JMP (JAX Mixed Precision)

0.0.4 · active · verified Tue Apr 14

JMP is a Python library from DeepMind that provides abstractions for mixed precision training in JAX. It lets models mix full- and half-precision floating-point numbers during training, reducing memory bandwidth usage and improving computational efficiency on accelerators. It is currently at version 0.0.4 and is actively developed, with new releases tracking JAX compatibility and adding features.

Warnings

Install
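
JMP is published on PyPI, so a plain pip install is the usual route (installing from the GitHub repository also works):

```shell
pip install jmp
```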

Imports

Quickstart

This quickstart demonstrates how to define a mixed precision policy using `jmp.Policy` and apply it to JAX arrays. It also shows the basic usage of `jmp.DynamicLossScale` for adjusting the loss scale during training based on gradient finiteness.

import jax
import jax.numpy as jnp
import jmp

# Define floating point types based on your hardware (e.g., bfloat16 for TPU, float16 for GPU)
half = jnp.float16 # or jnp.bfloat16 for TPUs
full = jnp.float32

# Create a mixed precision policy
# Parameters stored in full precision, computation and output in half precision
policy = jmp.Policy(param_dtype=full, compute_dtype=half, output_dtype=half)

# Example: Applying policy to a JAX array
x = jnp.array([1.0, 2.0, 3.0], dtype=full)
x_half = policy.cast_to_compute(x)
print(f"Original (full): {x.dtype}, {x}")
print(f"Computed (half): {x_half.dtype}, {x_half}")

# Example: Using DynamicLossScale
# Initialize DynamicLossScale with an initial loss scale value.
# Keep the scale itself in float32: the dynamic scale is periodically
# doubled, which would overflow a float16 value
loss_scale = jmp.DynamicLossScale(jnp.float32(2**15))

# Simulate a gradient check and adjustment
# Simulate a gradient check and adjustment; in real training this flag
# comes from jmp.all_finite(grads)
grads_finite = jnp.bool_(True)  # assume gradients were finite in this step
loss_scale = loss_scale.adjust(grads_finite)
print(f"Adjusted loss scale: {loss_scale.loss_scale}")
