DiffQ: Differentiable Quantization Framework for PyTorch

0.2.4 · active · verified Thu Apr 16

DiffQ is a differentiable quantization framework for PyTorch that provides tools to quantize models, with a focus on large language models (LLMs) and computer vision models. It supports quantization-aware training and post-training quantization methods such as GPTQ, HQQ, and AWQ. The current release is 0.2.4; development was most active in late 2023, with periodic releases adding features and fixing bugs.
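The "differentiable" part refers to making the (normally non-differentiable) rounding step trainable, which is what enables quantization-aware training. A common way to do this is the straight-through estimator (STE): round in the forward pass, but let gradients pass through unchanged in the backward pass. The sketch below illustrates that generic idea in plain PyTorch; it is not DiffQ's internal implementation.

```python
import torch

class RoundSTE(torch.autograd.Function):
    """Straight-through estimator for rounding.

    Forward: hard rounding (non-differentiable on its own).
    Backward: identity gradient, so the rounding step does not
    block gradient flow during quantization-aware training.
    Generic illustration only -- not DiffQ's actual implementation.
    """

    @staticmethod
    def forward(ctx, x):
        return x.round()

    @staticmethod
    def backward(ctx, grad_output):
        # Pretend rounding was the identity function for gradients.
        return grad_output

x = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
y = RoundSTE.apply(x)
y.sum().backward()
# y holds rounded values; x.grad is all ones despite the rounding.
```

This trick is what lets frameworks train through a quantizer: the loss sees quantized weights, while the optimizer still receives usable gradients for the underlying float weights.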

Quickstart

This quickstart demonstrates how to convert a standard PyTorch model into a `DiffQModel` using a `BaseQuantizationConfig`. While this example uses `quant_method="none"` for structural conversion, for actual quantization (e.g., 4-bit, 8-bit), you would specify a method like `"gptq"` and then run a specific quantizer (e.g., `GPTQQuantizer`) with a dataloader.

import torch
import torch.nn as nn
from diffq import DiffQModel, BaseQuantizationConfig

# 1. Define a simple PyTorch model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(20, 1)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

model = SimpleModel()

# 2. Define a basic quantization configuration
# For actual quantization (e.g., 'gptq', 'hqq', 'awq'),
# additional steps with a specific quantizer (e.g., GPTQQuantizer)
# and a dataloader would be required.
quant_config = BaseQuantizationConfig(
    quant_method="none", # Use "gptq", "hqq", "awq" for actual methods
    w_bits=8,
    w_group_size=128, # Not strictly applicable for "none", but often part of config
    w_sym=False,
    w_mse_scheme="per_tensor"
)

# 3. Convert the PyTorch model into a DiffQModel
# This automatically replaces modules with their quantized counterparts based on config.
diffq_model = DiffQModel(model, quantization_config=quant_config)

# Print the model structure to see the converted modules
print("Original model:")
print(model)
print("\nDiffQModel (converted structure):")
print(diffq_model)

# Example forward pass (will not perform actual quantization during inference
# without a preceding quantizer.quantize() call for methods like GPTQ)
dummy_input = torch.randn(1, 10)
output = diffq_model(dummy_input)
print(f"\nOutput shape: {output.shape}")
