{"library":"optax","title":"Optax","description":"Optax is a gradient processing and optimization library designed for JAX. It provides a rich set of optimizers (Adam, SGD, etc.), learning rate schedules, and gradient transformations that can be composed to build custom optimization pipelines. It's actively developed by DeepMind/Google, with the current stable version being 0.2.8, and follows a release cadence tied to JAX ecosystem developments, often releasing minor versions for bug fixes and new features.","language":"python","status":"active","last_verified":"Wed May 13","install":{"commands":["pip install optax jax jaxlib","pip install 'optax[accelerate]' jax[cuda12_pip] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"],"cli":null},"imports":["import optax","import optax\noptimizer = optax.adam(...)","from optax import chain\noptimizer = chain(...)"],"auth":{"required":false,"env_vars":[]},"quickstart":{"code":"import jax\nimport jax.numpy as jnp\nimport optax\n\n# 1. Define a simple model and loss function\ndef model(params, x):\n    return params['w'] * x + params['b']\n\ndef loss_fn(params, x, y):\n    predictions = model(params, x)\n    return jnp.mean((predictions - y)**2)\n\n# 2. Initialize parameters\nkey = jax.random.PRNGKey(0)\nparams = {\n    'w': jax.random.normal(key, ()),\n    'b': jax.random.normal(key, ())\n}\n\n# 3. Choose an optimizer (e.g., Adam)\nlearning_rate = 0.01\noptimizer = optax.adam(learning_rate)\n\n# 4. Initialize optimizer state\nopt_state = optimizer.init(params)\n\n# 5. Sample data for training\nx_data = jnp.array([1.0, 2.0, 3.0, 4.0])\ny_data = jnp.array([2.0, 4.0, 6.0, 8.0]) # Target: y = 2x\n\n# 6. Define a single training step using JAX's jit for performance\n@jax.jit\ndef train_step(params, opt_state, x, y):\n    # Compute loss and gradients\n    loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y)\n\n    # Compute updates from gradients and optimizer state\n    updates, new_opt_state = optimizer.update(grads, opt_state, params)\n\n    # Apply updates to parameters\n    new_params = optax.apply_updates(params, updates)\n\n    return new_params, new_opt_state, loss_value\n\n# 7. Training loop\nprint(f\"Initial parameters: {params}\")\nfor i in range(100):\n    params, opt_state, loss_value = train_step(params, opt_state, x_data, y_data)\n    if i % 20 == 0:\n        print(f\"Step {i}, Loss: {loss_value:.4f}\")\n\nprint(f\"Final parameters: {params}\")\n# Expected output for w: ~2.0, b: ~0.0","lang":"python","description":"This quickstart demonstrates a basic training loop with Optax and JAX. It defines a simple linear model and a mean squared error loss. An Adam optimizer is initialized, and a `train_step` function is created, leveraging `jax.grad` to compute gradients and Optax to apply updates. The `train_step` is `jax.jit`-compiled for efficiency. The loop iteratively updates parameters and optimizer state to minimize the loss.","tag":null,"tag_description":null,"last_tested":"2026-04-25","results":[{"runtime":"python:3.10-alpine","exit_code":1},{"runtime":"python:3.10-slim","exit_code":0},{"runtime":"python:3.11-alpine","exit_code":1},{"runtime":"python:3.11-slim","exit_code":0},{"runtime":"python:3.12-alpine","exit_code":1},{"runtime":"python:3.12-slim","exit_code":0},{"runtime":"python:3.13-alpine","exit_code":1},{"runtime":"python:3.13-slim","exit_code":0},{"runtime":"python:3.9-alpine","exit_code":1},{"runtime":"python:3.9-slim","exit_code":0}]},"compatibility":{"tag":null,"tag_description":null,"last_tested":"2026-05-13","installed_version":"0.2.4","pypi_latest":"0.2.8","is_stale":true,"summary":{"python_range":"3.10–3.9","success_rate":43,"avg_install_s":19.3,"avg_import_s":2.66,"wheel_type":"wheel"},"results":[{"runtime":"python:3.10-alpine","python_version":"3.10","os_libc":"alpine (musl)","variant":"accelerate","exit_code":1,"wheel_type":null,"failure_reason":"build_error","import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.10-alpine","python_version":"3.10","os_libc":"alpine (musl)","variant":"accelerate","exit_code":1,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.10-alpine","python_version":"3.10","os_libc":"alpine (musl)","variant":"optax","exit_code":1,"wheel_type":null,"failure_reason":"build_error","import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.10-alpine","python_version":"3.10","os_libc":"alpine (musl)","variant":"optax","exit_code":1,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.10-slim","python_version":"3.10","os_libc":"slim (glibc)","variant":"accelerate","exit_code":0,"wheel_type":"wheel","failure_reason":null,"import_side_effects":"clean","install_time_s":13.3,"import_time_s":1.86,"mem_mb":38.4,"disk_size":"589M"},{"runtime":"python:3.10-slim","python_version":"3.10","os_libc":"slim (glibc)","variant":"accelerate","exit_code":1,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.10-slim","python_version":"3.10","os_libc":"slim (glibc)","variant":"optax","exit_code":0,"wheel_type":"wheel","failure_reason":null,"import_side_effects":"clean","install_time_s":11.9,"import_time_s":1.18,"mem_mb":38.4,"disk_size":"589M"},{"runtime":"python:3.10-slim","python_version":"3.10","os_libc":"slim (glibc)","variant":"optax","exit_code":0,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":1.25,"mem_mb":38.4,"disk_size":"589M"},{"runtime":"python:3.11-alpine","python_version":"3.11","os_libc":"alpine (musl)","variant":"accelerate","exit_code":1,"wheel_type":null,"failure_reason":"build_error","import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.11-alpine","python_version":"3.11","os_libc":"alpine (musl)","variant":"accelerate","exit_code":1,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.11-alpine","python_version":"3.11","os_libc":"alpine (musl)","variant":"optax","exit_code":1,"wheel_type":null,"failure_reason":"build_error","import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.11-alpine","python_version":"3.11","os_libc":"alpine (musl)","variant":"optax","exit_code":1,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.11-slim","python_version":"3.11","os_libc":"slim (glibc)","variant":"accelerate","exit_code":0,"wheel_type":"wheel","failure_reason":null,"import_side_effects":"clean","install_time_s":14.5,"import_time_s":3.72,"mem_mb":47.3,"disk_size":"638M"},{"runtime":"python:3.11-slim","python_version":"3.11","os_libc":"slim (glibc)","variant":"accelerate","exit_code":0,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":3.24,"mem_mb":47.3,"disk_size":"638M"},{"runtime":"python:3.11-slim","python_version":"3.11","os_libc":"slim (glibc)","variant":"optax","exit_code":0,"wheel_type":"wheel","failure_reason":null,"import_side_effects":"clean","install_time_s":12.2,"import_time_s":2.9,"mem_mb":47.3,"disk_size":"638M"},{"runtime":"python:3.11-slim","python_version":"3.11","os_libc":"slim (glibc)","variant":"optax","exit_code":0,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":3.11,"mem_mb":47.3,"disk_size":"638M"},{"runtime":"python:3.12-alpine","python_version":"3.12","os_libc":"alpine (musl)","variant":"accelerate","exit_code":1,"wheel_type":null,"failure_reason":"build_error","import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.12-alpine","python_version":"3.12","os_libc":"alpine (musl)","variant":"accelerate","exit_code":1,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.12-alpine","python_version":"3.12","os_libc":"alpine (musl)","variant":"optax","exit_code":1,"wheel_type":null,"failure_reason":"build_error","import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.12-alpine","python_version":"3.12","os_libc":"alpine (musl)","variant":"optax","exit_code":1,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.12-slim","python_version":"3.12","os_libc":"slim (glibc)","variant":"accelerate","exit_code":0,"wheel_type":"wheel","failure_reason":null,"import_side_effects":"clean","install_time_s":12.9,"import_time_s":3.12,"mem_mb":46.3,"disk_size":"623M"},{"runtime":"python:3.12-slim","python_version":"3.12","os_libc":"slim (glibc)","variant":"accelerate","exit_code":1,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.12-slim","python_version":"3.12","os_libc":"slim (glibc)","variant":"optax","exit_code":0,"wheel_type":"wheel","failure_reason":null,"import_side_effects":"clean","install_time_s":12.4,"import_time_s":2.94,"mem_mb":46.3,"disk_size":"623M"},{"runtime":"python:3.12-slim","python_version":"3.12","os_libc":"slim (glibc)","variant":"optax","exit_code":0,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":3.35,"mem_mb":46.3,"disk_size":"623M"},{"runtime":"python:3.13-alpine","python_version":"3.13","os_libc":"alpine (musl)","variant":"accelerate","exit_code":1,"wheel_type":null,"failure_reason":"build_error","import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.13-alpine","python_version":"3.13","os_libc":"alpine (musl)","variant":"accelerate","exit_code":1,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.13-alpine","python_version":"3.13","os_libc":"alpine (musl)","variant":"optax","exit_code":1,"wheel_type":null,"failure_reason":"build_error","import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.13-alpine","python_version":"3.13","os_libc":"alpine (musl)","variant":"optax","exit_code":1,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.13-slim","python_version":"3.13","os_libc":"slim (glibc)","variant":"accelerate","exit_code":0,"wheel_type":"wheel","failure_reason":null,"import_side_effects":"clean","install_time_s":14.1,"import_time_s":2.94,"mem_mb":47.5,"disk_size":"621M"},{"runtime":"python:3.13-slim","python_version":"3.13","os_libc":"slim (glibc)","variant":"accelerate","exit_code":0,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":3.6,"mem_mb":47.5,"disk_size":"621M"},{"runtime":"python:3.13-slim","python_version":"3.13","os_libc":"slim (glibc)","variant":"optax","exit_code":0,"wheel_type":"wheel","failure_reason":null,"import_side_effects":"clean","install_time_s":12.4,"import_time_s":2.78,"mem_mb":47.5,"disk_size":"621M"},{"runtime":"python:3.13-slim","python_version":"3.13","os_libc":"slim (glibc)","variant":"optax","exit_code":0,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":3.52,"mem_mb":47.5,"disk_size":"621M"},{"runtime":"python:3.9-alpine","python_version":"3.9","os_libc":"alpine (musl)","variant":"accelerate","exit_code":1,"wheel_type":null,"failure_reason":"build_error","import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.9-alpine","python_version":"3.9","os_libc":"alpine (musl)","variant":"accelerate","exit_code":1,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.9-alpine","python_version":"3.9","os_libc":"alpine (musl)","variant":"optax","exit_code":1,"wheel_type":null,"failure_reason":"build_error","import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.9-alpine","python_version":"3.9","os_libc":"alpine (musl)","variant":"optax","exit_code":1,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.9-slim","python_version":"3.9","os_libc":"slim (glibc)","variant":"accelerate","exit_code":0,"wheel_type":"wheel","failure_reason":null,"import_side_effects":"clean","install_time_s":76,"import_time_s":1.98,"mem_mb":42.6,"disk_size":"4.8G"},{"runtime":"python:3.9-slim","python_version":"3.9","os_libc":"slim (glibc)","variant":"accelerate","exit_code":1,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.9-slim","python_version":"3.9","os_libc":"slim (glibc)","variant":"optax","exit_code":0,"wheel_type":"wheel","failure_reason":null,"import_side_effects":"clean","install_time_s":13.3,"import_time_s":1.86,"mem_mb":42.6,"disk_size":"561M"},{"runtime":"python:3.9-slim","python_version":"3.9","os_libc":"slim (glibc)","variant":"optax","exit_code":0,"wheel_type":null,"failure_reason":null,"import_side_effects":null,"install_time_s":null,"import_time_s":1.9,"mem_mb":42.6,"disk_size":"561M"}]}}