Flax
Flax is a high-performance neural network library for JAX, designed for flexibility and ease of use. It provides building blocks for defining models, handling parameters, and managing training state within the JAX ecosystem. Its release cadence closely tracks JAX updates and major developments in the JAX ecosystem, with frequent minor and patch releases.
Warnings
- breaking The `flax.optim` module, which previously provided optimizers, was deprecated around Flax 0.5 and has been fully removed in later versions. Attempting to import it raises an ImportError; use Optax (`optax`) for optimizers instead.
- gotcha Flax, built on JAX, enforces immutability for all parameters and model states. Operations that modify state (e.g., weight updates during training) do not mutate in place but return a *new* Pytree with the updated values. This is crucial for JAX's functional paradigm.
- gotcha JAX's functional approach requires explicit handling and splitting of pseudo-random number generator (PRNG) keys for any stochastic operation (e.g., `Dropout`, initializers). Reusing the same key reproduces the same 'random' sequence, so failing to split keys yields deterministic behavior where randomness is expected.
- gotcha Flax `nn.Module`s typically initialize parameters (and other variables) based on input shapes during `model.init()`. If input shapes change during inference or subsequent calls, the model's structure or behavior might implicitly change or lead to errors if not handled correctly.
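A minimal sketch of the PRNG-splitting gotcha above: the same key always reproduces the same draws, while split keys give independent streams. The key values and shapes here are illustrative.

```python
import jax

# Split one parent key into independent child keys; never reuse the
# parent for sampling after splitting.
key = jax.random.PRNGKey(42)
key, dropout_key, init_key = jax.random.split(key, num=3)

# Same key -> identical "random" values; different keys -> different values.
a = jax.random.normal(jax.random.PRNGKey(0), (2,))
b = jax.random.normal(jax.random.PRNGKey(0), (2,))
c = jax.random.normal(dropout_key, (2,))

same = bool((a == b).all())   # reused key repeats the sequence
diff = bool((a != c).any())   # split key gives a new sequence
```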
Install
- pip install flax "jax[cpu]"
- pip install flax "jax[cuda12]"
Imports
- nn
import flax.linen as nn
- FrozenDict
from flax.core import FrozenDict
- TrainState
from flax.training import train_state
Quickstart
import jax
import jax.numpy as jnp
import flax.linen as nn

# Define a simple Multi-Layer Perceptron (MLP)
class MLP(nn.Module):
    num_neurons: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.num_neurons)(x)  # First dense layer
        x = nn.relu(x)
        x = nn.Dense(features=self.num_neurons)(x)  # Second dense layer
        return x

# Example usage:
key = jax.random.PRNGKey(0)  # Initialize a PRNG key
model = MLP(num_neurons=64)

# Create a dummy input (batch_size, input_features)
dummy_input = jnp.ones((1, 10))

# Initialize model parameters.
# The 'params' are stored in a FrozenDict within the initialized variables.
variables = model.init(key, dummy_input)
params = variables['params']
print(f"Initial parameters structure: {jax.tree_util.tree_map(lambda x: x.shape, params)}")

# Perform a forward pass
output = model.apply({'params': params}, dummy_input)
print(f"Output shape: {output.shape}")

# Passing a separate PRNG key for stochastic layers (e.g., Dropout).
# This MLP has no dropout, so the key goes unused here, but this is the
# pattern a model containing nn.Dropout would require.
dropout_key, _ = jax.random.split(key)
output_with_rng = model.apply({'params': params}, dummy_input, rngs={'dropout': dropout_key})