e3nn-jax library

0.21.0 · active · verified Fri Apr 17

e3nn-jax is a Python library for building E(3)-equivariant neural networks with JAX, where E(3) is the group of 3D rotations, translations, and reflections. It provides the fundamental building blocks, such as irreducible representations (Irreps), spherical harmonics, and equivariant layers, for designing networks that respect these geometric symmetries. As of version 0.21.0, the project is actively maintained.

Quickstart

This quickstart defines irreducible representations (Irreps), samples random equivariant features, and computes spherical harmonics from 3D positions, three core operations in e3nn-jax. A fixed `jax.random` key, split into independent subkeys for the features and the positions, keeps the example reproducible.

```python
import jax
from e3nn_jax import Irreps, normal, spherical_harmonics

# Split one fixed key so features and positions are drawn independently
key_features, key_positions = jax.random.split(jax.random.PRNGKey(0))

# Irreps of the input features and of the spherical harmonics
irreps_in = Irreps("1x0e + 2x1o")   # one scalar and two vectors (dimension 7)
irreps_sh = Irreps("0e + 1o + 2e")  # spherical harmonics up to l=2 (dimension 9)

# Random equivariant features (an IrrepsArray) and positions for 10 samples
features = normal(irreps_in, key_features, (10,))
positions = jax.random.normal(key_positions, (10, 3))  # 3D coordinates

# Spherical harmonics of the normalized position vectors
sh = spherical_harmonics(irreps_sh, positions, normalize=True, normalization="component")

print(f"Input features irreps: {irreps_in}")
print(f"Input features shape: {features.shape}")    # (10, 7)
print(f"Spherical harmonics irreps: {irreps_sh}")
print(f"Spherical harmonics shape: {sh.shape}")     # (10, 9)
```
