JMP (JAX Mixed Precision)
JMP is a Python library from DeepMind that provides abstractions for mixed precision training in JAX. It lets you mix full- and half-precision floating-point numbers during model training to reduce memory bandwidth usage and improve computational efficiency. It is currently at version 0.0.4 and is actively developed, with new releases addressing JAX compatibility and feature enhancements.
Warnings
- breaking JMP v0.0.3 dropped support for Python 3.7. Upgrade to Python 3.8 or newer to use v0.0.3 or later.
- gotcha JMP relies on JAX, which has specific installation instructions depending on your desired accelerator (CPU, GPU, TPU). JMP does not list JAX as a direct dependency in its `requirements.txt` to avoid conflicts. You must install JAX separately *before* installing JMP.
- gotcha `DynamicLossScale` might warn if non-floating point types are passed where floating types are expected. While JMP v0.0.4 includes fixes to avoid triggering certain warnings, ensuring correct dtype usage is crucial for stable mixed precision training.
Install
- pip
# First, install JAX with accelerator support (e.g., CPU, CUDA, TPU)
# following JAX's official installation instructions. Example for CPU:
pip install "jax[cpu]"
pip install jmp
Imports
- jmp
import jmp
- Policy
from jmp import Policy
- get_policy
from jmp import get_policy
- DynamicLossScale
from jmp import DynamicLossScale
- all_finite
from jmp import all_finite
Quickstart
import jax
import jax.numpy as jnp
import jmp
# Define floating point types based on your hardware (e.g., bfloat16 for TPU, float16 for GPU)
half = jnp.float16 # or jnp.bfloat16 for TPUs
full = jnp.float32
# Create a mixed precision policy
# Parameters stored in full precision, computation and output in half precision
policy = jmp.Policy(param_dtype=full, compute_dtype=half, output_dtype=half)
# Example: Applying policy to a JAX array
x = jnp.array([1.0, 2.0, 3.0], dtype=full)
x_half = policy.cast_to_compute(x)
print(f"Original (full): {x.dtype}, {x}")
print(f"Computed (half): {x_half.dtype}, {x_half}")
# Example: Using DynamicLossScale
# Initialize with a large starting scale; jmp.half_dtype() is float16 on CPU/GPU and bfloat16 on TPU
loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2**15))
# Check gradients for overflow, then adjust the scale accordingly
grads = {"w": jnp.array([0.1, 0.2])}  # stand-in for real gradients
grads_finite = jmp.all_finite(grads)  # True here; False when an overflow produces inf/nan
loss_scale = loss_scale.adjust(grads_finite)
print(f"Adjusted loss scale: {loss_scale.loss_scale}")