Flax

0.12.6 · active · verified Thu Apr 09

Flax is a high-performance neural network library for JAX, designed for flexibility and ease of use. It provides building blocks for defining models, handling parameters, and managing training state within the JAX ecosystem. The current version is 0.12.6; its release cadence tracks JAX updates and major developments in the JAX ecosystem, with frequent minor and patch releases.

Warnings

Install
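Flax is distributed on PyPI; the standard installation (which pulls in a compatible `jax` and `jaxlib`) is:

```shell
pip install flax
# GPU/TPU users typically install an accelerator-specific jaxlib build first;
# see the JAX installation instructions for the matching wheels.
```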

Imports

Quickstart

This quickstart demonstrates how to define a basic neural network module using `flax.linen`, initialize its parameters with a JAX PRNG key, and perform a forward pass. It also shows how to pass explicit PRNG keys for stochastic operations.

import jax
import jax.numpy as jnp
import flax.linen as nn

# Define a simple Multi-Layer Perceptron (MLP)
class MLP(nn.Module):
    num_neurons: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.num_neurons)(x) # First dense layer
        x = nn.relu(x)
        x = nn.Dense(features=self.num_neurons)(x) # Second dense layer
        return x

# Example usage:
key = jax.random.PRNGKey(0) # Initialize a PRNG key
model = MLP(num_neurons=64)

# Create a dummy input (batch_size, input_features)
dummy_input = jnp.ones((1, 10))

# Initialize model parameters
# The 'params' collection of the returned variables dict holds the parameter pytree
variables = model.init(key, dummy_input)
params = variables['params']

print(f"Initial parameters structure: {jax.tree_util.tree_map(lambda x: x.shape, params)}")

# Perform a forward pass
output = model.apply({'params': params}, dummy_input)

print(f"Output shape: {output.shape}")

# Stochastic modules (e.g. nn.Dropout) draw randomness from named RNG streams
# passed via `rngs`; this MLP has no stochastic layers, so the key is unused
# here, but this is the pattern for models that need one.
dropout_key, _ = jax.random.split(key)
output_with_rng = model.apply({'params': params}, dummy_input, rngs={'dropout': dropout_key})
