JAXopt
JAXopt is a Python library providing hardware-accelerated, batchable, and differentiable optimizers built on JAX. It offers a wide range of solvers for convex and non-convex optimization problems in machine learning and scientific computing, including gradient descent, L-BFGS, and quadratic programming (OSQP). The current version is 0.8.5, and the project maintains a frequent release cadence of bug fixes and new features.
Warnings
- breaking Support for Python 3.8 and 3.9 has been removed in recent versions (v0.8.4 and v0.8.5 respectively). Users on these Python versions must upgrade to Python 3.10 or newer.
- gotcha JAXopt's usage of `jax.pure_callback` has been migrated, which affects behavior under `vmap` when `vmap_method` is not explicitly specified: the implicit fallback to `vmap_method='sequential'` is deprecated, and future versions will raise `NotImplementedError` if `vmap_method` is omitted.
- gotcha Early versions of JAXopt mishandled parameters structured as JAX PyTrees in certain solvers (e.g., `prox` operators, `BoxOSQP`), producing incorrect results or errors. These issues are fixed in newer releases.
- deprecated The 'boston' dataset, previously used in some examples, was removed due to ethical concerns; older tutorials or user code that directly referenced it will break.
Install
-
pip install jaxopt
Imports
- GradientDescent
from jaxopt import GradientDescent
- LBFGS
from jaxopt import LBFGS
- OSQP
from jaxopt import OSQP
Quickstart
import jax
import jax.numpy as jnp
from jaxopt import GradientDescent
# Define a quadratic function to minimize
def quadratic_loss(params, data):
X, y = data
return jnp.mean((jnp.dot(X, params['weights']) + params['bias'] - y)**2)
# Generate some dummy linear-regression data
# (split the PRNG key so the noise is independent of X)
key = jax.random.PRNGKey(0)
key_X, key_noise = jax.random.split(key)
num_samples = 100
num_features = 2
true_weights = jnp.array([1.0, 2.0])
true_bias = 3.0
X = jax.random.normal(key_X, (num_samples, num_features))
y = jnp.dot(X, true_weights) + true_bias + 0.1 * jax.random.normal(key_noise, (num_samples,))
data = (X, y)
# Initialize parameters
init_params = {'weights': jnp.zeros(num_features), 'bias': 0.0}
# Instantiate the optimizer; with no stepsize given, JAXopt picks one via line search
gd = GradientDescent(fun=quadratic_loss, maxiter=1000, tol=1e-3)
# Run the optimization
sol = gd.run(init_params, data=data)
print(f"Optimal parameters: {sol.params}")
print(f"True weights: {true_weights}, True bias: {true_bias}")