TorchAx: PyTorch on JAX/TPU Bridge

0.0.11 · active · verified Mon Apr 13

torchax is a backend for PyTorch that lets users run PyTorch programs on JAX-supported hardware such as Google Cloud TPUs. It provides graph-level interoperability: JAX and PyTorch code can be mixed within the same program, and JAX features such as `jax.grad`, Optax, and GSPMD can be used to train PyTorch models. The current version is 0.0.11, and the project is under active development on GitHub.

Quickstart

This quickstart runs a standard PyTorch `nn.Module` with torchax. The key steps are: import `torchax`, call `torchax.enable_globally()` *after* the model has been instantiated, then move the model to the 'jax' device and create the inputs on that device. For better performance, wrapping the model in `torchax.interop.JittableModule`, which compiles it with `jax.jit`, is recommended.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchax

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Instantiate the PyTorch model
m = MyModel()

# IMPORTANT: Enable torchax GLOBALLY *after* model instantiation/loading
torchax.enable_globally()

# Move the model to the 'jax' device
m.to('jax')

# Create input tensor on the 'jax' device
inputs = torch.randn(3, 1, 28, 28, device='jax')

# Run the model; operations will be executed by JAX
outputs = m(inputs)
print(outputs.shape)
print(outputs.device)

# Example with jax.jit for performance (using JittableModule)
from torchax.interop import JittableModule

m_jitted = JittableModule(m) # Wraps the model for JIT compilation
jitted_outputs = m_jitted(inputs)
print(jitted_outputs.shape)
print(jitted_outputs.device)
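
For readers less familiar with the JAX side, the sketch below shows roughly what `jax.jit` and `jax.grad` do on their own, using a pure-JAX network with the same layer sizes as `MyModel`. It does not use torchax at all; the helper names (`init_params`, `forward`, `loss_fn`), the weight initialization, and the dummy labels are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp


def init_params(key, sizes=(28 * 28, 120, 84, 10)):
    """Build one (weight, bias) pair per dense layer (illustrative init)."""
    params = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        key, w_key = jax.random.split(key)
        w = jax.random.normal(w_key, (fan_in, fan_out)) * jnp.sqrt(1.0 / fan_in)
        params.append((w, jnp.zeros(fan_out)))
    return params


def forward(params, x):
    """Same computation as MyModel.forward: flatten, two ReLU layers, linear head."""
    x = x.reshape(-1, 28 * 28)
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def loss_fn(params, x, labels):
    """Mean cross-entropy over the batch."""
    log_probs = jax.nn.log_softmax(forward(params, x))
    return -jnp.mean(log_probs[jnp.arange(labels.shape[0]), labels])


key = jax.random.PRNGKey(0)
key, x_key = jax.random.split(key)
params = init_params(key)
inputs = jax.random.normal(x_key, (3, 1, 28, 28))
labels = jnp.array([0, 1, 2])  # dummy class labels

# jax.jit traces the function once and compiles it for the active backend.
jitted_forward = jax.jit(forward)
logits = jitted_forward(params, inputs)
print(logits.shape)  # (3, 10)

# jax.grad differentiates with respect to the first argument (params).
grads = jax.jit(jax.grad(loss_fn))(params, inputs, labels)
print(grads[0][0].shape)  # (784, 120), matching the first layer's weights
```

The gradients come back in the same nested structure as `params`, which is what lets JAX-side optimizers such as Optax consume them directly.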
