TorchAx: PyTorch on JAX/TPU Bridge
torchax is a library that serves as a PyTorch backend, enabling PyTorch programs to run on JAX-supported hardware such as Google Cloud TPUs. It provides graph-level interoperability, allowing JAX and PyTorch code to be mixed within the same program and letting PyTorch model training leverage JAX features such as `jax.grad`, Optax, and GSPMD. The current version is 0.0.11; development is active on GitHub.
Warnings
- gotcha Enabling `torchax.enable_globally()` before loading a PyTorch model can lead to errors, as it might intercept unsupported initialization operations. Always enable globally *after* the model has been fully loaded or instantiated.
- gotcha Running `torchax` models in eager mode (without JAX JIT compilation) can be significantly slower than native PyTorch or JIT-compiled JAX execution. JAX's eager mode generally does not offer the same performance benefits as its compiled mode.
- gotcha JAX's JIT compilation specializes for fixed input shapes. If input shapes change between calls (common in scenarios like autoregressive text generation), JAX will recompile the graph, potentially leading to performance degradation worse than eager mode.
- gotcha JAX transformations, including JIT, require functions to be 'pure' (i.e., all inputs passed as arguments, all outputs returned, no side effects or closure over mutable state). PyTorch `nn.Module.forward` implicitly closes over model weights. This can lead to unexpected behavior or performance issues with JAX.
- gotcha When interoperating with custom JAX types (e.g., specific output types from HuggingFace models like `CausalLMOutputWithPast`), these types might not be automatically recognized by JAX's pytree mechanism. This can cause `TypeError: ... is not a valid JAX type` errors.
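For the last gotcha, one common workaround is to register the offending type with JAX's pytree mechanism so it can cross `jit` boundaries. A minimal sketch in plain JAX, where `ModelOutput` is a hypothetical stand-in for a custom output class (e.g. a `CausalLMOutputWithPast`-style type):

```python
import jax
import jax.numpy as jnp
from jax import tree_util

# Hypothetical custom container that JAX does not recognize by default.
class ModelOutput:
    def __init__(self, logits, past):
        self.logits = logits
        self.past = past

# Tell JAX how to flatten/unflatten the type so it is a valid pytree node.
tree_util.register_pytree_node(
    ModelOutput,
    lambda o: ((o.logits, o.past), None),         # flatten: (children, aux data)
    lambda aux, children: ModelOutput(*children), # unflatten
)

@jax.jit
def step(out):
    # The custom type can now be passed into and returned from jitted code.
    return ModelOutput(out.logits * 2, out.past)

res = step(ModelOutput(jnp.ones(3), jnp.zeros(2)))
print(res.logits.shape)
```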
Install
# First, install PyTorch CPU:
pip install torch --index-url https://download.pytorch.org/whl/cpu  # Linux
pip install torch  # Mac

# Then, install JAX for your accelerator:
pip install -U jax[tpu]  # Google Cloud TPU
pip install -U jax[cuda12]  # GPU machines
pip install -U jax  # Linux CPU or Mac

# Finally, install torchax:
pip install torchax
Imports
- torchax
import torchax
- enable_globally
torchax.enable_globally()
- JittableModule
from torchax.interop import JittableModule
- jax_jit
from torchax.interop import jax_jit
Quickstart
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchax
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# Instantiate the PyTorch model
m = MyModel()
# IMPORTANT: Enable torchax GLOBALLY *after* model instantiation/loading
torchax.enable_globally()
# Move the model to the 'jax' device
m.to('jax')
# Create input tensor on the 'jax' device
inputs = torch.randn(3, 1, 28, 28, device='jax')
# Run the model; operations will be executed by JAX
outputs = m(inputs)
print(outputs.shape)
print(outputs.device)
# Example with jax.jit for performance (using JittableModule)
from torchax.interop import JittableModule
m_jitted = JittableModule(m) # Wraps the model for JIT compilation
jitted_outputs = m_jitted(inputs)
print(jitted_outputs.shape)
print(jitted_outputs.device)
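Because `forward` closes over the model's weights (the purity gotcha above), a JIT-friendly variant passes parameters explicitly so all inputs arrive as arguments. A minimal sketch using plain PyTorch's `torch.func.functional_call`, independent of torchax; the function and variable names here are illustrative:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

model = nn.Linear(4, 2)

# Extract the weights so they can be passed as an explicit argument.
params = dict(model.named_parameters())

def pure_forward(params, x):
    # The module is used only for its structure; the weights come in as
    # arguments, so there is no closure over mutable model state.
    return functional_call(model, params, (x,))

x = torch.randn(3, 4)
out = pure_forward(params, x)
print(out.shape)
```

A function in this shape is what JAX transformations such as `jax.jit` and `jax.grad` expect; `JittableModule` performs a similar weight extraction under the hood.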