Vector Quantization - Pytorch
A vector quantization library for PyTorch, originally transcribed from DeepMind's TensorFlow implementation. It updates the codebook with exponential moving averages, an approach used successfully in generative models for images (VQ-VAE-2) and music (Jukebox). The library is actively maintained, with frequent micro-releases that fold in new research techniques.
Warnings
- gotcha: Dead codebook entries are a common failure mode in vector quantization: some codebook vectors are rarely or never used, wasting model capacity. The library offers features such as `orthogonal_reg_weight` to mitigate this by encouraging codebook diversity.
- gotcha: The quantization step is non-differentiable, so gradients typically flow through a straight-through estimator (STE), which ignores the dynamics of the quantization operation itself. The library includes the 'rotation trick' as an alternative that can improve gradient quality through the VQ layer.
- gotcha: Distributed Data Parallel (DDP) training can hang, especially in older versions or with configurations that update the codebook (e.g., k-means initialization or EMA updates). Several DDP issues have been addressed, but it remains a known area of complexity.
- gotcha: Hyperparameters such as `decay` (for EMA codebook updates) and `commitment_weight` are critical for training stability and performance. Poor values can lead to unstable training, dead codes, or low reconstruction quality.
Install
- pip install vector-quantize-pytorch
Imports
- VectorQuantize
from vector_quantize_pytorch import VectorQuantize
- ResidualVQ
from vector_quantize_pytorch import ResidualVQ
- ResidualFSQ
from vector_quantize_pytorch import ResidualFSQ
Quickstart
import torch
from vector_quantize_pytorch import VectorQuantize
# Initialize VectorQuantize
vq = VectorQuantize(
    dim = 256,             # input feature dimension
    codebook_size = 512,   # number of vectors in the codebook
    decay = 0.8,           # exponential moving average decay; lower means the codebook changes faster
    commitment_weight = 1. # weight on the commitment loss
)
# Example input tensor: (batch_size, sequence_length, dim)
x = torch.randn(1, 1024, 256)
# Perform quantization
quantized, indices, commit_loss = vq(x)
print(f"Original input shape: {x.shape}")
print(f"Quantized output shape: {quantized.shape}")
print(f"Indices shape: {indices.shape}")
print(f"Commitment loss: {commit_loss.item():.4f}")