e3nn-jax library
e3nn-jax is a Python library for constructing Equivariant Neural Networks (ENN) using JAX, specifically designed for the E(3) group of 3D rotations, translations, and reflections. It provides fundamental building blocks like Irreducible Representations (Irreps), spherical harmonics, and equivariant layers, enabling the design of networks that respect geometric symmetries. As of version 0.21.0, it is actively maintained with regular updates, reflecting advancements in the E(3) equivariant deep learning field.
Common errors
- TypeError: 'Irreps' object is not callable
cause: An existing `Irreps` instance is being called as if it were a function. `Irreps` is a class. Calling the class with a string constructs an `Irreps` object, but calling that object again raises this error.
fix: Construct the object once by passing the string to the constructor, e.g. `my_irreps = Irreps("1x0e + 2x1o")`, not `my_irreps = Irreps("1x0e + 2x1o")()`.
- RuntimeError: JAX is not installed correctly. Please follow the instructions at https://github.com/google/jax#installation to install JAX.
cause: This usually means the installed versions of `jax`, `jaxlib` and (on GPU) the CUDA libraries do not match, or the installation is corrupted.
fix: Reinstall `jax` and `jaxlib` carefully. Make sure the `jaxlib` build matches your hardware (CPU or a specific CUDA version), as described in the official JAX installation guide. On GPU, this usually means installing the CUDA-enabled JAX before `e3nn-jax`. Older JAX releases used `pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html`, and newer releases use `pip install --upgrade "jax[cuda12]"`.
- ValueError: Arguments '...' and '...' have different shapes. Expected equal shapes, but got ... and ...
cause: Shape mismatches often surface inside JIT-compiled functions. They happen when input shapes change unexpectedly, when an `e3nn-jax` operation receives an array whose last axis does not match the declared `Irreps` dimension, or when batch dimensions are inconsistent.
fix: Check the shapes of every input to the failing `e3nn-jax` call. Make sure the input and output `Irreps` are defined correctly (the last axis of each array must equal `irreps.dim`) and that batch dimensions agree. To apply a per-sample function across a batch, consider `jax.vmap`. Also note that `jax.jit` recompiles for every new input shape.
- TypeError: rand_irreps() missing 1 required positional argument: 'irreps'
cause: A random-array helper (`e3nn_jax.rand_irreps` or a similar function) was called without all of its required arguments. It needs both a JAX `PRNGKey` and an `Irreps` object.
fix: Check the function signature in the documentation or source for your installed `e3nn-jax` version, and pass every required positional and keyword argument with the correct type. For `rand_irreps`, the expected form is `rand_irreps(key, irreps_object, shape_tuple)`. Recent releases expose the random helper as `e3nn_jax.normal(irreps, key, leading_shape)`, with a different argument order. If your version does not export `rand_irreps`, use that instead.
Warnings
- breaking JAX version compatibility is critical. `e3nn-jax` closely tracks JAX's development. Upgrading JAX (especially `jaxlib`) often requires a corresponding `e3nn-jax` upgrade to avoid cryptic JIT errors, `AttributeError`s, or unexpected behavior.
- breaking The `Irreps._repr_html_` method was removed in version 0.20.0, which means `Irreps` objects no longer render as rich HTML in Jupyter notebooks by default.
- breaking The argument order for `ir_in` and `ir_out` in `e3nn_jax.flax.Linear` was swapped around version 0.17.0 to align with more natural tensor flow.
- gotcha For GPU (or other accelerator) support, install the hardware-specific `jaxlib` manually *before* `e3nn-jax` so the correct build is picked up. Running `pip install e3nn-jax` alone may install a CPU-only `jaxlib` or an incompatible version.
Install
- CPU: pip install e3nn-jax
- GPU (CUDA 12): pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
  pip install e3nn-jax
Imports
- Irreps
from e3nn_jax import Irreps
- spherical_harmonics
from e3nn_jax import spherical_harmonics
- normal (random IrrepsArray)
from e3nn_jax import normal
- flax.Linear
from e3nn_jax.flax import Linear
Quickstart
import jax
import e3nn_jax as e3nn
from e3nn_jax import Irreps, spherical_harmonics
# Split the PRNG key so features and positions are drawn independently
key_features, key_positions = jax.random.split(jax.random.PRNGKey(0))
# Define the feature Irreps and the spherical-harmonics Irreps
irreps_in = Irreps("1x0e + 2x1o")  # 1 scalar + 2 vectors = 7 components
irreps_sh = Irreps("0e + 1o + 2e")  # Spherical harmonics up to l=2 (1 + 3 + 5 = 9 components)
# Create random input features (e3nn_jax.normal returns an IrrepsArray) and positions
features = e3nn.normal(irreps_in, key_features, (10,)).array  # 10 samples, shape (10, 7)
positions = jax.random.normal(key_positions, (10, 3))  # 10 samples, 3D coordinates
# Compute spherical harmonics of the normalized position vectors (returns an IrrepsArray)
sh = spherical_harmonics(irreps_sh, positions, normalize=True, normalization="component")
print(f"Input features irreps: {irreps_in}")
print(f"Input features shape: {features.shape}")
print(f"Spherical harmonics irreps: {irreps_sh}")
print(f"Spherical harmonics shape: {sh.shape}")