DiffQ: Differentiable Quantization Framework for PyTorch
DiffQ is a differentiable quantization framework for PyTorch that provides tools to quantize PyTorch models, with a focus on large language models (LLMs) and computer vision models. It enables quantization-aware training and supports quantization methods such as GPTQ, HQQ, and AWQ. The current version is 0.2.4; development was most active in late 2023, with periodic releases adding features and fixing bugs.
Common errors
- `ModuleNotFoundError: No module named 'bitsandbytes'`
  - Cause: you are attempting to use an 8-bit quantization method (often through the `transformers` integration) that relies on the `bitsandbytes` library, but it is not installed or not correctly linked to your CUDA setup.
  - Fix: install `bitsandbytes` with `pip install bitsandbytes`, and ensure your CUDA environment and GPU drivers are compatible, as `bitsandbytes` builds are often CUDA-specific.
- `AttributeError: 'tuple' object has no attribute 'input_ids'`
  - Cause: the `dataloader` passed to a quantizer (such as `GPTQQuantizer`) does not yield data in the format the model or quantizer expects, particularly for language models expecting specific keys like `input_ids`.
  - Fix: make your `dataloader` yield data in the format expected by the model's `forward` method and the chosen quantizer. For many LLM examples this means a tuple or dictionary containing `input_ids` (e.g., `(input_ids,)` for a simple tuple, or `{'input_ids': input_ids}` for a dict).
- `RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM`
  - Cause: a low-level CUDA error, often indicating an incompatibility between your installed PyTorch version, CUDA toolkit, GPU driver, or the specific `pytorch_quantization` version in use. Quantization operations are highly sensitive to the entire CUDA software stack.
  - Fix: check the PyTorch and CUDA versions required by `pytorch_quantization` and `diffq`. Often you must reinstall PyTorch with the exact matching CUDA build (e.g., `pip install torch==X.Y.Z+cuXXX -f https://download.pytorch.org/whl/torch_stable.html`).
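For the `input_ids` error above, a minimal sketch of a calibration dataloader that yields dict batches is shown below. It uses random token ids as a stand-in for real tokenized text; the exact batch format your quantizer expects may differ.

```python
import torch
from torch.utils.data import DataLoader, Dataset

# Hypothetical calibration set: random token ids standing in for real text.
# Replace with tokenized samples from your own corpus.
class CalibrationSet(Dataset):
    def __init__(self, n_samples=8, seq_len=16, vocab_size=32000):
        self.data = torch.randint(0, vocab_size, (n_samples, seq_len))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Yield a dict so downstream code can access batch["input_ids"],
        # matching what most LLM forward() signatures expect.
        return {"input_ids": self.data[idx]}

loader = DataLoader(CalibrationSet(), batch_size=4)
batch = next(iter(loader))
print(batch["input_ids"].shape)  # torch.Size([4, 16])
```

The default collate function stacks dict samples into a dict of batched tensors, so `batch["input_ids"]` is directly available to the quantizer's calibration loop.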
Warnings
- Gotcha: quantization libraries like `diffq` and its dependency `pytorch_quantization` are highly sensitive to PyTorch and CUDA version compatibility. Mismatches can lead to cryptic `RuntimeError`s, `cuDNN` errors, or unexpected behavior.
- Gotcha: many advanced quantization methods (e.g., HQQ, AWQ, 8-bit quantization via `bitsandbytes`) require additional, often hardware-specific, optional dependencies. Forgetting to install these results in a `ModuleNotFoundError` or `ImportError` when the corresponding method is used.
- Gotcha: while `diffq` supports general PyTorch models, its primary examples and optimizations target large language models (LLMs) from the `transformers` library. Applying aggressive quantization to arbitrary custom models, or models with complex non-standard layers, may require manual adaptation or may not yield optimal results.
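Given how version-sensitive these libraries are, it helps to print the exact PyTorch/CUDA stack before debugging quantization errors. A quick diagnostic using only standard `torch` attributes:

```python
import torch

# Report the versions the quantization extensions must be compatible with.
print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA (compiled):", torch.version.cuda)    # None on CPU-only builds
print("cuDNN:", torch.backends.cudnn.version())  # None if cuDNN is absent
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))
```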
Install
- `pip install diffq`
- `pip install diffq[hqq_ext]`  # for HQQ support
- `pip install diffq[awq_ext]`  # for AWQ support
- `pip install bitsandbytes`    # for 8-bit quantization
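Because the optional extras above only fail at use time with an `ImportError`, you can check for them up front. A small stdlib-only sketch (the module names `hqq` and `awq` are assumptions about what the extras install; adjust to your environment):

```python
import importlib.util

# Map each quantization method to the module its extra is expected to provide.
# The extension module names here are assumptions; adjust to your install.
OPTIONAL_DEPS = {
    "8-bit (bitsandbytes)": "bitsandbytes",
    "HQQ": "hqq",
    "AWQ": "awq",
}

for method, module in OPTIONAL_DEPS.items():
    status = "OK" if importlib.util.find_spec(module) else "missing"
    print(f"{method}: {status}")
```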
Imports
- DiffQModel
from diffq import DiffQModel
- BaseQuantizationConfig
from diffq import BaseQuantizationConfig
- GPTQQuantizer
from diffq.quantizers import GPTQQuantizer
- ViTForImageClassification
from diffq.models import ViTForImageClassification
Quickstart
import torch
import torch.nn as nn
from diffq import DiffQModel, BaseQuantizationConfig
# 1. Define a simple PyTorch model
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 20)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(20, 1)
def forward(self, x):
return self.linear2(self.relu(self.linear1(x)))
model = SimpleModel()
# 2. Define a basic quantization configuration
# For actual quantization (e.g., 'gptq', 'hqq', 'awq'),
# additional steps with a specific quantizer (e.g., GPTQQuantizer)
# and a dataloader would be required.
quant_config = BaseQuantizationConfig(
quant_method="none", # Use "gptq", "hqq", "awq" for actual methods
w_bits=8,
w_group_size=128, # Not strictly applicable for "none", but often part of config
w_sym=False,
w_mse_scheme="per_tensor"
)
# 3. Convert the PyTorch model into a DiffQModel
# This automatically replaces modules with their quantized counterparts based on config.
diffq_model = DiffQModel(model, quantization_config=quant_config)
# Print the model structure to see the converted modules
print("Original model:")
print(model)
print("\nDiffQModel (converted structure):")
print(diffq_model)
# Example forward pass (will not perform actual quantization during inference
# without a preceding quantizer.quantize() call for methods like GPTQ)
dummy_input = torch.randn(1, 10)
output = diffq_model(dummy_input)
print(f"\nOutput shape: {output.shape}")
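To build intuition for what `w_bits=8` with asymmetric quantization (`w_sym=False`) does to a weight tensor, here is a plain-PyTorch fake-quantization round trip. This is a pedagogical sketch of the general technique, not `diffq`'s internal implementation:

```python
import torch

def quantize_dequantize(w: torch.Tensor, bits: int = 8) -> torch.Tensor:
    """Asymmetric per-tensor fake quantization: float -> integer grid -> float."""
    qmin, qmax = 0, 2 ** bits - 1
    # Scale maps the tensor's full range onto [qmin, qmax].
    scale = (w.max() - w.min()).clamp(min=1e-8) / (qmax - qmin)
    # Zero point shifts the grid so that w.min() lands near qmin.
    zero_point = torch.round(-w.min() / scale)
    q = torch.clamp(torch.round(w / scale + zero_point), qmin, qmax)
    return (q - zero_point) * scale

w = torch.randn(20, 10)
w_hat = quantize_dequantize(w, bits=8)
# With 256 levels, the reconstruction error is at most about half a grid step.
max_err = (w - w_hat).abs().max().item()
step = ((w.max() - w.min()) / 255).item()
print(f"max error {max_err:.5f}, half step {step / 2:.5f}")
```

Lower `bits` widens the grid step and increases reconstruction error, which is why methods like GPTQ and AWQ add calibration data to place that error where it hurts the model least.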