Python Optimal Transport Library (POT)

0.9.6.post1 · active · verified Wed Apr 15

POT (Python Optimal Transport) is a Python library of solvers for optimal transport (OT) problems. It provides implementations of exact (linear-programming) OT, Wasserstein distances, the Sinkhorn algorithm for entropic-regularized OT, Gromov-Wasserstein, and more recent extensions such as unbalanced OT and Gaussian mixture model OT (GMM-OT). The current version is 0.9.6.post1, and the library has frequent minor releases that typically add new features and solvers and fix bugs.

Install

pip install POT

Imports

import ot

Quickstart

This example computes the exact optimal transport plan (the Earth Mover's Distance, EMD) between two 1D samples using POT's core `ot.emd` function. It generates the samples, assigns uniform weights to the points, builds a squared Euclidean cost matrix normalized by its maximum, and then computes the optimal transport plan and its total cost.

import numpy as np
import ot

# Generate two 1D samples
n = 100
np.random.seed(0)
xs = np.random.normal(0, 1, n)
xt = np.random.normal(5, 1, n)

# Uniform weights on the samples (each histogram sums to 1)
a = np.ones(n) / n
b = np.ones(n) / n

# Cost matrix: squared Euclidean distance
M = ot.dist(xs.reshape((n, 1)), xt.reshape((n, 1)))
M /= M.max()  # Scale costs to [0, 1]; this rescales the cost but does not change the optimal plan

# Solve the exact OT problem (EMD). With a squared Euclidean cost, the total
# cost is the squared 2-Wasserstein distance (here scaled by 1 / M.max()).
G = ot.emd(a, b, M)
print(f"Optimal transport plan (first 5x5):\n{G[:5, :5]}")
print(f"EMD cost: {np.sum(G * M)}")
