NVIDIA Model Optimizer
NVIDIA Model Optimizer (nvidia-modelopt) is an open-source toolkit that accelerates AI inference by applying state-of-the-art model optimization techniques such as quantization, pruning, and distillation. It primarily targets PyTorch and ONNX models, integrates directly into the training loop, and enables seamless deployment to NVIDIA inference frameworks such as TensorRT-LLM and TensorRT. The library is actively developed: its current stable version is 0.42.0, and frequent pre-release candidates (e.g., 0.43.0rcX) indicate a rapid release cadence.
Warnings
- gotcha For full functionality, especially with pre-release versions or specific NVIDIA-optimized components, it is often necessary to install `nvidia-modelopt` using `--extra-index-url https://pypi.nvidia.com`. Without this, certain features or versions might not be available or compatible.
- breaking The `num_query_groups` parameter in Minitron pruning (specifically for `mcore_minitron`) was deprecated. If you relied on this for pruning, you might need to use an older version of ModelOpt.
- gotcha NVIDIA Model Optimizer has specific Python version requirements (currently Python >=3.10, <3.13). Using incompatible Python versions can lead to installation failures or runtime errors.
- gotcha The actual inference performance gains from model optimization (quantization, pruning, distillation) depend heavily on the downstream deployment framework (e.g., TensorRT-LLM, TensorRT) and the specific hardware configuration. `nvidia-modelopt` optimizes the model, but the runtime performance is realized by these specialized inference engines.
- gotcha When working with ONNX models, specific opset versions are required for certain quantization types (e.g., INT8 requires opset 13+, FP8 and INT4 require opset 21+). While ModelOpt can automatically upgrade lower opset versions, awareness of these requirements can prevent unexpected behavior or errors.
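The opset constraints above can be encoded as a small pre-flight check before handing an ONNX model to ModelOpt. This is an illustrative sketch: `MIN_OPSET` and `check_opset` are hypothetical names, not part of the ModelOpt API, and the thresholds are the ones stated in the warning above.

```python
# Minimum ONNX opset per quantization type, per the warning above.
# Hypothetical helper for illustration only -- not a ModelOpt API.
MIN_OPSET = {"INT8": 13, "FP8": 21, "INT4": 21}

def check_opset(quant_type: str, model_opset: int) -> bool:
    """Return True if the model's opset meets the quantization requirement."""
    required = MIN_OPSET.get(quant_type.upper())
    if required is None:
        raise ValueError(f"Unknown quantization type: {quant_type}")
    return model_opset >= required

print(check_opset("INT8", 13))  # True: INT8 requires opset 13+
print(check_opset("FP8", 17))   # False: FP8 requires opset 21+
```

In practice ModelOpt can upgrade a lower opset automatically, but a check like this makes the requirement explicit and fails fast with a clear message.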
Install
- pip install nvidia-modelopt
- pip install "nvidia-modelopt[all]" --extra-index-url https://pypi.nvidia.com
Imports
- NVIDIAModelOptConfig
from diffusers import NVIDIAModelOptConfig
- enable_huggingface_checkpointing
from modelopt.torch.opt import enable_huggingface_checkpointing
- quantization
import modelopt.torch.quantization as mtq
- export_hf_checkpoint
from modelopt.torch.export import export_hf_checkpoint
Quickstart
import torch
from diffusers import AutoModel, NVIDIAModelOptConfig
from modelopt.torch.opt import enable_huggingface_checkpointing

# Enable ModelOpt-aware checkpointing so quantized state can be saved and
# reloaded through the standard Hugging Face save/load APIs
enable_huggingface_checkpointing()

# Model ID and compute data type
model_id = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
dtype = torch.bfloat16

# FP8 quantization configuration handled by the ModelOpt backend
quantization_config = NVIDIAModelOptConfig(quant_type="FP8", quant_method="modelopt")

# Loading with quantization requires a compatible GPU and a working
# NVIDIA driver/CUDA setup.
try:
    print(f"Attempting to load model {model_id} with FP8 quantization...")
    model = AutoModel.from_pretrained(
        model_id,
        subfolder="transformer",
        quantization_config=quantization_config,
        torch_dtype=dtype,
    )
    print("Model loaded successfully with quantization enabled.")

    # To save the quantized model (pickle-based serialization is required):
    # model.save_pretrained("path/to/sana_fp8", safe_serialization=False)
except Exception as e:
    print(f"Error loading or processing model: {e}")
    print(
        "Ensure `diffusers` is installed, a compatible GPU is available, and that "
        "`nvidia-modelopt` was installed with --extra-index-url https://pypi.nvidia.com."
    )