Variational Auto-Encoders (VAEs)
Table of Contents:
- VAE
- PyTorch VAE implementation and JAX VAE implementation
Encoder and Decoder
- The latent prior is given by \(p(\mathbf{z})=\mathcal{N}(0,I)\).
- From a coding theory perspective, the unobserved variables \(\mathbf{z}\) have an interpretation as a latent representation or code.
- Encoder: We refer to the recognition model \(q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x})\) as a probabilistic encoder, since given a datapoint \(\mathbf{x}\) it produces a distribution (e.g. a Gaussian) over the possible values of the code \(\mathbf{z}\) from which the datapoint \(\mathbf{x}\) could have been generated.
- Decoder: We will refer to \(p_{\boldsymbol{\theta}}(\mathbf{x} \mid \mathbf{z})\) as a probabilistic decoder, since given a code \(\mathbf{z}\) it produces a distribution over the possible corresponding values of \(\mathbf{x}\).
In JAX (using the Flax linen API), this resembles:
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn


class Encoder(nn.Module):
    """Maps an input x to the mean and log-variance of q(z | x)."""
    latents: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(500, name='fc1')(x)
        x = nn.relu(x)
        mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
        logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
        return mean_x, logvar_x


class Decoder(nn.Module):
    """Maps a latent code z to the logits of p(x | z)."""

    @nn.compact
    def __call__(self, z):
        z = nn.Dense(500, name='fc1')(z)
        z = nn.relu(z)
        x = nn.Dense(784, name='fc2')(z)
        return x
We convert the output of the encoder into latent samples \(\mathbf{z}\) using the reparameterization trick:
def reparameterize(rng, mean, logvar):
    # Sample z = mean + std * eps with eps ~ N(0, I), so gradients can flow
    # through mean and logvar while the randomness stays in eps.
    std = jnp.exp(0.5 * logvar)
    eps = random.normal(rng, logvar.shape)
    return mean + eps * std
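For example, a single draw of \(\mathbf{z}\) might look like this (the PRNG key and shapes below are illustrative, not from the original):

rng = random.PRNGKey(42)               # illustrative PRNG key
mean = jnp.zeros((1, 20))              # batch of one, 20 latent dimensions
logvar = jnp.zeros((1, 20))            # log-variance 0, i.e. unit standard deviation
z = reparameterize(rng, mean, logvar)  # z has shape (1, 20)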
Putting these pieces together, a VAE can be defined as follows in JAX:
class VAE(nn.Module):
    """Full VAE: encode x, sample z via reparameterization, decode back to x."""
    latents: int = 20

    def setup(self):
        self.encoder = Encoder(self.latents)
        self.decoder = Decoder()

    def __call__(self, x, z_rng):
        mean, logvar = self.encoder(x)
        z = reparameterize(z_rng, mean, logvar)
        recon_x = self.decoder(z)
        return recon_x, mean, logvar

    def generate(self, z):
        # Map a latent code to pixel probabilities (sigmoid of the decoder logits).
        return nn.sigmoid(self.decoder(z))
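As a rough sketch of how this module might be initialized and applied with Flax (the batch shape, dummy input, and key handling below are illustrative assumptions, not part of the original):

rng = random.PRNGKey(0)
init_rng, z_rng = random.split(rng)

model = VAE(latents=20)
x = jnp.ones((8, 784))                             # dummy batch of flattened 28x28 inputs
params = model.init(init_rng, x, z_rng)['params']  # initialize the model parameters
recon_x, mean, logvar = model.apply({'params': params}, x, z_rng)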
Training VAEs: KL Divergence
KL Divergence: Revisited
First, we’ll revisit the definition of KL divergence:
\[D_{KL}(P \lVert Q) = KL(P \lVert Q) = \sum\limits_x P(x) \log \frac{P(x)}{Q(x)}\]
(see Appendix B of the VAE paper, Kingma and Welling, 2014 [1])
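As a quick numeric illustration of this discrete definition (the two distributions below are made up for the example):

p = jnp.array([0.5, 0.3, 0.2])       # P(x), an arbitrary illustrative distribution
q = jnp.array([0.4, 0.4, 0.2])       # Q(x), another illustrative distribution
kl_pq = jnp.sum(p * jnp.log(p / q))  # D_KL(P || Q), roughly 0.025 nats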
KL Divergence Between Two Gaussians
Solution of \(D_{KL}(q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x}) \lVert p_{\mathbf{\theta}}(\mathbf{z}))\), Gaussian case:
The variational lower bound (the objective to be maximized) contains a KL term that can often be integrated analytically. Here we give the solution when both the prior \(p_{\mathbf{\theta}}(\mathbf{z}) = \mathcal{N}(0,\mathbf{I})\) and the posterior approximation \(q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x}^{(i)})\) are Gaussian. Let \(J\) be the dimensionality of \(\mathbf{z}\). Let \(\mathbf{\mu}\) and \(\mathbf{\sigma}\) denote the variational mean and s.d. evaluated at datapoint \(i\), and let \(\mu_j\) and \(\sigma_j\) simply denote the \(j\)-th element of these vectors. Then:
\[\begin{align} \int q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x}) \log p(\mathbf{z}) \,d\mathbf{z} &= \int \mathcal{N}(\mathbf{z};\mathbf{\mu},\mathbf{\sigma}^2 \mathbf{I}) \log \mathcal{N}(\mathbf{z};\mathbf{0},\mathbf{I}) \,d\mathbf{z} \\ &= - \frac{J}{2} \log (2 \pi) - \frac{1}{2} \sum_{j=1}^J (\mu_j^2 + \sigma_j^2) \end{align}\]
And:
\[\begin{align} \int q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x}) \log q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x}) \,d\mathbf{z} &= \int \mathcal{N}(\mathbf{z};\mathbf{\mu},\mathbf{\sigma}^2 \mathbf{I}) \log \mathcal{N}(\mathbf{z};\mathbf{\mu},\mathbf{\sigma}^2 \mathbf{I}) \,d\mathbf{z} \\ &= - \frac{J}{2} \log (2 \pi) - \frac{1}{2} \sum_{j=1}^J ( 1 + \log \sigma^2_j ) \end{align}\]
Therefore:
\[\begin{align} - D_{KL}(q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x}) \lVert p_{\mathbf{\theta}}(\mathbf{z})) &= \int q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x}) \left(\log p_{\mathbf{\theta}}(\mathbf{z}) - \log q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x})\right) \,d\mathbf{z} \\ &= \frac{1}{2} \sum_{j=1}^J \left(1 + \log ((\sigma_j)^2) - (\mu_j)^2 - (\sigma_j)^2 \right) \end{align}\]
When using a recognition model \(q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x})\), the mean \(\mathbf{\mu}\) and s.d. \(\mathbf{\sigma}\) are simply functions of \(\mathbf{x}\) and the variational parameters \(\mathbf{\phi}\), as exemplified in the text.
A full derivation can be found here.
In JAX:
@jax.vmap
def kl_divergence(mean, logvar):
    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian q, one value per example.
    return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))
In PyTorch:
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
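To train the model, this KL term is combined with a reconstruction term into the negative ELBO. Below is a minimal sketch in JAX, assuming binarized inputs in [0, 1] and the VAE and kl_divergence definitions above; the helper names binary_cross_entropy_with_logits and loss_fn are illustrative, not from the original:

@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
    # Reconstruction term: -log p(x | z) for Bernoulli pixels parameterized by decoder logits.
    log_p = nn.log_sigmoid(logits)
    log_not_p = jnp.log(-jnp.expm1(log_p))  # numerically stable log(1 - sigmoid(logits))
    return -jnp.sum(labels * log_p + (1.0 - labels) * log_not_p)

def loss_fn(params, x, z_rng):
    recon_x, mean, logvar = VAE(latents=20).apply({'params': params}, x, z_rng)
    bce = jnp.mean(binary_cross_entropy_with_logits(recon_x, x))
    kld = jnp.mean(kl_divergence(mean, logvar))
    return bce + kld  # negative ELBO, averaged over the batch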
References
[1] Diederik P. Kingma, Max Welling. Auto-Encoding Variational Bayes. 2013. PDF
[2] Carl Doersch. Tutorial on Variational Autoencoders. 2016. PDF