VQ-VAE & VQ-GAN
Table of Contents:
- VQ Overview
- Straight-through Estimation (STE)
- How to learn the codebook?
- Implementation
- Improving VQ Networks: Approaches
- VQ-GAN
- Parti
- References
VQ Overview
A vector-quantized network (VQN) is a neural network that contains a vector-quantization layer \(h(\cdot,\cdot)\).
Given an autoencoder \(\hat{\mathbf{y}} = G(F(\mathbf{x}))\) with encoder \(F\) and decoder \(G\).
Add a VQ layer: \(\hat{\mathbf{y}} = G\Big(h(F(\mathbf{x}), \mathcal{C})\Big) = G\Big(h(\mathbf{z}_e, \mathcal{C})\Big) = G(\mathbf{z}_q)\)
Round/Replace: The VQ layer \(h(\cdot,\cdot)\) quantizes the embedding \(\mathbf{z}_e = F(\mathbf{x})\) by replacing it with the nearest vector from a collection of \(m\) vectors.
Each vector \(\mathbf{c}_i\) is referred to as a code vector, its index \(i\) as the code, and the collection of code vectors as the codebook \(\mathcal{C} = \{\mathbf{c}_1, \mathbf{c}_2, \dots, \mathbf{c}_m \}\).
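A minimal sketch of this round/replace step in PyTorch (the names `quantize`, `z_e`, and `codebook` are illustrative, not taken from a particular implementation): each embedding is replaced by its nearest code vector under the Euclidean distance.

```python
import torch

def quantize(z_e: torch.Tensor, codebook: torch.Tensor):
    """Replace each embedding with its nearest code vector.

    z_e:      (N, d) encoder outputs z_e = F(x)
    codebook: (m, d) code vectors c_1, ..., c_m
    returns:  z_q of shape (N, d) and the selected codes of shape (N,)
    """
    dists = torch.cdist(z_e, codebook)   # (N, m) pairwise Euclidean distances
    codes = dists.argmin(dim=1)          # index i of the nearest code vector
    z_q = codebook[codes]                # the selected code vectors
    return z_q, codes

# Toy usage: 4 embeddings of dimension 8, a codebook with m = 16 code vectors.
z_e = torch.randn(4, 8)
codebook = torch.randn(16, 8)
z_q, codes = quantize(z_e, codebook)
```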
Straight-through Estimation (STE)
Estimator of the expected gradient through stochastic neurons.
The idea is simply to back-propagate through the hard threshold function (1 if the argument is positive, 0 otherwise) as if it had been the identity function. It is clearly a biased estimator, but when considering a single layer of neurons, it has the right sign (this is not guaranteed anymore when backpropagating through more hidden layers).
\[\frac{\partial \mathcal{L}}{\partial F} = \frac{\partial \mathcal{L}}{\partial \hat{\mathbf{y}}}\, \frac{\partial \hat{\mathbf{y}}}{\partial \mathbf{z}_q}\, \frac{\partial \mathbf{z}_q}{\partial \mathbf{z}_e}\, \frac{\partial \mathbf{z}_e}{\partial F} \;\approx\; \frac{\partial \mathcal{L}}{\partial \hat{\mathbf{y}}}\, \frac{\partial \hat{\mathbf{y}}}{\partial \mathbf{z}_q}\, \frac{\partial \mathbf{z}_e}{\partial F}, \qquad \text{i.e. } \frac{\partial \mathbf{z}_q}{\partial \mathbf{z}_e} \approx \mathbf{I}\]
The STE “simply copies the gradients from the decoder to the encoder.”
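In code, the estimator is usually implemented with the standard detach trick: the forward pass uses \(\mathbf{z}_q\), while the backward pass treats the quantizer as the identity. A sketch continuing the hypothetical `quantize` example above (`decoder` is a placeholder for \(G\)):

```python
# Straight-through estimator: the forward value equals z_q, but the gradient
# w.r.t. z_e is the gradient w.r.t. z_q (the quantizer is treated as identity).
z_q, codes = quantize(z_e, codebook)
z_q_ste = z_e + (z_q - z_e).detach()

y_hat = decoder(z_q_ste)   # decoder stands in for G; its gradient is copied onto the encoder F
```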
How to learn the codebook?
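One answer, from the VQ-VAE paper: since the nearest-neighbor selection passes no gradient to the code vectors, two auxiliary terms are added to the reconstruction loss (written here as an \(L_2\) loss): a codebook term that pulls the selected code toward the stop-gradded encoder output, and a commitment term that keeps the encoder output close to its code. The paper also proposes an exponential-moving-average update of the codebook as an alternative to the codebook term.

\[\mathcal{L} = \|\mathbf{x} - \hat{\mathbf{y}}\|_2^2 + \|\operatorname{sg}[\mathbf{z}_e] - \mathbf{z}_q\|_2^2 + \beta\, \|\mathbf{z}_e - \operatorname{sg}[\mathbf{z}_q]\|_2^2,\]

where \(\operatorname{sg}[\cdot]\) denotes the stop-gradient operator and \(\beta\) weights the commitment term.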
Implementation
Karpathy’s VQ-VAE (Github)
Code vectors can be initialized uniformly, scaled w.r.t. the square root of the embedding dimension, and unused codes can be re-initialized with a k-means-style update.
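A sketch of that initialization (the exact scaling rule is an assumption here; check the referenced implementation for the precise constant):

```python
import torch

m, d = 1024, 256                       # codebook size and embedding dimension
codebook = torch.nn.Embedding(m, d)
# Uniform initialization scaled by the square root of the dimension (illustrative).
codebook.weight.data.uniform_(-1.0 / d**0.5, 1.0 / d**0.5)
```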
Improving VQ Networks: Approaches
- LRU Policy / Random Restarts for “Codebook Collapse” (see the sketch after this list).
- Affine Reparameterization of the code vectors (Huh et al., 2023).
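A sketch of the random-restart idea (the usage counter, threshold, and reset rule below are illustrative assumptions rather than a specific paper's recipe): track how often each code is selected and re-initialize rarely used codes from current encoder outputs.

```python
import torch

def restart_dead_codes(codebook: torch.Tensor, usage: torch.Tensor,
                       z_e: torch.Tensor, min_usage: float = 1.0) -> None:
    """Re-initialize under-used code vectors from random encoder outputs.

    codebook: (m, d) code vectors, modified in place
    usage:    (m,) running (e.g. exponential-moving-average) count of code selections
    z_e:      (N, d) batch of current encoder outputs
    """
    dead = usage < min_usage                         # codes that are (almost) never selected
    n_dead = int(dead.sum())
    if n_dead == 0:
        return
    idx = torch.randint(0, z_e.shape[0], (n_dead,))  # random embeddings from the batch
    codebook.data[dead] = z_e[idx].detach()
```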
VQ-GAN
Two-stage training: first train the VQ-GAN autoencoder (encoder, codebook, decoder, with a patch discriminator and perceptual loss), then train a transformer over the resulting code indices.
The second-stage transformer predicts the 16×16 grid of code indices as a sequence; model capacities vary between 85M and 310M parameters, and the setup covers both conditional and unconditional synthesis.
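A minimal sketch of this second stage under stated assumptions (the frozen VQ-GAN encoder output is faked with random indices, and a single linear layer stands in for the actual transformer): the 16×16 code grid is flattened to a 256-token sequence and trained with next-token cross-entropy.

```python
import torch

B, H, W, vocab = 4, 16, 16, 1024           # batch, code grid, codebook size (ImageNet setup below)

# Stage 1 (frozen): the VQ-GAN encoder + quantizer map each image to a 16x16
# grid of code indices. Here we fake that output with random codes.
codes = torch.randint(0, vocab, (B, H, W))
tokens = codes.reshape(B, H * W)            # flatten to a length-256 sequence

# Stage 2: a GPT-style transformer (e.g. minGPT) is trained to predict the next
# code index; a token embedding + linear layer stands in for it here.
gpt = torch.nn.Sequential(torch.nn.Embedding(vocab, 512), torch.nn.Linear(512, vocab))

logits = gpt(tokens[:, :-1])                # (B, 255, vocab): predict token t from tokens < t
loss = torch.nn.functional.cross_entropy(
    logits.reshape(-1, vocab),              # flatten predictions
    tokens[:, 1:].reshape(-1),              # next-token targets
)
loss.backward()
```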
ImageNet: 256 embed dim, 1024 embeddings in codebook
- Discriminator weight: 0.8
- Discriminator start: 250K
- Codebook weight: 1.0
- Discriminator in channels: 3
- Base learning rate: 4.5e-6
OpenImages: 256 embed dim, 8192 embeddings in codebook
- 36 layers, 16 heads, 1536 embedding dim
- Uses Karpathy’s minGPT.
For comparison, DALL-E also uses a codebook of 8192 entries.
Parti
L1 & L2 reconstruction losses.
References
- VQ-VAE. van den Oord et al., Neural Discrete Representation Learning, NeurIPS 2017.
- VQ-GAN. Esser et al., Taming Transformers for High-Resolution Image Synthesis, CVPR 2021. PDF.
- Muse. Chang et al., Muse: Text-To-Image Generation via Masked Generative Transformers, 2023.
- ViT-VQGAN. Yu et al., Vector-quantized Image Modeling with Improved VQGAN, 2022. PDF.
- Parti. Yu et al., Scaling Autoregressive Models for Content-Rich Text-to-Image Generation, 2022. PDF.
- Bengio, Léonard, and Courville, Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation. arXiv, 2013. PDF.
- Huh, Cheung, Agrawal, and Isola, Straightening Out the Straight-Through Estimator: Overcoming Optimization Challenges in Vector Quantized Networks. arXiv, 2023. PDF.