Qwix: A JAX Quantization Library

0.1.5 · active · verified Tue Apr 14

Qwix is a JAX-native quantization library for both research and production. It provides model compression through Quantization-Aware Training (QAT), Post-Training Quantization (PTQ), and ODML (on-device ML) quantization, supports XLA targets (CPU/GPU/TPU) as well as LiteRT, and features a flexible, regex-based configuration system for Flax Linen and NNX models.
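For context on what int8 quantization of weights and activations means in practice, here is a minimal NumPy sketch of symmetric per-tensor int8 quantization. This illustrates the general technique only; it is not Qwix's internal implementation, and the function names are invented for illustration:

```python
import numpy as np

def quantize_int8(x):
    # Symmetric per-tensor quantization: map [-max|x|, max|x|] to [-127, 127].
    scale = np.max(np.abs(x)) / 127.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    # Recover an approximation of the original float values.
    return q.astype(np.float32) * scale

x = np.array([0.5, -1.27, 0.001], dtype=np.float32)
q, s = quantize_int8(x)
x_hat = dequantize(q, s)  # close to x, within half a quantization step
```

Each element is stored as one byte plus a shared float scale, which is the source of the memory and bandwidth savings that libraries like Qwix exploit.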

Warnings

Install
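Assuming the package is published on PyPI under the name `qwix` (the project lives at google/qwix on GitHub), installation is a single pip command:

```shell
# Install Qwix from PyPI (assumed package name: qwix)
pip install qwix
```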

Imports

The quickstart below relies on the following imports:

import jax
import jax.numpy as jnp
from flax import linen as nn
import qwix

Quickstart

This quickstart defines a simple Flax Linen MLP and applies Post-Training Quantization (PTQ) with Qwix. A single `QuantizationRule` matching every module (`module_path='.*'`) quantizes both weights and activations to int8, and `qwix.quantize_model` applies it to the model.

import jax
import jax.numpy as jnp
from flax import linen as nn
import qwix

# Define a simple MLP model using Flax Linen
class MLP(nn.Module):
    dhidden: int = 64
    dout: int = 10

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.dhidden, use_bias=False)(x)
        x = nn.relu(x)
        x = nn.Dense(self.dout, use_bias=False)(x)
        return x

# Initialize the model and dummy input
model = MLP()
key = jax.random.key(0)
model_input = jax.random.uniform(key, (8, 16))
params = model.init(key, model_input)['params']

# Define quantization rules for int8 weight and activation quantization
# This rule matches all modules ('.*')
rules = [
    qwix.QuantizationRule(
        module_path='.*',
        weight_qtype=jnp.int8,
        act_qtype=jnp.int8,
    )
]

# Apply Post-Training Quantization (PTQ)
ptq_model = qwix.quantize_model(model, qwix.PtqProvider(rules))

# Compare the original parameter structure with the quantized model's,
# using jax.eval_shape to trace init without allocating real arrays.
print("Original kernel shape:", params['Dense_0']['kernel'].shape)
quantized_shapes = jax.tree.map(
    lambda x: (x.shape, x.dtype),
    jax.eval_shape(ptq_model.init, key, model_input),
)
print("Quantized variable shapes and dtypes:", quantized_shapes)
