Accurate Quantized Training Library (AQT)

0.9.0 · active · verified Wed Apr 15

AQT (Accurate Quantized Training) is a Python library for easily quantizing tensor operations in JAX, delivering high-quality int8 models without extensive manual tuning. It enables significant training speedups on modern ML accelerators and offers simple, flexible APIs suitable for both production and research. AQT quantizes tensor operations such as matmul, einsum, and conv without making assumptions about how they are used in a neural network, so it can be injected into any JAX computation. It has been extensively tested at Google with frameworks such as Flax, Pax, and MaxText. The current version is 0.9.0, with a rapid release cadence for minor versions (monthly to bi-monthly).

Common errors

Warnings

Install

Imports

Quickstart

This quickstart defines a simple multi-layer perceptron (MLP) in Flax and applies 8-bit quantization with AQT's `aqt.jax.v2` API. It builds a `DotGeneral` configuration and injects a quantized `dot_general` into the network's `Dense` layers, quantizing both the forward and backward passes. The example requires `jax` and `flax` to be installed.

import jax
import jax.numpy as jnp
from flax import linen as nn
import aqt.jax.v2.config as aqt_config
import aqt.jax.v2.flax.aqt_flax as aqt_flax

# A simple MLP. Passing an AQT config injects a quantized dot_general
# into every Dense layer; leaving it as None keeps the model unquantized.
class Mlp(nn.Module):
  num_layers: int
  features: int
  aqt_cfg: aqt_config.DotGeneral | None = None

  @nn.compact
  def __call__(self, x):
    dot_general = aqt_flax.AqtDotGeneral(self.aqt_cfg)
    for _ in range(self.num_layers - 1):
      x = nn.Dense(self.features, dot_general=dot_general)(x)
      x = nn.relu(x)
    x = nn.Dense(self.features, dot_general=dot_general)(x)
    return x

# int8 quantization of both the forward and backward pass
int8_config = aqt_config.fully_quantized(fwd_bits=8, bwd_bits=8)

key = jax.random.PRNGKey(0)
x = jnp.zeros((1, 5))  # batch size 1, 5 input features

# Initialize an unquantized and a quantized version of the model
model = Mlp(num_layers=3, features=10)
params = model.init(key, x)

quantized_model = Mlp(num_layers=3, features=10, aqt_cfg=int8_config)
quantized_params = quantized_model.init(key, x)

print(f"Original output shape: {model.apply(params, x).shape}")
print(f"Quantized output shape: {quantized_model.apply(quantized_params, x).shape}")
