Equinox

0.13.6 · active · verified Sat Apr 11

Equinox is a Python library for building and training neural networks and for scientific computing in JAX. It provides a PyTorch-like class-based API in which models are ordinary PyTrees, so they stay compatible with JAX's functional programming paradigm and its ecosystem. Equinox is not a framework; instead, it offers tools for filtered transformations and PyTree manipulation that give fine-grained control over models. The current version is 0.13.6, and the library is actively maintained.

Warnings

Install

Imports

Quickstart

This quickstart defines a simple Multi-Layer Perceptron (MLP) using `eqx.Module` and `eqx.nn.Linear`. It demonstrates how to build a neural network with a PyTorch-like class syntax and then use it for inference. The model's parameters are initialized using JAX random keys.

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom

class MLP(eqx.Module):
    layers: list

    def __init__(self, key):
        # Split the key so each Linear layer gets its own independent subkey.
        key1, key2 = jrandom.split(key, 2)
        self.layers = [
            eqx.nn.Linear(2, 4, key=key1),
            jax.nn.relu,
            eqx.nn.Linear(4, 1, key=key2),
        ]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

key = jrandom.PRNGKey(0)
model = MLP(key)

x_input = jnp.array([1., 2.])
output = model(x_input)
print(f"Model output for input {x_input}: {output}")
