torchdiffeq

0.2.5 · active · verified Sun Apr 12

torchdiffeq is a Python library of ordinary differential equation (ODE) solvers implemented in PyTorch. It supports backpropagation through ODE solutions using the adjoint method (via `odeint_adjoint`), whose memory cost stays constant regardless of the number of solver steps. The API is designed for use in deep learning applications and fully supports GPU execution. The current version, 0.2.5, was released in November 2024, and the project remains under active maintenance.


Install

pip install torchdiffeq

Imports

from torchdiffeq import odeint, odeint_adjoint

Quickstart

This quickstart defines a simple ODE as an `nn.Module` and solves it with `torchdiffeq.odeint` at a fixed set of time points. The returned `solution` tensor holds the state at each requested time point, with shape `(len(t), *y0.shape)`.

import torch
import torch.nn as nn
from torchdiffeq import odeint

# Define the ODE function as an nn.Module
class ODEFunc(nn.Module):
    def forward(self, t, y):
        # Example ODE: dy/dt = -0.1y + t
        # y and t are torch.Tensor
        return -0.1 * y + t

# Initial state y(t=0)
y0 = torch.tensor([0.7])

# Time points at which to evaluate the solution
t = torch.linspace(0., 10., 100) # 100 points from t=0 to t=10

# Solve the ODE using the default (dopri5) solver
solution = odeint(ODEFunc(), y0, t)

print("Shape of solution (time_steps, state_dim):")
print(solution.shape) # Expected: torch.Size([100, 1])
print("\nFirst 5 values of the solution:")
print(solution[:5])
