ML-Dtypes

0.5.4 · active · verified Sun Mar 29

ml_dtypes is a stand-alone library that adds several machine-learning data types to NumPy as dtype extensions. These include bfloat16; 8-bit, 6-bit and 4-bit floating-point formats; and narrow integer types (int1, int2, int4, uint1, uint2, uint4). The current version is 0.5.4. The library is updated regularly, largely in response to the needs of projects such as JAX.
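These formats differ mainly in range and precision. A quick way to compare them is through `ml_dtypes.finfo` and `ml_dtypes.iinfo`, which mirror `numpy.finfo` and `numpy.iinfo` for the extension types. The sketch below queries a few of them. The particular selection of types is arbitrary.

```python
import numpy as np
import ml_dtypes

# Floating-point formats: bit width, machine epsilon and largest finite value.
for t in (ml_dtypes.bfloat16, ml_dtypes.float8_e4m3fn, ml_dtypes.float8_e5m2):
    fi = ml_dtypes.finfo(t)
    print(f"{np.dtype(t).name:>14}: bits={fi.bits}, eps={float(fi.eps)}, max={float(fi.max)}")

# Narrow integer formats: representable range.
for t in (ml_dtypes.int4, ml_dtypes.uint4):
    ii = ml_dtypes.iinfo(t)
    print(f"{np.dtype(t).name:>14}: min={ii.min}, max={ii.max}")
```

The two float8 variants show the trade-off encoded in their names. e4m3fn spends more bits on the mantissa, so its largest finite value is 448. e5m2 spends more on the exponent, so its largest finite value is 57344.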

Install

pip install ml-dtypes

Imports

import ml_dtypes
from ml_dtypes import bfloat16, float8_e5m2

Quickstart

This example imports two of the dtypes and creates NumPy arrays with them. It passes the type object in one case and the type's string name in the other. String names work because importing ml_dtypes registers the types with NumPy. Basic NumPy operations also work on these arrays, as the reduction at the end shows.

import numpy as np
from ml_dtypes import bfloat16, float8_e5m2

# Create an array with bfloat16 dtype
b_array = np.zeros(4, dtype=bfloat16)
print(f"bfloat16 array: {b_array}, dtype: {b_array.dtype}")

# Create an array using a string name (importing ml_dtypes registered it with NumPy)
f8_array = np.array([0.5, 1.0, 1.5, 2.0], dtype='float8_e5m2')
print(f"float8_e5m2 array: {f8_array}, dtype: {f8_array.dtype}")

# Perform a basic operation; the result is a float8_e5m2 scalar
sum_f8 = np.sum(f8_array)
print(f"Sum of float8_e5m2 array: {sum_f8}, type: {type(sum_f8)}")
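These types have far fewer mantissa bits than float32, so casting into them rounds values. Long reductions performed in the narrow type can also drift. The sketch below shows both effects with bfloat16. Casting back to float32 before summing is plain NumPy (`astype`). It is one way to get a wider accumulator, not an ml_dtypes-specific API.

```python
import numpy as np
from ml_dtypes import bfloat16

x = np.array([1.0, 1.001, 3.14159], dtype=np.float32)

# Round-trip through bfloat16: only 8 significant bits survive the cast.
x_bf16 = x.astype(bfloat16)
back = x_bf16.astype(np.float32)
print(back)      # each value rounded to the nearest bfloat16
print(back - x)  # rounding error introduced by the cast

# Summing many small values: compare a bfloat16 sum with a float32 one.
ones = np.full(1000, 0.01, dtype=bfloat16)
print(float(ones.sum()))                     # accumulated in bfloat16
print(float(ones.astype(np.float32).sum()))  # accumulated in float32
```

Near 1.0, bfloat16 values are spaced 2**-7 apart. As a result, 1.001 rounds back to 1.0, and 3.14159 becomes 3.140625.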
