{"id":3991,"library":"equinox","title":"Equinox","description":"Equinox is a Python library that simplifies building and training neural networks and performing scientific computing in JAX. It provides a PyTorch-like class-based API while remaining fully compatible with JAX's functional programming paradigm and the wider JAX ecosystem. Equinox is not a framework: instead, it offers tools for filtered transformations and PyTree manipulation, allowing fine-grained control over models. It is currently at version 0.13.6 and is actively maintained.","status":"active","version":"0.13.6","language":"en","source_language":"en","source_url":"https://github.com/patrick-kidger/equinox","tags":["jax","deep learning","neural networks","machine learning","functional programming","scientific computing"],"install":[{"cmd":"pip install equinox","lang":"bash","label":"Install latest version"}],"dependencies":[{"reason":"Core dependency for numerical computation and automatic differentiation.","package":"jax","optional":false},{"reason":"JAX's compiled XLA operations library, required for JAX functionality.","package":"jaxlib","optional":false}],"imports":[{"symbol":"equinox","correct":"import equinox as eqx"},{"symbol":"jax","correct":"import jax"},{"symbol":"jax.nn","correct":"import jax.nn as jnn"},{"symbol":"jax.numpy","correct":"import jax.numpy as jnp"},{"symbol":"jax.random","correct":"import jax.random as jrandom"},{"note":"Usually accessed as `eqx.Module` after `import equinox as eqx`; importing it directly is also common.","symbol":"eqx.Module","correct":"from equinox import Module"},{"note":"Usually accessed as `eqx.nn.Linear`; importing it directly is also common.","symbol":"eqx.nn.Linear","correct":"from equinox.nn import Linear"}],"quickstart":{"code":"import equinox as eqx\nimport jax\nimport jax.numpy as jnp\nimport jax.random as jrandom\n\nclass MLP(eqx.Module):\n    layers: list\n\n    def __init__(self, key):\n        key1, key2 = jrandom.split(key, 2)\n        self.layers = [\n            eqx.nn.Linear(2, 4, key=key1),\n            jax.nn.relu,\n            eqx.nn.Linear(4, 1, key=key2),\n        ]\n\n    def __call__(self, x):\n        for layer in self.layers:\n            x = layer(x)\n        return x\n\nkey = jrandom.PRNGKey(0)\nmodel = MLP(key)\n\nx_input = jnp.array([1., 2.])\noutput = model(x_input)\nprint(f\"Model output for input {x_input}: {output}\")","lang":"python","description":"This quickstart defines a simple multi-layer perceptron (MLP) using `eqx.Module` and `eqx.nn.Linear`. It demonstrates how to build a neural network with a PyTorch-like class syntax and then run inference with it. The model's parameters are initialized from a JAX random key. Note that plain functions such as `jax.nn.relu` can be stored alongside layers, because an Equinox model is just a PyTree whose leaves may be arbitrary Python objects."},"warnings":[{"fix":"Always use `eqx.filter_jit` and `eqx.filter_grad` for transformations on Equinox models, or explicitly `eqx.partition` your model into array and non-array parts (e.g. with `eqx.is_array`) before applying raw JAX transformations.","message":"Equinox models (subclasses of `eqx.Module`) are JAX PyTrees. While Equinox allows arbitrary Python objects as leaves, standard JAX transformations such as `jax.jit` or `jax.grad` expect every leaf to be an array. Using `eqx.filter_jit` and `eqx.filter_grad` is crucial for correctly handling non-array leaves and for applying transformations only to the relevant parts of a model (e.g. its trainable parameters).","severity":"gotcha","affected_versions":"All versions"},{"fix":"Embrace functional programming paradigms. To update part of a model, use a functional approach such as `eqx.tree_at` to create a new model with the updated values rather than modifying attributes directly. For example: `model = eqx.tree_at(lambda m: m.attribute, model, new_value)`.","message":"JAX, and by extension Equinox, emphasizes immutable data structures. Direct in-place modification of model attributes outside of `__init__`, instead of an explicit functional update (e.g. via `eqx.tree_at`), can lead to unexpected behavior, JIT compilation errors, or silently incorrect computations. Model updates should typically create a new model instance.","severity":"gotcha","affected_versions":"All versions"},{"fix":"Be mindful of argument precedence for stochastic layers. Explicitly pass `inference=True` or `inference=False` at call time when you want to override the default set at initialization, especially when switching between training and inference.","message":"Stochastic layers such as `eqx.nn.Dropout` accept an `inference` flag (with `deterministic` as a deprecated alias) both at initialization time (`__init__`) and at call time (`__call__`). The call-time argument takes precedence over the initialization-time argument, which can cause confusion if overlooked.","severity":"gotcha","affected_versions":"All versions"}],"env_vars":null,"last_verified":"2026-04-11T00:00:00.000Z","next_check":"2026-07-10T00:00:00.000Z"}