{"id":8938,"library":"diffrax","title":"Diffrax","description":"Diffrax is a high-performance Python library for solving ordinary, stochastic, and controlled differential equations (ODEs, SDEs, CDEs). Built on JAX, it supports GPU acceleration and automatic differentiation, and is designed for research in scientific machine learning. It is currently at version 0.7.2 and is updated regularly, often alongside releases in the wider JAX ecosystem.","status":"active","version":"0.7.2","language":"en","source_language":"en","source_url":"https://github.com/patrick-kidger/diffrax","tags":["JAX","ODE","SDE","CDE","numerical-solvers","differential-equations","autodiff","GPU","scientific-machine-learning"],"install":[{"cmd":"pip install diffrax","lang":"bash","label":"Stable release"},{"cmd":"pip install -U 'jax[cuda12]' diffrax # GPU support comes from a CUDA-enabled JAX build","lang":"bash","label":"GPU support"}],"dependencies":[{"reason":"Core numerical backend for Diffrax, providing automatic differentiation and GPU support.","package":"jax","optional":false},{"reason":"Used for structuring models and differential equation terms as JAX-compatible PyTrees.","package":"equinox","optional":false},{"reason":"Provides shape- and dtype-aware type annotations for JAX arrays, used throughout Diffrax's signatures.","package":"jaxtyping","optional":false},{"reason":"Linear solvers used internally, for example by implicit solvers.","package":"lineax","optional":false},{"reason":"Root finders used internally by implicit solvers.","package":"optimistix","optional":false}],"imports":[{"symbol":"diffeqsolve","correct":"from diffrax import diffeqsolve"},{"symbol":"ODETerm","correct":"from diffrax import ODETerm"},{"symbol":"Tsit5","correct":"from diffrax import Tsit5"},{"symbol":"Euler","correct":"from diffrax import Euler"},{"symbol":"SaveAt","correct":"from diffrax import SaveAt"},{"symbol":"Solution","correct":"from diffrax import Solution"}],"quickstart":{"code":"import diffrax as dfx\nimport jax\nimport jax.numpy as jnp\n\n# Define the ODE function dy/dt = -y\ndef func(t, y, args):\n    return -y\n\n# Define the ODE term\nterm = dfx.ODETerm(func)\n\n# Choose a solver and step size controller\nsolver = dfx.Tsit5()\nstepsize_controller = 
dfx.PIDController(rtol=1e-5, atol=1e-5)\n\n# Initial conditions and time span\nt0 = 0.0\nt1 = 1.0\ndt0 = 0.1 # initial step size; PIDController adapts it afterwards\ny0 = jnp.array([1.0])\nargs = () # No extra arguments for func in this example\n\n# Solve the differential equation\nsol = dfx.diffeqsolve(\n    term,\n    solver,\n    t0,\n    t1,\n    dt0,\n    y0,\n    args=args,\n    stepsize_controller=stepsize_controller,\n    saveat=dfx.SaveAt(ts=jnp.linspace(t0, t1, 11))\n)\n\n# Access the solution\n# print(sol.ts) # Time points\n# print(sol.ys) # Solutions at time points\nassert jnp.allclose(sol.ys[-1], jnp.exp(-t1), atol=1e-4)\nprint(\"Solution at t=1.0:\", sol.ys[-1])\n","lang":"python","description":"This quickstart solves the ODE dy/dt = -y from t=0 to t=1 with initial condition y(0)=1, whose exact solution is y(t)=exp(-t). It wraps the vector field in an `ODETerm`, integrates it with the explicit Runge-Kutta solver `Tsit5`, lets `PIDController` adapt the step size to the given tolerances, and saves the solution at 11 evenly spaced times via `SaveAt`."},"warnings":[{"fix":"Enable 64-bit precision before creating any arrays, either by calling `jax.config.update('jax_enable_x64', True)` at the very start of your program or by setting the environment variable `JAX_ENABLE_X64=1`, then pass `float64` inputs (e.g., `jnp.array([1.0], dtype=jnp.float64)`). Casting to `jnp.float64` without enabling x64 does not help: JAX keeps the array as `float32` and emits a warning.","message":"JAX defaults to `float32`, whereas `scipy.integrate` and MATLAB compute in `float64`. Results can therefore be less precise than, and differ from, those libraries, and tolerances close to `float32` machine epsilon (about `1.2e-7`) cannot be met.","severity":"gotcha","affected_versions":"All versions"},{"fix":"Write `func` using only JAX operations (`jax.numpy`, `jax.lax`), and make `y0` and `args` PyTrees of arrays (tuples, lists, dicts, and `equinox.Module` instances all qualify). Replace Python `if`/`while` on traced values with `jnp.where` or `jax.lax.cond`. Do not rely on Python side effects: a plain `print` or a mutation of an external list runs only while tracing, so use `jax.debug.print` to print at runtime. Register any custom classes used in `y0` or `args` as PyTrees; subclassing `equinox.Module` does this automatically.","message":"Operations within the `func` passed to `ODETerm` (or other terms) must be JAX-compatible and free of Python side effects, because Diffrax traces and JIT-compiles them. 
Calling non-JAX libraries (such as NumPy or SciPy) on traced values, branching with Python `if` on array values, or using JAX arrays as dictionary keys raises errors during tracing. Python side effects inside `func` run only at trace time, not at every solver step.","severity":"gotcha","affected_versions":"All versions"},{"fix":"Choose the step size controller explicitly. With the default `ConstantStepSize()`, `dt0` is used for every step, so it must be small enough for accuracy and stability. With an adaptive controller such as `PIDController(rtol=..., atol=...)`, `dt0` is only the first step, and you can pass `dt0=None` to let Diffrax choose it automatically. Adaptive control requires a solver with an embedded error estimate (e.g., `Tsit5`, `Dopri5`), so it cannot be combined with `Euler`.","message":"Whether `dt0` is a fixed or an initial step size depends on the `stepsize_controller`, not on the solver alone. With a constant step size controller (`ConstantStepSize`, the default in `diffeqsolve`), `dt0` is the step size for every step, even with solvers such as `Tsit5` or `Dopri5`. With an adaptive controller (`PIDController`), it is only the initial step size. Misreading this can waste computation (a needlessly small fixed `dt0`) or give inaccurate or unstable results (a fixed `dt0` that is too large, or forgetting to pass an adaptive controller).","severity":"gotcha","affected_versions":"All versions"}],"env_vars":null,"last_verified":"2026-04-16T00:00:00.000Z","next_check":"2026-07-15T00:00:00.000Z","problems":[{"fix":"Ensure all JAX arrays involved in `diffrax` computations have the same `dtype`. The easiest way is to set `jax.config.update('jax_enable_x64', True)` at the very start of your script to make JAX default to `float64`, or explicitly cast all input arrays to `jnp.float32` (e.g., `jnp.array(my_data, dtype=jnp.float32)`).","cause":"You are mixing JAX arrays of `float32` (the JAX default) and `float64` precision within a computation.","error":"RuntimeError: A JAX array was produced with dtype float32, but an input array had dtype float64."},{"fix":"Verify that the first argument to `ODETerm` is indeed a function or `equinox.Module` that accepts `(t, y, args)` as arguments. 
Ensure that any auxiliary data is passed through the `args` keyword of `diffeqsolve` rather than as the vector field given to `ODETerm`.","cause":"The first argument to `ODETerm` (or similar `Term` classes) must be a callable, such as a Python function or an `equinox.Module` with a `__call__` method. This error usually occurs when `args`, or another non-callable object such as a tuple, is passed where the vector field is expected.","error":"TypeError: 'tuple' object is not callable"},{"fix":"If `y0` is meant to be a single state vector, convert it to a JAX array (e.g., `y0 = jnp.array([1.0, 2.0])`) before passing it. Python lists, tuples, and dicts are valid PyTrees, but each element becomes a separate leaf, so `func` must return exactly the same structure. For custom classes in `y0` or `args`, subclass `equinox.Module` or register them with `jax.tree_util.register_pytree_node`.","cause":"This typically happens when a Python list (or other container) in `y0` or `args` holds values that JAX cannot convert to arrays, such as strings or instances of custom classes that are not registered as PyTrees. Plain lists and dicts are already registered PyTrees, so the problem lies in what they contain rather than in the container itself.","error":"ValueError: Cannot array-like from non-array-like python list"},{"fix":"Do not use JAX arrays as keys in Python dictionaries or as elements in Python sets. Use strings, integers, or other hashable Python values as keys instead. To select values based on array contents, use array operations such as indexing, `jnp.where`, or `jax.lax.switch`. To apply a function across a PyTree, use `jax.tree_util.tree_map`; the older `jax.tree_map` alias is deprecated and has been removed from recent JAX releases.","cause":"You are using a JAX array as a key in a Python dictionary or as an element in a Python set. JAX arrays are not hashable, whether they are concrete arrays or the tracers that stand in for them during JIT compilation.","error":"KeyError: This JAX array is not hashable"}]}