PyTorch/XLA

PyTorch/XLA is a Python package that bridges PyTorch with XLA devices (TPU, GPU, CPU) to enable high-performance machine learning. The current stable version is 2.9.0, with releases aligned to PyTorch minor versions. It supports Python 3.10-3.13 and runs on the PJRT runtime; the legacy XRT runtime is deprecated and has been removed (see the breaking change below).

pip install torch torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
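
After installing, a quick sanity check confirms the runtime can see a device. A minimal sketch, assuming PJRT_DEVICE is set appropriately for your hardware:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Report the PJRT device type ("TPU", "CUDA", or "CPU")
print(xr.device_type())

# Allocate a tensor on the XLA device to confirm the runtime works
t = torch.ones(2, 2, device=xm.xla_device())
print(t.device)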
error RuntimeError: XLA device not found
cause No XLA device (TPU/GPU) available or PJRT runtime not initialized properly.
fix
Verify you are on a TPU VM or have a GPU with XLA support. Ensure libtpu is installed and set the environment variable PJRT_DEVICE=TPU (or PJRT_DEVICE=CUDA for GPU).
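
For example, the runtime can be selected explicitly before torch_xla is imported (a minimal sketch; the TPU value assumes a TPU VM):

import os
os.environ.setdefault("PJRT_DEVICE", "TPU")  # or "CUDA" / "CPU"

import torch_xla.core.xla_model as xm
device = xm.xla_device()  # raises if no XLA device is found
print(device)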
error ImportError: cannot import name 'xla_model' from 'torch_xla'
cause Incorrect import path for xla_model; it resides in torch_xla.core.xla_model.
fix
Use 'import torch_xla.core.xla_model as xm' instead of 'import torch_xla.xla_model as xm'.
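
The wrong and right import paths side by side:

# Wrong: xla_model is not a top-level module
# import torch_xla.xla_model as xm  # raises ImportError

# Right: the module lives under torch_xla.core
import torch_xla.core.xla_model as xm
print(xm.xla_device())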
error AssertionError: Torch not compiled with CUDA enabled
cause Code calls CUDA-specific APIs (e.g. tensor.cuda() or device='cuda'), but the installed PyTorch wheel is CPU-only, which is typical on TPU VMs.
fix
Move tensors and models to the XLA device (xm.xla_device()) instead of calling .cuda(). Install a CUDA-enabled PyTorch wheel from pytorch.org only if you actually target GPUs.
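
A device-agnostic pattern that avoids hard-coded .cuda() calls (a minimal sketch):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()  # TPU, GPU, or CPU, whichever XLA backs

model = torch.nn.Linear(4, 2).to(device)  # instead of model.cuda()
x = torch.randn(8, 4, device=device)      # instead of x.cuda()
out = model(x)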
breaking XRT runtime is deprecated and removed in PyTorch/XLA 2.1+. Use PJRT runtime for all new code.
fix PJRT is the default; remove any XRT_* environment variables (e.g. XRT_TPU_CONFIG) and any XRT-specific code so the PJRT runtime is used.
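
A migration sketch (the XRT_* names below are legacy variables from the XRT era; PJRT needs no equivalent configuration on a TPU VM):

import os

# Drop legacy XRT configuration so the default PJRT runtime is used
for var in ("XRT_TPU_CONFIG", "XRT_DEVICE_MAP", "XRT_WORKERS"):
    os.environ.pop(var, None)

os.environ.setdefault("PJRT_DEVICE", "TPU")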
gotcha XLA tensors execute lazily: without xm.mark_step() the traced graph is never cut and executed, so work accumulates into one huge graph, leading to hangs, out-of-memory errors, or misleading timings.
fix Call xm.mark_step() after each training step, and xm.wait_device_ops() before measuring time or otherwise synchronizing with the host.
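
A minimal training-step sketch (the model, data, and optimizer are placeholders):

import time
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

start = time.time()
for _ in range(10):
    x = torch.randn(32, 10, device=device)
    y = torch.randn(32, 1, device=device)
    loss = torch.nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    xm.mark_step()    # cut and execute the lazily traced graph
xm.wait_device_ops()  # block until all queued device work finishes
print(f"elapsed: {time.time() - start:.3f}s")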
deprecated On older releases the [tpu] extra could pull in a libtpu that conflicts with the installed torch_xla, so libtpu had to come from a separate index URL.
fix For PyTorch/XLA 2.9+, pip install torch torch_xla[tpu] installs a matching libtpu automatically. On older versions, add the index explicitly: pip install torch torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

Basic example: get XLA device, create tensor, run ops, mark step.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Get XLA device
device = xm.xla_device()

# Create tensor on XLA device
t = torch.randn(3, 3, device=device)
print(f"Tensor device: {t.device}")

# Perform operations
result = t + t
print(f"Result: {result}")

# Mark step and synchronize (required for XLA)
xm.mark_step()
xm.wait_device_ops()
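
Note that printing an XLA tensor, as above, also forces the pending graph to execute, since the values must be transferred to the host. xm.mark_step() makes the execution boundary explicit, and xm.wait_device_ops() blocks until all queued device work has finished, which matters when timing code.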