PyTorch/XLA
2.9.0 verified Mon Apr 27 auth: no python
PyTorch/XLA is a Python package that connects PyTorch to XLA devices (TPU, GPU, CPU) for high-performance machine learning. The current stable version is 2.9.0, and releases are aligned with PyTorch minor versions. It supports Python 3.10-3.13 and runs on the PJRT runtime; the older XRT runtime was removed in PyTorch/XLA 2.1.
pip install torch torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
Common errors
error RuntimeError: XLA device not found
cause No XLA device (TPU/GPU) is available, or the PJRT runtime was not initialized properly.
fix
Verify that you are running on a TPU VM or on a machine with an XLA-supported GPU. For PJRT, ensure libtpu is installed and the environment variable PJRT_DEVICE=TPU is set. XRT_TPU_CONFIG applies only to the legacy XRT runtime, which PyTorch/XLA 2.1+ no longer includes.
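Before digging into the runtime itself, it can help to print what the environment looks like. The following is a minimal diagnostic sketch using only the standard library; the helper name diagnose_xla_env is hypothetical, and it assumes that libtpu is importable as a Python module named libtpu when it is installed.

```python
import importlib.util
import os


def diagnose_xla_env(env=None):
    """Collect basic facts relevant to 'XLA device not found' (hypothetical helper)."""
    env = os.environ if env is None else env
    return {
        # PJRT reads this variable to choose the device type (e.g. TPU or CPU)
        "PJRT_DEVICE": env.get("PJRT_DEVICE"),
        # find_spec locates a module without importing it
        "torch_xla_installed": importlib.util.find_spec("torch_xla") is not None,
        # Assumption: the libtpu wheel exposes a top-level module named 'libtpu'
        "libtpu_installed": importlib.util.find_spec("libtpu") is not None,
        # Leftover XRT_* variables hint at a configuration for the removed XRT runtime
        "legacy_xrt_vars": sorted(k for k in env if k.startswith("XRT_")),
    }


for key, value in diagnose_xla_env().items():
    print(f"{key}: {value}")
```

If PJRT_DEVICE is unset or legacy_xrt_vars is non-empty, the environment is a more likely culprit than the hardware.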
error ImportError: cannot import name 'xla_model' from 'torch_xla'
cause Incorrect import path for xla_model; the module lives in torch_xla.core.xla_model.
fix
Use 'import torch_xla.core.xla_model as xm' instead of 'from torch_xla import xla_model' or 'import torch_xla.xla_model as xm'.
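error AssertionError: Torch not compiled with CUDA enabled
cause The code requests a CUDA device (for example .cuda() or device='cuda') on a PyTorch build without CUDA support. torch_xla does not need a CUDA-enabled PyTorch build to run on TPU.
fix
Replace hard-coded CUDA device references with the XLA device (for example xm.xla_device()). Install a CUDA-enabled PyTorch wheel from pytorch.org only if you intend to run on an NVIDIA GPU; on TPU, a prebuilt TPU VM image with a matching PyTorch version is sufficient.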
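PyTorch raises this assertion when code requests a 'cuda' device on a build without CUDA, so hard-coded .cuda() calls are worth checking first. A minimal sketch of device selection that avoids hard-coding follows; the helper names (module_available, pick_device_kind, resolve_device) are hypothetical, and the sketch assumes that an installed torch_xla means an XLA device is the intended target.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if a top-level module can be found without importing it."""
    return importlib.util.find_spec(name) is not None


def pick_device_kind() -> str:
    """Prefer XLA, then CUDA, then CPU, based on what is installed."""
    if module_available("torch_xla"):
        return "xla"
    if module_available("torch"):
        import torch

        if torch.cuda.is_available():
            return "cuda"
    return "cpu"


def resolve_device(kind: str):
    """Turn the chosen kind into a device object (imports torch lazily)."""
    import torch

    if kind == "xla":
        import torch_xla.core.xla_model as xm

        return xm.xla_device()
    return torch.device(kind)


kind = pick_device_kind()
print(f"Selected device kind: {kind}")
```

Model and tensor code can then call .to(resolve_device(kind)) instead of .cuda(), so the same script runs on TPU, GPU, or CPU.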
Warnings
breaking The XRT runtime is deprecated and has been removed as of PyTorch/XLA 2.1. Use the PJRT runtime for all new code.
fix Ensure you are using PJRT (the default). If you previously selected XRT explicitly, switch to PJRT by removing the XRT runtime environment variables.
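gotcha PyTorch/XLA executes lazily: operations are recorded into a graph and run only when the graph is cut. Without xm.mark_step(), the graph can keep growing across training steps, which may show up as slow compilation, growing memory use, or apparent hangs. Without xm.wait_device_ops(), timing measurements may stop before the device has finished its work.
fix Call xm.mark_step() after each training step, and call xm.wait_device_ops() before measuring time or synchronizing.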
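The placement of these two calls is easiest to see in a loop. Below is a minimal sketch of a training loop with xm.mark_step() once per step and xm.wait_device_ops() before the timer is read. The tiny model and random data are placeholders, and the CPU fallback with no-op hooks is an assumption added only so the sketch also runs where no XLA device is present.

```python
import time

import torch
import torch.nn as nn

try:
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    mark_step = xm.mark_step
    wait_device_ops = xm.wait_device_ops
except (ImportError, RuntimeError):
    # Assumption for illustration: fall back to CPU with no-op sync hooks
    device = torch.device("cpu")

    def mark_step() -> None:
        pass

    def wait_device_ops() -> None:
        pass


torch.manual_seed(0)
model = nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

losses = []
start = time.perf_counter()
for _ in range(3):
    x = torch.randn(8, 4, device=device)
    y = torch.randn(8, 1, device=device)
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    losses.append(loss.detach())
    # Cut the lazy graph once per step so it does not grow across iterations
    mark_step()

# Block until queued device work is done before reading the clock
wait_device_ops()
elapsed = time.perf_counter() - start

loss_values = [l.item() for l in losses]
print(f"{len(loss_values)} steps in {elapsed:.3f}s")
```

Reading loss values with .item() inside the loop would also force execution, but it adds a device-to-host transfer on every step, which is why the sketch defers it until after the loop.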
deprecated The [tpu] installation extra installs torch-xla but may conflict with newer versions of libtpu. Use a separate pip install with the libtpu index URL.
fix Install with the libtpu index URL; for PyTorch/XLA 2.9+ this installs libtpu automatically: pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
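Because torch_xla releases are aligned to PyTorch minor versions, a quick check of the installed versions can catch a mismatched install early. The sketch below uses importlib.metadata from the standard library; the helper names are hypothetical, and the distribution names 'torch' and 'torch_xla' are assumed to match what pip reports on your system.

```python
from importlib import metadata


def minor_version(version: str) -> tuple:
    """Reduce a version string such as '2.9.0+cu121' to (major, minor)."""
    parts = version.split("+")[0].split(".")
    return int(parts[0]), int(parts[1])


def versions_aligned(torch_version: str, xla_version: str) -> bool:
    """True when both packages share the same major.minor version."""
    return minor_version(torch_version) == minor_version(xla_version)


def installed_version(dist_name: str):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


torch_v = installed_version("torch")
xla_v = installed_version("torch_xla")
if torch_v is None or xla_v is None:
    print(f"Not fully installed: torch={torch_v}, torch_xla={xla_v}")
elif versions_aligned(torch_v, xla_v):
    print(f"OK: torch {torch_v} matches torch_xla {xla_v}")
else:
    print(f"Mismatch: torch {torch_v} vs torch_xla {xla_v}")
```

Patch versions are deliberately ignored here, since only the minor versions are stated to be aligned.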
Imports
- torch_xla
wrong
from torch_xla import xla_model
correct
import torch_xla
- torch_xla.core.xla_model
wrong
import torch_xla.xla_model as xm
correct
import torch_xla.core.xla_model as xm
Quickstart
import torch
import torch_xla
import torch_xla.core.xla_model as xm
# Get XLA device
device = xm.xla_device()
# Create tensor on XLA device
t = torch.randn(3, 3, device=device)
print(f"Tensor device: {t.device}")
# Perform operations
result = t + t
print(f"Result: {result}")
# Cut the lazy graph, then wait for pending device work to finish
xm.mark_step()
xm.wait_device_ops()