{"id":9714,"library":"e3nn-jax","title":"e3nn-jax library","description":"e3nn-jax is a Python library for constructing Equivariant Neural Networks (ENN) using JAX, specifically designed for the E(3) group of 3D rotations, translations, and reflections. It provides fundamental building blocks such as irreducible representations (Irreps), spherical harmonics, and equivariant layers for designing networks that respect these geometric symmetries. As of version 0.21.0, the library is actively maintained and regularly updated.","status":"active","version":"0.21.0","language":"en","source_language":"en","source_url":"https://github.com/e3nn/e3nn-jax","tags":["deep learning","machine learning","neural networks","equivariant networks","geometric deep learning","JAX","E(3) equivariance","3D data"],"install":[{"cmd":"pip install e3nn-jax","lang":"bash","label":"Basic installation"},{"cmd":"pip install --upgrade \"jax[cuda12]\"\npip install e3nn-jax","lang":"bash","label":"Installation with JAX for CUDA 12 (adjust the extra for your CUDA version)"}],"dependencies":[{"reason":"Core dependency for array manipulation and automatic differentiation.","package":"jax","optional":false},{"reason":"JAX's compiled backend. GPU support often requires installing the correct hardware-specific build manually. 
If `jaxlib` is not already present, installing `e3nn-jax` pulls in the CPU-only build by default.","package":"jaxlib","optional":false}],"imports":[{"symbol":"Irreps","correct":"from e3nn_jax import Irreps"},{"symbol":"spherical_harmonics","correct":"from e3nn_jax import spherical_harmonics"},{"symbol":"normal","correct":"from e3nn_jax import normal"},{"note":"Flax modules are in the `e3nn_jax.flax` submodule, not directly under `e3nn_jax` or in a `linear` submodule.","wrong":"from e3nn_jax.linear import Linear","symbol":"flax.Linear","correct":"from e3nn_jax.flax import Linear"}],"quickstart":{"code":"import jax\nfrom e3nn_jax import Irreps, normal, spherical_harmonics\n\nkey = jax.random.PRNGKey(0)\nkey_features, key_positions = jax.random.split(key) # independent keys for each random draw\n\n# Irreps of the input features and of the spherical harmonics\nirreps_in = Irreps(\"1x0e + 2x1o\") # dimension 1 + 2*3 = 7\nirreps_sh = Irreps(\"0e + 1o + 2e\") # Spherical harmonics up to l=2, dimension 1 + 3 + 5 = 9\n\n# Create random input features and positions\nfeatures = normal(irreps_in, key_features, (10,)).array # 10 samples, shape (10, 7)\npositions = jax.random.normal(key_positions, (10, 3)) # 10 samples, 3D coordinates\n\n# Compute spherical harmonics of the normalized position vectors\nsh = spherical_harmonics(irreps_sh, positions, normalize=True, normalization='component')\n\nprint(f\"Input features irreps: {irreps_in}\")\nprint(f\"Input features shape: {features.shape}\")\nprint(f\"Spherical harmonics irreps: {irreps_sh}\")\nprint(f\"Spherical harmonics shape: {sh.shape}\")","lang":"python","description":"This quickstart demonstrates how to define irreducible representations (Irreps), generate random equivariant features with `e3nn_jax.normal`, and compute spherical harmonics from 3D positions, which are core operations in e3nn-jax. It splits a fixed `jax.random` key so that each random draw is independent and reproducible."},"warnings":[{"fix":"Always check the `e3nn-jax` release notes and `pyproject.toml` for supported JAX versions. 
Update `jax` and `e3nn-jax` together, making sure `jaxlib` matches both your hardware and your `jax` version (e.g., `pip install --upgrade jax jaxlib e3nn-jax`).","message":"JAX version compatibility is critical. `e3nn-jax` closely tracks JAX's development, so upgrading JAX (especially `jaxlib`) often requires a corresponding `e3nn-jax` upgrade to avoid cryptic JIT errors, `AttributeError`s, or unexpected behavior.","severity":"breaking","affected_versions":"All versions, especially when crossing major JAX versions."},{"fix":"If you relied on rich HTML output in notebooks, format the `Irreps` object for display yourself, e.g., with `str(irreps_object)` or custom display logic.","message":"The `Irreps._repr_html_` method was removed in version 0.20.0, so `Irreps` objects no longer render as rich HTML in Jupyter notebooks.","severity":"breaking","affected_versions":">=0.20.0"},{"fix":"Check every `e3nn_jax.flax.Linear` call in your code and pass the irreps by keyword (`irreps_out=...` and, where needed, `irreps_in=...`) so the call does not depend on argument order.","message":"The positional order of the `irreps_in` and `irreps_out` arguments of `e3nn_jax.flax.Linear` changed around version 0.17.0; `irreps_out` now comes first.","severity":"breaking","affected_versions":">=0.17.0"},{"fix":"Follow the official JAX installation guide (`https://github.com/jax-ml/jax#installation`) to install the correct `jax` and `jaxlib` versions for your CPU/CUDA setup first, then install `e3nn-jax`.","message":"For GPU support in particular, install the hardware-specific `jaxlib` build manually *before* `e3nn-jax` so that it is the one picked up. 
`pip install e3nn-jax` alone might install a CPU-only `jaxlib` or an incompatible version.","severity":"gotcha","affected_versions":"All versions"}],"env_vars":null,"last_verified":"2026-04-17T00:00:00.000Z","next_check":"2026-07-16T00:00:00.000Z","problems":[{"fix":"Instantiate the `Irreps` class by passing the string to its constructor and use the result directly, e.g., `my_irreps = Irreps(\"1x0e + 2x1o\")` instead of `my_irreps = Irreps(\"...\")()`.","cause":"Calling an `Irreps` object as if it were a function. Constructing `Irreps` from a string already returns the representation; the resulting instance is not callable.","error":"TypeError: 'Irreps' object is not callable"},{"fix":"Reinstall `jax` and `jaxlib`, making sure the `jaxlib` build matches your hardware (CPU/CUDA) as described in the official JAX installation guide. On GPU this usually means installing the CUDA-enabled JAX build (e.g., `pip install --upgrade \"jax[cuda12]\"`) before installing `e3nn-jax`.","cause":"This typically indicates a mismatch between the installed `jax`, `jaxlib`, and (on GPU) CUDA versions, or a corrupted installation.","error":"RuntimeError: JAX is not installed correctly. Please follow the instructions at https://github.com/google/jax#installation to install JAX."},{"fix":"Print the shapes of all inputs to the failing `e3nn-jax` call. Check that the last axis of each array equals the dimension of its `Irreps` (`irreps.dim`) and that leading batch dimensions agree. To apply a per-sample function over a batch axis, use `jax.vmap`.","cause":"An `e3nn-jax` operation received inputs whose last axis does not match the expected Irreps dimension, or whose batch dimensions are inconsistent. Note that `jax.jit` retraces rather than fails when input shapes change, so look for the mismatch in how the arrays were constructed.","error":"ValueError: Arguments '...' and '...' have different shapes. Expected equal shapes, but got ... 
and ..."},{"fix":"Use `e3nn_jax.normal(irreps, key, leading_shape)` to sample a random `IrrepsArray`; note that the `Irreps` comes first and the PRNG key second. Consult the latest `e3nn-jax` API reference for the full signature.","cause":"`e3nn_jax` does not provide a `rand_irreps` function; random features for a given `Irreps` are generated with `e3nn_jax.normal`.","error":"ImportError: cannot import name 'rand_irreps' from 'e3nn_jax'"}]}