torchsde: SDE Solvers and Adjoint Sensitivity in PyTorch

0.2.6 · active · verified Sat Apr 11

torchsde provides stochastic differential equation (SDE) solvers for PyTorch, with GPU support and efficient adjoint sensitivity analysis. Because the solvers are differentiable, gradients can flow through an SDE simulation, which matters for deep learning models built on stochastic processes. The current version is 0.2.6, a maintenance release that followed a long gap between releases and addressed dependency issues and stability. The project is active, but releases are infrequent.

Warnings

Install

Imports

Quickstart

This quickstart defines a simple SDE model as a `torch.nn.Module` that implements the drift function `f` and the diffusion function `g`. It then calls `torchsde.sdeint` to solve the SDE from an initial state `y0`, returning the solution at each time in `ts`. Defining the optional `f_and_g` method, which computes drift and diffusion in a single call, is recommended for performance.

import torch
import torchsde

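class SDE(torch.nn.Module):
    # torchsde requires both attributes on the SDE object.
    noise_type = "general"  # g returns a (batch, d, m) tensor
    sde_type = "ito"

    def __init__(self, d, m):
        super().__init__()
        self.d, self.m = d, m
        self.mu = torch.nn.Linear(d, d)
        self.sigma = torch.nn.Linear(d, d * m)

    def f(self, t, y):
        # Drift, shape (batch, d)
        return self.mu(y)

    def g(self, t, y):
        # Diffusion, shape (batch, d, m)
        return self.sigma(y).view(-1, self.d, self.m)

    def f_and_g(self, t, y):
        # Drift and diffusion in a single call
        return self.f(t, y), self.g(t, y)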

# Define parameters
D = 2  # State dimension
M = 2  # Noise dimension
T = 1.0 # End time

sde = SDE(D, M)
ts = torch.linspace(0, T, 10)
eps = 0.1 # Small initial perturbation
y0 = torch.rand(1, D) * eps # Initial state

# Solve the SDE
with torch.no_grad(): # For inference, use no_grad
    ys = torchsde.sdeint(sde, y0, ts)

print("SDE solved, output shape:", ys.shape)
# Expected output shape: (len(ts), batch_size, D)
# e.g., (10, 1, 2)
