Optax
Optax is a gradient processing and optimization library for JAX. It provides a rich set of optimizers (Adam, SGD, etc.), learning rate schedules, and gradient transformations that can be composed into custom optimization pipelines. It is actively developed by Google DeepMind; the current stable version is 0.2.8, and its release cadence tracks JAX ecosystem developments, with frequent minor releases for bug fixes and new features.
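As a minimal sketch of that composition pattern (the specific transformations, schedule, and hyperparameters below are illustrative, not prescribed by Optax):
import optax

# Cosine-decayed learning rate fed into Adam, with global-norm gradient
# clipping applied first; optax.chain composes the transformations in order.
schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=10_000)
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=schedule),
)
The resulting `optimizer` is used like any single transformation: `optimizer.init(params)` followed by `optimizer.update(grads, opt_state, params)`.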
Common errors
- TypeError: 'type' object is not subscriptable
  Cause: Python versions older than 3.9 do not support built-in generics such as `tuple[float, float]` (PEP 585), which appear in `optax`'s type hints.
  Fix: Upgrade your Python environment to 3.9 or newer. Alternatively, if modifying the `optax` source is an option, change `tuple[...]` to `typing.Tuple[...]` where the error occurs.
- TypeError: zeros_like requires ndarray or scalar arguments
  Cause: Optax optimizers operate on JAX arrays, specifically floating-point types. This error occurs when `optimizer.init()` is called with a PyTree (e.g., a model) containing non-JAX-array leaves (such as Python objects or other data types) that Optax cannot process.
  Fix: Ensure that only JAX arrays are passed to `optimizer.init()`. If using Equinox, filter the model with `eqx.filter(model, eqx.is_inexact_array)` so that only the optimizable floating-point arrays are included (see the sketch after this list).
- AttributeError: module 'jax.interpreters.xla' has no attribute 'DeviceArray'
  Cause: The installed `optax` version is incompatible with your JAX installation. Newer JAX versions removed `jax.interpreters.xla.DeviceArray` in favor of `jax.Array`, and older `optax` versions may still reference the deprecated attribute.
  Fix: Update both `optax` and JAX to their latest compatible versions with `pip install --upgrade optax jax jaxlib` to keep the APIs consistent.
- TypeError: true_fun and false_fun output must have identical types, got (...)
  Cause: This JAX error, frequently triggered by `optax` transformations that involve conditional logic (such as `optax.MultiSteps`, which uses `jax.lax.cond`), occurs when the true and false branches of the conditional produce PyTrees (e.g., optimizer states or parameter updates) with different structures or inconsistent dtypes.
  Fix: Ensure that the PyTree structures and dtypes of both branches' outputs are exactly identical. This often requires explicit casting or enforcing consistent dtypes (e.g., `jnp.float32`, `jnp.bfloat16`) across all relevant components.
- ValueError: Expected dict, got Traced<ShapedArray(...)>
  Cause: An `optax` function such as `optimizer.update()` or `optax.apply_updates()` expected a dictionary-like PyTree for parameters or gradients but received something else, often a bare `Traced<ShapedArray>` or an unexpected PyTree structure from another library (such as Flax or Brax) inside JAX's tracing context.
  Fix: Verify that the PyTree structure of your parameters and gradients matches what `optax` expects, typically a nested dictionary-like structure. When integrating with other JAX-based libraries, convert or unwrap parameters before passing them to `optax` functions. Upgrading `optax` may also resolve known issues with specific optimizers and parameter types.
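For the zeros_like error above, a minimal sketch of the Equinox filtering pattern (Equinox and the toy `eqx.nn.Linear` module here are only illustrative; `eqx.filter` and `eqx.is_inexact_array` come from Equinox, not Optax):
import equinox as eqx
import jax
import optax

# A small Equinox module purely for illustration.
model = eqx.nn.Linear(2, 2, key=jax.random.PRNGKey(0))

optimizer = optax.adam(learning_rate=1e-3)
# Keep only the floating-point array leaves (weights, biases); static fields
# and other non-array leaves never reach the optimizer.
params = eqx.filter(model, eqx.is_inexact_array)
opt_state = optimizer.init(params)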
Warnings
- breaking The signature of `optimizer.update` changed between Optax 0.1.x and 0.2.x to include `params` as the third argument: `optimizer.update(grads, opt_state, params)`. This matters for optimizers that need the current parameter values to compute updates (e.g., for weight decay); see the sketch after this list.
- gotcha Optax, like JAX, operates on immutable data structures. When you call `optimizer.update`, it returns a *new* optimizer state and *new* updates. You must reassign these values (`opt_state = new_opt_state`, `params = optax.apply_updates(params, updates)`); otherwise your training loop will not progress.
- gotcha Optax expects parameters, gradients, and optimizer states to be JAX PyTrees (e.g., nested dictionaries, lists, tuples, or custom types registered with `jax.tree_util`). Passing non-PyTree structures or incompatible types can lead to errors.
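A minimal sketch of the 0.2.x-style update call and the required reassignments (using `optax.adamw`, where the `params` argument actually matters because of decoupled weight decay; the parameter names and values are illustrative):
import jax
import jax.numpy as jnp
import optax

optimizer = optax.adamw(learning_rate=1e-3, weight_decay=1e-4)
params = {'w': jnp.ones((3,))}
opt_state = optimizer.init(params)

grads = jax.grad(lambda p: jnp.sum(p['w'] ** 2))(params)

# `params` is the third argument; adamw uses it to apply weight decay.
updates, opt_state = optimizer.update(grads, opt_state, params)  # reassign the new state
params = optax.apply_updates(params, updates)                    # reassign the new params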
Install
- pip install optax jax jaxlib
- pip install optax "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Imports
- optax
import optax
- adam
from optax import adam
optimizer = adam(...)  # or: import optax; optimizer = optax.adam(...)
- chain
from optax import chain
optimizer = chain(...)  # or: import optax; optimizer = optax.chain(...)
Quickstart
import jax
import jax.numpy as jnp
import optax
# 1. Define a simple model and loss function
def model(params, x):
    return params['w'] * x + params['b']

def loss_fn(params, x, y):
    predictions = model(params, x)
    return jnp.mean((predictions - y)**2)
# 2. Initialize parameters
key = jax.random.PRNGKey(0)
w_key, b_key = jax.random.split(key)  # use distinct keys so 'w' and 'b' get different values
params = {
    'w': jax.random.normal(w_key, ()),
    'b': jax.random.normal(b_key, ())
}
# 3. Choose an optimizer (e.g., Adam)
learning_rate = 0.01
optimizer = optax.adam(learning_rate)
# 4. Initialize optimizer state
opt_state = optimizer.init(params)
# 5. Sample data for training
x_data = jnp.array([1.0, 2.0, 3.0, 4.0])
y_data = jnp.array([2.0, 4.0, 6.0, 8.0]) # Target: y = 2x
# 6. Define a single training step using JAX's jit for performance
@jax.jit
def train_step(params, opt_state, x, y):
    # Compute loss and gradients
    loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Compute updates from gradients and optimizer state
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    # Apply updates to parameters
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss_value
# 7. Training loop
print(f"Initial parameters: {params}")
for i in range(100):
    params, opt_state, loss_value = train_step(params, opt_state, x_data, y_data)
    if i % 20 == 0:
        print(f"Step {i}, Loss: {loss_value:.4f}")
print(f"Final parameters: {params}")
# Expected output for w: ~2.0, b: ~0.0