Accurate Quantized Training Library (AQT)
AQT (Accurate Quantized Training) is a Python library for quantizing tensor operations in JAX. It delivers strong int8 model quality without extensive manual tuning, enables significant training speedups on modern ML accelerators, and offers simple, flexible APIs suitable for both production and research. AQT quantizes tensor operations such as matmul, einsum, and conv without making assumptions about how they are used in a neural network, so it can be injected into any JAX computation. It has been extensively tested at Google with frameworks such as Flax, Pax, and MaxText. The current version is 0.9.0, with minor versions released roughly every one to two months.
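Conceptually, quantizing a tensor op means scaling inputs to int8, doing the arithmetic in integers, and rescaling the result. The sketch below illustrates that idea for a matmul in plain JAX; it is an illustrative sketch only and does not use AQT's actual API:

```python
import jax.numpy as jnp

def int8_matmul(lhs, rhs):
    # Per-tensor symmetric int8 quantization: scale so the max
    # magnitude maps to 127, round, clip, then rescale the output.
    def quantize(x):
        scale = jnp.max(jnp.abs(x)) / 127.0
        q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
        return q, scale

    q_lhs, s_lhs = quantize(lhs)
    q_rhs, s_rhs = quantize(rhs)
    # Accumulate in int32 (as int8 hardware units do), then undo both scales.
    acc = jnp.matmul(q_lhs.astype(jnp.int32), q_rhs.astype(jnp.int32))
    return acc.astype(jnp.float32) * (s_lhs * s_rhs)

a = jnp.array([[1.0, -2.0], [3.0, 4.0]])
b = jnp.array([[0.5, 1.0], [-1.0, 2.0]])
print(jnp.allclose(int8_matmul(a, b), a @ b, atol=0.1))
```

The quantized result closely tracks the float matmul; AQT applies this kind of transformation (with far more sophisticated calibration and gradient handling) to dot_general-based ops.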
Common errors
- `ModuleNotFoundError: No module named 'aqtp'`
  - Cause: the `aqtp` package is not installed in the Python environment.
  - Fix: install it with `pip install aqtp`.
- `ImportError: cannot import name 'AQT' from 'aqtp'`
  - Cause: the package does not export an `AQT` attribute.
  - Fix: import the correct module or attribute; the library is typically used via `import aqt.jax.v2 as aqt`.
- `AttributeError: module 'aqtp' has no attribute 'quantize'`
  - Cause: there is no top-level `quantize` function or attribute.
  - Fix: verify the correct usage and available functions in the AQT documentation.
- `ModuleNotFoundError: No module named 'aqt'`
  - Cause: the PyPI distribution is named `aqtp`, even though the importable package is `aqt`.
  - Fix: install `aqtp` with `pip install aqtp`, then import it as `import aqt.jax.v2 as aqt` or `from aqt.jax.v2 import ...`.
- `AttributeError: module 'jax.config' has no attribute 'define_bool_state'`
  - Cause: an older `aqtp` release (e.g. 0.1.1) is incompatible with a newer JAX or Flax, in which `jax.config.define_bool_state` was removed or renamed.
  - Fix: update with `pip install --upgrade aqtp` and ensure your JAX, jaxlib, and Flax installations are compatible with the latest `aqtp` release; pin specific JAX/Flax versions if the issue persists.
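Because the distribution name (`aqtp`) differs from the import name (`aqt`), a quick way to diagnose the `ModuleNotFoundError` cases above is to probe for the import name directly. This is a generic diagnostic snippet, not part of AQT:

```python
from importlib import util

# The pip distribution is named 'aqtp', but it installs a package
# importable as 'aqt' -- so probe for 'aqt', not 'aqtp'.
for name in ("aqt", "aqtp"):
    found = util.find_spec(name) is not None
    print(f"import {name!r}: {'available' if found else 'missing'}")
```

If `aqt` is missing, run `pip install aqtp`; `aqtp` itself showing as missing is expected, since that name exists only on PyPI.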
Warnings
- breaking The documentation explicitly states that 'Other AQT versions are obsolete.' Import from `aqt.jax.v2` and use the latest API patterns; older versions may be unmaintained or removed.
- gotcha The `aqtp` package (Accurate Quantized Training) can be confused with other similarly named packages on PyPI, notably `aqtinstall` (a Qt CLI installer) and an unrelated `aqt` package. Ensure you install `aqtp` for the JAX quantization library.
- gotcha The `aqtp` library requires Python 3.10 or newer. Attempting to install or run `aqtp` with older Python versions will lead to compatibility issues or failures.
- gotcha The project is currently in 'Development Status :: 3 - Alpha'. This indicates that while actively developed, the API may still undergo non-backward compatible changes in minor versions, and stability is not yet guaranteed for long-term production use without careful version pinning.
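Given the alpha status above, production deployments typically pin an exact, known-good version rather than floating on the latest release. A minimal requirements fragment (the pinned number reflects the version current at the time of writing; adjust to the version you have tested):

```
# requirements.txt -- pin AQT (and ideally jax/jaxlib/flax) to tested versions
aqtp==0.9.0
```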
Install
- `pip install aqtp`
Imports
- `aqt.jax.v2`: `import aqt.jax.v2 as aqt`
Quickstart
import jax
import jax.numpy as jnp
from flax import linen as nn
from aqt.jax.v2 import config as aqt_config
from aqt.jax.v2.flax import aqt_flax

# A simple MLP whose Dense layers accept an injectable dot_general.
class Mlp(nn.Module):
    num_layers: int
    features: int
    aqt_cfg: aqt_config.DotGeneral | None = None  # None => unquantized

    @nn.compact
    def __call__(self, x):
        # AqtDotGeneral replaces lax.dot_general inside nn.Dense;
        # with a None config it falls back to the float implementation.
        # (AqtDotGeneral / config_v4 follow the AQT v2 README; verify
        # the names against your installed aqtp version.)
        dot_general = aqt_flax.AqtDotGeneral(self.aqt_cfg)
        for _ in range(self.num_layers - 1):
            x = nn.Dense(self.features, dot_general=dot_general)(x)
            x = nn.relu(x)
        x = nn.Dense(self.features, dot_general=dot_general)(x)
        return x

# config_v4() is the recommended int8 training config in AQT v2.
aqt_cfg = aqt_config.config_v4()

key = jax.random.PRNGKey(0)
x = jnp.zeros((1, 5))  # batch size 1, 5 input features

model = Mlp(num_layers=3, features=10)
params = model.init(key, x)

quantized_model = Mlp(num_layers=3, features=10, aqt_cfg=aqt_cfg)
quantized_params = quantized_model.init(key, x)

print(f"Original output shape: {model.apply(params, x).shape}")
print(f"Quantized output shape: {quantized_model.apply(quantized_params, x).shape}")
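Training through quantized ops works because the non-differentiable rounding in the forward pass is paired with a straight-through estimator (STE) in the backward pass. A minimal sketch of that idea in plain JAX (illustrative only, not AQT internals):

```python
import jax
import jax.numpy as jnp

def ste_round(x):
    # Straight-through estimator: round in the forward pass, but let
    # gradients flow through as if this were the identity function.
    return x + jax.lax.stop_gradient(jnp.round(x) - x)

def loss(w):
    # Toy loss over a "quantized" value of the scaled weights.
    return jnp.sum(ste_round(w * 4.0) ** 2)

w = jnp.array([0.3, -0.7])
# Gradients are nonzero despite the rounding, so training can proceed.
print(jax.grad(loss)(w))
```

Without the STE trick, `jnp.round` would have zero gradient almost everywhere and the quantized model could not learn.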