{"library":"torchax","type":"library","category":null,"description":"torchax is a library that serves as a backend for PyTorch, enabling users to run PyTorch programs on JAX-supported hardware such as Google Cloud TPUs. It provides seamless graph-level interoperability, letting users mix JAX and PyTorch syntax within the same program and leverage JAX features such as `jax.grad`, Optax, and GSPMD for PyTorch model training. The current version is 0.0.11, and development is active on GitHub.","language":"python","status":"active","version":"0.0.11","tags":["pytorch","jax","deep-learning","interoperability","tpu","gpu","accelerator"],"last_verified":"Sat May 23","install":[{"cmd":"pip install torchax","imports":["import torchax","torchax.enable_globally()","from torchax.interop import JittableModule","from torchax.interop import jax_jit"]},{"cmd":"# First, install PyTorch CPU:\npip install torch --index-url https://download.pytorch.org/whl/cpu # Linux\npip install torch # Mac\n\n# Then, install JAX for your accelerator:\npip install -U jax[tpu] # Google Cloud TPU\npip install -U jax[cuda12] # GPU machines\npip install -U jax # Linux CPU or Mac","imports":[]}],"homepage":null,"github":"https://github.com/google/torchax","docs":null,"changelog":null,"pypi":"https://pypi.org/project/torchax/","npm":null,"openapi_spec":null,"status_page":null,"smithery":null,"compatibility":{"summary":{"python_range":"3.10–3.9","success_rate":80,"avg_install_s":1.6,"avg_import_s":null,"wheel_type":"wheel"},"url":"https://checklist.day/v1/registry/torchax/compatibility"}}