{"id":2032,"library":"flax","title":"Flax","description":"Flax is a high-performance neural network library for JAX, designed for flexibility and ease of use. It provides building blocks for defining models, handling parameters, and managing training state within the JAX ecosystem. As of my last check, the current version is 0.12.6. Its release cadence closely tracks JAX updates, with frequent minor and patch releases.","status":"active","version":"0.12.6","language":"en","source_language":"en","source_url":"https://github.com/google/flax","tags":["jax","deep learning","neural networks","machine learning","google"],"install":[{"cmd":"pip install flax jax[cpu]","lang":"bash","label":"CPU-only JAX"},{"cmd":"pip install flax jax[cuda12]","lang":"bash","label":"CUDA JAX (adjust the extra for your CUDA version)"}],"dependencies":[{"reason":"Flax is built on top of JAX and requires it for all operations.","package":"jax"},{"reason":"The recommended library for optimizers in Flax, replacing the deprecated flax.optim.","package":"optax","optional":true}],"imports":[{"note":"The primary module for defining neural network layers and modules.","symbol":"nn","correct":"import flax.linen as nn"},{"note":"Used for immutable parameter structures returned by model initialization.","symbol":"FrozenDict","correct":"from flax.core import FrozenDict"},{"note":"Common utility for managing model parameters, optimizer state, and other training state.","symbol":"TrainState","correct":"from flax.training.train_state import TrainState"}],"quickstart":{"code":"import jax\nimport jax.numpy as jnp\nimport flax.linen as nn\n\n# Define a simple Multi-Layer Perceptron (MLP)\nclass MLP(nn.Module):\n    num_neurons: int\n\n    @nn.compact\n    def __call__(self, x):\n        x = nn.Dense(features=self.num_neurons)(x) # First dense layer\n        x = nn.relu(x)\n        x = nn.Dense(features=self.num_neurons)(x) # Second dense layer\n        return x\n\n# Example 
usage:\nkey = jax.random.PRNGKey(0) # Initialize a PRNG key\nmodel = MLP(num_neurons=64)\n\n# Create a dummy input (batch_size, input_features)\ndummy_input = jnp.ones((1, 10))\n\n# Initialize model parameters\n# The 'params' collection is stored within the initialized variables\nvariables = model.init(key, dummy_input)\nparams = variables['params']\n\nprint(f\"Initial parameters structure: {jax.tree_util.tree_map(lambda x: x.shape, params)}\")\n\n# Perform a forward pass\noutput = model.apply({'params': params}, dummy_input)\n\nprint(f\"Output shape: {output.shape}\")\n\n# Example of passing a dedicated PRNG key for stochastic layers (e.g., dropout).\n# This MLP has no dropout, so the key is unused here, but the pattern applies generally.\ndropout_key, _ = jax.random.split(key)\noutput_with_rng = model.apply({'params': params}, dummy_input, rngs={'dropout': dropout_key})\n","lang":"python","description":"This quickstart demonstrates how to define a basic neural network module using `flax.linen`, initialize its parameters with a JAX PRNG key, and perform a forward pass. It also shows how to pass explicit PRNG keys for stochastic operations."},"warnings":[{"fix":"Migrate to `optax` for all optimizer definitions and updates; `flax.training.train_state.TrainState` is designed to work seamlessly with `optax` optimizers.","message":"The `flax.optim` module, which previously provided optimizers, has been deprecated since Flax 0.5.0 and removed entirely in later versions. Attempting to use it will result in import errors.","severity":"breaking","affected_versions":">=0.5.0"},{"fix":"Always assign the result of any state-updating call back to your state variable: `state = state.apply_gradients(grads=grads)` (note that `grads` is a keyword-only argument of `TrainState.apply_gradients`).","message":"Flax, built on JAX, enforces immutability for all parameters and model states. Operations that modify state (e.g., weight updates during training) do not mutate in place but return a *new* pytree with the updated values. 
This is crucial for JAX's functional paradigm.","severity":"gotcha","affected_versions":"All versions"},{"fix":"For operations requiring randomness within an `nn.Module` (e.g., dropout), pass a dictionary of PRNG keys to `model.apply()` via the `rngs` argument (e.g., `model.apply(..., rngs={'dropout': dropout_key})`), and use `jax.random.split` to derive fresh sub-keys for each subsequent random operation.","message":"JAX's functional approach requires explicit handling and splitting of pseudo-random number generator (PRNG) keys for any stochastic operation (e.g., `Dropout`, initializers). Reusing the same key reproduces the same 'random' sequence, so failing to split keys yields deterministic behavior where fresh randomness is expected.","severity":"gotcha","affected_versions":"All versions"},{"fix":"Ensure that the input shape passed to `model.init()` matches the feature dimensions the model will see at `apply` time. For shape-only (abstract) initialization, consider `jax.eval_shape` with `jax.ShapeDtypeStruct` inputs, or design modules to be robust to varying batch sizes and sequence lengths.","message":"Flax `nn.Module`s initialize parameters (and other variables) lazily from the input shapes seen during `model.init()`. Batch dimensions can usually vary afterwards, but if feature dimensions change between `init` and later `apply` calls, the stored parameter shapes will no longer match and errors will result.","severity":"gotcha","affected_versions":"All versions"}],"env_vars":null,"last_verified":"2026-04-09T00:00:00.000Z","next_check":"2026-07-08T00:00:00.000Z"}