Diffrax
Diffrax is a high-performance Python library for solving ordinary, stochastic, and controlled differential equations (ODEs, SDEs, CDEs). Built on JAX, it offers GPU acceleration and automatic differentiation, and is designed for research in scientific machine learning. It is currently at version 0.7.2 and receives regular updates, often in sync with JAX ecosystem developments.
Common errors
-
RuntimeError: A JAX array was produced with dtype float32, but an input array had dtype float64.
cause You are mixing JAX arrays of `float32` (the JAX default) and `float64` precision within a computation.
fix Ensure all JAX arrays involved in `diffrax` computations have the same `dtype`. The easiest way is to set `jax.config.update('jax_enable_x64', True)` at the very start of your script to make JAX default to `float64`, or explicitly cast all input arrays to `jnp.float32` (e.g., `jnp.array(my_data, dtype=jnp.float32)`).
-
TypeError: 'tuple' object is not callable
cause The `func` argument to `ODETerm` (or similar `Term` classes) must be a callable, such as a Python function or a callable `equinox.Module`. This error often occurs when the `args` value, or another non-callable object, is mistakenly passed as the `func`.
fix Verify that the first argument to `ODETerm` is a function or callable `equinox.Module` that accepts `(t, y, args)` as arguments. Pass any auxiliary data through the `args` parameter of `diffeqsolve`, not in place of the function.
-
ValueError: Cannot array-like from non-array-like python list
cause This typically happens when `y0` or `args` contains plain Python lists of numbers instead of JAX arrays, or objects of custom classes that are not registered as JAX pytrees, especially in a JIT-compiled context.
fix Convert Python lists or scalars to JAX arrays (e.g., `y0 = jnp.array([1.0, 2.0])`) before passing them as `y0`. For complex structures in `args`, use an `equinox.Module` or explicitly register custom Python classes as JAX pytrees if they contain JAX arrays.
-
KeyError: This JAX array is not hashable
cause You are using a JAX array as a key in a Python dictionary or as an element of a Python set. JAX arrays are not hashable, and inside a JIT-compiled function they are replaced by tracers, which are not hashable either.
fix Do not use JAX arrays as dictionary keys or set elements. Use strings, integers, or other hashable Python values as keys instead. If you need to transform structured data, store the arrays as pytree leaves (for example as dictionary values) and apply `jax.tree_util.tree_map`.
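The dtype and hashability fixes above can be sketched with plain JAX. This is a minimal illustration, not a Diffrax recipe: the names `decay`, `forcing`, and `data` are made up for the example, and `func` simply follows the `(t, y, args)` signature that `ODETerm` expects.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit mode before any arrays are created.
jax.config.update("jax_enable_x64", True)

y0 = jnp.array([1.0, 2.0])  # float64 once x64 mode is enabled
# Cast other inputs explicitly so every array shares y0's dtype.
data = jnp.asarray([0.5, 0.25], dtype=y0.dtype)

# Auxiliary data lives in a pytree: hashable string keys, arrays as values.
# (The key names are hypothetical, chosen only for this example.)
args = {"decay": jnp.asarray(0.5, dtype=y0.dtype), "forcing": data}

def func(t, y, args):
    # Vector field with the (t, y, args) signature used by ODETerm.
    return -args["decay"] * y + args["forcing"]

# Evaluating under jit checks that the structure traces cleanly.
dy = jax.jit(func)(0.0, y0, args)
```

Passing `func` itself to `ODETerm` and the dictionary to the `args` parameter of `diffeqsolve` keeps the callable and its data in the places the fixes above describe.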
Warnings
- gotcha JAX defaults to `float32` precision, which is often insufficient for scientific computing and numerical ODE/SDE solvers, leading to lower precision than expected from traditional numerical libraries. This can cause discrepancies in results compared to `scipy.integrate` or MATLAB.
- gotcha Operations within the `func` passed to `ODETerm` (or other terms) must be JAX-compatible and side-effect free for efficient JIT compilation. Using standard Python data structures (lists, dicts with JAX array keys) or non-JAX operations inside `func` can cause errors, prevent JIT compilation, or lead to incorrect gradients.
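- gotcha The meaning of `dt0` depends on the step size controller rather than on the solver alone. With an adaptive controller (e.g., `PIDController`, typically paired with `Tsit5` or `Dopri5`), it is only the *initial* step size, and it may be set to `None` to have one chosen automatically. With the default `ConstantStepSize` controller (the usual choice for `Euler` or `Midpoint`), it is the *fixed* size of every step. Misinterpreting this can lead to inefficient computation (a `dt0` that is too small for an adaptive solve) or inaccurate/unstable results (a `dt0` that is too large for a fixed-step solve).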
Install
-
pip install diffrax -
pip install -U 'jax[cuda12]' diffrax # for GPU users: CUDA support comes from the JAX installation
Imports
- diffeqsolve
from diffrax import diffeqsolve
- ODETerm
from diffrax import ODETerm
- Tsit5
from diffrax import Tsit5
- Euler
from diffrax import Euler
- SaveAt
from diffrax import SaveAt
- Solution
from diffrax import Solution
Quickstart
import diffrax as dfx
import jax
import jax.numpy as jnp
# Define the ODE function dy/dt = -y
def func(t, y, args):
    return -y
# Define the ODE term
term = dfx.ODETerm(func)
# Choose a solver and step size controller
solver = dfx.Tsit5()
stepsize_controller = dfx.PIDController(rtol=1e-5, atol=1e-5)
# Initial conditions and time span
t0 = 0.0
t1 = 1.0
dt0 = 0.1 # initial step size
y0 = jnp.array([1.0])
args = () # No extra arguments for func in this example
# Solve the differential equation
sol = dfx.diffeqsolve(
    term,
    solver,
    t0,
    t1,
    dt0,
    y0,
    args=args,
    stepsize_controller=stepsize_controller,
    saveat=dfx.SaveAt(ts=jnp.linspace(t0, t1, 11)),
)
# Access the solution
# print(sol.ts) # Time points
# print(sol.ys) # Solutions at time points
assert jnp.allclose(sol.ys[-1], jnp.exp(-t1), atol=1e-4)
print("Solution at t=1.0:", sol.ys[-1])