torchsde: SDE Solvers and Adjoint Sensitivity in PyTorch
torchsde provides stochastic differential equation (SDE) solvers for PyTorch with GPU support and memory-efficient adjoint sensitivity analysis. Because the solvers are differentiable, SDE simulations can be embedded in deep learning models that involve stochastic processes and trained end to end. The current version is 0.2.6, a maintenance release that arrived after a long gap to fix dependency issues and improve stability; the project is maintained, but releases are infrequent.
Warnings
- breaking Versions before 0.2.6 had dependency-resolution problems on PyPI that could cause installation failures or pull in incorrect dependency versions; install 0.2.6 or later.
- gotcha The choice between `sdeint` and `sdeint_adjoint` determines memory use during training. `sdeint` backpropagates through every solver step, so memory grows with the number of steps. `sdeint_adjoint` computes gradients by solving an adjoint SDE backward in time, so memory is O(1) in the number of solver steps at the cost of extra computation. This matters for long simulations or small step sizes.
- gotcha The `torchsde.brownian_lib` module (the C++ Brownian motion backend from the 0.1.x releases) does not exist in 0.2.x, so old tutorials that import from it fail. Import `BrownianInterval`, `BrownianPath`, or `BrownianTree` directly from `torchsde` instead.
- gotcha torchsde depends directly on PyTorch and declares a minimum supported PyTorch version. An incompatible PyTorch install can cause runtime errors or unexpected behavior, so check the requirement of the torchsde release you install against your PyTorch version.
Install
-
pip install torchsde
Imports
- sdeint
from torchsde import sdeint
- sdeint_adjoint
from torchsde import sdeint_adjoint
- BrownianPath
from torchsde import BrownianPath
Quickstart
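import torch
import torchsde

class SDE(torch.nn.Module):
    # torchsde requires these attributes to select the noise model and calculus.
    noise_type = "general"  # g returns a (batch, D, M) matrix
    sde_type = "ito"

    def __init__(self, d, m):
        super().__init__()
        self.d, self.m = d, m
        self.mu = torch.nn.Linear(d, d)         # drift network
        self.sigma = torch.nn.Linear(d, d * m)  # diffusion network, reshaped in g

    def f(self, t, y):  # drift: (batch, D) -> (batch, D)
        return self.mu(y)

    def g(self, t, y):  # diffusion: (batch, D) -> (batch, D, M)
        return self.sigma(y).view(y.size(0), self.d, self.m)

    def f_and_g(self, t, y):  # optional: drift and diffusion in a single call
        return self.f(t, y), self.g(t, y)

# Define parameters
D = 2  # State dimension
M = 2  # Noise (Brownian motion) dimension
T = 1.0  # End time
sde = SDE(D, M)
ts = torch.linspace(0, T, 10)  # Times at which the solution is returned
eps = 0.1  # Scale of the small random initial state
y0 = torch.rand(1, D) * eps  # Initial state, shape (batch_size, D)

# Solve the SDE
with torch.no_grad():  # Inference only, so no autograd graph is needed
    ys = torchsde.sdeint(sde, y0, ts)
print("SDE solved, output shape:", ys.shape)
# Expected output shape: (len(ts), batch_size, D)
# e.g., (10, 1, 2)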