TPU Inference for vLLM

0.13.3 · active · verified Thu Apr 16

tpu-inference is a hardware plugin for vLLM that enables efficient inference of large language models (LLMs) on Google Cloud TPUs. It unifies JAX and PyTorch under a single lowering path, so PyTorch model definitions run performantly on TPUs without additional code changes, while JAX models are supported natively. The library aims to maximize TPU hardware performance while retaining vLLM's standardized user experience. It is actively maintained by the vLLM project and Google Cloud, with releases tied to vLLM development.
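As a hardware plugin, tpu-inference is discovered by vLLM at startup through Python entry points. The sketch below shows that discovery mechanism in isolation; the `vllm.platform_plugins` group name follows vLLM's documented plugin system, and the exact entry-point name/target registered by tpu-inference is an assumption here.

```python
from importlib.metadata import entry_points

def discover_vllm_platform_plugins():
    """Return {name: target} for installed vLLM platform plugins."""
    eps = entry_points()
    # Python 3.10+: select entry points by group; older versions
    # return a dict keyed by group name.
    if hasattr(eps, "select"):
        group = eps.select(group="vllm.platform_plugins")
    else:
        group = eps.get("vllm.platform_plugins", [])
    return {ep.name: ep.value for ep in group}

plugins = discover_vllm_platform_plugins()
# Empty dict if no platform plugins are installed; contains the
# tpu-inference registration when vllm-tpu is present.
print(plugins)
```

This is the same mechanism vLLM uses internally: if the returned mapping contains the TPU plugin, vLLM selects the TPU platform automatically, with no extra configuration.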

Quickstart

This quickstart verifies that `vllm-tpu` is installed and configured in the current Python environment. It checks the versions of `vLLM` and `tpu-inference`, confirms the detected vLLM platform, and lists the available JAX devices, indicating whether TPU hardware is present and ready. Actual model serving typically means running the vLLM API server (`vllm serve`) on a Google Cloud TPU VM.

import os
# Ensure you have a Hugging Face token for model downloads if not public
# os.environ['HF_TOKEN'] = os.environ.get('HF_TOKEN', 'hf_...')

import jax
import vllm
import importlib.metadata
from vllm.platforms import current_platform

try:
    tpu_version = importlib.metadata.version("tpu_inference")
    print(f"vLLM version: {vllm.__version__}")
    print(f"tpu_inference version: {tpu_version}")
    print(f"vLLM platform: {current_platform.get_device_name()}")
    print(f"JAX devices: {jax.devices()}")

    # Example of how you would typically use vLLM with the TPU backend.
    # This assumes a TPU VM environment and access to the model weights.
    # The TPU platform is selected automatically when tpu-inference is
    # installed; no TPU-specific engine flag is required. For instance:
    # from vllm import LLM
    # llm = LLM(model="google/gemma-2b-it")
    # outputs = llm.generate(["Hello, TPU!"])
    # print("vLLM engine with TPU backend initialized.")

except importlib.metadata.PackageNotFoundError:
    print("tpu-inference package not found. Please ensure it's installed in a TPU environment.")
except Exception as e:
    print(f"An error occurred during quickstart verification: {e}")
    print("Ensure you are running this in a Google Cloud TPU VM environment with appropriate setup.")
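Once the environment check above passes, serving a model follows vLLM's standard CLI flow; the TPU backend is picked up automatically through the installed plugin. A minimal sketch, assuming the published wheel name is `vllm-tpu` and using an illustrative model and flag (check the tpu-inference docs for the current install command):

```shell
# Install the TPU-enabled vLLM distribution on a Cloud TPU VM
# (assumed wheel name).
pip install vllm-tpu

# Start an OpenAI-compatible API server. vLLM detects the TPU
# platform via the installed plugin; no TPU-specific flag is needed.
vllm serve google/gemma-2b-it --max-model-len 4096
```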
