{"id":5522,"library":"torchax","title":"TorchAx: PyTorch on JAX/TPU Bridge","description":"torchax is a library that serves as a backend for PyTorch, enabling users to run PyTorch programs on JAX-supported hardware such as Google Cloud TPUs. It provides seamless graph-level interoperability, allowing JAX and PyTorch syntax to be mixed within the same program and enabling PyTorch model training to leverage JAX features such as `jax.grad`, Optax, and GSPMD. The current version is 0.0.11, with development active on GitHub.","status":"active","version":"0.0.11","language":"en","source_language":"en","source_url":"https://github.com/google/torchax","tags":["pytorch","jax","deep-learning","interoperability","tpu","gpu","accelerator"],"install":[{"cmd":"pip install torchax","lang":"bash","label":"Install torchax"},{"cmd":"# First, install PyTorch CPU:\npip install torch --index-url https://download.pytorch.org/whl/cpu # Linux\npip install torch # Mac\n\n# Then, install JAX for your accelerator (extras quoted for shells like zsh):\npip install -U \"jax[tpu]\" # Google Cloud TPU\npip install -U \"jax[cuda12]\" # GPU machines\npip install -U jax # Linux CPU or Mac","lang":"bash","label":"Prerequisites: PyTorch and JAX"}],"dependencies":[{"reason":"torchax is a PyTorch backend/frontend for JAX, requiring PyTorch to function. Users must choose their desired PyTorch build (CPU, CUDA, etc.).","package":"torch","optional":false},{"reason":"torchax runs PyTorch models on JAX's backend, requiring JAX to be installed with the appropriate accelerator support (TPU, CUDA, or CPU).","package":"jax","optional":false}],"imports":[{"note":"The primary import for enabling torchax functionality.","symbol":"torchax","correct":"import torchax"},{"note":"Activates torchax to intercept PyTorch operations and execute them via JAX. Should be called after model loading.","symbol":"enable_globally","correct":"torchax.enable_globally()"},{"note":"Helper class to wrap PyTorch models for JIT compilation with JAX.","symbol":"JittableModule","correct":"from torchax.interop import JittableModule"},{"note":"Decorator that compiles PyTorch functions (taking/returning torch.Tensors) into JAX-compiled versions for performance.","symbol":"jax_jit","correct":"from torchax.interop import jax_jit"}],"quickstart":{"code":"import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchax\n\nclass MyModel(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.fc1 = nn.Linear(28 * 28, 120)\n        self.fc2 = nn.Linear(120, 84)\n        self.fc3 = nn.Linear(84, 10)\n\n    def forward(self, x):\n        x = x.view(-1, 28 * 28)\n        x = F.relu(self.fc1(x))\n        x = F.relu(self.fc2(x))\n        x = self.fc3(x)\n        return x\n\n# Instantiate the PyTorch model\nm = MyModel()\n\n# IMPORTANT: Enable torchax GLOBALLY *after* model instantiation/loading\ntorchax.enable_globally()\n\n# Move the model to the 'jax' device\nm.to('jax')\n\n# Create input tensor on the 'jax' device\ninputs = torch.randn(3, 1, 28, 28, device='jax')\n\n# Run the model; operations will be executed by JAX\noutputs = m(inputs)\nprint(outputs.shape)\nprint(outputs.device)\n\n# Example with jax.jit for performance (using JittableModule)\nfrom torchax.interop import JittableModule\n\nm_jitted = JittableModule(m)  # Wraps the model for JIT compilation\njitted_outputs = m_jitted(inputs)\nprint(jitted_outputs.shape)\nprint(jitted_outputs.device)","lang":"python","description":"This quickstart demonstrates how to run a standard PyTorch `nn.Module` using torchax. The key steps are to import `torchax`, call `torchax.enable_globally()` *after* model initialization, and then move the model and inputs to the 'jax' device. For improved performance, especially in production, `torchax.interop.JittableModule` (which leverages `jax.jit`) is recommended for compiling the model."},"warnings":[{"fix":"Ensure `torchax.enable_globally()` is called only after your `torch.nn.Module` has been initialized and its weights potentially loaded. For example: `model = MyModel(); torchax.enable_globally(); model.to('jax')`.","message":"Enabling `torchax.enable_globally()` before loading a PyTorch model can lead to errors, as it might intercept unsupported initialization operations. Always enable globally *after* the model has been fully instantiated or loaded.","severity":"gotcha","affected_versions":"All versions"},{"fix":"For performance-critical workloads, always use JAX's Just-In-Time (JIT) compilation. Wrap your model with `torchax.interop.JittableModule` or decorate functions with `torchax.interop.jax_jit` to compile the computation graph for faster execution. The first call will include compilation time, but subsequent calls will be much faster.","message":"Running `torchax` models in eager mode (without JAX JIT compilation) can be significantly slower than native PyTorch or JIT-compiled JAX execution. JAX's eager mode generally does not offer the same performance benefits as its compiled mode.","severity":"gotcha","affected_versions":"All versions"},{"fix":"For dynamic input shapes, consider using techniques like `StaticCache` (for Hugging Face models) or ensuring that varying dimensions are handled as static arguments (`static_argnums`) if using `jax.jit` directly. Alternatively, refactor the computation to minimize shape changes within JIT-compiled regions.","message":"JAX's JIT compilation specializes for fixed input shapes. If input shapes change between calls (common in scenarios like autoregressive text generation), JAX will recompile the graph, potentially leading to performance degradation worse than eager mode.","severity":"gotcha","affected_versions":"All versions"},{"fix":"Utilize `torchax.interop.JittableModule`, which handles this by passing weights as explicit arguments, or explicitly convert your PyTorch model to a functional form using `torch.func.functional_call` when interacting with JAX transforms.","message":"JAX transformations, including JIT, require functions to be 'pure' (i.e., all inputs passed as arguments, all outputs returned, no side effects or closure over mutable state). PyTorch `nn.Module.forward` implicitly closes over model weights. This can lead to unexpected behavior or performance issues with JAX.","severity":"gotcha","affected_versions":"All versions"},{"fix":"Register custom types as JAX pytrees using `jax.tree_util.register_pytree_node`. Refer to the JAX documentation and `torchax` examples for correct registration patterns.","message":"When interoperating with custom JAX types (e.g., specific output types from Hugging Face models like `CausalLMOutputWithPast`), these types might not be automatically recognized by JAX's pytree mechanism. This can cause `TypeError: ... is not a valid JAX type` errors.","severity":"gotcha","affected_versions":"All versions"}],"env_vars":null,"last_verified":"2026-04-13T00:00:00.000Z","next_check":"2026-07-12T00:00:00.000Z"}