TPU Inference for vLLM
tpu-inference is a hardware plugin for vLLM, designed to enable efficient inference of large language models (LLMs) on Google Cloud TPUs. It unifies JAX and PyTorch under a single lowering path, allowing PyTorch model definitions to run performantly on TPUs without additional code changes, while also extending native support to JAX. The library aims to push TPU hardware performance limits and retain vLLM's standardized user experience. It is actively maintained by the vLLM project and Google Cloud, with releases tied to vLLM development.
Common errors
- 'Stuck' or hanging behavior during PyTorch `model.generate()` on TPU (sometimes investigated with `torch.autograd.set_detect_anomaly(True)`).
  cause: PyTorch's dynamic computation graphs and lazy execution model can interact poorly with XLA (the compiler for TPUs) during auto-regressive decoding, leading to frequent recompilations and performance bottlenecks.
  fix: Implement a manual decode loop and explicitly force execution with `torch_xla.core.xla_model.mark_step()` after each generated token. Keep input shapes static and minimize dynamic control flow inside the loop. Consider a JAX-based model if possible.
- `PackageNotFoundError: No package metadata found for tpu_inference`
  cause: The `tpu-inference` Python package is not installed, or the check is run in a different environment from the one where it was installed.
  fix: Ensure `vllm-tpu` (which includes `tpu-inference` as a dependency) is installed in the active Python environment, preferably a virtual environment, following the official documentation for Google Cloud TPU VMs. Use `pip install vllm-tpu` or `uv pip install vllm-tpu`.
- `JAX backends: []` (or only CPU/GPU devices listed)
  cause: The Python environment cannot detect a TPU device, usually because the code is not running on a Google Cloud TPU VM or the underlying driver/runtime setup is incomplete.
  fix: Run the code on a properly configured Google Cloud TPU VM. Verify the TPU runtime setup (e.g., the `PJRT_DEVICE` environment variable and `libtpu` availability) per the Google Cloud and vLLM TPU documentation. Access to a TPU VM and sufficient quota are prerequisites.
- `OutOfMemoryError`, or performance degradation with spikes in host memory usage during inference.
  cause: The model, or parts of it, exceeds the on-chip memory available on the TPU, forcing data movement between TPU and host memory and introducing significant latency.
  fix: Use a smaller model or a TPU generation with more memory. Apply quantization or pruning to reduce the memory footprint. Consider model partitioning or pipeline parallelism where supported, though vLLM handles much of this automatically.
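The manual decode loop suggested for the first error above can be sketched as follows. This is a minimal illustration, not the library's implementation: `model` and the token IDs are placeholders, `torch_xla` is only assumed present on an actual TPU VM (off-TPU the step marker becomes a no-op), and a real implementation would pre-pad the sequence to a static length rather than growing it each step.

```python
import torch

try:
    import torch_xla.core.xla_model as xm

    def sync() -> None:
        xm.mark_step()  # force execution of the lazily built XLA graph
except ImportError:
    def sync() -> None:
        pass  # no-op off-TPU so the sketch still runs

def greedy_decode(model, input_ids: torch.Tensor, max_new_tokens: int = 8) -> torch.Tensor:
    """Append one argmax token per step, cutting the XLA graph each time.

    Note: concatenation grows the sequence dynamically; for real TPU use,
    pre-pad to a static length to avoid recompilation per step.
    """
    for _ in range(max_new_tokens):
        logits = model(input_ids)  # expected shape: (batch, seq, vocab)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_id], dim=-1)
        sync()  # one compiled step per token instead of one huge graph
    return input_ids
```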
Warnings
- gotcha tpu-inference is designed for specific TPU generations. Recommended generations are v7x, v5e, and v6e; older generations (v3, v4, v5p) are experimental. Ensure your Cloud TPU VM uses a compatible generation.
- gotcha When using PyTorch with `tpu-inference` via `vLLM`, direct `model.generate()` calls can be significantly slower than expected due to lazy graph execution and dynamic control flow in PyTorch/XLA. TPUs generally favor static computation graphs, which JAX utilizes more inherently.
- gotcha vLLM on TPUs uses a bucketization strategy for sequence lengths. Requests are rounded up to the nearest bucket size (e.g., a 176-token request might be treated as 256 tokens). This can lead to wasted computation and increased latency for workloads with varying or long sequence lengths if not accounted for.
- gotcha Speculative decoding, a technique to accelerate LLM inference, is currently not supported for TPUs when using vLLM.
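The bucketization effect can be illustrated with a small helper. The bucket sizes below are illustrative placeholders, not vLLM's actual configuration, which depends on the model and engine settings.

```python
import bisect

# Illustrative padded sequence-length buckets (powers of two).
BUCKETS = [16, 32, 64, 128, 256, 512, 1024, 2048]

def padded_len(num_tokens: int) -> int:
    """Round a request's token count up to the nearest bucket size."""
    i = bisect.bisect_left(BUCKETS, num_tokens)
    if i == len(BUCKETS):
        raise ValueError(f"request of {num_tokens} tokens exceeds the largest bucket")
    return BUCKETS[i]

# A 176-token request is padded to 256 tokens: 80 tokens of wasted compute.
print(padded_len(176))  # 256
```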
Install
- `pip install vllm-tpu`
- `uv pip install vllm-tpu`
Imports
- tpu_inference
import importlib.metadata
tpu_version = importlib.metadata.version("tpu_inference")
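The version lookup above raises `PackageNotFoundError` if the package is missing. A guarded variant (a hypothetical helper, not part of the library) degrades gracefully:

```python
import importlib.metadata
from typing import Optional

def tpu_inference_version() -> Optional[str]:
    """Return the installed tpu_inference version, or None if absent."""
    try:
        return importlib.metadata.version("tpu_inference")
    except importlib.metadata.PackageNotFoundError:
        return None

print(tpu_inference_version() or "tpu_inference not installed")
```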
Quickstart
import os
import importlib.metadata

# Ensure you have a Hugging Face token for model downloads if not public
# os.environ['HF_TOKEN'] = os.environ.get('HF_TOKEN', 'hf_...')

import jax
import vllm
from vllm.platforms import current_platform

try:
    tpu_version = importlib.metadata.version("tpu_inference")
    print(f"vLLM version: {vllm.__version__}")
    print(f"tpu_inference version: {tpu_version}")
    print(f"vLLM platform: {current_platform.get_device_name()}")
    print(f"JAX backends: {jax.devices()}")

    # Example of how you would typically use vLLM with the TPU backend.
    # This assumes a TPU VM environment and necessary model access.
    # For a full server quickstart, refer to the vLLM TPU documentation.
    # For instance:
    # from vllm.engine.arg_utils import AsyncEngineArgs
    # from vllm.engine.async_llm_engine import AsyncLLMEngine
    # engine_args = AsyncEngineArgs(model="google/gemma-2b-it")  # TPU backend is auto-detected
    # engine = AsyncLLMEngine.from_engine_args(engine_args)
    # print("vLLM engine with TPU backend initialized.")
except importlib.metadata.PackageNotFoundError:
    print("tpu-inference package not found. Please ensure it's installed in a TPU environment.")
except Exception as e:
    print(f"An error occurred during quickstart verification: {e}")
    print("Ensure you are running this in a Google Cloud TPU VM environment with appropriate setup.")