PyTorch AO (torchao)
TorchAO is a PyTorch library for applying architecture optimization (AO) techniques, primarily quantization and sparsity, to deep learning models. It focuses on performance acceleration through low-precision kernels, mixture-of-experts (MoE) optimizations, and quantization-aware training (QAT), with the largest speedups on GPUs. The current version is 0.17.0, with new versions and significant features released frequently, often monthly.
Warnings
- breaking The `quantize_` API underwent a significant overhaul in version 0.9.0, changing how quantization recipes are applied to models. Direct calls to `quantize_` with previous argument patterns will fail.
- deprecated Older configurations and less-used quantization options have been deprecated to streamline the library. Using these deprecated features may lead to warnings or errors in future releases.
- gotcha Features located in `torchao.prototype` modules are experimental and subject to frequent, unannounced API changes, or may be removed entirely without prior deprecation. They are not considered stable.
- gotcha Optimal performance for `torchao`'s advanced kernels (e.g., MXFP8 MoE, W4A8) often requires specific CUDA versions (e.g., CUDA 12.8+) or particular GPU architectures (e.g., Blackwell, GB200). Using non-supported environments may result in reduced performance, errors, or inability to leverage certain features.
Install
pip install torchao
Imports
- quantize_
from torchao.quantization import quantize_
- Int8DynamicActivationInt4WeightConfig
from torchao.quantization import Int8DynamicActivationInt4WeightConfig
- int8_dynamic_activation_int4_weight
from torchao.quantization import int8_dynamic_activation_int4_weight
- sparsify_
from torchao.sparsity import sparsify_
Quickstart
import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_dynamic_activation_int4_weight
# 1. Define a simple model
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(20, 5)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))
model = MyModel()
print(f"Original model: {model}")
# 2. Define a quantization recipe
# This uses a predefined post-training quantization recipe
quantizer = int8_dynamic_activation_int4_weight()
# 3. Apply quantization to the model
# quantize_ modifies the model in place and returns None,
# so keep using the original reference afterward.
quantize_(model, quantizer)
print(f"\nQuantized model: {model}")
# Test with some dummy input
dummy_input = torch.randn(1, 10)
output = model(dummy_input)
print(f"\nOutput shape: {output.shape}")
assert isinstance(model.linear1, torch.nn.Module)  # Verify structure
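For intuition about what the int8 dynamic-activation path does, here is a simplified sketch in plain PyTorch: at runtime, a tensor is quantized symmetrically per tensor with a scale derived from its absolute maximum, then dequantized for use. This illustrates the numerics only; it is not torchao's actual kernel.

```python
import torch

def int8_dynamic_quant(x: torch.Tensor):
    # Symmetric per-tensor quantization: map [-absmax, absmax] onto [-127, 127].
    scale = x.abs().max().clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return q, scale

x = torch.randn(4, 10)
q, scale = int8_dynamic_quant(x)

# Dequantize: each element is off by at most half a quantization step (scale / 2).
x_hat = q.to(torch.float32) * scale
max_err = (x - x_hat).abs().max().item()
```

"Dynamic" means these scales are computed from the live activation values on every forward pass, in contrast to static quantization, where scales are fixed ahead of time from calibration data.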