Haiku

0.0.16 · maintenance · verified Wed Apr 15

Haiku is a simple neural network library for JAX that lets users write models in a familiar object-oriented style while retaining full access to JAX's pure function transformations. It provides a module abstraction (`hk.Module`) and a function transformation (`hk.transform`) to manage model parameters and state. As of July 2023, Google DeepMind recommends Flax for new projects; Haiku has entered maintenance mode, focusing on bug fixes and compatibility with new JAX releases. The current version is 0.0.16.

Warnings

Haiku is in maintenance mode: development is limited to bug fixes and compatibility with new JAX releases, and Google DeepMind recommends Flax for new projects.

Install
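
Haiku is distributed on PyPI under the name `dm-haiku`; JAX itself is not bundled and must be installed separately for your platform:

```shell
pip install dm-haiku
```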

Imports

Quickstart

This quickstart demonstrates how to define a simple Multi-Layer Perceptron (MLP) using Haiku modules and then transform it into a pair of pure functions (`init` and `apply`) compatible with JAX transformations. It shows how to initialize model parameters using a JAX PRNG key and then apply the model to an input.

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward_fn(x):
    mlp = hk.nets.MLP([300, 100, 10])
    return mlp(x)

# hk.transform yields a pair of pure functions: init and apply.
transformed_forward = hk.transform(forward_fn)

# PRNGSequence produces a fresh PRNG key each time next() is called.
rng = hk.PRNGSequence(jax.random.PRNGKey(42))
x = jnp.ones([8, 28 * 28])  # example input: a batch of 8 flattened 28x28 images

# Initialize parameters
params = transformed_forward.init(next(rng), x)

# Apply the model
logits = transformed_forward.apply(params, next(rng), x)

print("Parameters structure:", jax.tree_util.tree_map(lambda x: x.shape, params))
print("Output shape:", logits.shape)
```
