e3nn
e3nn is a Python library built on PyTorch for developing Euclidean (E(3)) equivariant neural networks. It targets the symmetries of 3D space (rotations, translations, and mirror reflections) and provides core mathematical building blocks such as tensor products and spherical harmonics. The library is under active development; the current release is 0.6.0, and the second version number is incremented whenever breaking changes are introduced.
Warnings
- breaking e3nn is under active development. Breaking changes are introduced in minor version increments (e.g., from 0.x.x to 0.y.x). Always check the CHANGELOG when updating.
- breaking Normalization constants for `o3.TensorProduct` and `o3.Linear` were changed in version 0.4.0. Models whose irreps have inhomogeneous (unequal) multiplicities, e.g. `8x0e + 2x1o`, may produce different results after upgrading across this version.
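- gotcha An e3nn model (or any equivariant operation) cannot break symmetry: its output always has symmetry equal to or higher than that of its input (Curie's principle). If a target is less symmetric than the input, the model cannot fit it, because the output is restricted to the part that shares the input's symmetry (for example, an inversion-symmetric input always yields a zero `1o` vector output). This can show up as silent underfitting rather than an explicit error.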
- deprecated Older tutorials and examples (especially those for e3nn versions < 0.2) may be outdated and incompatible with current API versions. Always refer to the latest User Guide.
- breaking Python 3.6 support was dropped in earlier versions. The current library requires Python 3.8 or higher.
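The versioning rule in the first warning can be written as a small check. This is a minimal sketch, not part of e3nn: `parse_version` and `is_breaking_upgrade` are hypothetical helpers that assume plain `MAJOR.MINOR.PATCH` strings (no pre-release tags) and treat any change in the first or second number as potentially breaking, as described above.

```python
from typing import Tuple


def parse_version(version: str) -> Tuple[int, int, int]:
    """Split a plain 'MAJOR.MINOR.PATCH' string into integers (no pre-release tags)."""
    major, minor, patch = (int(part) for part in version.split("."))
    return major, minor, patch


def is_breaking_upgrade(old: str, new: str) -> bool:
    """Return True if moving from `old` to `new` may include breaking changes.

    Assumes the convention described above: while the major version is 0,
    breaking changes are signalled by bumping the second number.
    """
    old_major, old_minor, _ = parse_version(old)
    new_major, new_minor, _ = parse_version(new)
    return (old_major, old_minor) != (new_major, new_minor)


print(is_breaking_upgrade("0.5.1", "0.5.6"))  # False: patch-level upgrade
print(is_breaking_upgrade("0.5.6", "0.6.0"))  # True: second number changed, read the CHANGELOG
```

In practice the installed version string can be read with `importlib.metadata.version("e3nn")`, which is available from Python 3.8.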
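The symmetry gotcha above can be checked numerically without e3nn. The sketch below is an illustration under simplifying assumptions: it only considers rotations about the z axis (no parity), and `f` is a hypothetical equivariant map on 3D vectors, f(v) = |v|² v, not an e3nn module. A vector along z is unchanged by rotations about z, so an equivariant map must return an output with that same symmetry.

```python
import math
from typing import List

Vector = List[float]


def rotate_z(v: Vector, angle: float) -> Vector:
    """Rotate a 3D vector by `angle` radians about the z axis."""
    c, s = math.cos(angle), math.sin(angle)
    x, y, z = v
    return [c * x - s * y, s * x + c * y, z]


def f(v: Vector) -> Vector:
    """A simple rotation-equivariant map: scale v by its squared norm."""
    norm_sq = sum(component * component for component in v)
    return [norm_sq * component for component in v]


def close(a: Vector, b: Vector, tol: float = 1e-9) -> bool:
    """Component-wise comparison with an absolute tolerance."""
    return all(abs(p - q) <= tol for p, q in zip(a, b))


angle = 0.7
v = [0.3, -1.2, 0.5]
# Equivariance: rotating then applying f equals applying f then rotating.
print(close(f(rotate_z(v, angle)), rotate_z(f(v), angle)))  # True

# Curie's principle: this input is invariant under rotations about z ...
v_sym = [0.0, 0.0, 2.0]
print(close(rotate_z(v_sym, angle), v_sym))  # True
# ... so the output of the equivariant map is invariant under them too.
print(close(rotate_z(f(v_sym), angle), f(v_sym)))  # True
```

No choice of equivariant `f` could map `v_sym` to a vector with a nonzero x component, since that output would not be invariant under the rotations that leave the input unchanged.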
Install
-
pip install --upgrade e3nn
Imports
- o3
from e3nn import o3
- Irreps
from e3nn.o3 import Irreps
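Irreps strings such as `"2x0e + 2x1o"` list terms of the form multiplicity `x` degree l and parity (`e` even, `o` odd); an irrep of degree l has 2l + 1 components. The plain-Python sketch below is a hypothetical stand-in for `o3.Irreps` that handles only the `e`/`o` labels used on this page, to show how the feature dimension is counted. The real `o3.Irreps` object exposes this total as `.dim`.

```python
import re
from typing import List, Tuple

# Matches one term such as "2x1o" or "0e": optional multiplicity, degree l, parity e/o.
TERM = re.compile(r"^\s*(?:(\d+)x)?(\d+)([eo])\s*$")


def parse_irreps(spec: str) -> List[Tuple[int, int, int]]:
    """Parse an irreps string into (multiplicity, degree, parity) triples."""
    result = []
    for term in spec.split("+"):
        match = TERM.match(term)
        if match is None:
            raise ValueError(f"cannot parse irreps term {term!r}")
        mul = int(match.group(1) or 1)
        degree = int(match.group(2))
        parity = 1 if match.group(3) == "e" else -1
        result.append((mul, degree, parity))
    return result


def irreps_dim(spec: str) -> int:
    """Total feature dimension: each irrep of degree l contributes 2l + 1 components."""
    return sum(mul * (2 * degree + 1) for mul, degree, _ in parse_irreps(spec))


print(parse_irreps("2x0e + 2x1o"))  # [(2, 0, 1), (2, 1, -1)]
print(irreps_dim("0e + 1o"))        # 4
print(irreps_dim("2x0e + 2x1o"))    # 8
```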
Quickstart
import torch
from e3nn import o3
# Create a random feature vector made of one scalar (0e) and one vector (1o).
# '0e' denotes a scalar with even parity, '1o' denotes a vector with odd parity.
irreps_in = o3.Irreps("0e + 1o")
x = irreps_in.randn(-1)  # -1 marks the irreps axis, so x has shape (4,): 1 + 3 components
# Define output representations and apply a linear layer
irreps_out = o3.Irreps("2x0e + 2x1o") # Two scalars and two vectors
linear = o3.Linear(irreps_in=irreps_in, irreps_out=irreps_out)
y = linear(x)
# Compute a tensor product of the input with itself
tp = o3.FullTensorProduct(irreps_in1=irreps_in, irreps_in2=irreps_in)
z = tp(x, x)
print(f"Input (x) shape: {x.shape}, irreps: {irreps_in}")
print(f"Linear output (y) shape: {y.shape}, irreps: {irreps_out}")
print(f"Tensor product output (z) shape: {z.shape}, irreps: {tp.irreps_out}")
# For performance, modules can optionally be compiled with torch.compile
# tp_pt2 = torch.compile(tp, fullgraph=True)
# z_pt2 = tp_pt2(x, x)
# torch.testing.assert_close(z, z_pt2)
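`o3.Linear` in the quickstart only connects input and output irreps of the same type (same degree and parity); each matching pair is mixed by a dense multiplicity-by-multiplicity weight matrix shared across its 2l + 1 components. The sketch below counts those weights in plain Python. It is an illustration only: `linear_weight_count` is hypothetical, it ignores biases and e3nn's normalization constants, and the real module reports its count as `weight_numel`.

```python
from typing import List, Tuple

# (multiplicity, degree l, parity) triples; a hypothetical plain-Python stand-in for o3.Irreps
IrrepsList = List[Tuple[int, int, int]]


def linear_weight_count(irreps_in: IrrepsList, irreps_out: IrrepsList) -> int:
    """Count the weights of an equivariant linear map.

    Only blocks with the same (l, parity) may be connected, and each such
    block is a dense mul_in x mul_out matrix shared by all 2l + 1 components.
    """
    total = 0
    for mul_in, l_in, p_in in irreps_in:
        for mul_out, l_out, p_out in irreps_out:
            if (l_in, p_in) == (l_out, p_out):
                total += mul_in * mul_out
    return total


irreps_in = [(1, 0, 1), (1, 1, -1)]   # "0e + 1o"
irreps_out = [(2, 0, 1), (2, 1, -1)]  # "2x0e + 2x1o"
print(linear_weight_count(irreps_in, irreps_out))  # 4
```

An output irrep with no matching input type, such as `1e` here, receives no weights at all, which is why a linear layer alone cannot create new irrep types; the tensor product is used for that.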
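The output of `o3.FullTensorProduct` follows the coupling rule for O(3) irreps: combining degrees l1 and l2 yields every degree from |l1 - l2| to l1 + l2, with parity p1 * p2. The plain-Python sketch below (a hypothetical helper, multiplicities fixed to 1) lists the resulting irreps for `"0e + 1o"` with itself, which gives 16 components in total. e3nn may sort or group the terms differently in `tp.irreps_out`; only the multiset of irreps and the total dimension are expected to match.

```python
from typing import List, Tuple

Irrep = Tuple[int, int]  # (degree l, parity p) with p = +1 (even) or -1 (odd)


def full_tensor_product(irreps1: List[Irrep], irreps2: List[Irrep]) -> List[Irrep]:
    """List every irrep produced by coupling each pair (l1, p1) x (l2, p2).

    Selection rule: the degree ranges over |l1 - l2| .. l1 + l2 and the parity is p1 * p2.
    """
    out = []
    for l1, p1 in irreps1:
        for l2, p2 in irreps2:
            for degree in range(abs(l1 - l2), l1 + l2 + 1):
                out.append((degree, p1 * p2))
    return out


def label(irrep: Irrep) -> str:
    """Format (degree, parity) in e3nn's string notation, e.g. (1, -1) -> '1o'."""
    degree, parity = irrep
    return f"{degree}{'e' if parity == 1 else 'o'}"


irreps_in = [(0, 1), (1, -1)]  # "0e + 1o" from the quickstart
out = full_tensor_product(irreps_in, irreps_in)
print(" + ".join(label(ir) for ir in out))       # 0e + 1o + 1o + 0e + 1e + 2e
print(sum(2 * degree + 1 for degree, _ in out))  # 16
```

The `1o x 1o` block alone contributes `0e + 1e + 2e` (the dot product, cross product, and symmetric traceless part), which is why the product contains even-parity vectors that the input did not have.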