VAE

PyTorch VAE implementation and JAX VAE implementation

Encoder and Decoder

  • The latent prior is given by \(p(\mathbf{z})=\mathcal{N}(0,I)\).
  • From a coding theory perspective, the unobserved variables \(\mathbf{z}\) have an interpretation as a latent representation or code.
  • Encoder: We refer to the recognition model \(q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x})\) as a probabilistic encoder, since given a datapoint \(\mathbf{x}\) it produces a distribution (e.g. a Gaussian) over the possible values of the code \(\mathbf{z}\) from which the datapoint \(\mathbf{x}\) could have been generated.
  • Decoder: We will refer to \(p_{\boldsymbol{\theta}}(\mathbf{x} \mid \mathbf{z})\) as a probabilistic decoder, since given a code \(\mathbf{z}\) it produces a distribution over the possible corresponding values of \(\mathbf{x}\).

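In JAX, using Flax's linen API, this resembles:

from flax import linen as nn

class Encoder(nn.Module):
  latents: int  # dimensionality of the code z

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(500, name='fc1')(x)
    x = nn.relu(x)
    # Two heads parameterize the diagonal Gaussian q(z | x).
    mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
    logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
    return mean_x, logvar_x

class Decoder(nn.Module):

  @nn.compact
  def __call__(self, z):
    z = nn.Dense(500, name='fc1')(z)
    z = nn.relu(z)
    # Returns unnormalized logits over 784 output dimensions; no sigmoid here.
    x = nn.Dense(784, name='fc2')(z)
    return x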

We convert the output of the encoder into latent samples \(\mathbf{z}\) using the reparameterization trick:

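import jax.numpy as jnp
from jax import random

def reparameterize(rng, mean, logvar):
  # logvar = log(sigma^2), so sigma = exp(0.5 * logvar)
  std = jnp.exp(0.5 * logvar)
  # eps ~ N(0, I) carries all the randomness; mean and std stay differentiable
  eps = random.normal(rng, logvar.shape)
  return mean + eps * std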
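As a quick sanity check of the reparameterization above, the following plain-Python sketch (standard library only, not JAX) draws many samples for a single latent dimension and confirms that they have roughly the requested mean and standard deviation. The function name `reparameterize_scalar`, the seed, and the chosen values are illustrative assumptions, not part of the original implementation.

```python
import math
import random
import statistics

def reparameterize_scalar(rng, mean, logvar):
    """Scalar analogue of reparameterize: z = mean + eps * std, eps ~ N(0, 1)."""
    std = math.exp(0.5 * logvar)   # logvar = log(sigma^2), so std = exp(logvar / 2)
    eps = rng.gauss(0.0, 1.0)      # the noise does not depend on mean or logvar
    return mean + eps * std

rng = random.Random(0)                # fixed seed so the run is repeatable
mean, logvar = 2.0, math.log(0.25)    # log-variance of sigma = 0.5
samples = [reparameterize_scalar(rng, mean, logvar) for _ in range(20000)]

sample_mean = statistics.fmean(samples)
sample_std = statistics.pstdev(samples)
print(sample_mean, sample_std)        # both should be close to 2.0 and 0.5
```

Because the randomness enters only through `eps`, the sample is a deterministic function of `mean` and `logvar`, which is what lets gradients flow back into the encoder.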

Putting the pieces together, a VAE can be defined as follows in JAX:

class VAE(nn.Module):
  latents: int = 20

  def setup(self):
    self.encoder = Encoder(self.latents)
    self.decoder = Decoder()

  def __call__(self, x, z_rng):
    # Encode, sample z with the reparameterization trick, then decode.
    mean, logvar = self.encoder(x)
    z = reparameterize(z_rng, mean, logvar)
    recon_x = self.decoder(z)
    return recon_x, mean, logvar

  def generate(self, z):
    # Pass the decoder outputs through a sigmoid to get values in (0, 1).
    return nn.sigmoid(self.decoder(z))

Training VAEs: KL Divergence

KL Divergence: Revisited

First, we’ll revisit the definition of KL divergence for two discrete distributions \(P\) and \(Q\):

\[D_{KL}(P \lVert Q) = KL(P \lVert Q) = \sum\limits_x P(x) \log \frac{P(x)}{Q(x)}\]
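To make the definition concrete, here is a minimal plain-Python sketch that evaluates the sum for two small discrete distributions. The helper name `kl_divergence_discrete` and the example probabilities are illustrative assumptions; the code also assumes \(Q(x) > 0\) wherever \(P(x) > 0\).

```python
import math

def kl_divergence_discrete(p, q):
    """KL(P || Q) in nats for two discrete distributions given as probability lists."""
    # Terms with P(x) = 0 contribute nothing, by the usual convention 0 * log 0 = 0.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p = [0.5, 0.5]
q = [0.9, 0.1]
kl_pq = kl_divergence_discrete(p, q)
kl_qp = kl_divergence_discrete(q, p)
print(kl_pq, kl_qp)
```

Here \(KL(P \lVert Q) \approx 0.511\) nats while \(KL(Q \lVert P) \approx 0.368\) nats, which illustrates that KL divergence is not symmetric; it is zero when the two distributions are identical.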

KL Divergence Between Two Gaussians

Solution of \(D_{KL}(q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x} ) \lVert p_{\mathbf{\theta}}(\mathbf{z}))\), Gaussian case (see Appendix B of the VAE paper, Kingma and Welling, 2014):

The variational lower bound (the objective to be maximized) contains a KL term that can often be integrated analytically. Here we give the solution when both the prior \(p_{\mathbf{\theta}}(\mathbf{z}) = \mathcal{N}(0,\mathbf{I})\) and the posterior approximation \(q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x}^{(i)})\) are Gaussian. Let \(J\) be the dimensionality of \(\mathbf{z}\). Let \(\mathbf{\mu}\) and \(\mathbf{\sigma}\) denote the variational mean and standard deviation evaluated at datapoint \(i\), and let \(\mu_j\) and \(\sigma_j\) denote the \(j\)-th elements of these vectors. Then:

\[\begin{align} \int q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x} ) \log p_{\mathbf{\theta}}(\mathbf{z}) \,d\mathbf{z} &= \int \mathcal{N}(\mathbf{z};\mathbf{\mu},\mathbf{\sigma}^2 \mathbf{I} ) \log \mathcal{N}(\mathbf{z};\mathbf{0},\mathbf{I}) \,d\mathbf{z} \\ &= - \frac{J}{2} \log (2 \pi) - \frac{1}{2} \sum_{j=1}^J (\mu_j^2 + \sigma_j^2) \end{align}\]

And:

\[\begin{align} \int q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x} ) \log q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x} ) \,d\mathbf{z} &= \int \mathcal{N}(\mathbf{z};\mathbf{\mu},\mathbf{\sigma}^2 \mathbf{I} ) \log \mathcal{N}(\mathbf{z};\mathbf{\mu},\mathbf{\sigma}^2 \mathbf{I}) \,d\mathbf{z} \\ &= - \frac{J}{2} \log (2 \pi) - \frac{1}{2} \sum_{j=1}^J ( 1 + \log \sigma^2_j ) \end{align}\]

Therefore, subtracting the second expression from the first:

\[\begin{align} - D_{KL}(q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x} ) \lVert p_{\mathbf{\theta}}(\mathbf{z})) &= \int q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x} ) \left(\log p_{\mathbf{\theta}}(\mathbf{z}) - \log q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x} )\right) \,d\mathbf{z} \\ &= \frac{1}{2} \sum_{j=1}^J \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right) \end{align}\]

When using a recognition model \(q_{\mathbf{\phi}}(\mathbf{z} \mid \mathbf{x})\), the mean \(\mathbf{\mu}\) and standard deviation \(\mathbf{\sigma}\) are simply functions of \(\mathbf{x}\) and the variational parameters \(\mathbf{\phi}\), as in the encoder above, which outputs the mean and the log-variance \(\log \mathbf{\sigma}^2\).
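The closed form can be checked numerically. The sketch below, in plain Python with only the standard library, compares it against a Monte Carlo estimate of \(\mathbb{E}_{q}[\log q(\mathbf{z} \mid \mathbf{x}) - \log p(\mathbf{z})]\) for a two-dimensional example. All function names, the seed, and the values of `mean` and `logvar` are assumptions made for illustration.

```python
import math
import random

def kl_closed_form(mean, logvar):
    """KL( N(mean, diag(exp(logvar))) || N(0, I) ) using the closed-form sum."""
    return -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mean, logvar))

def log_normal_pdf(z, mean, logvar):
    """Log density of a 1-D Gaussian with the given mean and log-variance."""
    return -0.5 * (math.log(2 * math.pi) + logvar + (z - mean) ** 2 / math.exp(logvar))

def kl_monte_carlo(mean, logvar, n_samples, rng):
    """Estimate E_q[log q(z) - log p(z)] by sampling z ~ q, dimension by dimension."""
    total = 0.0
    for _ in range(n_samples):
        for m, lv in zip(mean, logvar):
            z = m + rng.gauss(0.0, 1.0) * math.exp(0.5 * lv)  # reparameterized sample
            total += log_normal_pdf(z, m, lv) - log_normal_pdf(z, 0.0, 0.0)
    return total / n_samples

mean = [1.0, -0.5]
logvar = [0.0, math.log(0.25)]   # sigma = 1.0 and sigma = 0.5
exact = kl_closed_form(mean, logvar)
estimate = kl_monte_carlo(mean, logvar, 50000, random.Random(0))
print(exact, estimate)           # the two values should agree closely
```

Since both distributions factorize over dimensions, the KL divergence is a sum of per-dimension terms, which is why the estimator can accumulate each dimension independently.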

A full derivation can be found here.

In JAX:

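import jax
import jax.numpy as jnp

@jax.vmap
def kl_divergence(mean, logvar):
  # vmap maps over the leading (batch) axis, so this returns one KL value per example.
  return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))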

In PyTorch (here the sum runs over all elements of the tensors, including any batch dimension):

KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
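To show where this KL term sits in training, here is a hedged plain-Python sketch of a per-datapoint loss: the negative variational lower bound, written as a reconstruction term plus the KL term. It assumes a Bernoulli decoder likelihood, so the reconstruction term is a binary cross-entropy computed from the decoder's raw logits. That likelihood choice and the helper names `bce_with_logits`, `kl_to_standard_normal`, and `negative_elbo` are assumptions for illustration, not taken from the implementations linked above.

```python
import math

def bce_with_logits(logits, targets):
    """Binary cross-entropy summed over output dimensions, from raw logits."""
    # softplus(l) - t * l equals -[t * log(sigmoid(l)) + (1 - t) * log(1 - sigmoid(l))];
    # softplus is written in a numerically stable form.
    return sum(max(l, 0.0) + math.log1p(math.exp(-abs(l))) - t * l
               for l, t in zip(logits, targets))

def kl_to_standard_normal(mean, logvar):
    """Closed-form KL( N(mean, diag(exp(logvar))) || N(0, I) )."""
    return -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mean, logvar))

def negative_elbo(logits, targets, mean, logvar):
    """Loss to minimize: reconstruction error plus KL regularizer."""
    return bce_with_logits(logits, targets) + kl_to_standard_normal(mean, logvar)

# Toy example: two output dimensions, one latent dimension.
loss = negative_elbo(logits=[0.0, 0.0], targets=[1.0, 0.0], mean=[0.0], logvar=[0.0])
print(loss)
```

In this toy case the posterior equals the prior, so the KL term vanishes and the loss is just the cross-entropy of two maximally uncertain predictions, \(2 \log 2\).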

References

[1] Diederik P. Kingma, Max Welling. Auto-Encoding Variational Bayes. 2013.

[2] Carl Doersch. Tutorial on Variational Autoencoders.