{"id":9367,
"library":"tpu-inference",
"title":"TPU Inference for vLLM",
"description":"tpu-inference is a hardware plugin for vLLM, designed to enable efficient inference of large language models (LLMs) on Google Cloud TPUs. It unifies JAX and PyTorch under a single lowering path, allowing PyTorch model definitions to run performantly on TPUs without additional code changes, while also extending native support to JAX. The library aims to push TPU hardware performance limits while retaining vLLM's standardized user experience. It is actively maintained by the vLLM project and Google Cloud, with releases tied to vLLM development.",
"status":"active","version":"0.13.3","language":"en","source_language":"en","source_url":"https://github.com/vllm-project/tpu-inference",
"tags":["TPU","inference","vLLM","JAX","PyTorch","LLM","Google Cloud","AI/ML acceleration"],
"install":[{"cmd":"pip install vllm-tpu","lang":"bash","label":"Recommended installation"},{"cmd":"uv pip install vllm-tpu","lang":"bash","label":"Faster installation with uv"}],
"dependencies":[{"reason":"tpu-inference is a plugin for vLLM and requires it for operation.","package":"vllm","optional":false},{"reason":"Provides core JAX functionality; commonly used with TPUs.","package":"jax","optional":true},{"reason":"Provides PyTorch functionality; commonly used with TPUs.","package":"torch","optional":true}],
"imports":[{"note":"tpu-inference is primarily an underlying plugin for vLLM; direct top-level imports of user-facing classes are uncommon. Its presence is typically checked via package metadata.","symbol":"tpu_inference","correct":"import importlib.metadata\ntpu_version = importlib.metadata.version(\"tpu_inference\")"}],
"quickstart":{"code":"import os\n# Ensure you have a Hugging Face token for model downloads if the model is not public\n# os.environ['HF_TOKEN'] = os.environ.get('HF_TOKEN', 'hf_...')\n\nimport jax\nimport vllm\nimport importlib.metadata\nfrom vllm.platforms import current_platform\n\ntry:\n    tpu_version = importlib.metadata.version(\"tpu_inference\")\n    print(f\"vLLM version: {vllm.__version__}\")\n    print(f\"tpu_inference version: {tpu_version}\")\n    print(f\"vLLM platform: {current_platform.get_device_name()}\")\n    print(f\"JAX devices: {jax.devices()}\")\n\n    # Example of how you would typically use vLLM with the TPU backend.\n    # This assumes a TPU VM environment and the necessary model access;\n    # the TPU backend is selected automatically once tpu-inference is installed.\n    # For a full server quickstart, refer to the vLLM TPU documentation.\n    # For instance:\n    # from vllm import LLM, SamplingParams\n    # llm = LLM(model=\"google/gemma-2b-it\")\n    # outputs = llm.generate([\"Hello, TPU!\"], SamplingParams(max_tokens=16))\n    # print(\"vLLM engine with TPU backend initialized.\")\n\nexcept importlib.metadata.PackageNotFoundError:\n    print(\"tpu-inference package not found. Please ensure it's installed in a TPU environment.\")\nexcept Exception as e:\n    print(f\"An error occurred during quickstart verification: {e}\")\n    print(\"Ensure you are running this in a Google Cloud TPU VM environment with appropriate setup.\")\n","lang":"python","description":"This quickstart verifies the successful installation and configuration of `vllm-tpu` in a Python environment. It checks the versions of `vLLM` and `tpu-inference`, confirms the detected vLLM platform, and lists available JAX devices, indicating the presence and readiness of TPU hardware. Actual model serving typically involves running the `vLLM` API server (e.g., `vllm serve <model>`) on a Google Cloud TPU VM; the TPU backend is selected automatically when tpu-inference is installed."},
"warnings":[{"fix":"Provision a Google Cloud TPU VM (e.g., v5e, v6e, v7x) as per the official documentation for optimal performance and stability. For Ironwood (TPU7x), Google Kubernetes Engine (GKE) is required.","message":"tpu-inference is designed for specific TPU generations. Recommended versions are v7x, v5e, and v6e; older versions (v3, v4, v5p) are experimental. Ensure your Cloud TPU VM uses a compatible generation.","severity":"gotcha","affected_versions":"All"},
{"fix":"For PyTorch, consider refactoring generation to use a manual decode loop with `torch_xla.core.xla_model.mark_step()` for explicit execution. Aim for static shapes (fixed batch/sequence length) and avoid dynamic branching. Converting models to JAX-compatible checkpoints may offer better performance.","message":"When using PyTorch with `tpu-inference` via `vLLM`, direct `model.generate()` calls can be significantly slower than expected due to lazy graph execution and dynamic control flow in PyTorch/XLA. TPUs strongly favor static computation graphs, which JAX produces more naturally.","severity":"gotcha","affected_versions":"All"},
{"fix":"Understand your typical request patterns and model sequence lengths. While `vLLM` handles this automatically, awareness of the bucketization can help in performance analysis and potentially in optimizing input lengths or batching strategies.","message":"vLLM on TPUs uses a bucketization strategy for sequence lengths: requests are rounded up to the nearest bucket size (e.g., a 176-token request might be treated as 256 tokens). This can lead to wasted computation and increased latency for workloads with varying or long sequence lengths if not accounted for.","severity":"gotcha","affected_versions":"All"},
{"fix":"Plan inference strategies without relying on speculative decoding for TPU deployments with vLLM. Monitor vLLM and tpu-inference release notes for future support.","message":"Speculative decoding, a technique to accelerate LLM inference, is currently not supported for TPUs when using vLLM.","severity":"gotcha","affected_versions":"All"}],
"env_vars":null,"last_verified":"2026-04-16T00:00:00.000Z","next_check":"2026-07-15T00:00:00.000Z",
"problems":[{"fix":"Implement a manual decode loop and explicitly force execution with `torch_xla.core.xla_model.mark_step()` after each token generation. Optimize for static input shapes and minimize dynamic control flow within the loop. Consider using a JAX-based model if possible.","cause":"PyTorch's dynamic computational graphs and lazy execution model can interact poorly with XLA (the compiler for TPUs) when performing auto-regressive decoding, leading to frequent recompilations and performance bottlenecks.","error":"Apparent hang or extremely slow ('stuck') behavior during PyTorch model.generate() on TPU."},
{"fix":"Ensure `vllm-tpu` (which includes `tpu-inference` as a dependency) is installed in the active Python environment. If using a Google Cloud TPU VM, follow the installation steps from the official documentation, preferably within a virtual environment. Use `pip install vllm-tpu` or `uv pip install vllm-tpu`.","cause":"The `tpu-inference` Python package is not installed, or the environment where the check is performed is not the one where it was installed.","error":"PackageNotFoundError: No package metadata found for tpu_inference"},
{"fix":"Ensure your code is executed on a properly configured Google Cloud TPU VM. Verify the TPU runtime environment setup (e.g., the `PJRT_DEVICE` environment variable, `libtpu` availability) as per Google Cloud and vLLM TPU documentation. Access to a TPU VM and sufficient quota are prerequisites.","cause":"The Python environment cannot detect an active TPU device, often because the code is not running on a Google Cloud TPU VM, or the necessary underlying drivers/setup are incomplete.","error":"JAX devices: [] (or only CPU/GPU devices listed)"},
{"fix":"Use a smaller model or a TPU generation with more memory if available. Employ model quantization or pruning techniques to reduce the memory footprint. Consider model partitioning or pipeline parallelism if supported for the specific model and TPU setup, though vLLM handles much of this automatically.","cause":"The model, or parts of it, exceeds the on-chip memory available on the TPU. This forces data to be moved between the TPU and host memory, introducing significant latency.","error":"OutOfMemoryError or performance degradation characterized by spikes in host memory usage during inference."}]}