{"id":5977,"library":"jmp","title":"JMP (JAX Mixed Precision)","description":"JMP is a Python library from DeepMind that provides abstractions for mixed precision training in JAX. It lets models use both full- and half-precision floating-point numbers during training, reducing memory bandwidth requirements and improving computational efficiency. The current version is 0.0.4, and the library is under active development, with recent releases addressing JAX compatibility and adding feature enhancements.","status":"active","version":"0.0.4","language":"en","source_language":"en","source_url":"https://github.com/deepmind/jmp","tags":["jax","mixed precision","deep learning","deepmind","machine learning"],"install":[{"cmd":"pip install jmp","lang":"bash","label":"Install JMP from PyPI"},{"cmd":"# First, install JAX for your accelerator (CPU, GPU or TPU) following JAX's official instructions.\n# Example for CPU:\npip install \"jax[cpu]\"\n\npip install jmp","lang":"bash","label":"Install JAX (prerequisite) and JMP"}],"dependencies":[{"reason":"Core dependency for numerical computation. Must be installed separately so that the build matching your accelerator can be chosen.","package":"jax","optional":false},{"reason":"Used by JAX internally and for array handling.","package":"numpy","optional":false}],"imports":[{"symbol":"jmp","correct":"import jmp"},{"symbol":"Policy","correct":"from jmp import Policy"},{"symbol":"get_policy","correct":"from jmp import get_policy"},{"symbol":"DynamicLossScale","correct":"from jmp import DynamicLossScale"},{"symbol":"all_finite","correct":"from jmp import all_finite"}],"quickstart":{"code":"import jax\nimport jax.numpy as jnp\nimport jmp\n\n# Define floating point types based on your hardware (e.g., bfloat16 for TPU, float16 for GPU)\nhalf = jnp.float16  # or jnp.bfloat16 for TPUs\nfull = jnp.float32\n\n# Create a mixed precision policy:\n# parameters stored in full precision, computation and output in half precision\npolicy = jmp.Policy(param_dtype=full, compute_dtype=half, output_dtype=half)\n\n# Example: casting a JAX array to the compute dtype\nx = jnp.array([1.0, 2.0, 3.0], dtype=full)\nx_half = policy.cast_to_compute(x)\nprint(f\"Original (full): {x.dtype}, {x}\")\nprint(f\"Computed (half): {x_half.dtype}, {x_half}\")\n\n# Example: one step of dynamic loss scaling\n# Keep the loss scale in float32: doubling 2**15 would overflow float16 (max 65504)\nloss_scale = jmp.DynamicLossScale(jnp.float32(2**15))\n\ndef loss_fn(w):\n    # Compute in half precision, return the loss in full precision\n    y = policy.cast_to_compute(w) * x_half\n    return jnp.mean(y).astype(full)\n\n# Scale the loss before differentiating, then unscale the gradients\nw = jnp.ones(3, dtype=full)\ngrads = jax.grad(lambda w: loss_scale.scale(loss_fn(w)))(w)\ngrads = loss_scale.unscale(grads)\n\n# Adjust the loss scale depending on whether all gradients are finite\ngrads_finite = jmp.all_finite(grads)\nloss_scale = loss_scale.adjust(grads_finite)\nprint(f\"Gradients finite: {grads_finite}\")\nprint(f\"Adjusted loss scale: {loss_scale.loss_scale}\")\n","lang":"python","description":"This quickstart defines a mixed precision policy with `jmp.Policy` and applies it to JAX arrays. It then runs one step of dynamic loss scaling with `jmp.DynamicLossScale`: the loss is scaled before differentiation, the gradients are unscaled, and the loss scale is adjusted according to whether all gradients are finite (`jmp.all_finite`)."},"warnings":[{"fix":"Upgrade to Python 3.8 or newer.","message":"JMP v0.0.3 dropped support for Python 3.7. Users on Python 3.7 must upgrade Python to use v0.0.3 or newer.","severity":"breaking","affected_versions":">=0.0.3"},{"fix":"Follow JAX's official installation guide (e.g., `pip install \"jax[cuda12]\"` for CUDA 12) before running `pip install jmp`.","message":"JMP relies on JAX, whose installation command depends on the target accelerator (CPU, GPU, TPU). To avoid conflicting with your chosen JAX build, JMP does not list JAX as a direct dependency in its `requirements.txt`. Install JAX separately *before* installing JMP.","severity":"gotcha","affected_versions":"All"},{"fix":"Pass floating-point values (e.g., `jnp.float32`, `jnp.float16`, `jnp.bfloat16`) to `DynamicLossScale` and related functions; for example, use `jnp.float32(2**15)` rather than the Python integer `2**15` as the initial loss scale.","message":"`DynamicLossScale` may emit a warning if a non-floating-point value is passed where a floating-point one is expected. JMP v0.0.4 includes fixes that avoid triggering certain warnings, but correct dtype usage remains important for stable mixed precision training.","severity":"gotcha","affected_versions":"All"}],"env_vars":null,"last_verified":"2026-04-14T00:00:00.000Z","next_check":"2026-07-13T00:00:00.000Z"}