better-optimize


A drop-in replacement for scipy.optimize functions with quality-of-life improvements such as automatic gradient computation via JAX, progress bars, and early stopping. The current version is 0.4.1, released in 2025; the project is under active development with monthly releases.

pip install better-optimize
error ModuleNotFoundError: No module named 'better_optimize'
cause Library not installed or installed under a different name.
fix Run 'pip install better-optimize' and ensure you are using the correct import name 'better_optimize' (underscore).
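The mismatch is only hyphen versus underscore, as the fix above notes:

# Installed as:  pip install better-optimize  (hyphen)
import better_optimize  # imported with an underscore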
error TypeError: expected JAX array, got numpy.ndarray
cause Passing NumPy arrays directly to functions that expect JAX arrays.
fix Convert inputs to JAX arrays: x_jax = jnp.array(x_numpy).
breaking This library requires Python >=3.12. Attempting to install on older versions will fail.
fix Upgrade Python to 3.12 or later.
gotcha The library expects JAX arrays as inputs when using JAX-based gradients. Using NumPy arrays may cause errors or unexpected behavior.
fix Convert NumPy arrays to JAX arrays via jnp.array(your_array).
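A minimal sketch of the conversion described in the fix above, using the minimize entry point shown in the example at the end of this document; the quadratic objective here is chosen purely for illustration:

import numpy as np
import jax.numpy as jnp
from better_optimize import minimize

def f(x):
    return jnp.sum((x - 1.0) ** 2)  # simple quadratic, for illustration

x0_np = np.array([0.0, 0.0])  # a NumPy array may trigger the TypeError above
x0 = jnp.array(x0_np)         # convert to a JAX array first
res = minimize(f, x0, method='BFGS')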
deprecated The 'callback' parameter in earlier versions was renamed to 'progress_callback'. The old name still works but will be removed in 0.5.0.
fix Use 'progress_callback' instead of 'callback'.
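A migration sketch for the rename. Only the parameter names 'callback' and 'progress_callback' are confirmed above; the callback signature used here (receiving the current iterate, as in scipy.optimize) is an assumption:

import jax.numpy as jnp
from better_optimize import minimize

def report(xk):
    # Assumed signature: called with the current iterate,
    # mirroring scipy.optimize's callback convention.
    print("iterate:", xk)

x0 = jnp.array([0.0, 0.0])
# Before (deprecated, removed in 0.5.0): minimize(f, x0, callback=report)
# After:
res = minimize(lambda x: jnp.sum(x ** 2), x0, method='BFGS',
               progress_callback=report)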

Minimize the Rosenbrock function using BFGS. Note: JAX arrays are used for automatic differentiation.

import jax.numpy as jnp
from better_optimize import minimize

def rosen(x):
    # 2-D Rosenbrock function; its global minimum is at (1, 1).
    return (1 - x[0])**2 + 100 * (x[1] - x[0]**2)**2

# Start from a JAX array so gradients can be computed automatically.
x0 = jnp.array([0.0, 0.0])
res = minimize(rosen, x0, method='BFGS')
print(res.x)  # should be close to [1. 1.]