Improving Variational Inference

Recap: Variational Inference

The key idea in variational inference is to approximate the true posterior with a variational distribution. This turns the inference problem into an optimization problem: minimize the divergence between these two distributions using a measure such as the KL divergence.

In the above diagram $x$ is the data (e.g. an image or a sound clip) that has been generated from an unknown latent factor $z$. Part of the problem is to infer $z$, and to do that we need to approximate the true posterior with a tractable distribution.

Let $q_{x}(z)$ be the approximation of the true posterior $p(z | x)$, and pick the KL divergence as the measure of discrepancy between these two distributions.

$\begin{aligned} D_{\mathrm{KL}}\left[q_{x}(z) \| p(z | x)\right] &=\mathbb{E}_{z \sim q_{x}(z)}\left[\log q_{x}(z)-\log p(z | x)\right] \\ &=\mathbb{E}_{z \sim q_{x}(z)}\left[\log q_{x}(z)-\log \frac{p(z, x)}{p(x)}\right] \\ &=\mathbb{E}_{z \sim q_{x}(z)}\left[\log q_{x}(z)-\log p(z)-\log p(x | z)+\log p(x)\right] \\ &=\underbrace{\mathbb{E}_{z \sim q_{x}(z)}\left[\log q_{x}(z)-\log p(z)-\log p(x | z)\right]}_{\text {Only this part depends on } q_{x}(z)}+\log p(x) \end{aligned}$

Using simple algebra and Bayes' rule we arrive at the above expression. Since $\log p(x)$ is a constant with respect to $q_{x}(z)$, minimizing the KL divergence amounts to minimizing the bracketed expectation. That expectation can be approximated with stochastic samples from $q_{x}(z)$, and each term inside it is just a density evaluation that costs $O(1)$ per sample.
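To make the "stochastic samples" remark concrete, here is a minimal NumPy sketch that estimates the expectation above for a toy fully-Gaussian model. The specific model ($p(z)=\mathcal{N}(0,1)$, $p(x|z)=\mathcal{N}(z,1)$) and the variational parameters are assumptions made purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

def log_normal(v, mean, std):
    # Log density of a univariate Gaussian
    return -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((v - mean) / std) ** 2

# Toy model (assumed for illustration): p(z) = N(0, 1), p(x|z) = N(z, 1)
x = 1.5                     # a single observed data point
mu_q, sigma_q = 1.0, 0.8    # parameters of q_x(z); in practice these come from an optimizer

# Draw stochastic samples z ~ q_x(z) and average the integrand;
# each term inside the expectation is a single density evaluation.
z = rng.normal(mu_q, sigma_q, size=10_000)
integrand = log_normal(z, mu_q, sigma_q) - log_normal(z, 0.0, 1.0) - log_normal(x, z, 1.0)
print("estimate of E_q[log q(z) - log p(z) - log p(x|z)]:", integrand.mean())
# Minimizing this over (mu_q, sigma_q) minimizes KL[q_x(z) || p(z|x)],
# because log p(x) does not depend on q.
```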

Variational Inference as Importance Sampling

To train a latent variable model with maximum likelihood, we need the marginal likelihood of every datapoint $x$: $p(x)=\sum_{z} p(z, x)$. Assuming a discrete latent code $z$, this means summing the joint probability $p(z, x)$ over all values of $z$, which becomes intractable when $z$ has an exponential number of configurations.

But here's an intuition from empirical observations. For any $x$, $p(x, z)$ typically has its probability mass concentrated in very few places. For example, if the latent code $z$ describes a semantic attribute such as whether there is a car in the image, then only the tiny fraction of latent configurations consistent with that image will carry significant probability mass.

Hence in most VAEs, or any other meaningful latent variable model, the joint distribution is very peaked for any given $x$. This implies that in the high dimensional space we don't need to enumerate all the possibilities of $z$; we can place emphasis on the values of $z$ that are more likely under that $x$. This is one way of looking at variational inference as importance sampling.

Intuition: The variational distribution $q(z | x)$ samples the high density region of $p(z, x)$

Have a look at the following derivation:

$\begin{aligned} \log p(x) &=\log \sum_{z} p(z, x) \\ &=\log \sum_{z} q(z | x) \frac{p(z, x)}{q(z | x)} \\ &=\log \mathbb{E}_{z \sim q(z | x)}\left[\frac{p(z, x)}{q(z | x)}\right] \\ & \geq \mathbb{E}_{z \sim q(z | x)}\left[\log \frac{p(z, x)}{q(z | x)}\right] \end{aligned}$

A few things to note here:

  • The expression in the third line is very close to the Variational Lower Bound (VLB). In fact, if we push the $\log$ inside the expectation by applying Jensen's inequality, as in the fourth line, we get the VLB itself.
  • The third line of the derivation implies that if we draw many samples from $q(z | x)$, average the importance weights $\frac{p(z, x)}{q(z | x)}$, and then take the $\log$, we approach $\log p(x)$.
  • The fourth line of the derivation gives the VLB itself: with the $\log$ pushed inside, even a single sample $z$ yields an unbiased estimate of the lower bound. A small numeric sketch follows this list.
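Here is a small numeric sketch of the last two bullets, using a toy Gaussian model where $\log p(x)$ is available in closed form; the model and the (deliberately imperfect) proposal are assumptions chosen just so the two estimates can be compared.

```python
import numpy as np

rng = np.random.default_rng(0)

def log_normal(v, mean, std):
    return -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((v - mean) / std) ** 2

# Toy model (assumed): p(z) = N(0, 1), p(x|z) = N(z, 1), so p(x) = N(0, 2) in closed form.
x = 1.5
true_log_px = log_normal(x, 0.0, np.sqrt(2.0))

# A deliberately imperfect proposal q(z|x); the exact posterior here would be N(x/2, 1/sqrt(2)).
mu_q, sigma_q = 0.6 * x, 0.9

z = rng.normal(mu_q, sigma_q, size=100_000)
log_w = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0) - log_normal(z, mu_q, sigma_q)

print("true log p(x)            :", true_log_px)
print("log of averaged weights  :", np.log(np.mean(np.exp(log_w))))  # third line: approaches log p(x)
print("average of log weights   :", np.mean(log_w))                  # fourth line: a lower bound (VLB)
```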

So what the above implies is that variational inference is one way to do importance sampling, and that's exactly what we do in a VAE with one sample (the fourth line in the derivation above). The question is: if we use multiple samples for importance sampling in a VAE, will that improve the lower bound?

The answer is yes, and Burda et al. show how to do it in [5].

VAE Improvements

One way to improve a VAE is to reduce the gap between the marginal log likelihood and the Variational Lower Bound, which means reducing the mismatch between the approximate posterior and the true posterior, i.e. reducing $D_{\mathrm{KL}}\left[q(z | x) \| p(z | x)\right]$. We can do the following to reduce this gap:

  • Use importance sampling, as per [5]
  • Use a more expressive approximate posterior $q(z | x)$, which also allows the true posterior to take more forms
  • Use a more expressive prior $p(z)$

Using Importance Sampling

As Burda et al. suggest, the idea is to draw $k$ samples (instead of 1) to form a bound that is closer to the true marginal log likelihood, and simply train on this objective:

$\mathcal{L}_{k}=\mathbb{E}\left[\log \frac{1}{k} \sum_{i=1}^{k} w_{i}\right] \leq \log \mathbb{E}\left[\frac{1}{k} \sum_{i=1}^{k} w_{i}\right]=\log p(\mathbf{x})$

where

$w_{i}=\frac{p\left(z_{i}, x\right)}{q\left(z_{i} | x\right)}$

Here we choose $k$ to be the number of samples that we can afford to draw and then train on this modified objective. When $k$ is large the bound becomes tight and it is as if we are training on the true marginal log likelihood. A sketch of this objective follows.
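Here is a minimal PyTorch sketch of this $k$-sample objective, assuming the encoder has already produced a diagonal Gaussian $q(z|x)$ with a standard normal prior, and that `decoder_log_likelihood` (a name introduced here, not from the papers) returns $\log p(x|z)$ for each sample.

```python
import math
import torch

def iwae_bound(x, mu, log_sigma, decoder_log_likelihood, k=50):
    """Monte Carlo estimate of L_k = E[log (1/k) sum_i w_i] for one minibatch.

    mu, log_sigma: [batch, dim] encoder outputs defining a diagonal Gaussian q(z|x).
    decoder_log_likelihood(x, z): assumed to return log p(x|z) with shape [k, batch].
    """
    batch, dim = mu.shape
    eps = torch.randn(k, batch, dim)
    z = mu + log_sigma.exp() * eps                                    # reparameterized samples

    log2pi = math.log(2 * math.pi)
    log_q = (-0.5 * eps.pow(2) - 0.5 * log2pi - log_sigma).sum(-1)    # log q(z|x)
    log_p = (-0.5 * z.pow(2) - 0.5 * log2pi).sum(-1)                  # log p(z), standard normal prior
    log_w = log_p + decoder_log_likelihood(x, z) - log_q              # log w_i, shape [k, batch]

    # log (1/k) sum_i w_i, computed stably with logsumexp; average over the batch
    return (torch.logsumexp(log_w, dim=0) - math.log(k)).mean()
```

The only change from the usual single-sample VAE objective is the `logsumexp` over the $k$ sample dimension; with $k = 1$ this reduces to the ordinary VLB.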

Here's a comparison of the improvement we get using an IWAE over a plain VAE (source: Burda et al. [5]):

And here's an illustration that shows how using an IWAE can result in more expressive posteriors than a regular VAE (source: Burda et al. [5]):

In the above diagram, the heatmaps show the true posteriors $p(z | x)$ of four data points for VAE and IWAE models. The leftmost column is for a VAE and the other two show the same for IWAEs trained with $k = 5$ and $k = 50$ respectively.

For the VAE, the figure shows that the posteriors are roughly Gaussian: some circular, some elongated and not axis-aligned. Why is this so?

The setting that we have here is:

  • The prior $p(z)$ is $\mathcal{N}(0, I)$
  • The approximate posterior (variational distribution) $q(z | x)$ is $\mathcal{N}(\mu, \operatorname{diag}(\sigma))$, where $\mu$ and $\sigma$ come from the output of a neural network. So the approximate posterior is a diagonal Gaussian.

But none of the above explains why the true posterior should also be roughly Gaussian (even if elongated and not axis-aligned). How can we explain this?

In a VAE, when we maximize the variational lower bound (VLB) we are both pushing up the marginal log likelihood and pushing down the KL divergence between the approximate and the true posterior. This has two effects:

  • Pushes the approximate posterior to be closer to the true posterior
  • Pulls the true posterior to a form that can be expressed by the approximate posterior

This means that in deep generative modeling with latent codes, the choice of the variational distribution also shapes the generative model. In a sense this forces the generative model to conform to the family of the variational distribution.

As the IWAE paper suggests, if we use a larger $k$ we effectively get a more expressive variational distribution $q$, which in turn allows the true posterior to be more expressive. So a larger $k$ means fewer restrictions on the true posterior. Have a look at the MNIST example above for the heatmaps corresponding to the IWAE models. In the first row of the IWAE model with $k = 5$, we get an elongated Gaussian that is not axis-aligned. In the second row we get two modes for the IWAE, which is not possible with a plain VAE. And with larger $k$ we get even more exotic shapes.

True posteriors can be complicated, and if we restrict the approximate posterior to a Gaussian with diagonal covariance, then the true posterior can only take a limited set of forms, as with a plain VAE.

More Expressive Posterior

In the last section we saw one way of making posteriors more expressive: using importance sampling to build an Importance Weighted Autoencoder. In this section we discuss another approach, based on a first-principles way of thinking about posteriors.

Say we have a fixed prior $p(z)$ for the latent space (e.g. an isotropic Gaussian). We can then think of the approximate posterior $q(z | x)$ as solving a bin packing problem: for every data point $x$, it maps $x$ to a distinct region in $p(z)$, so that $p(x | z)$ can reconstruct the data point with as little loss of information as possible. The tighter the packing of the data points into $p(z)$, the better. This is just an intuition; in reality we also want the model to generalize to data points it has not seen so far.

Let's consider the following example where we fit a VAE with spherical Gaussian prior $p(z)$ and factorized Gaussian approximate posterior with diagonal covariance $q(z | x)$ to a toy data set of 4 points $\{A, B, C, D\}$:

The left box shows the isotropic Gaussian prior, and in the right box each colored cluster corresponds to the approximate posterior distribution of one datapoint, e.g. $q(z | x = A)$. Note that all four posteriors are some form of isotropic Gaussian, which is not very expressive. The large empty spaces in between the distributions also imply that the bin packing is poor, which would lead to inefficient inference downstream.

Now compare this with the following diagram where the rightmost box has a much more expressive posterior that allows for a much better fit between the prior and the posterior (aka better bin packing):

Now it turns out that implementing such expressive posteriors is possible, and we will discuss some approaches in the following sections. The above diagram is from [2], which improves the VAE with more flexible posteriors.

A Lookback at the VAE Computational Model

So why do we need a more expressive posterior in a VAE? As a recap, here's how computation flows through a Variational Autoencoder's inference and generative networks:

Computational Flow in a Variational Autoencoder (Source: [3])

The inference model needs to produce an approximate posterior $q(\mathbf{z} | \mathbf{x})$ that is computationally tractable yet more flexible than a simple diagonal Gaussian. Kingma et al. explain this tension in their paper [2]:

Requirements for the inference model, in order to be able to efficiently optimize the bound, are that it is (1) computationally efficient to compute and differentiate its probability density $q(\mathbf{z} | \mathbf{x})$, and (2) computationally efficient to sample from, since both these operations need to be performed for each datapoint in a minibatch at every iteration of optimization. If $\mathbf{z}$ is high-dimensional and we want to make efficient use of parallel computational resources like GPUs, then parallelizability of these operations across dimensions of $\mathbf{z}$ is a large factor towards efficiency. This requirement restrict the class of approximate posteriors $q(\mathbf{z} | \mathbf{x})$ that are practical to use. In practice this often leads to the use of diagonal posteriors, e.g. $q(\mathbf{z} | \mathbf{x}) \sim \mathcal{N}\left(\boldsymbol{\mu}(\mathbf{x}), \boldsymbol{\sigma}^{2}(\mathbf{x})\right)$, where $\mu(x)$ and $\sigma(x)$ are often nonlinear functions parameterized by neural networks. However, as explained above, we also need the density $q(\mathbf{z} | \mathbf{x})$ to be sufficiently flexible to match the true posterior $p(\mathbf{z} | \mathbf{x})$.

Normalizing Flows

Now that we have established the requirements for a posterior distribution in a VAE that is sufficiently flexible yet balances that flexibility with computational tractability, we need a framework for building such distributions. Normalizing flows offer one such framework: we start with a random variable with a simple distribution and apply a sequence of invertible transformations to turn it into something more flexible. Rezende and Mohamed first integrated normalizing flows into the framework of variational inference in their paper Variational Inference with Normalizing Flows.

We start with $\mathbf{z}_{0} \sim q\left(\mathbf{z}_{0} | \mathbf{x}\right)$ and then apply the transformations successively: $\mathbf{z}_{t}=\mathbf{f}_{t}\left(\mathbf{z}_{t-1}, \mathbf{x}\right)$ for $t=1 \ldots T$. As long as we can compute the Jacobian determinant of each transformation $\mathbf{f}_{t}$, we can compute the probability density of the final iterate:

$$\log q\left(\mathbf{z}_{T} | \mathbf{x}\right)=\log q\left(\mathbf{z}_{0} | \mathbf{x}\right)-\sum_{t=1}^{T} \log \operatorname{det}\left|\frac{d \mathbf{z}_{t}}{d \mathbf{z}_{t-1}}\right|$$

Mapping this back to our VAE, the idea is to apply a normalizing flow to the fully factorized posterior to turn it into something more flexible that can match the true posterior. The real trick is to find transformations that are computationally tractable yet flexible enough to do so.
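The bookkeeping implied by the density formula above is straightforward: each transformation returns its output together with its log Jacobian determinant, and we subtract the accumulated sum from the initial log density. A minimal sketch, assuming flow objects that expose a `forward(z, x)` method returning that pair (an interface invented here for illustration):

```python
def flow_posterior_sample(x, z0, log_q0, flows):
    """Push z0 ~ q(z0|x) through a chain of invertible transformations.

    log_q0: log q(z0|x) evaluated at the initial sample.
    flows:  list of objects whose forward(z, x) returns (z_next, log_det),
            where log_det = log|det dz_next/dz| (assumed interface).
    """
    z, log_q = z0, log_q0
    for f in flows:
        z, log_det = f.forward(z, x)
        # log q(z_T|x) = log q(z_0|x) - sum_t log|det dz_t/dz_{t-1}|
        log_q = log_q - log_det
    return z, log_q
```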

Inverse Autoregressive Transformations

An autoregressive transformation is one where given a sequence of variables $\mathbf{y}=\left\{y_{i}\right\}_{i=0}^{D}$, each variable is dependent only on the previously indexed variables i.e. $y_{i}=f_{i}\left(y_{0: i-1}\right)$.

Let's start with a noise vector $\boldsymbol{\epsilon} \sim \mathcal{N}(0, \mathbf{I})$ and apply the following autoregressive transformation:

$$\begin{array}{l} {y_{0}=\mu_{0}+\sigma_{0} \odot \epsilon_{0}} \\ {y_{i}=\mu_{i}\left(\mathbf{y}_{0: i-1}\right)+\sigma_{i}\left(\mathbf{y}_{0: i-1}\right) \odot \epsilon_{i} \quad \text { for } i>0} \end{array}$$

Here $[\mu(\mathbf{y}), \sigma(\mathbf{y})]$ comes from a neural network (the inference model in our VAE case), and we choose an ordering on $\mathbf{y}$ that establishes the autoregressive property of the model. Addition here is element-wise and $\odot$ denotes element-wise multiplication.

Note, however, that sampling from such a model requires running the whole chain of transformations sequentially, so the computation is clearly $\mathcal{O}(D)$: generating $y_{i}$ needs the previously generated $y_{0: i-1}$. Not good for our VAE as far as computational tractability is concerned. Here's an intuitive diagram of this autoregressive transformation (note the figure uses an $\mathbf{x}$ transformed into a $\mathbf{y}$, since we can transform any vector, not just noise):

Autoregressive Transform (Source: [4])
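A short NumPy sketch makes the sequential dependence explicit; `autoregressive_net` below is a hypothetical stand-in for the networks $\mu_i(\cdot), \sigma_i(\cdot)$, chosen only so that each output depends solely on the prefix $\mathbf{y}_{0:i-1}$.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 8

def autoregressive_net(y_prefix):
    # Hypothetical stand-in for the network producing (mu_i, sigma_i) from y_{0:i-1};
    # any function of the prefix alone preserves the autoregressive property.
    return 0.1 * y_prefix.sum(), 1.0 + 0.05 * np.abs(y_prefix).sum()

eps = rng.normal(size=D)
y = np.zeros(D)
# Strictly sequential: y_i cannot be computed before y_0 ... y_{i-1} exist, hence O(D) steps.
for i in range(D):
    mu_i, sigma_i = autoregressive_net(y[:i])
    y[i] = mu_i + sigma_i * eps[i]
```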

Kingma et al. in their paper (referenced above) suggest using the inverse transformation instead and show how it proves to be a perfect fit for the VAE use case:

$$x_{i}=\frac{y_{i}-\mu_{i}\left(\mathbf{y}_{0: i-1}\right)}{\sigma_{i}\left(\mathbf{y}_{0: i-1}\right)}$$

where subtraction and division are elementwise. Note that in such a transformation we have all the $y_{i}$s upfront and we generate the $x_{i}$s. Hence this transformation can be parallelized in the implementation.

Inverse Autoregressive Transform (Source: [4])

The other great part of this transform is that due to the autoregressive structure of the model, $\partial\left[\mu_{i}, \sigma_{i}\right] / \partial y_{j}=[0,0] \text { for } j \geq i$. Hence the Jacobian is lower triangular and its determinant equals the product of the diagonal terms. This makes the log determinant of the Jacobian quite simple and easy to compute:

$$\log \left|\operatorname{det} \frac{d \mathbf{x}}{d \mathbf{y}}\right|=-\sum_{i=0}^{D} \log \sigma_{i}\left(\mathbf{y}_{0: i-1}\right)$$
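A sketch of the inverse direction, again with a hypothetical `autoregressive_net` standing in for $\mu_i(\cdot), \sigma_i(\cdot)$: since all of $\mathbf{y}$ is available upfront, no output of the computation feeds back into it, and the log determinant is just the negative sum of the log scales.

```python
import numpy as np

def autoregressive_net(y_prefix):
    # Hypothetical stand-in for the network producing (mu_i, sigma_i) from y_{0:i-1}.
    return 0.1 * y_prefix.sum(), 1.0 + 0.05 * np.abs(y_prefix).sum()

y = np.array([0.3, -1.2, 0.7, 2.1])   # all of y is known upfront
D = len(y)

# Each (mu_i, sigma_i) depends only on y_{0:i-1}, which we already have; a masked network
# such as MADE would emit all of them in a single parallel pass instead of this loop.
mu, sigma = map(np.array, zip(*(autoregressive_net(y[:i]) for i in range(D))))

x = (y - mu) / sigma                   # elementwise inverse autoregressive transform
log_det = -np.sum(np.log(sigma))       # log|det dx/dy| = -sum_i log sigma_i
```

Running this on a $\mathbf{y}$ produced by the sequential sketch above (with the same `autoregressive_net`) recovers its noise vector $\boldsymbol{\epsilon}$ exactly.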

Inverse Autoregressive Flows

Now that we have the transform (the inverse autoregressive transform, IAF), all we need to do is plug a chain of IAF transforms in right after the latent variable in our VAE. [2] and [4] have all the rigorous details and the math behind such an implementation. This diagram from [2] describes it all:

(Left) A chain of Inverse Autoregressive Flows transforming the posterior of a VAE (Right) A single IAF step (Source: [2])

Finally the density under the final iterate using the inverse autoregressive flow (details in [2] and [4]) is:

$$\log q\left(\mathbf{z}_{T} | \mathbf{x}\right)=-\sum_{i=1}^{D}\left(\frac{1}{2} \epsilon_{i}^{2}+\frac{1}{2} \log (2 \pi)+\sum_{t=0}^{T} \log \sigma_{t, i}\right)$$

And as [4] shows, the modified variational objective becomes the following:

$$\begin{aligned} \log p(\mathbf{x}) & \geq-\mathbb{E}_{q}\left[\log \frac{q\left(\mathbf{z}_{T} | \mathbf{x}\right)}{p\left(\mathbf{z}_{T}, \mathbf{x}\right)}\right] \\ &=\mathbb{E}_{q}\left[\log p\left(\mathbf{z}_{T}, \mathbf{x}\right)-\log q\left(\mathbf{z}_{T} | \mathbf{x}\right)\right] \\ &=\mathbb{E}_{q}\left[\log p\left(\mathbf{x} | \mathbf{z}_{T}\right)+\log p\left(\mathbf{z}_{T}\right)-\log q\left(\mathbf{z}_{T} | \mathbf{x}\right)\right] \\ \log q\left(\mathbf{z}_{T} | \mathbf{x}\right) &=-\sum_{i=1}^{D}\left[\frac{1}{2} \epsilon_{i}^{2}+\frac{1}{2} \log (2 \pi)+\sum_{t=0}^{T} \log \sigma_{t, i}\right] \\ \log p\left(\mathbf{z}_{T}\right) &=-\sum_{i=1}^{D}\left[\frac{1}{2} z_{T, i}^{2}+\frac{1}{2} \log (2 \pi)\right] \end{aligned}$$
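Putting the pieces together, here is a sketch of how an IAF-style posterior could accumulate $\log q(\mathbf{z}_T | \mathbf{x})$ while producing $\mathbf{z}_T$, using flow steps of the form $\mathbf{z}_t = \boldsymbol{\mu}_t + \boldsymbol{\sigma}_t \odot \mathbf{z}_{t-1}$, which is consistent with the density formula above. The `autoregressive_nets` are assumed MADE-style modules (an interface invented here), and this bare-bones version does not necessarily match the exact parameterization used in [2].

```python
import math
import torch

def iaf_posterior(eps, h, autoregressive_nets):
    """Sample z_T and accumulate log q(z_T|x) for an IAF-style posterior.

    eps: [batch, D] standard normal noise.
    h:   [batch, H] context produced by the encoder.
    autoregressive_nets: assumed modules; nets[t](z, h) -> (mu_t, log_sigma_t), where
        mu_t[:, i] and log_sigma_t[:, i] depend only on z[:, :i] (e.g. MADE-style masks).
        nets[0] ignores its first argument and produces the initial reparameterization.
    """
    log2pi = math.log(2 * math.pi)

    mu0, log_sigma0 = autoregressive_nets[0](None, h)
    z = mu0 + log_sigma0.exp() * eps
    log_q = (-0.5 * eps.pow(2) - 0.5 * log2pi - log_sigma0).sum(-1)   # log q(z_0|x)

    # Each step z_t = mu_t + sigma_t * z_{t-1} has a lower-triangular Jacobian with
    # sigma_t on its diagonal, so it contributes sum_i log sigma_{t,i} to the log det.
    for net in autoregressive_nets[1:]:
        mu_t, log_sigma_t = net(z, h)
        z = mu_t + log_sigma_t.exp() * z
        log_q = log_q - log_sigma_t.sum(-1)

    return z, log_q   # feed into E_q[log p(x|z_T) + log p(z_T) - log q(z_T|x)]
```

The returned `log_q` matches the closed-form expression for $\log q(\mathbf{z}_T | \mathbf{x})$ above, since the $t=0$ term of the inner sum is the $\log \sigma_{0,i}$ of the initial reparameterization.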

As [2] mentions:

The flexibility of the distribution of the final iterate $\mathbf{z}_{T}$, and its ability to closely fit to the true posterior, increases with the expressivity of the autoregressive models and the depth of the chain.

References

