Optax
Optax is a gradient processing and optimization library designed for JAX. It provides a rich set of optimizers (Adam, SGD, etc.), learning rate schedules, and gradient transformations that can be composed to build custom optimization pipelines. It is actively developed by Google DeepMind; the current stable version is 0.2.8, with minor releases tracking JAX ecosystem developments and delivering bug fixes and new features.
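For example, a schedule can be passed anywhere a fixed learning rate is accepted; a minimal sketch (the warmup and decay step counts below are arbitrary):
import optax
# Linear warmup followed by cosine decay.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,     # learning rate at step 0
    peak_value=1e-3,    # learning rate reached after warmup
    warmup_steps=100,
    decay_steps=1_000,
)
# Adam accepts either a float or a schedule as its learning rate.
optimizer = optax.adam(learning_rate=schedule)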
Warnings
- breaking The signature of `optimizer.update` changed between Optax 0.1.x and 0.2.x to include `params` as the third argument: `optimizer.update(grads, opt_state, params)`. This matters for optimizers that need the current parameter values to compute updates (e.g., for weight decay); see the sketch after this list.
- gotcha Optax, like JAX, operates on immutable data structures. When you call `optimizer.update`, it returns a *new* optimizer state and *new* updates. You must reassign these values (`opt_state = new_opt_state`, `params = optax.apply_updates(params, updates)`); otherwise your training loop will not progress.
- gotcha Optax expects parameters, gradients, and optimizer states to be JAX PyTrees (e.g., nested dictionaries, lists, tuples, or custom types registered with `jax.tree_util`). Passing non-PyTree structures or incompatible types can lead to errors.
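- example A minimal sketch of the `params`-aware update signature, using `optax.adamw` (which needs `params` to apply decoupled weight decay). The toy shapes and hyperparameters below are illustrative; the reassignments show the immutable-state pattern, and the parameters are a plain dict PyTree:
import jax.numpy as jnp
import optax
# Params and grads as a plain dict PyTree (values here are made up).
params = {'w': jnp.ones((3,)), 'b': jnp.zeros(())}
grads = {'w': jnp.full((3,), 0.1), 'b': jnp.array(0.2)}
# adamw applies decoupled weight decay, so update() needs the current params.
optimizer = optax.adamw(learning_rate=1e-3, weight_decay=1e-4)
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grads, opt_state, params)  # reassign the new state
params = optax.apply_updates(params, updates)                    # reassign the new params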
Install
-
pip install optax jax jaxlib
-
pip install 'optax[accelerate]' jax[cuda12_pip] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Imports
- optax
import optax
- adam
import optax
optimizer = optax.adam(...)
- chain
from optax import chain
optimizer = chain(...)
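A minimal composition sketch (the clipping threshold and learning rate below are arbitrary):
import optax
# chain composes transformations in order: here gradients are clipped
# by global norm before the Adam update rule is applied.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=1e-3),
)
# `optimizer` is a regular GradientTransformation: use .init / .update as usual.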
Quickstart
import jax
import jax.numpy as jnp
import optax
# 1. Define a simple model and loss function
def model(params, x):
    return params['w'] * x + params['b']
def loss_fn(params, x, y):
    predictions = model(params, x)
    return jnp.mean((predictions - y)**2)
# 2. Initialize parameters
key = jax.random.PRNGKey(0)
w_key, b_key = jax.random.split(key)  # independent keys so w and b get different initial values
params = {
    'w': jax.random.normal(w_key, ()),
    'b': jax.random.normal(b_key, ())
}
# 3. Choose an optimizer (e.g., Adam)
learning_rate = 0.01
optimizer = optax.adam(learning_rate)
# 4. Initialize optimizer state
opt_state = optimizer.init(params)
# 5. Sample data for training
x_data = jnp.array([1.0, 2.0, 3.0, 4.0])
y_data = jnp.array([2.0, 4.0, 6.0, 8.0]) # Target: y = 2x
# 6. Define a single training step using JAX's jit for performance
@jax.jit
def train_step(params, opt_state, x, y):
    # Compute loss and gradients
    loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Compute updates from gradients and optimizer state
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    # Apply updates to parameters
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss_value
# 7. Training loop
print(f"Initial parameters: {params}")
for i in range(1000):
    params, opt_state, loss_value = train_step(params, opt_state, x_data, y_data)
    if i % 100 == 0:
        print(f"Step {i}, Loss: {loss_value:.4f}")
print(f"Final parameters: {params}")
# Expected output for w: ~2.0, b: ~0.0