Accurate Quantized Training Library (AQT)

0.9.0 verified Wed Apr 15 auth: no python

AQT (Accurate Quantized Training) is a Python library for quantizing tensor operations in JAX. It aims to deliver high-quality int8 quantized models without extensive manual tuning, enables significant training speedups on modern ML accelerators, and offers simple, flexible APIs suited to both production and research. AQT quantizes tensor operations such as matmul, einsum, and conv without making assumptions about how they are used in a neural network, so it can be injected into any JAX computation. It has been extensively tested at Google with frameworks such as Flax, Pax, and MaxText. The current version is 0.9.0; minor versions are released frequently (roughly every one to two months).
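
To make "quantizing a matmul to int8" concrete, here is a minimal sketch of symmetric per-tensor quantization in plain JAX. It is illustrative only and is not AQT's implementation; the helper names `quantize_int8` and `int8_matmul` are hypothetical and do not come from the library.

```python
import jax.numpy as jnp


def quantize_int8(x):
    """Symmetric per-tensor quantization: returns int8 values and a float scale."""
    max_abs = jnp.max(jnp.abs(x))
    # Avoid division by zero for an all-zero tensor.
    scale = jnp.where(max_abs == 0, 1.0, max_abs / 127.0)
    q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return q, scale


def int8_matmul(a, b):
    """Quantize both operands, multiply in integers, then rescale to float."""
    qa, scale_a = quantize_int8(a)
    qb, scale_b = quantize_int8(b)
    # Accumulate in int32 so that sums of int8 products cannot overflow.
    acc = jnp.matmul(qa.astype(jnp.int32), qb.astype(jnp.int32))
    return acc.astype(jnp.float32) * (scale_a * scale_b)


a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
b = jnp.array([[0.5, -1.0], [2.0, 0.25]])
print("float result:", jnp.matmul(a, b))
print("int8 result: ", int8_matmul(a, b))
```

The two results should agree to within a small rounding error. A library such as AQT manages that error during training so that it does not need to be tuned by hand.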

pip install aqtp
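error ModuleNotFoundError: No module named 'aqtp'
cause `aqtp` is the distribution name on PyPI, not an importable module. The package installs its code under the top-level module `aqt`.
fix
Install the package with `pip install aqtp`, then import from `aqt`, for example `import aqt.jax.v2 as aqt`.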
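error ImportError: cannot import name 'AQT' from 'aqtp'
cause `aqtp` is the PyPI distribution name, and the library does not expose an `AQT` name under it. The library's code lives under `aqt.jax.v2`.
fix
Import from `aqt.jax.v2` instead, for example `import aqt.jax.v2 as aqt` or `from aqt.jax.v2 import ...`.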
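error AttributeError: module 'aqtp' has no attribute 'quantize'
cause `aqtp` is the PyPI distribution name rather than the library's import path, and the quickstart on this page applies quantization through `aqt.jax.v2` rather than a top-level `quantize` function.
fix
Import from `aqt.jax.v2` and follow the quickstart pattern. Check the names available in your installed version against the AQT documentation.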
error ModuleNotFoundError: No module named 'aqt'
cause The `aqt` module is not available in the environment. It is provided by the PyPI package `aqtp`, not by a package named `aqt`.
fix
Install the correct package with `pip install aqtp`, then import it as `import aqt.jax.v2 as aqt` or `from aqt.jax.v2 import ...`.
error AttributeError: module 'jax.config' has no attribute 'define_bool_state'
cause This error typically occurs due to an incompatibility between the installed version of `aqtp` (e.g., an older version like 0.1.1) and a newer version of JAX or related frameworks like Flax, where `jax.config.define_bool_state` might have been removed or renamed.
fix
Update `aqtp` to the latest version (`pip install --upgrade aqtp`) and make sure your `jax`, `jaxlib`, and `flax` installations are compatible with that release. If the issue persists, you may need to pin specific versions of JAX and Flax.
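
When diagnosing this kind of mismatch, it helps to see which versions are actually installed. The following is a minimal sketch using only the standard library; `installed_versions` is a hypothetical helper name, and the default package list is an assumption to adjust for your project.

```python
from importlib import metadata


def installed_versions(packages=("aqtp", "jax", "jaxlib", "flax")):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Including this output in a bug report makes it easier to tell whether an old `aqtp` is paired with a newer JAX.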
breaking The documentation explicitly states 'Other AQT versions are obsolete.' Import from `aqt.jax.v2` and use the latest API patterns, because older versions may be unmaintained or removed.
fix Migrate code to use `import aqt.jax.v2 as aqt` and update API calls according to the latest documentation. Review the official GitHub repository for migration guides.
gotcha The `aqtp` package (Accurate Quantized Training) can be confused with other similarly named packages on PyPI, notably `aqtinstall` (a Qt CLI installer) and an unrelated `aqt` package. Ensure you install `aqtp` for the JAX quantization library.
fix Always use `pip install aqtp` and verify the package description on PyPI matches 'Accurate Quantized Training library' to avoid installing an incorrect package.
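
The check described above can also be done from Python after installation. This is a minimal sketch built on the standard library's `importlib.metadata`; the helper names `distribution_summary` and `looks_like_aqt` are hypothetical, and the substring test is an assumption based on the description quoted above.

```python
from importlib import metadata


def distribution_summary(name="aqtp"):
    """Return the installed distribution's Summary field, or None if not installed."""
    try:
        return metadata.metadata(name).get("Summary")
    except metadata.PackageNotFoundError:
        return None


def looks_like_aqt(summary):
    """Heuristic: does the summary mention 'quantized training'?"""
    return summary is not None and "quantized training" in summary.lower()


if __name__ == "__main__":
    summary = distribution_summary("aqtp")
    print("summary:", summary)
    print("looks like AQT:", looks_like_aqt(summary))
```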
gotcha The `aqtp` library requires Python 3.10 or newer. Attempting to install or run `aqtp` with older Python versions will lead to compatibility issues or failures.
fix Ensure your Python environment is version 3.10 or greater before installing `aqtp`.
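
A script can check the interpreter version up front instead of failing later with a less obvious error. This is a minimal sketch; `MIN_PYTHON` mirrors the 3.10 requirement stated above, and `python_is_supported` is a hypothetical helper name.

```python
import sys

MIN_PYTHON = (3, 10)  # Minimum version stated for aqtp on this page.


def python_is_supported(version_info=None):
    """Return True if the given (or current) interpreter meets MIN_PYTHON."""
    if version_info is None:
        version_info = sys.version_info
    return tuple(version_info[:2]) >= MIN_PYTHON


if __name__ == "__main__":
    if not python_is_supported():
        found = ".".join(map(str, sys.version_info[:3]))
        raise SystemExit(f"aqtp needs Python 3.10 or newer, found {found}")
```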
gotcha The project is currently in 'Development Status :: 3 - Alpha'. This indicates that while actively developed, the API may still undergo non-backward compatible changes in minor versions, and stability is not yet guaranteed for long-term production use without careful version pinning.
fix Pin `aqtp` to specific versions in your `requirements.txt` or `pyproject.toml` to prevent unexpected breaking changes during updates.

This quickstart demonstrates how to define a simple Multi-Layer Perceptron (MLP) using Flax, then apply 8-bit quantization using AQT's `aqt.jax.v2` API. It shows how to create an AQT configuration for `DotGeneral` operations and inject it into the neural network, allowing for quantized forward and backward passes. This example requires `jax` and `flax` to be installed.

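from typing import Any

import jax
import jax.numpy as jnp
from flax import linen as nn
# Import the AQT submodules explicitly so that they are actually loaded.
from aqt.jax.v2 import config as aqt_config
from aqt.jax.v2.flax import aqt_flax


# Define a simple MLP whose Dense layers take an injectable dot_general
class Mlp(nn.Module):
  num_layers: int
  features: int
  aqt_cfg: Any = None  # An AQT DotGeneral config; None keeps the float path

  @nn.compact
  def __call__(self, x):
    # AqtDotGeneral stands in for jax.lax.dot_general inside nn.Dense.
    if self.aqt_cfg is None:
      dot_general = jax.lax.dot_general
    else:
      dot_general = aqt_flax.AqtDotGeneral(self.aqt_cfg)
    for _ in range(self.num_layers - 1):
      x = nn.Dense(self.features, dot_general=dot_general)(x)
      x = nn.relu(x)
    x = nn.Dense(self.features, dot_general=dot_general)(x)
    return x


# Create an AQT configuration for int8 forward and backward passes.
# This follows the pattern in the AQT README; verify it against your version.
aqt_cfg = aqt_config.fully_quantized(fwd_bits=8, bwd_bits=8)

# Initialize the float model and the quantized model
key = jax.random.PRNGKey(0)
inputs = jnp.zeros((1, 5))  # Batch size 1, 5 input features

model = Mlp(num_layers=3, features=10)
params = model.init(key, inputs)

quantized_model = Mlp(num_layers=3, features=10, aqt_cfg=aqt_cfg)
quantized_params = quantized_model.init(key, inputs)

# AQT may need a PRNG key at apply time (for example for stochastic rounding),
# so pass one explicitly; it is harmless for the float model.
rngs = {"params": key}
print(f"Original output shape: {model.apply(params, inputs, rngs=rngs).shape}")
print(f"Quantized output shape: {quantized_model.apply(quantized_params, inputs, rngs=rngs).shape}")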