Qwix: A JAX Quantization Library
Qwix is a JAX-native quantization library for both research and production, providing efficient model compression through Quantization-Aware Training (QAT), Post-Training Quantization (PTQ), and ODML quantization. It supports various XLA targets (CPU/GPU/TPU) and LiteRT, featuring a flexible, regex-based configuration system for Flax Linen and NNX models. The current version is 0.1.5, with an active release cadence.
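The regex-based configuration can be pictured with plain `re`: a rule's `module_path` pattern is matched against each module's path string, and a catch-all pattern like `'.*'` selects every module. The sketch below is illustrative only; the example paths and the use of `re.fullmatch` are assumptions about the matching semantics, not Qwix internals.

```python
import re

# Hypothetical module paths, named the way Flax Linen names submodules.
paths = ['Dense_0', 'Dense_1', 'mlp/Dense_0']

# A rule with module_path='.*' matches every module; a narrower
# pattern selects only a subset of them.
match_all = [p for p in paths if re.fullmatch('.*', p)]
match_first = [p for p in paths if re.fullmatch('.*Dense_0', p)]

print(match_all)    # all three paths
print(match_first)  # ['Dense_0', 'mlp/Dense_0']
```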
Warnings
- gotcha The PyPI project page for Qwix currently states, 'Qwix doesn't provide a PyPI package yet. To use Qwix, you need to install from GitHub directly.' even though a `qwix` package is listed on PyPI. For the recommended installation, install from GitHub: `pip install git+https://github.com/google/qwix`.
- gotcha An RNG issue for LoRA (Low-Rank Adaptation) was fixed in version 0.1.5. If you are using LoRA or QLoRA with Qwix in older versions, you may encounter unexpected behavior related to random number generation.
- gotcha Qwix does not currently expose a `__version__` attribute, which is a common Python practice for programmatic version checking. This may complicate dependency management or conditional logic based on the installed Qwix version.
- gotcha Users should understand the distinction between Qwix's Quantization-Aware Training (QAT) and Quantized Training (QT). QAT uses fake quantization to recover model quality by making the model aware of precision loss during inference, while QT performs computations using low-precision integer arithmetic in both forward and backward passes for performance benefits. Choosing the correct mode depends on your specific optimization goals.
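The fake-quantization idea behind QAT can be sketched in plain NumPy: values are quantized to int8 and immediately dequantized, so the forward pass sees the precision loss of inference while all arithmetic stays in float. This is a conceptual sketch of the technique, not Qwix's implementation.

```python
import numpy as np

def fake_quantize(x, bits=8):
    """Symmetric per-tensor quantize-dequantize, the core of fake quantization."""
    qmax = 2 ** (bits - 1) - 1            # 127 for int8
    scale = np.max(np.abs(x)) / qmax      # map the largest magnitude to qmax
    q = np.clip(np.round(x / scale), -qmax, qmax)
    return q * scale                      # back to float, rounding error included

x = np.array([0.5, -1.27, 0.031], dtype=np.float32)
x_fq = fake_quantize(x)

# The result is still float32, but it carries int8 rounding error,
# which the training loss can then learn to compensate for.
print(x_fq, np.max(np.abs(x - x_fq)))
```

QT, by contrast, would actually run the matmuls in integer arithmetic for speed; fake quantization only simulates the rounding so gradients can flow through the float computation.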
Install
- from GitHub
pip install git+https://github.com/google/qwix
Imports
- qwix
import qwix
- QuantizationRule
from qwix import QuantizationRule
- quantize_model
from qwix import quantize_model
Quickstart
import jax
import jax.numpy as jnp
from flax import linen as nn
import qwix
# Define a simple MLP model using Flax Linen
class MLP(nn.Module):
  dhidden: int = 64
  dout: int = 10

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.dhidden, use_bias=False)(x)
    x = nn.relu(x)
    x = nn.Dense(self.dout, use_bias=False)(x)
    return x
# Initialize the model and dummy input
model = MLP()
key = jax.random.key(0)
model_input = jax.random.uniform(key, (8, 16))
params = model.init(key, model_input)['params']
# Define quantization rules for int8 weight and activation quantization
# This rule matches all modules ('.*')
rules = [
    qwix.QuantizationRule(
        module_path='.*',
        weight_qtype=jnp.int8,
        act_qtype=jnp.int8,
    )
]
# Apply Post-Training Quantization (PTQ)
ptq_model = qwix.quantize_model(model, qwix.PtqProvider(rules))
# Inspect the quantized parameter structure without materializing arrays.
# (jax.eval_shape on the model's init returns abstract shapes/dtypes only.)
print("Original kernel shape:", params['Dense_0']['kernel'].shape)
abs_quantized = jax.eval_shape(ptq_model.init, key, model_input)['params']
print("Quantized param structure:",
      jax.tree_util.tree_map(lambda a: (a.shape, a.dtype), abs_quantized))
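Conceptually, PTQ replaces each float kernel with an int8 value array plus a float scale, shrinking storage roughly 4x versus float32. The NumPy sketch below shows that transformation and its reconstruction error; the `qvalue`/`scale` field names are illustrative, not Qwix's actual parameter layout.

```python
import numpy as np

rng = np.random.default_rng(0)
kernel = rng.uniform(-1.0, 1.0, size=(16, 64)).astype(np.float32)

# Symmetric per-tensor int8 quantization: store int8 values plus one scale.
scale = np.abs(kernel).max() / 127.0
qvalue = np.clip(np.round(kernel / scale), -127, 127).astype(np.int8)
quantized = {'qvalue': qvalue, 'scale': np.float32(scale)}  # illustrative layout

# Dequantize to check the approximation quality.
recon = quantized['qvalue'].astype(np.float32) * quantized['scale']
print('max abs error:', np.abs(kernel - recon).max())  # bounded by ~scale/2
```

Per-channel scales (one per output column) usually reduce this error further; per-tensor is shown here only for simplicity.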