Diffrax

0.7.2 · active · verified Thu Apr 16

Diffrax is a high-performance, JAX-based Python library of numerical solvers for ordinary, stochastic, and controlled differential equations (ODEs, SDEs, and CDEs). Because it is built on JAX, solves can run on GPU, can be differentiated with automatic differentiation, and compose with JAX transformations such as `jax.jit` and `jax.vmap`. The library is aimed at research in scientific machine learning. The current version is 0.7.2, and it is updated regularly, often in step with developments elsewhere in the JAX ecosystem.

Quickstart

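This quickstart solves the ODE dy/dt = -y from t=0 to t=1 with initial condition y(0)=1, whose exact solution is y(t) = e^{-t}. It wraps the right-hand side in an `ODETerm`, integrates with the adaptive `Tsit5` solver (Tsitouras' 5(4) explicit Runge-Kutta method), and lets a `PIDController` choose step sizes that meet the given tolerances. `SaveAt(ts=...)` records the solution at 11 evenly spaced time points, and the final check compares y(1) against e^{-1}.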

import diffrax as dfx
import jax
import jax.numpy as jnp

# Define the ODE function dy/dt = -y
def func(t, y, args):
    return -y

# Define the ODE term
term = dfx.ODETerm(func)

# Choose a solver and step size controller
solver = dfx.Tsit5()
stepsize_controller = dfx.PIDController(rtol=1e-5, atol=1e-5)

# Initial conditions and time span
t0 = 0.0
t1 = 1.0
dt0 = 0.1 # initial step size
y0 = jnp.array([1.0])
args = () # No extra arguments for func in this example

# Solve the differential equation
sol = dfx.diffeqsolve(
    term,
    solver,
    t0,
    t1,
    dt0,
    y0,
    args=args,
    stepsize_controller=stepsize_controller,
    saveat=dfx.SaveAt(ts=jnp.linspace(t0, t1, 11))
)

# Access the solution:
# sol.ts has shape (11,): the save times
# sol.ys has shape (11, 1): y at each save time (y0 has shape (1,))
# Check y(1) against the exact solution y(t) = exp(-t)
assert jnp.allclose(sol.ys[-1], jnp.exp(-t1), atol=1e-4)
print("Solution at t=1.0:", sol.ys[-1])
