Another Favourite Machine Learning Paper: Adversarial Networks vs Kernel Scoring Rules
A while ago I wrote a post on my then-current favourite machine learning paper, which connected denoising autoencoders to pseudolikelihood learning. I am generally intrigued by papers that establish connections between seemingly unrelated techniques, thereby bringing communities and frameworks together. Today I have found a similarly interesting paper from my old lab in Cambridge. In fact, there are two papers whose authors invented pretty much the same thing in parallel:
- GK Dziugaite, DM Roy & Z Ghahramani: Training generative neural networks via Maximum Mean Discrepancy optimization
- Y Li, K Swersky & R Zemel: Generative Moment Matching Networks (ICML 2015).
Unsupervised Learning with Deep Models
Deep learning has achieved several breakthroughs in supervised learning, setting new benchmarks in speech, computer vision and natural language understanding. Many agree that one of the most interesting frontiers of research in this area is unsupervised learning: how do we learn rich, deep, distributed representations of data, when labels are not provided?
In unsupervised learning we are given independent samples $x_i$ from some underlying data distribution $P$, and our goal is to come up with an approximate distribution $Q$ that is as close to $P$ as possible, using only the samples $x_i$. Often, $Q$ is chosen from a parametric family of distributions $\{Q(\cdot ;\theta),\theta\in\Theta\}$, and the task becomes finding the optimal parameters $\theta^{\ast}$ under which $Q$ best approximates $P$.
The central issue of unsupervised learning is choosing an objective function $\ell(\theta,P)$ that faithfully measures the quality of our approximation, and which is tractable to compute and optimise when we are working with complicated, deep models. The most straightforward choice, marginal likelihood, is intractable in most cases, and researchers over the past decades have come up with a zoo of methods, such as contrastive divergence, pseudolikelihood and variational bounds, to work around this problem. But largely, people agree that unsupervised learning in deep generative models is far from solved.
Here, I'm writing about a promising new objective function used in Generative Adversarial Networks, and its connection to maximum mean discrepancy, which is well known from the kernel machines literature.
Generative Adversarial Networks
Generative Adversarial Networks (GANs) are intuitively described as a game or fight between two neural networks: a generative model that specifies how to generate samples from the approximate distribution $Q$, and a discriminative network that tries to distinguish samples generated by the generative network from real data. The hope is that over time both models get better at their tasks, so that eventually we are left with a generative model whose samples even the best discriminative network has a hard time distinguishing from real data.
Mathematically, the procedure minimises the following objective function:
$$\ell(\theta,P) = \max_{\psi} \left[ E_{x\sim Q(\cdot;\theta)} \log f(x; \psi) + E_{x\sim P} \log(1 - f(x; \psi)) \right]$$
Explanation of terms:
- $P$ is the true data distribution; in this case $P$ will be approximated by the empirical distribution of the data points.
- $Q(\cdot; \theta)$ is a probability distribution over $x$, parametrised by $\theta$. I will call this the approximate distribution. In the original paper, $Q$ was described as a generative procedure: a low-dimensional latent vector $z$ is drawn from some simple parametric distribution, and then transformed nonlinearly into $x$ by a neural network whose parameters are $\theta$.
- $f(x; \psi)$ is a parametric discriminative function taking values between 0 and 1 and parametrised by $\psi$. It takes an input $x$ and predicts whether $x$ was a sample from the true data distribution $P$ or the approximate distribution $Q$. In this work, $f$ is modelled as a deep neural network trained in a supervised way.
- the first term $E_{x\sim Q(\cdot;\theta)} \log f(x; \psi)$ encourages the discriminative function $f$ to give a positive output ($f(x)=1$) for any sample from $Q$, while the second term $E_{x\sim P} \log(1 - f(x; \psi))$ encourages it to give a negative output ($f(x)=0$) for samples from $P$. If $f$ discriminates perfectly between $Q$ and $P$, the sum of these two terms is 0; otherwise it is negative.
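To make the alternating optimisation concrete, below is a minimal PyTorch sketch of one discriminator update and one generator update on toy data. The network sizes, learning rates and `latent_dim` are illustrative choices of mine rather than the setup of the original paper, and I follow the sign convention above, where $f$ outputs the probability that a sample came from $Q$.

```python
import torch
import torch.nn as nn

latent_dim, data_dim = 8, 2  # illustrative sizes, not from the paper

# Generator: transforms a simple latent vector z into a sample x ~ Q(.; theta)
G = nn.Sequential(nn.Linear(latent_dim, 64), nn.ReLU(), nn.Linear(64, data_dim))
# Discriminator f(x; psi) in (0, 1): probability that x came from Q
f = nn.Sequential(nn.Linear(data_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid())

opt_G = torch.optim.Adam(G.parameters(), lr=1e-4)
opt_f = torch.optim.Adam(f.parameters(), lr=1e-4)

def adversarial_objective(x_real, x_fake):
    # E_{x~Q} log f(x) + E_{x~P} log(1 - f(x)), as in the equation above
    eps = 1e-7  # numerical safety inside the logs
    return (torch.log(f(x_fake) + eps).mean()
            + torch.log(1 - f(x_real) + eps).mean())

x_real = torch.randn(128, data_dim)  # stand-in for a minibatch of real data

# Inner maximisation: the discriminator ascends the objective...
x_fake = G(torch.randn(128, latent_dim))
loss_f = -adversarial_objective(x_real, x_fake.detach())
opt_f.zero_grad(); loss_f.backward(); opt_f.step()

# ...outer minimisation: the generator descends it, so f cannot tell Q from P
x_fake = G(torch.randn(128, latent_dim))
loss_G = adversarial_objective(x_real, x_fake)
opt_G.zero_grad(); loss_G.backward(); opt_G.step()
```

In practice these two steps are alternated over many minibatches, often with several discriminator updates per generator update.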
The core of the idea is to train a probabilistic model $Q(\cdot; \theta)$ in such a way that even the best possible discriminator function $f$ cannot differentiate well between $Q$ and the true data distribution $P$. "This is not the first time I've seen this idea!", I thought. Maximum mean discrepancy also tries to do exactly this.
Maximum Mean Discrepancy
Maximum mean discrepancy (MMD) was originally proposed in the kernel machines community as a nonparametric way to measure dissimilarity between two probability distributions. It's worth pointing out that it was also independently discovered as the kernel scoring rule in statistics.
Just like any measure of dissimilarity between distributions, MMD can be used as an objective function for generative modelling, and this is exactly what the two papers above propose to do. Have a look at the MMD objective, and observe the similarities to the adversarial objective above:
$$\ell(\theta,P) = MMD(Q(\cdot; \theta),P) = \sup_{f\in \mathcal{F}} \left\vert E_{x\sim P}f(x) - E_{x\sim Q}f(x)\right\vert$$
The MMD criterion also uses the concept of an 'adversarial' function $f$ that discriminates between samples from $Q$ and $P$. However, instead of a binary classifier whose output is squashed between 0 and 1, here $f$ can be any function chosen from some function class $\mathcal{F}$. The discrimination is measured not in terms of binary classification accuracy as above, but as the difference between the expected value of $f$ under $P$ and $Q$. The idea is: if $P$ and $Q$ are exactly the same, there is no function whose expectations differ under the two distributions, so the supremum is zero.
In adversarial networks the maximisation over $f$ is carried out via stochastic gradient descent; here, under suitable assumptions, it can be done analytically. It turns out that if the function class $\mathcal{F}$ is (the unit ball of) a reproducing kernel Hilbert space characterised by a kernel $k$, applying the kernel trick yields the following closed-form expression for the squared MMD:
$$MMD^2_k(Q,P) = E_{x,x' \sim P}\,k(x,x') + E_{x,x' \sim Q}\,k(x,x') - 2\, E_{x \sim P,\, x' \sim Q}\,k(x,x')$$
In the expression above:
- the first term $E_{x,x' \sim P}k(x,x')$ is constant with respect to $Q$, so we can simply drop it from the objective;
- the second term $E_{x,x' \sim Q}k(x,x')$ can be interpreted as an entropy term: minimising it forces $Q$ to be spread out, rather than concentrated on a small set of points;
- the third term $E_{x \sim P, x' \sim Q}k(x,x')$ enters with a factor of $-2$, so minimising the objective maximises it, ensuring that samples from $Q$ are on average close to samples from $P$.
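To see what this looks like in practice, here is a small NumPy sketch of the (biased) sample-based estimator of $MMD^2_k$ with a Gaussian kernel, where the expectations are simply replaced by averages over samples. The bandwidth $\sigma$ and sample sizes are arbitrary choices of mine; in practice the bandwidth needs tuning, for example via the median heuristic.

```python
import numpy as np

def gaussian_kernel(X, Y, sigma=1.0):
    # k(x, x') = exp(-||x - x'||^2 / (2 sigma^2)), evaluated for all pairs
    sq_dists = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq_dists / (2 * sigma ** 2))

def mmd2(X, Y, sigma=1.0):
    # Biased estimate of MMD_k^2(Q, P) from samples X ~ P and Y ~ Q:
    # the three expectations above, each replaced by a sample mean
    k_pp = gaussian_kernel(X, X, sigma).mean()
    k_qq = gaussian_kernel(Y, Y, sigma).mean()
    k_pq = gaussian_kernel(X, Y, sigma).mean()
    return k_pp + k_qq - 2 * k_pq

rng = np.random.default_rng(0)
same = mmd2(rng.normal(size=(500, 2)), rng.normal(size=(500, 2)))
diff = mmd2(rng.normal(size=(500, 2)), rng.normal(loc=2.0, size=(500, 2)))
print(same, diff)  # near zero when the distributions match, larger otherwise
```

An unbiased variant drops the diagonal $k(x_i, x_i)$ terms from the two within-sample averages. Training a generative network with this loss then amounts to drawing $x' = G(z)$ and backpropagating the estimator through the kernel into the generator's parameters, which is essentially what both papers above propose.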
Pros and cons of each approach
- kernel MMD is strictly proper: When designing loss functions for unsupervised learning, it is important to check whether the loss function is a strictly proper scoring rule or not. Roughly speaking, strictly proper means that the loss is uniquely minimised when $Q = P$, so the objective function can't be easily tricked. We know that both maximum likelihood learning and pseudo-likelihood learning possess this property, and this ensures that they result in statistically consistent learning algorithms. Kernel MMD has been shown to be strictly proper for suitable choices of kernel (so-called characteristic kernels). However, it is unlikely that the adversarial network objective is strictly proper without further restrictions.
- the adversarial network might overfit: The discriminative adversarial network $f$ can overfit the training data if not properly regularised. For example, it can learn to memorise all data points from $P$, making it very easy to discriminate between any $Q$ and $P$. In kernel MMD, regularisation is part of the definition of the loss function (the unit-ball constraint on $\mathcal{F}$), so one does not have to worry as much about overfitting.
- deep neural networks are superior discriminative models: MMD uses functions from a reproducing kernel Hilbert space as discriminatory functions. We know from recent successes that deep neural networks are probably superior discriminative models for high-dimensional data such as natural images.
- adversarial networks are easier to generalise: because the application of the kernel trick relies on the particular form of the MMD objective, it is very hard to generalise the MMD objective to make it more performant. On the other hand, one can imagine alternative adversarial network objectives based on ROC curves or F-scores, which may perform better than the original criterion.
- the two can be combined: One could design a kernel which has a deep neural network inside it, and use the MMD objective with that kernel for a hybrid method; a minimal sketch of this idea follows below. Not sure if this would make sense, but why not :)
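For concreteness, here is a hedged sketch of what such a hybrid might look like: a Gaussian kernel applied on top of a learned feature map $\phi$, so that $k(x,x') = \exp(-\Vert \phi(x) - \phi(x') \Vert^2 / 2\sigma^2)$. The feature map architecture and all names here are hypothetical illustrations of mine, not something proposed in either paper.

```python
import torch
import torch.nn as nn

# Hypothetical learned feature map; the Gaussian kernel acts on phi(x) rather than on x
phi = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 16))

def deep_mmd2(X, Y, sigma=1.0):
    # Biased MMD^2 estimate, as before, but with the composed kernel
    # k(x, x') = exp(-||phi(x) - phi(x')||^2 / (2 sigma^2))
    fX, fY = phi(X), phi(Y)
    def k(A, B):
        sq_dists = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
        return torch.exp(-sq_dists / (2 * sigma ** 2))
    return k(fX, fX).mean() + k(fY, fY).mean() - 2 * k(fX, fY).mean()

x_real = torch.randn(128, 2)  # stand-in for a minibatch of real data
x_fake = torch.randn(128, 2)  # stand-in for generator samples
loss = deep_mmd2(x_real, x_fake)  # differentiable w.r.t. phi and, through x_fake, a generator
```

How $\phi$ should be chosen or trained is left open here; that is exactly where the "not sure if this would make sense" caveat applies.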