Haiku
Haiku is a simple neural network library for JAX that lets you write models with familiar object-oriented programming patterns while keeping full access to JAX's pure function transformations. It provides a module abstraction (`hk.Module`) and a function transformation (`hk.transform`) that converts code using modules into a pair of pure `init` and `apply` functions, with parameters (and, via `hk.transform_with_state`, mutable state) passed explicitly. As of July 2023, Google DeepMind recommends Flax for new projects; Haiku is in maintenance mode, focusing on bug fixes and compatibility with new JAX releases. The current version is 0.0.16.
Warnings
- deprecated As of July 2023, Google DeepMind recommends that new projects adopt Flax instead of Haiku. Haiku is in maintenance mode, focusing on bug fixes and JAX compatibility rather than new features.
- gotcha Haiku releases frequently to stay compatible with new JAX releases. Pin `jax` and `jaxlib` to versions known to work with your installed `dm-haiku` to avoid unexpected breakage, especially in production environments.
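- breaking `hk.vmap(...)` now requires the `split_rng` keyword argument to be passed explicitly. With `split_rng=True`, `hk.next_rng_key()` inside the mapped function yields a different key per example; with `split_rng=False`, every example sees the same key.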
- breaking `hk.jit` was removed from the public API; apply `jax.jit` to the pure `init`/`apply` functions returned by `hk.transform` instead.
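- gotcha Applying JAX transformations (such as `jax.jit`, `jax.vmap`, `jax.remat`, or `jax.lax.scan`) to code that creates or calls Haiku modules, inside a function passed to `hk.transform`, can raise `jax.errors.UnexpectedTracerError` or silently produce wrong results. This happens because Haiku modules rely on side effects (parameter and state creation) that only become pure after `hk.transform`. Inside the transformed function, use Haiku's wrapped equivalents (`hk.vmap`, `hk.remat`, `hk.scan`, and so on); apply plain JAX transformations only to the resulting `init`/`apply` functions.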
Install
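- dm-haiku (install `jax` first, with the accelerator support you need; Haiku does not list JAX as a dependency)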
pip install -U dm-haiku
Imports
- haiku
import haiku as hk
- jax
import jax
- jax.numpy
import jax.numpy as jnp
Quickstart
import haiku as hk
import jax
import jax.numpy as jnp

def forward_fn(x):
    # A three-layer MLP: 784 -> 300 -> 100 -> 10.
    mlp = hk.nets.MLP([300, 100, 10])
    return mlp(x)

# hk.transform turns the impure forward_fn into a pair of pure functions (init, apply).
transformed_forward = hk.transform(forward_fn)
rng = hk.PRNGSequence(jax.random.PRNGKey(42))
x = jnp.ones([8, 28 * 28])  # Example input: a batch of 8 flattened 28x28 images
# Initialize parameters
params = transformed_forward.init(next(rng), x)
# Apply the model (the rng is unused here because this MLP has no dropout)
logits = transformed_forward.apply(params, next(rng), x)
print("Parameters structure:", jax.tree_util.tree_map(lambda p: p.shape, params))
print("Output shape:", logits.shape)  # (8, 10)