Optimistix
Optimistix is a JAX library for nonlinear solvers: root finding, minimisation, fixed points, and least squares. It features highly modular optimisers, interoperable solvers (e.g. a root-find problem can be handed to a least-squares solver), PyTree-based state management, fast compile times and runtimes, and deep integration with the JAX ecosystem (autodiff, autoparallelism, GPU/TPU support). As of version 0.1.0 it requires Python 3.11+, and it is under active, rapid development with frequent updates.
Common errors
- XlaRuntimeError: The linear solver returned non-finite (NaN or inf) output.
  cause: The linear operator was not well-posed (e.g. a singular or ill-conditioned Jacobian matrix), or it received non-finite input.
  fix: Check the inputs to the problem for `NaN` or `inf` values. If solving a linear least-squares problem, pass `solver=AutoLinearSolver(well_posed=False)`. If the problem is inherently ill-conditioned, consider a more robust solver or a re-parametrisation. Placing `jax.debug.print` or `jax.debug.breakpoint` calls can help diagnose the issue.
- sol.result indicates 'max_steps_reached' or 'nonlinear_max_steps_reached'
  cause: The solver iterated for the maximum allowed number of steps without converging to the specified tolerance. The problem may have no solution, or the tolerances may be too strict for the given initial conditions.
  fix: Increase the `max_steps` argument of the solve function (e.g. `optx.fixed_point(..., max_steps=N)`). Verify that the problem actually has a solution. Loosen `rtol` (relative tolerance) or `atol` (absolute tolerance) if appropriate for the application.
- sol.result indicates 'nonlinear_divergence' or 'nonfinite'
  cause: The iterative solver diverged, or non-finite values (NaN/inf) were detected during the solve.
  fix: This often points to a poorly scaled problem, a bad initial guess (`y0`), or an unsuitable solver. Try different initial guesses, rescale the problem variables, or switch to a solver that is more robust on 'messier' problems (e.g. `OptaxMinimiser` for minimisation, or `LevenbergMarquardt`/`Dogleg` for root finding and least squares).
- Solver fails to converge, or errors, on a root-finding problem with no root (e.g. `1 + y**2`).
  cause: The function never crosses zero (or, for a fixed-point problem `f(x) = x`, no such `x` exists).
  fix: Verify the mathematical properties of the function being solved. If you expect a root or fixed point but the solver fails, it may be converging to a local minimum of the squared residual rather than to zero. For problems where a root is not guaranteed, minimise the squared residual `f(y)^2` instead of calling a root finder directly.
Warnings
- breaking In Optimistix v0.1.0, the `verbose` argument for solvers (e.g., `LevenbergMarquardt`) changed from accepting a `frozenset` of elements to display to a simple boolean (`True`/`False`) or a callable for full control.
- gotcha By default, solver failures (e.g. maximum steps reached, divergence, non-finite values) raise an `XlaRuntimeError`. Pass `throw=False` to the solve function to instead report the failure via `sol.result`.
- gotcha Optimistix solvers may converge to a local minimum or fixed point, not necessarily the global optimum.
- gotcha JAX's `jax.scipy.optimize.minimize` API is being deprecated in favor of libraries like Optimistix and JAXopt. Optimistix provides a compatibility layer.
- gotcha Iterative solvers in Optimistix often return a solution that satisfies the tolerance conditions, but it is not necessarily the 'best-so-far' value encountered during the iterations, as tracking this would require additional memory.
Install
- pip install optimistix
Imports
- optimistix
import optimistix as optx
- jax.numpy
import jax.numpy as jnp
- equinox
import equinox as eqx
Quickstart
import jax.numpy as jnp
import optimistix as optx
# Let's solve the ODE dy/dt = tanh(y(t)) with the implicit Euler method.
# We need to find y1 s.t. y1 = y0 + tanh(y1) * dt.
y0 = jnp.array(1.0)
dt = jnp.array(0.1)
def fn(y, args):
    # The function whose fixed point we seek: y1 = fn(y1, args).
    # Here, fn(y1) = y0 + tanh(y1) * dt.
    return y0 + jnp.tanh(y) * dt
solver = optx.Newton(rtol=1e-5, atol=1e-5)
# Find the fixed point: y1 such that y1 = fn(y1).
sol = optx.fixed_point(fn, solver, y0)
y1 = sol.value
print(f"Initial y0: {y0}")
print(f"dt: {dt}")
print(f"Fixed point y1: {y1}")
print(f"Check fn(y1): {fn(y1, None)}")
print(f"Success: {sol.result == optx.RESULTS.successful}")