{"library":"jax","type":"library","category":null,"description":"Composable transformations of Python+NumPy programs: differentiate (jax.grad), compile (jax.jit), vectorize (jax.vmap), parallelize (jax.shard_map). Current version is 0.9.2 (Mar 2026). Requires Python >=3.11. GPU/TPU support requires install extras (e.g. jax[cuda12], jax[tpu]) — a bare pip install jax gives a CPU-only minimal build.","language":"python","status":"active","version":"0.9.2","tags":["deep-learning","ml","gpu","tpu","autodiff","jit","numpy","xla"],"last_verified":"Tue Jun 09","install":[{"cmd":"pip install jax[cpu]","imports":["import jax\nimport jax.numpy as jnp\n\n@jax.jit\ndef f(x, y):\n    return jnp.dot(x, y)\n\n# Or explicit keyword args:\njax.jit(f, static_argnums=(0,))","# Pure function — same inputs always give same outputs\n@jax.jit\ndef add(x, y):\n    return x + y"]},{"cmd":"pip install jax[cuda12]","imports":[]},{"cmd":"pip install jax[cuda13]","imports":[]},{"cmd":"pip install jax[tpu]","imports":[]}],"homepage":"https://jax.readthedocs.io","github":"https://github.com/jax-ml/jax","docs":null,"changelog":null,"pypi":"https://pypi.org/project/jax/","npm":null,"openapi_spec":null,"status_page":null,"smithery":null,"compatibility":{"summary":{"python_range":"3.10–3.9","success_rate":55,"avg_install_s":32.8,"avg_import_s":2,"wheel_type":"wheel"},"url":"https://checklist.day/v1/registry/jax/compatibility"}}