{"id":6346,"library":"dm-haiku","title":"Haiku","description":"Haiku is a simple neural network library for JAX that enables users to use familiar object-oriented programming models while allowing full access to JAX's pure function transformations. It provides a module abstraction (`hk.Module`) and a function transformation (`hk.transform`) to manage model parameters and state. As of July 2023, Google DeepMind recommends Flax for new projects, and Haiku is in maintenance mode, with development focused on bug fixes and compatibility with new JAX releases. The current version is 0.0.16.","status":"maintenance","version":"0.0.16","language":"en","source_language":"en","source_url":"https://github.com/deepmind/dm-haiku","tags":["JAX","neural networks","deep learning","machine learning","DeepMind"],"install":[{"cmd":"pip install -U dm-haiku","lang":"bash","label":"Install from PyPI"}],"dependencies":[{"reason":"Haiku is built on JAX; JAX must be installed separately with appropriate accelerator support (e.g., CUDA) before installing Haiku.","package":"jax","optional":false},{"reason":"Compiled backend required by JAX; install a build that matches your accelerator.","package":"jaxlib","optional":false},{"reason":"Required for configuration and logging utilities.","package":"absl-py","optional":false},{"reason":"Used for mixed-precision training (`hk.mixed_precision`).","package":"jmp","optional":false},{"reason":"Fundamental numerical computing library.","package":"numpy","optional":false},{"reason":"Used to render tabular summaries of modules (e.g., `hk.experimental.tabulate`).","package":"tabulate","optional":false},{"reason":"Recommended by DeepMind for new projects as an alternative to Haiku; Haiku can be used without Flax in Python >=3.13.","package":"flax","optional":true}],"imports":[{"symbol":"haiku","correct":"import haiku as hk"},{"symbol":"jax","correct":"import jax"},{"symbol":"jax.numpy","correct":"import jax.numpy as jnp"}],"quickstart":{"code":"import haiku as hk\nimport jax\nimport jax.numpy as jnp\n\n# Haiku modules must be created inside a function that is passed to hk.transform\ndef forward_fn(x):\n    mlp = hk.nets.MLP([300, 100, 10])\n    return mlp(x)\n\ntransformed_forward = hk.transform(forward_fn)\n\nrng = hk.PRNGSequence(jax.random.PRNGKey(42))\nx = jnp.ones([8, 28 * 28])  # Example input: a batch of 8 flattened 28x28 images\n\n# Initialize parameters\nparams = transformed_forward.init(next(rng), x)\n\n# Apply the model\nlogits = transformed_forward.apply(params, next(rng), x)\n\nprint(\"Parameters structure:\", jax.tree_util.tree_map(lambda p: p.shape, params))\nprint(\"Output shape:\", logits.shape)","lang":"python","description":"This quickstart defines a simple multi-layer perceptron (MLP) with Haiku modules and uses `hk.transform` to turn it into a pair of pure functions (`init` and `apply`) that are compatible with JAX transformations. It initializes the model parameters from a JAX PRNG key and then applies the model to an input batch."},"warnings":[{"fix":"Consider using Flax (e.g., `flax.linen`) for new projects. Existing Haiku code remains supported on a best-effort basis.","message":"As of July 2023, Google DeepMind recommends that new projects adopt Flax instead of Haiku. Haiku is in maintenance mode, focusing on bug fixes and JAX compatibility rather than new features.","severity":"deprecated","affected_versions":"0.0.15 and later"},{"fix":"Pin explicit `jax` and `jaxlib` versions in your `requirements.txt` or `pyproject.toml`. Check the release notes on Haiku's GitHub releases page for the JAX versions each Haiku release supports.","message":"Haiku releases track compatibility with new JAX releases, so upgrading JAX without upgrading Haiku (or the reverse) can cause unexpected breakage. Pin `jax` and `jaxlib` to versions known to work with your Haiku version, especially in production environments.","severity":"gotcha","affected_versions":"All versions"},{"fix":"Pass `split_rng` explicitly when calling `hk.vmap`. With `split_rng=True`, the RNG key is split so that each mapped element gets a different key. With `split_rng=False`, every element receives the same key (e.g., `hk.vmap(func, split_rng=False)`).","message":"`hk.vmap(..)` now requires the `split_rng` argument to be passed explicitly.","severity":"breaking","affected_versions":"0.0.7 and later"},{"fix":"Use `jax.jit` directly on the `apply` (and, if needed, `init`) function of your `hk.transform`-ed model instead of `hk.jit`. For example: `apply_jit = jax.jit(transformed_forward.apply)`.","message":"`hk.jit` was removed from the public API.","severity":"breaking","affected_versions":"0.0.7 and later"},{"fix":"Apply JAX transformations to the pure `init`/`apply` functions returned by `hk.transform` (or `hk.transform_with_state`). Inside a transformed function, use Haiku's module-aware equivalents such as `hk.vmap`, `hk.remat`, `hk.scan`, `hk.cond` and `hk.fori_loop`. Use `hk.lift` to register the parameters of an inner `hk.transform`-ed function in an outer transform.","message":"Applying raw JAX transformations (such as `jax.jit`, `jax.vmap`, `jax.remat` or `jax.lax.scan`) to code that creates or calls Haiku modules inside a function passed to `hk.transform` can raise `jax.errors.UnexpectedTracerError` or silently produce wrong results. This happens because Haiku modules rely on side effects (parameter and state creation) that only `hk.transform` makes pure.","severity":"gotcha","affected_versions":"All versions"}],"env_vars":null,"last_verified":"2026-04-15T00:00:00.000Z","next_check":"2026-07-14T00:00:00.000Z"}