Python Optimal Transport Library (POT)
POT (Python Optimal Transport) is a comprehensive Python library offering solvers for a wide range of optimal transport problems. It provides efficient implementations of exact (linear-program) optimal transport, Wasserstein distances, the Sinkhorn algorithm for entropic regularization, Gromov-Wasserstein, and more, including recent extensions such as unbalanced OT and Gaussian mixture model OT (GMM-OT). The current version is 0.9.6.post1. The library has frequent minor releases, which often add new features and solvers along with bug fixes.
Warnings
- breaking The Gromov-Wasserstein (GW) solvers underwent a major refactor in version 0.9.0, leading to significant performance gains and the ability to handle non-symmetric cost matrices. While the API generally remained consistent, users relying on specific internal behaviors or numerical properties of older GW implementations might observe changes in results or performance characteristics.
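- gotcha POT supports multiple array backends (NumPy, PyTorch, JAX, CuPy, TensorFlow). The backend is not configured globally: POT infers it from the type of the input arrays, so NumPy inputs are computed with NumPy and return NumPy arrays, while PyTorch tensors or CuPy arrays that already live on a GPU are processed there. Some solvers, notably the exact `ot.emd`, convert their inputs to NumPy internally and run on CPU even for GPU inputs. All inputs to one call must use the same backend; mixing, for example, a NumPy array with a PyTorch tensor raises an error. The older CuPy-only `ot.gpu` module predates this system and is deprecated in favor of it.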
- gotcha Input array shapes are a frequent source of errors. Marginal weights `a` and `b` are 1D arrays of shape (n,) and (m,); sample coordinates `X` and `Y` are 2D arrays of shape (n_samples, n_features); and the cost matrix `M` has shape (n, m), with rows indexing source samples and columns indexing target samples. Passing single-feature coordinates to `ot.dist` as shape (n,) instead of (n, 1), or passing a transposed cost matrix when n != m, typically raises an error. When n == m, a transposed `M` is accepted without complaint and silently solves a different problem.
- gotcha Balanced solvers such as `ot.emd` and `ot.sinkhorn` assume that `a` and `b` have the same total mass, and the probabilistic interpretation additionally assumes that this mass is 1. By default `ot.emd` checks that the two sums agree and raises an error otherwise, while solvers that do not check may return a plan that cannot match both marginals. Normalizing both vectors (e.g., `a = a / np.sum(a)`) avoids these problems; when the total masses genuinely differ, the unbalanced solvers in `ot.unbalanced` are the appropriate tool.
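- gotcha Optimal transport can be computationally expensive for large numbers of samples. The exact solver behind `ot.emd` is a network simplex implemented in compiled C++ (wrapped with Cython), but its running time grows roughly cubically with the number of samples, and the dense n-by-m cost matrix alone needs quadratic memory. For large problems, consider entropic regularized solvers (Sinkhorn, where each iteration costs O(nm)) or other approximate methods, which trade some accuracy for speed. Smaller values of the regularization `reg` approximate the exact solution more closely but converge more slowly and can underflow numerically; the log-domain variant (`method="sinkhorn_log"`) is more stable in that regime.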
Install
pip install pot
Imports
- ot
import ot
Quickstart
import numpy as np
import ot
# Generate two 1D samples
n = 100
np.random.seed(0)
xs = np.random.normal(0, 1, n)
xt = np.random.normal(5, 1, n)
# Uniform weights on the samples (each marginal sums to 1)
a = np.ones(n) / n
b = np.ones(n) / n
# Cost matrix: squared Euclidean distance; ot.dist expects 2D (n_samples, n_features) inputs
M = ot.dist(xs.reshape((n, 1)), xt.reshape((n, 1)))
M /= M.max()  # Rescale costs to [0, 1]; mostly useful for regularized solvers
# Exact OT plan (Earth Mover's Distance). With a squared Euclidean cost the total
# cost is the squared Wasserstein-2 distance, here divided by the rescaling factor
G = ot.emd(a, b, M)
print(f"Optimal transport plan (first 5x5):\n{G[:5, :5]}")
print(f"EMD cost: {np.sum(G * M)}")