Vector Quantization - Pytorch

1.28.1 · active · verified Sat Apr 11

A vector quantization library for PyTorch, originally transcribed from DeepMind's TensorFlow implementation. It uses exponential moving averages to update the dictionary and has been applied successfully in generative models for images (VQ-VAE-2) and music (Jukebox). The library is actively maintained with frequent micro-releases that often incorporate new research techniques.
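For background: vector quantization maps each input vector to its nearest codebook entry, and EMA training updates the codebook from running statistics of the assigned vectors rather than through gradient descent. The following is a minimal illustrative sketch of that idea in plain PyTorch; it is not the library's internal implementation, and all names in it are made up for illustration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

dim, codebook_size, decay = 4, 8, 0.8
codebook = torch.randn(codebook_size, dim)
# EMA statistics: per-code hit counts and per-code vector sums
ema_count = torch.ones(codebook_size)
ema_sum = codebook.clone()

def quantize(x):
    # nearest-neighbour lookup by Euclidean distance
    flat = x.reshape(-1, dim)                        # (n, dim)
    dists = torch.cdist(flat, codebook)              # (n, codebook_size)
    indices = dists.argmin(dim=-1)                   # (n,)
    quantized = codebook[indices].reshape(x.shape)   # same shape as x
    return quantized, indices.reshape(x.shape[:-1])

def ema_update(x, indices):
    # move each code toward the running mean of the vectors assigned to it;
    # a lower decay means the dictionary changes faster
    global codebook, ema_count, ema_sum
    flat = x.reshape(-1, dim)
    onehot = F.one_hot(indices.reshape(-1), codebook_size).float()
    ema_count = decay * ema_count + (1 - decay) * onehot.sum(0)
    ema_sum = decay * ema_sum + (1 - decay) * (onehot.t() @ flat)
    codebook = ema_sum / ema_count.clamp(min=1e-5).unsqueeze(-1)

x = torch.randn(1, 16, dim)
quantized, indices = quantize(x)
ema_update(x, indices)
```

The library wraps this logic (and more, such as codebook expiration and commitment losses) behind the `VectorQuantize` module shown in the quickstart below.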

Warnings

Install

Imports

Quickstart

This quickstart demonstrates the basic usage of the `VectorQuantize` module. It initializes a VQ layer with a specified input dimension, codebook size, EMA decay, and commitment weight, then quantizes a random input tensor and returns the quantized output, codebook indices, and the commitment loss.

import torch
from vector_quantize_pytorch import VectorQuantize

# Initialize VectorQuantize
vq = VectorQuantize(
    dim = 256,          # input feature dimension
    codebook_size = 512,  # number of vectors in the codebook
    decay = 0.8,        # exponential moving average decay, lower means faster dictionary change
    commitment_weight = 1. # weight on the commitment loss
)

# Example input tensor: (batch_size, sequence_length, dim)
x = torch.randn(1, 1024, 256)

# Perform quantization
quantized, indices, commit_loss = vq(x)

print(f"Original input shape: {x.shape}")
print(f"Quantized output shape: {quantized.shape}")
print(f"Indices shape: {indices.shape}")
print(f"Commitment loss: {commit_loss.item():.4f}")
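The commitment loss returned above corresponds conceptually to a mean-squared error that pulls the encoder output toward its (detached) quantized version, while a straight-through estimator lets gradients flow through the non-differentiable lookup. A rough sketch of these two pieces, illustrative only and not the library's internals:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 1024, 256, requires_grad=True)    # stand-in encoder output
quantized = x.detach() + 0.1 * torch.randn_like(x)   # stand-in for the codebook lookup
# commitment term: encourages the encoder to commit to nearby codebook vectors
commit_loss = F.mse_loss(x, quantized.detach())
# straight-through estimator: forward pass returns the quantized values,
# backward pass treats quantization as the identity so gradients reach x
quantized_st = x + (quantized - x).detach()
```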
