{"library":"jax","install":[{"cmd":"pip install jax[cpu]","imports":["import jax\nimport jax.numpy as jnp\n\n@jax.jit\ndef f(x, y):\n    return jnp.dot(x, y)\n\n# Or call jax.jit explicitly; static_argnums marks hashable non-array\n# args (here `n`) as compile-time constants:\ndef power(x, n):\n    return x ** n\n\npower_jit = jax.jit(power, static_argnums=(1,))","# Pure function — same inputs always give same outputs\n@jax.jit\ndef add(x, y):\n    return x + y"]},{"cmd":"pip install jax[cuda12]","imports":[]},{"cmd":"pip install jax[cuda13]","imports":[]},{"cmd":"pip install jax[tpu]","imports":[]}]}