{"id":2172,"library":"optax","title":"Optax","description":"Optax is a gradient processing and optimization library for JAX. It provides a rich set of optimizers (Adam, SGD, etc.), learning rate schedules, and gradient transformations that can be composed into custom optimization pipelines. It is actively developed by Google DeepMind; the current stable version is 0.2.8, and minor releases track JAX ecosystem developments, delivering bug fixes and new features.","status":"active","version":"0.2.8","language":"en","source_language":"en","source_url":"https://github.com/deepmind/optax","tags":["jax","optimization","machine-learning","gradients","deep-learning"],"install":[{"cmd":"pip install optax jax jaxlib","lang":"bash","label":"Basic Installation"},{"cmd":"pip install -U optax 'jax[cuda12]'","lang":"bash","label":"CUDA-enabled Installation (example; see the JAX install guide for your CUDA version)"}],"dependencies":[],"imports":[{"symbol":"optax","correct":"import optax"},{"note":"Optimizers like adam are directly under the `optax` namespace, not a nested `optimizers` submodule.","wrong":"from optax import optimizers.adam","symbol":"adam","correct":"import optax\noptimizer = optax.adam(...)"},{"note":"`chain` is a function directly under the `optax` module, not a submodule.","wrong":"import optax.chain","symbol":"chain","correct":"from optax import chain\noptimizer = chain(...)"}],"quickstart":{"code":"import jax\nimport jax.numpy as jnp\nimport optax\n\n# 1. Define a simple model and loss function\ndef model(params, x):\n    return params['w'] * x + params['b']\n\ndef loss_fn(params, x, y):\n    predictions = model(params, x)\n    return jnp.mean((predictions - y)**2)\n\n# 2. Initialize parameters (split the PRNG key so 'w' and 'b' get independent values)\nkey = jax.random.PRNGKey(0)\nkey_w, key_b = jax.random.split(key)\nparams = {\n    'w': jax.random.normal(key_w, ()),\n    'b': jax.random.normal(key_b, ())\n}\n\n# 3. 
Choose an optimizer (e.g., Adam)\nlearning_rate = 0.01\noptimizer = optax.adam(learning_rate)\n\n# 4. Initialize optimizer state\nopt_state = optimizer.init(params)\n\n# 5. Sample data for training\nx_data = jnp.array([1.0, 2.0, 3.0, 4.0])\ny_data = jnp.array([2.0, 4.0, 6.0, 8.0]) # Target: y = 2x\n\n# 6. Define a single training step using JAX's jit for performance\n@jax.jit\ndef train_step(params, opt_state, x, y):\n    # Compute loss and gradients\n    loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y)\n\n    # Compute updates from gradients and optimizer state\n    updates, new_opt_state = optimizer.update(grads, opt_state, params)\n\n    # Apply updates to parameters\n    new_params = optax.apply_updates(params, updates)\n\n    return new_params, new_opt_state, loss_value\n\n# 7. Training loop\nprint(f\"Initial parameters: {params}\")\nfor i in range(100):\n    params, opt_state, loss_value = train_step(params, opt_state, x_data, y_data)\n    if i % 20 == 0:\n        print(f\"Step {i}, Loss: {loss_value:.4f}\")\n\nprint(f\"Final parameters: {params}\")\n# With enough training steps, w converges toward 2.0 and b toward 0.0","lang":"python","description":"This quickstart demonstrates a basic training loop with Optax and JAX. It defines a simple linear model and a mean squared error loss. An Adam optimizer is initialized, and a `train_step` function is created that uses `jax.value_and_grad` to compute the loss and gradients in one pass and Optax to apply updates. `train_step` is `jax.jit`-compiled for efficiency. The loop iteratively updates parameters and optimizer state to minimize the loss."},"warnings":[{"fix":"Ensure your `optimizer.update` calls pass the current model parameters as the third argument. If you're migrating from an older version, add `params` as the third argument to your `optimizer.update` calls.","message":"The signature of `optimizer.update` changed between the Optax 0.1.x and 0.2.x series to include `params` as the third argument: `optimizer.update(grads, opt_state, params)`. 
This is crucial for optimizers and transformations that need the current parameter values to compute updates (e.g., weight decay).","severity":"breaking","affected_versions":"<0.2.0"},{"fix":"Always reassign the optimizer state and parameters after calling `optimizer.update` and `optax.apply_updates` respectively, as demonstrated in the quickstart example.","message":"Optax, like JAX, operates on immutable data structures. When you call `optimizer.update`, it returns *new* updates and a *new* optimizer state; it never mutates its arguments. You must reassign these values (`opt_state = new_opt_state`, `params = optax.apply_updates(params, updates)`); otherwise your training loop will not progress.","severity":"gotcha","affected_versions":"All versions"},{"fix":"Ensure that your model parameters and the gradients produced by `jax.grad` are consistent PyTree structures. When composing optimizers or transformations, confirm that intermediate results maintain PyTree compatibility.","message":"Optax expects parameters, gradients, and optimizer states to be JAX PyTrees (e.g., nested dictionaries, lists, tuples, or custom types registered with `jax.tree_util`). Passing non-PyTree structures or incompatible types can lead to errors.","severity":"gotcha","affected_versions":"All versions"}],"env_vars":null,"last_verified":"2026-04-09T00:00:00.000Z","next_check":"2026-07-08T00:00:00.000Z"}