JAXopt

0.8.5 · active · verified Tue Apr 14

JAXopt is a Python library of hardware-accelerated, batchable, and differentiable optimizers built on JAX. It provides solvers for convex and non-convex optimization problems, including gradient descent, L-BFGS, and quadratic programming, suitable for machine learning and scientific computing. The current version is 0.8.5, and the project maintains a frequent release cadence of bug fixes and new features.

Warnings

Install

pip install jaxopt

Imports

import jax
import jax.numpy as jnp
from jaxopt import GradientDescent

Quickstart

This example uses `GradientDescent` from JAXopt to fit a small linear regression model by minimizing a mean-squared-error loss, which is quadratic in the parameters. It generates dummy data, initializes the model parameters, and runs the solver to recover the true weights. This covers the basic workflow: define an objective, choose a solver, and call `run`.

import jax
import jax.numpy as jnp
from jaxopt import GradientDescent

# Mean-squared-error loss of a linear model: quadratic in the parameters
def quadratic_loss(params, data):
    X, y = data
    return jnp.mean((jnp.dot(X, params['weights']) + params['bias'] - y)**2)

# Generate some dummy data
key = jax.random.PRNGKey(0)
num_samples = 100
num_features = 2
true_weights = jnp.array([1.0, 2.0])
true_bias = 3.0
key_X, key_noise = jax.random.split(key)  # independent keys for features and noise
X = jax.random.normal(key_X, (num_samples, num_features))
y = jnp.dot(X, true_weights) + true_bias + 0.1 * jax.random.normal(key_noise, (num_samples,))
data = (X, y)

# Initialize parameters
init_params = {'weights': jnp.zeros(num_features), 'bias': 0.0}

# Instantiate the optimizer
gd = GradientDescent(fun=quadratic_loss, maxiter=1000, tol=1e-3)

# Run the optimization; run() returns an OptStep with fields `params` and `state`
sol = gd.run(init_params, data=data)
print(f"Optimal parameters: {sol.params}")
print(f"True weights: {true_weights}, True bias: {true_bias}")
