Transformer Engine (CUDA 12)
Transformer Engine (TE) is a Python library from NVIDIA for accelerating Transformer models on NVIDIA GPUs. It enables lower-precision training and inference, notably 8-bit (FP8) and 4-bit (NVFP4) floating-point formats on Hopper, Ada, and Blackwell GPUs, improving performance and reducing memory utilization. It provides highly optimized building blocks for popular Transformer architectures and an automatic-mixed-precision-style API for PyTorch and JAX. The current version is 2.13.0, with an active release cadence that often aligns with new NVIDIA hardware and software.
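The key idea behind FP8 training is per-tensor scaling: FP8 formats have a narrow dynamic range (E4M3 tops out at 448), so each tensor is multiplied by a scale factor derived from its recent absolute-maximum (amax) history before being cast to FP8. The following is a plain-Python sketch of this delayed-scaling idea; it is illustrative only, not Transformer Engine's actual implementation, and it models only the scale/clamp step (real FP8 casts also round the mantissa).

```python
# Illustrative sketch of FP8 "delayed scaling" -- NOT Transformer Engine's
# internal implementation. E4M3's largest representable magnitude is 448.
FP8_E4M3_MAX = 448.0

def compute_scale(amax_history, margin=0):
    """Pick a scale so the largest recent value maps near the FP8 maximum."""
    amax = max(amax_history)  # amax over the sliding history window
    if amax == 0.0:
        return 1.0
    return FP8_E4M3_MAX / amax / (2 ** margin)

def fake_quantize(values, scale):
    """Scale into FP8 range, clamp, then unscale (simulated round trip)."""
    out = []
    for v in values:
        scaled = max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v * scale))
        out.append(scaled / scale)
    return out

amax_history = [3.0, 7.5, 5.0]       # amaxes observed in previous steps
scale = compute_scale(amax_history)  # 448 / 7.5
# Values within the recent amax survive; outliers beyond it are clamped.
print(fake_quantize([1.0, 7.5, 1000.0], scale))
```

Values inside the recent amax window round-trip essentially unchanged, while an outlier like 1000.0 is clamped back to the window's maximum; this is why TE tracks amax over a history rather than a single step.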
Common errors
- `ModuleNotFoundError: No module named 'transformer_engine.pytorch'`
  Cause: The framework-specific bindings for PyTorch (or JAX) were not installed. Installing `transformer-engine-cu12` by itself provides only the core library, not the Python bindings for deep learning frameworks.
  Fix: Install with the appropriate extra dependency: `pip install --no-build-isolation transformer-engine-cu12[pytorch]` (for PyTorch) or `pip install --no-build-isolation transformer-engine-cu12[jax]` (for JAX).
- `ImportError: undefined symbol: _ZN3c106cuda9SetDeviceEi`
  Cause: This error typically indicates a C++ ABI incompatibility between PyTorch and Transformer Engine: they were compiled with different C++ standards or settings.
  Fix: Verify that both PyTorch and Transformer Engine are built with compatible C++ ABIs. The simplest solution is often to use the NVIDIA NGC PyTorch or JAX Docker containers, which come pre-configured with compatible versions. If installing from source, ensure consistent compiler flags.
- `fatal error: cudnn.h: No such file or directory`
  Cause: The cuDNN headers are not found by the build system during installation. This can happen if cuDNN is not installed or its path is not exposed to the build environment.
  Fix: Install cuDNN 9.3+ and ensure that environment variables such as `CUDNN_PATH`, `CUDNN_HOME`, and `LD_LIBRARY_PATH` point to your cuDNN installation. For example: `export CUDNN_PATH=/path/to/cudnn`, `export CUDNN_HOME=$CUDNN_PATH`, `export LD_LIBRARY_PATH=$CUDNN_PATH/lib:$LD_LIBRARY_PATH`.
- `ERROR: Failed building wheel for transformer-engine`
  Cause: This generic `pip install` error often masks underlying issues with CMake, the CUDA Toolkit, the `nvcc` path, or the memory demands of compiling FlashAttention.
  Fix: First, ensure the CUDA Toolkit (12.1+), NVIDIA drivers, and cuDNN (9.3+) are correctly installed and configured. Check that `nvcc` is in your `PATH` or that the `CUDA_PATH` environment variable is set (e.g., `export CUDA_PATH=/usr/local/cuda`). If the error persists, especially when building FlashAttention, set `export MAX_JOBS=1` before installation to reduce memory usage during compilation: `MAX_JOBS=1 pip install --no-build-isolation transformer-engine-cu12[pytorch]`.
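Several of these failures come down to missing build prerequisites. A small preflight script can surface them before running `pip install`; note that this helper is purely illustrative (the function name and checks are not part of Transformer Engine):

```python
import os
import shutil

def build_preflight():
    """Report common prerequisites for building Transformer Engine.

    Hypothetical helper for illustration -- not part of Transformer Engine.
    """
    return {
        # nvcc must be discoverable for the CUDA compilation steps
        "nvcc_on_path": shutil.which("nvcc") is not None,
        # CUDA_PATH / CUDA_HOME give the build a fallback toolkit location
        "cuda_path_set": bool(os.environ.get("CUDA_PATH") or os.environ.get("CUDA_HOME")),
        # cuDNN headers (cudnn.h) are located via these variables
        "cudnn_path_set": bool(os.environ.get("CUDNN_PATH") or os.environ.get("CUDNN_HOME")),
        # MAX_JOBS=1 limits memory pressure from parallel compilation
        "max_jobs": os.environ.get("MAX_JOBS", "<unset>"),
    }

for key, value in build_preflight().items():
    print(f"{key}: {value}")
```

Run it before installing; any `False` entry points at one of the fixes listed above.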
Warnings
- breaking Breaking changes in `InferenceParams` and removal of the `interval` argument for `DelayedScaling` in PyTorch. `num_heads_kv`, `head_dim_k`, and `dtype` are now required for `InferenceParams` initialization, and `pre_step` must be called.
- breaking The deprecated packed fused attention C APIs (`nvte_fused_attn_{fwd,bwd}_{qkvpacked,kvpacked}`) have been removed. Users must migrate to the non-packed API variants.
- deprecated The installation of Transformer Engine now requires the `--no-build-isolation` flag when using PyPI or building from source. Support for installations *with* build isolation will be removed in a future release.
- gotcha ABI compatibility issues can arise if PyTorch and Transformer Engine are built with different C++ ABI settings, especially outside of NGC containers. This leads to `ImportError` with undefined symbols.
- gotcha Installing `transformer-engine-cu12` from PyPI may crash in environments with CUDA versions older than 12.8, even though the `cu12` suffix suggests support for any CUDA 12.x.
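Given the removal of the `interval` argument, FP8 recipes should be constructed from the remaining `DelayedScaling` knobs only. A sketch of building a recipe (the values shown are illustrative, not recommendations):

```python
from transformer_engine.common.recipe import DelayedScaling, Format

# DelayedScaling no longer accepts `interval`; configure the remaining
# knobs only. Values here are illustrative.
fp8_recipe = DelayedScaling(
    fp8_format=Format.HYBRID,   # E4M3 in the forward pass, E5M2 in backward
    amax_history_len=16,        # length of the amax sliding-history window
    amax_compute_algo="max",    # reduce the history window with max()
)

# The recipe is then passed to fp8_autocast, e.g.:
#   with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
#       out = te_layer(inp)
```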
Install
- pip install --no-build-isolation transformer-engine-cu12[pytorch]
- pip install --no-build-isolation transformer-engine-cu12[jax]
- pip install --no-build-isolation transformer-engine-cu12[core]
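In shells where square brackets are glob characters (zsh in particular), quote the extras specifier so the command is passed to pip unchanged:

```shell
# Quoting prevents zsh from treating [pytorch] as a glob pattern
pip install --no-build-isolation "transformer-engine-cu12[pytorch]"
```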
Imports
- Linear
from transformer_engine.pytorch import Linear
- LayerNorm
from transformer_engine.pytorch import LayerNorm
- TransformerLayer
from transformer_engine.pytorch import TransformerLayer
- fp8_autocast
from transformer_engine.pytorch import fp8_autocast
from transformer_engine.pytorch.fp8 import fp8_autocast
Quickstart
import torch
from transformer_engine.pytorch import Linear, fp8_autocast
# Dummy input tensor (FP8 GEMMs require tensor dimensions divisible by 16)
input_tensor = torch.randn(16, 128, device='cuda', dtype=torch.float16)
# Initialize a Transformer Engine Linear layer
te_linear_layer = Linear(128, 256, bias=True, params_dtype=torch.float16).cuda()
# Perform a forward pass with FP8 autocasting
with fp8_autocast(enabled=True):
    output_tensor = te_linear_layer(input_tensor)
print(f"Input shape: {input_tensor.shape}, dtype: {input_tensor.dtype}")
print(f"Output shape: {output_tensor.shape}, dtype: {output_tensor.dtype}")
# FP8 is used internally for the matrix multiply; module inputs and outputs
# stay in higher precision, so the output dtype matches the input dtype.
assert output_tensor.dtype == torch.float16, "Output should match the input dtype; FP8 is internal."
print("Quickstart example ran successfully with FP8 autocasting.")