{"id":6517,"library":"aqtp","title":"Accurate Quantized Training Library (AQT)","description":"AQT (Accurate Quantized Training) is a Python library for quantizing tensor operations in JAX. It delivers excellent int8 model quality without extensive manual tuning, enables significant training speedups on modern ML accelerators, and offers simple, flexible APIs suitable for both production and research. AQT quantizes tensor operations such as matmul, einsum, and conv without making assumptions about how they are used in a neural network, so it can be injected into any JAX computation. It has been extensively tested at Google with frameworks such as Flax, Pax, and MaxText. The current version is 0.9.0; minor versions are released frequently (monthly or bi-monthly).","status":"active","version":"0.9.0","language":"en","source_language":"en","source_url":"https://github.com/google/aqt","tags":["quantization","jax","machine learning","deep learning","ai","neural networks","tensor operations","flax"],"install":[{"cmd":"pip install aqtp","lang":"bash","label":"Install aqtp"}],"dependencies":[{"reason":"Core dependency; AQT quantizes tensor operations inside JAX computations. 
Specific JAX installation (CPU/GPU/TPU) is left to the user.","package":"jax","optional":false},{"reason":"Neural network library commonly used with JAX; used in the quickstart example.","package":"flax","optional":true}],"imports":[{"note":"The AQT documentation states 'Other AQT versions are obsolete.', indicating that `aqt.jax.v2` is the current recommended import path for JAX integration. Its Flax integration lives in `aqt.jax.v2.flax.aqt_flax`.","symbol":"aqt.jax.v2","correct":"from aqt.jax.v2 import config as aqt_config"}],"quickstart":{"code":"from typing import Any\n\nimport jax\nimport jax.numpy as jnp\nfrom flax import linen as nn\nfrom aqt.jax.v2 import config as aqt_config\nfrom aqt.jax.v2.flax import aqt_flax\n\n# A simple MLP; pass an AQT config to quantize its Dense layers.\nclass Mlp(nn.Module):\n  num_layers: int\n  features: int\n  aqt_cfg: Any = None  # None keeps the regular float matmul\n\n  @nn.compact\n  def __call__(self, x):\n    for i in range(self.num_layers):\n      if self.aqt_cfg is None:\n        dense = nn.Dense(self.features)\n      else:\n        # AqtDotGeneral stands in for jax.lax.dot_general inside nn.Dense.\n        dense = nn.Dense(self.features, dot_general=aqt_flax.AqtDotGeneral(self.aqt_cfg))\n      x = dense(x)\n      if i < self.num_layers - 1:\n        x = nn.relu(x)\n    return x\n\n# int8 quantization of the forward and backward passes.\nint8_cfg = aqt_config.fully_quantized(fwd_bits=8, bwd_bits=8)\n\nkey = jax.random.PRNGKey(0)\nx = jnp.zeros((1, 5))  # batch size 1, 5 input features\n\nmodel = Mlp(num_layers=3, features=10)\nparams = model.init(key, x)\n\nquantized_model = Mlp(num_layers=3, features=10, aqt_cfg=int8_cfg)\nquantized_params = quantized_model.init(key, x)\n\nprint(f\"Original output shape: {model.apply(params, x).shape}\")\n# Provide a 'params' RNG in case the config uses stochastic rounding.\nprint(f\"Quantized output shape: {quantized_model.apply(quantized_params, x, rngs={'params': key}).shape}\")","lang":"python","description":"This quickstart demonstrates how to define a simple Multi-Layer Perceptron (MLP) using Flax, then apply 8-bit quantization using AQT's `aqt.jax.v2` API. 
It shows how to create an AQT `DotGeneral` configuration and inject it into the network's matrix multiplications, so that both the forward and backward passes are quantized. This example requires `jax` and `flax` to be installed."},"warnings":[{"fix":"Migrate imports to the `aqt.jax.v2` package and update API calls to match the latest documentation in the official GitHub repository.","message":"The documentation states that other AQT versions are obsolete. Make sure you import from `aqt.jax.v2` and use the current API patterns; older versions may be unmaintained or removed.","severity":"breaking","affected_versions":"<0.9.0"},{"fix":"Always use `pip install aqtp` and verify that the package description on PyPI reads 'Accurate Quantized Training library' to avoid installing the wrong package.","message":"The `aqtp` package (Accurate Quantized Training) can be confused with similarly named packages on PyPI, notably `aqtinstall` (a Qt CLI installer) and an unrelated `aqt` package. Install `aqtp` to get the JAX quantization library.","severity":"gotcha","affected_versions":"All"},{"fix":"Ensure your Python environment is version 3.10 or newer before installing `aqtp`.","message":"The `aqtp` library requires Python 3.10 or newer. Installing or running `aqtp` on older Python versions will fail or cause compatibility issues.","severity":"gotcha","affected_versions":"<0.9.0"},{"fix":"Pin `aqtp` to a specific version in your `requirements.txt` or `pyproject.toml` to avoid unexpected breaking changes during updates.","message":"The project is currently in 'Development Status :: 3 - Alpha'. 
This indicates that, although the project is actively developed, its API may still change in backward-incompatible ways between minor versions, and long-term production use requires careful version pinning.","severity":"gotcha","affected_versions":"All"}],"env_vars":null,"last_verified":"2026-04-15T00:00:00.000Z","next_check":"2026-07-14T00:00:00.000Z","problems":[{"fix":"Install the library with `pip install aqtp` if needed, then import from the `aqt.jax.v2` package (e.g. `from aqt.jax.v2 import config as aqt_config`).","cause":"`aqtp` is only the PyPI distribution name. The library installs its modules under the `aqt` package, so `import aqtp` fails even when the library is installed.","error":"ModuleNotFoundError: No module named 'aqtp'"},{"fix":"Import from the `aqt.jax.v2` submodules instead (e.g. `from aqt.jax.v2 import config as aqt_config` and `from aqt.jax.v2.flax import aqt_flax`).","cause":"The code imports from `aqtp` (the PyPI name, not the import name) and expects an `AQT` symbol that the library does not provide; its API lives in submodules of `aqt.jax.v2`.","error":"ImportError: cannot import name 'AQT' from 'aqtp'"},{"fix":"Follow the v2 API in the AQT documentation: build a quantization config with `aqt.jax.v2.config` and inject it into the tensor operations you want to quantize.","cause":"The library has no top-level `quantize` function. Quantization is configured and injected into individual tensor operations rather than applied by a single call.","error":"AttributeError: module 'aqtp' has no attribute 'quantize'"},{"fix":"Install the library with `pip install aqtp`, then import from `aqt.jax.v2` (e.g. `from aqt.jax.v2 import config as aqt_config`).","cause":"The `aqtp` distribution, which provides the importable `aqt` package, is not installed in the active environment. Note that the PyPI name (`aqtp`) differs from the import name (`aqt`).","error":"ModuleNotFoundError: No module named 'aqt'"},{"fix":"Update `aqtp` to the latest version (`pip install --upgrade aqtp`) and ensure your JAX, JAXlib, and Flax installations are compatible with the latest `aqtp` release. You might need to pin specific versions of JAX/Flax if the issue persists.","cause":"This error typically occurs due to an incompatibility between the installed version of `aqtp` (e.g., an older version like 0.1.1) and a newer version of JAX or related frameworks like Flax, where `jax.config.define_bool_state` might have been removed or renamed.","error":"AttributeError: module 'jax.config' has no attribute 'define_bool_state'"}]}