Equinox
Equinox is a Python library that simplifies building and training neural networks and performing scientific computing within JAX. It provides a PyTorch-like class-based API while maintaining compatibility with JAX's functional programming paradigm and its ecosystem. Equinox isn't a framework; instead, it offers tools for filtered transformations and PyTree manipulation, allowing for fine-grained control over models. It is currently at version 0.13.6 and is actively maintained.
Warnings
- gotcha Equinox models (subclasses of `eqx.Module`) are JAX PyTrees. While Equinox allows arbitrary Python objects as leaves, standard JAX transformations like `jax.jit` or `jax.grad` usually expect PyTrees of arrays. Using `eqx.filter_jit` and `eqx.filter_grad` is crucial for correctly handling non-array leaves and selectively applying transformations only to relevant parts (e.g., trainable parameters).
- gotcha JAX, and by extension Equinox, emphasizes immutable data structures. Direct in-place modification of model attributes outside of `__init__` or explicit functional updates (e.g., via `eqx.tree_at`) can lead to unexpected behavior, JIT compilation errors, or silently incorrect computations. Model updates should typically involve creating new model instances.
- gotcha Stochastic layers like `eqx.nn.Dropout` have an `inference` flag (called `deterministic` in older releases) that disables the stochastic behavior. It can be set both at initialization time (`__init__`) and at call time (`__call__`). The call-time value takes precedence over the initialization-time one, which can be a source of confusion if not understood.
Install
-
pip install equinox
Imports
- equinox
import equinox as eqx
- jax
import jax
- jax.nn
import jax.nn as jnn
- jax.numpy
import jax.numpy as jnp
- jax.random
import jax.random as jrandom
- eqx.Module
from equinox import Module
- eqx.nn.Linear
from equinox.nn import Linear
Quickstart
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
class MLP(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2 = jrandom.split(key)
        self.layers = [
            eqx.nn.Linear(2, 4, key=key1),
            jax.nn.relu,
            eqx.nn.Linear(4, 1, key=key2),
        ]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

key = jrandom.PRNGKey(0)
model = MLP(key)
x_input = jnp.array([1., 2.])
output = model(x_input)
print(f"Model output for input {x_input}: {output}")