torchdiffeq
torchdiffeq is a Python library that provides ordinary differential equation (ODE) solvers implemented in PyTorch. It supports backpropagation through ODE solutions using the adjoint method, which keeps memory cost constant regardless of the number of solver steps. The library offers a clean API for deep learning applications, and its solvers run fully on the GPU. The current version is 0.2.5, released in November 2024.
Warnings
- gotcha When using `odeint_adjoint` for O(1)-memory backpropagation, the ODE function (`func`) must be an instance of `torch.nn.Module` so that the adjoint method can identify and collect the parameters that need gradients. Alternatively, those parameters can be passed explicitly through the `adjoint_params` argument.
- gotcha Backpropagating directly through `odeint` (rather than `odeint_adjoint`) stores every intermediate solver state, so memory grows with the number of steps; this becomes expensive for complex trajectories or long integration times. For O(1) memory cost, use the adjoint method (`odeint_adjoint`), at the price of an extra ODE solve backwards in time.
- gotcha Adaptive ODE solvers (such as the default `dopri5`) use `rtol` (relative tolerance, default `1e-7`) and `atol` (absolute tolerance, default `1e-9`) to control accuracy and, through it, the number of steps. Tolerances that are too tight make integration slow; tolerances that are too loose give inaccurate solutions.
- gotcha Adaptive solvers represent timelike quantities in `torch.float64` by default; this can be changed through the solver `options`, e.g. `options={'dtype': torch.float32}`. float64 is more stable, while float32 can improve speed but is more prone to numerical instability and underflow.
Install
pip install torchdiffeq
Imports
- odeint
from torchdiffeq import odeint
- odeint_adjoint
from torchdiffeq import odeint_adjoint as odeint
Quickstart
import torch
import torch.nn as nn
from torchdiffeq import odeint
# Define the ODE function as an nn.Module
class ODEFunc(nn.Module):
    def forward(self, t, y):
        # Example ODE: dy/dt = -0.1*y + t
        # t is a scalar tensor; y is a tensor with the same shape as y0
        return -0.1 * y + t
# Initial state y(t=0)
y0 = torch.tensor([0.7])
# Time points at which to evaluate the solution
t = torch.linspace(0., 10., 100) # 100 points from t=0 to t=10
# Solve the ODE using the default (dopri5) solver
solution = odeint(ODEFunc(), y0, t)
print("Shape of solution (len(t), *y0.shape):")
print(solution.shape) # Expected: torch.Size([100, 1])
print("\nFirst 5 values of the solution:")
print(solution[:5])