Flow Matching

Version 1.0.10 · verified Sat May 09 · auth: no · language: Python

Flow Matching for Generative Modeling is a PyTorch library from Meta Research for training continuous normalizing flows with flow matching objectives. Current version: 1.0.10. Requires Python >= 3.9. Actively maintained.
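For orientation, the objective in its most common form is summarized below: the conditional optimal-transport (CondOT) path of Lipman et al. with σ_min = 0. This is the textbook formulation, not a transcription of the library's code. Here x_0 is a source sample (typically Gaussian noise), x_1 a target data sample, t a time in [0, 1], and v_θ the learned velocity field. Sampling then integrates dx/dt = v_θ(x, t) from t = 0 to t = 1.

```latex
x_t = (1 - t)\,x_0 + t\,x_1, \qquad
u_t(x_t \mid x_0, x_1) = \frac{\mathrm{d}x_t}{\mathrm{d}t} = x_1 - x_0,

\mathcal{L}_{\mathrm{CFM}}(\theta) =
\mathbb{E}_{t \sim \mathcal{U}[0,1],\; x_0 \sim p_0,\; x_1 \sim q}
\left\lVert v_\theta(x_t, t) - (x_1 - x_0) \right\rVert^2 .
```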

pip install flow-matching
error AttributeError: module 'flow_matching' has no attribute 'FlowMatching'
cause Code written against the pre-1.0.0 API, or copied from outdated documentation, that imports `FlowMatching` from the package top level.
fix Import a specific flow class instead: `from flow_matching import CondOTFlowMatching` (or another flow class) rather than `from flow_matching import FlowMatching`.
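Before rewriting imports, it can help to see which names your installed version actually exports. The sketch below uses only the standard library. `suggest_exports` is a hypothetical helper, and it is demonstrated on a stand-in module with made-up contents so that it runs without the package installed.

```python
import difflib
import types


def suggest_exports(module: types.ModuleType, wanted: str, n: int = 3) -> list[str]:
    """Hypothetical helper: up to n public names in `module` that resemble `wanted`."""
    public = [name for name in dir(module) if not name.startswith("_")]
    return difflib.get_close_matches(wanted, public, n=n, cutoff=0.4)


# Stand-in module with hypothetical contents, so the sketch runs without the package.
fake = types.ModuleType("fake_flow_matching")
fake.CondOTFlowMatching = object
fake.some_helper = object

print(hasattr(fake, "FlowMatching"))         # False
print(suggest_exports(fake, "FlowMatching"))  # ['CondOTFlowMatching']
```

Against the real package, `import flow_matching` followed by `suggest_exports(flow_matching, "FlowMatching")` lists the closest exported names, if any. Submodules that the package's `__init__` does not import will not show up in `dir()`.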
error TypeError: __init__() got an unexpected keyword argument 'dim'
cause Passing the old `dim` keyword to a constructor that no longer accepts it. For example, `CondOTFlowMatching` does not take `dim`; depending on the version it takes `input_dim` or defaults to 2.
fix Check the constructor signature and pass `input_dim` instead, e.g. `CondOTFlowMatching(input_dim=2)`, or omit the argument to use the default.
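The constructor signature can be checked programmatically with `inspect.signature`. The sketch below runs against a hypothetical stand-in class (`HypotheticalFlow`), not the library's real constructor, so the printed parameters only illustrate the technique. `accepts_kwarg` is likewise a hypothetical helper.

```python
import inspect


def accepts_kwarg(cls: type, name: str) -> bool:
    """Return True if cls(...) would accept `name` as a keyword argument."""
    for p in inspect.signature(cls).parameters.values():
        if p.kind == inspect.Parameter.VAR_KEYWORD:
            return True  # **kwargs accepts any keyword at this level
        if p.name == name and p.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
    return False


class HypotheticalFlow:
    """Stand-in with an `input_dim` parameter; not the library's real class."""

    def __init__(self, input_dim: int = 2):
        self.input_dim = input_dim


print(inspect.signature(HypotheticalFlow))          # (input_dim: int = 2)
print(accepts_kwarg(HypotheticalFlow, "dim"))        # False
print(accepts_kwarg(HypotheticalFlow, "input_dim"))  # True
```

With the package installed, calling `inspect.signature` on the real class shows exactly which parameters your installed version accepts.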
error RuntimeError: Expected all tensors to be on the same device, but found at least two devices
cause Model parameters and input tensors are on different devices (CPU vs GPU).
fix Move the model and every input tensor to the same device, e.g. `model = model.to(device)`, then `x0 = x0.to(device)`, `x1 = x1.to(device)` and `t = t.to(device)`.
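Here is a minimal sketch of the device fix in plain PyTorch, assuming at most one GPU. `nn.Linear` stands in for the flow model, since every `nn.Module` is moved the same way. Note the asymmetry: `Module.to()` modifies the module in place, while `Tensor.to()` returns a tensor that must be reassigned.

```python
import torch
from torch import nn

# Use the first CUDA GPU when present, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the flow model; any nn.Module is moved the same way.
model = nn.Linear(2, 2)
model.to(device)  # Module.to() moves parameters and buffers in place

x0 = torch.randn(100, 2)
x1 = torch.randn(100, 2)
t = torch.rand(100)
# Alternatively, create tensors on the device directly: torch.randn(100, 2, device=device)

# Tensor.to() is not in place: it returns a tensor on `device`, so reassign.
x0, x1, t = (x.to(device) for x in (x0, x1, t))

out = model(x0)  # inputs and weights now share a device
print(out.device)
```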
breaking As of v1.0.0 the API was completely rewritten: `from flow_matching import FlowMatching` no longer provides the old `CNF` class. If upgrading from <1.0, review all imports.
fix Update imports: use `from flow_matching import CondOTFlowMatching` instead of the old `from flow_matching import FlowMatching`.
deprecated The `models.CNF` class has been deprecated since v1.0.0. Use `CondOTFlowMatching` or another specific flow matching class.
fix Replace `from flow_matching.models import CNF` with `from flow_matching import CondOTFlowMatching`.
gotcha The library expects input tensors (x0, x1) to have dtype `torch.float32`. Passing float64 tensors may cause dtype-mismatch errors against float32 model parameters or degrade performance.
fix Cast inputs to float32 before use: `x0 = x0.float()`, `x1 = x1.float()`.
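Data loaded from NumPy arrays is float64 by default, which is a common way float64 tensors slip in. Below is a minimal sketch that casts once at the input boundary. `as_float32` is a hypothetical helper, not part of the library.

```python
import torch


def as_float32(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
    """Hypothetical helper: cast each tensor to float32.

    Tensor.to() returns the tensor itself when it already has that dtype,
    so calling this on float32 inputs is a no-op.
    """
    return tuple(x.to(torch.float32) for x in tensors)


# Stand-ins for data that arrived as float64 (e.g. via torch.from_numpy on a
# default NumPy float array).
x0 = torch.randn(100, 2, dtype=torch.float64)
x1 = torch.randn(100, 2, dtype=torch.float64)

x0, x1 = as_float32(x0, x1)
print(x0.dtype, x1.dtype)  # torch.float32 torch.float32
```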

Minimal example: create a flow matching model, then compute the flow matching objective on random source samples (x0), target samples (x1) and times (t).

import torch
from flow_matching import CondOTFlowMatching
from flow_matching.objectives import FMObjective

dim = 2
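flow = CondOTFlowMatching(input_dim=dim)  # constructor takes input_dim, not dim
objective = FMObjective(flow)

x0 = torch.randn(100, dim)  # source samples, shape (batch, dim), float32
x1 = torch.randn(100, dim)  # target samples (random stand-ins for real data)
t = torch.rand(100)         # one time per sample, drawn uniformly from [0, 1)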
loss = objective(x0, x1, t)
print(f"Loss: {loss.item()}")
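To see what a CondOT flow matching objective computes independently of the package's classes, here is a minimal plain-PyTorch sketch of one training step followed by fixed-step Euler sampling. It does not use or reproduce the library's implementation. `VelocityMLP`, `cond_ot_loss` and `sample_euler` are hypothetical names, the path is the linear interpolation with σ_min = 0, and the "data" is a shifted Gaussian chosen purely for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
dim = 2


class VelocityMLP(nn.Module):
    """Hypothetical velocity field v_theta(x, t); any network with this signature works."""

    def __init__(self, dim: int, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Append time as an extra input feature: (batch, dim) + (batch, 1).
        return self.net(torch.cat([x, t[:, None]], dim=-1))


def cond_ot_loss(model: nn.Module, x0: torch.Tensor, x1: torch.Tensor,
                 t: torch.Tensor) -> torch.Tensor:
    """Conditional flow matching loss on the straight-line (CondOT, sigma_min = 0) path."""
    t_ = t[:, None]                # broadcast time over the feature dimension
    x_t = (1 - t_) * x0 + t_ * x1  # point on the straight path at time t
    target = x1 - x0               # d x_t / dt, constant along the path
    # Squared error, averaged over batch and features.
    return ((model(x_t, t) - target) ** 2).mean()


@torch.no_grad()
def sample_euler(model: nn.Module, x0: torch.Tensor, steps: int = 50) -> torch.Tensor:
    """Integrate dx/dt = v_theta(x, t) from t = 0 to t = 1 with fixed-step Euler."""
    x = x0.clone()
    dt = 1.0 / steps
    for i in range(steps):
        t = torch.full((x.shape[0],), i * dt)
        x = x + dt * model(x, t)
    return x


model = VelocityMLP(dim)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

x0 = torch.randn(100, dim)        # source: standard Gaussian noise
x1 = torch.randn(100, dim) + 3.0  # target: hypothetical shifted-Gaussian "data"
t = torch.rand(100)               # one time per sample, uniform on [0, 1)

loss = cond_ot_loss(model, x0, x1, t)
opt.zero_grad()
loss.backward()
opt.step()
print(f"Loss: {loss.item():.4f}")

samples = sample_euler(model, torch.randn(8, dim))
print(samples.shape)  # torch.Size([8, 2])
```

In practice the training step runs in a loop over many fresh batches before sampling. The fixed-step Euler integrator is the simplest choice and is used here only for readability; higher-order ODE solvers are common.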