Optax

0.2.8 · active · verified Thu Apr 09

Optax is a gradient processing and optimization library for JAX. It provides a rich set of optimizers (Adam, SGD, etc.), learning rate schedules, and gradient transformations that can be composed into custom optimization pipelines. It is actively developed by Google DeepMind; the current stable version is 0.2.8, with minor releases for bug fixes and new features tracking the wider JAX ecosystem.

Warnings

Install

pip install optax

Imports

import jax
import jax.numpy as jnp
import optax

Quickstart

This quickstart demonstrates a basic training loop with Optax and JAX. It defines a simple linear model and a mean squared error loss. An Adam optimizer is initialized, and a `train_step` function is created, leveraging `jax.grad` to compute gradients and Optax to apply updates. The `train_step` is `jax.jit`-compiled for efficiency. The loop iteratively updates parameters and optimizer state to minimize the loss.

import jax
import jax.numpy as jnp
import optax

# 1. Define a simple model and loss function
def model(params, x):
    return params['w'] * x + params['b']

def loss_fn(params, x, y):
    predictions = model(params, x)
    return jnp.mean((predictions - y)**2)

# 2. Initialize parameters
key = jax.random.PRNGKey(0)
w_key, b_key = jax.random.split(key)  # use distinct keys for each parameter
params = {
    'w': jax.random.normal(w_key, ()),
    'b': jax.random.normal(b_key, ())
}

# 3. Choose an optimizer (e.g., Adam)
learning_rate = 0.01
optimizer = optax.adam(learning_rate)

# 4. Initialize optimizer state
opt_state = optimizer.init(params)

# 5. Sample data for training
x_data = jnp.array([1.0, 2.0, 3.0, 4.0])
y_data = jnp.array([2.0, 4.0, 6.0, 8.0]) # Target: y = 2x

# 6. Define a single training step using JAX's jit for performance
@jax.jit
def train_step(params, opt_state, x, y):
    # Compute loss and gradients
    loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y)

    # Compute updates from gradients and optimizer state
    updates, new_opt_state = optimizer.update(grads, opt_state, params)

    # Apply updates to parameters
    new_params = optax.apply_updates(params, updates)

    return new_params, new_opt_state, loss_value

# 7. Training loop
print(f"Initial parameters: {params}")
for i in range(1000):
    params, opt_state, loss_value = train_step(params, opt_state, x_data, y_data)
    if i % 200 == 0:
        print(f"Step {i}, Loss: {loss_value:.4f}")

print(f"Final parameters: {params}")
# Parameters should approach w ≈ 2.0, b ≈ 0.0 as the loss is minimized
