Optimistix

0.1.0 · active · verified Thu Apr 16

Optimistix is a JAX library for nonlinear solvers: root finding, minimisation, fixed points, and least squares. It offers highly modular optimisers, interoperable solvers (for example, a root-finding problem can be converted to a least-squares problem), PyTree-based state management, fast compilation and runtimes, and deep integration with the JAX ecosystem, including autodiff, autoparallelism, and GPU/TPU support. As of version 0.1.0, it requires Python 3.11+ and is under active development with frequent updates.

Common errors

Warnings

Install

Imports

Quickstart

This quickstart finds the fixed point that defines one implicit Euler step of an ODE, using `optimistix.fixed_point` with a `Newton` solver. The function `fn` takes the current iterate `y` and `args` and returns the updated value; the solver looks for a `y` such that `y = fn(y, args)`. The solution object `sol` holds the fixed point in `value` and a `result` code indicating success or failure.

import jax.numpy as jnp
import optimistix as optx

# Let's solve the ODE dy/dt = tanh(y(t)) with the implicit Euler method.
# We need to find y1 s.t. y1 = y0 + tanh(y1) * dt.
y0 = jnp.array(1.0)
dt = jnp.array(0.1)

def fn(y, args):
    # The function to find the fixed point of: y1 = fn(y1, args)
    # Here, fn(y1) = y0 + tanh(y1) * dt
    return y0 + jnp.tanh(y) * dt

solver = optx.Newton(rtol=1e-5, atol=1e-5)

# Find the fixed point: y1 such that y1 = fn(y1).
sol = optx.fixed_point(fn, solver, y0)
y1 = sol.value

print(f"Initial y0: {y0}")
print(f"dt: {dt}")
print(f"Fixed point y1: {y1}")
print(f"Check fn(y1): {fn(y1, None)}")
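# sol.result is an enumeration item rather than a plain integer, so compare it explicitly.
print(f"Solve succeeded: {sol.result == optx.RESULTS.successful}")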
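
To show what the solver is being asked to do, the same implicit Euler step can be sketched by hand in plain Python. This is only an illustration of a scalar Newton iteration on the residual `y - y0 - dt * tanh(y)`, not Optimistix's implementation; the helper name `implicit_euler_step` and its `tol` and `max_iter` parameters are hypothetical.

```python
import math

def implicit_euler_step(y0, dt, tol=1e-10, max_iter=50):
    """Solve y1 = y0 + dt * tanh(y1) for y1 with a scalar Newton iteration."""
    y = y0  # use the previous value as the initial guess
    for _ in range(max_iter):
        t = math.tanh(y)
        residual = y - y0 - dt * t
        if abs(residual) < tol:
            return y
        # d/dy [y - y0 - dt * tanh(y)] = 1 - dt * (1 - tanh(y)^2)
        slope = 1.0 - dt * (1.0 - t * t)
        y -= residual / slope
    raise RuntimeError("Newton iteration did not converge")

y1 = implicit_euler_step(1.0, 0.1)
print(f"Hand-rolled y1: {y1}")
print(f"Check y0 + dt * tanh(y1): {1.0 + 0.1 * math.tanh(y1)}")
```

The two printed values should agree to within the tolerance, which is the same fixed-point condition the quickstart checks with `fn(y1, None)`.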
