ML-Dtypes
ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used in machine learning libraries. These include bfloat16, several 8-bit, 6-bit, and 4-bit floating-point formats, and narrow integer types (int1, int2, int4, uint1, uint2, uint4). It is currently at version 0.5.4 and is updated regularly, largely driven by its use in projects such as JAX.
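The bit budgets of these formats can be inspected with `ml_dtypes.finfo` and `ml_dtypes.iinfo`, which play the role of NumPy's `np.finfo` and `np.iinfo` for these types. The following is a minimal sketch; the types listed are an illustrative selection, not the full set.

```python
import ml_dtypes

# Illustrative selection of the floating-point formats.
FLOAT_TYPES = [ml_dtypes.bfloat16, ml_dtypes.float8_e4m3fn, ml_dtypes.float8_e5m2]
# Illustrative selection of the narrow integer types.
INT_TYPES = [ml_dtypes.int2, ml_dtypes.int4, ml_dtypes.uint2, ml_dtypes.uint4]

for t in FLOAT_TYPES:
    info = ml_dtypes.finfo(t)  # counterpart of np.finfo for these formats
    print(f"{t.__name__}: {info.bits} bits, largest finite value {float(info.max)}")

for t in INT_TYPES:
    info = ml_dtypes.iinfo(t)  # counterpart of np.iinfo for narrow integers
    print(f"{t.__name__}: range [{info.min}, {info.max}]")
```

For example, `float8_e4m3fn` tops out at 448 while `float8_e5m2` reaches 57344: the extra exponent bit buys range at the cost of one mantissa bit.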
Warnings
- breaking Values pickled with `ml_dtypes` versions prior to 0.5.4 are incompatible with the current release because the `NPY_NEEDS_PYAPI` flag was removed from the dtype flags. Regenerate such values with `ml_dtypes>=0.5.4`.
- deprecated The `float8_e4m3b11` dtype has been deprecated since version 0.3.0. This change caused `AttributeError` failures in older versions of dependent libraries, such as JAX and TensorFlow, that expected the name to exist.
- gotcha The narrow integer types (e.g., `int2`, `int4`, `uint2`, `uint4`) use an unpacked representation: because NumPy does not support types smaller than one byte, each element is padded to a full byte in memory. An `int4` array therefore uses twice, and an `int2` array four times, the memory its bit width suggests, so plan storage and memory layout accordingly.
- gotcha There is an open proposal to deprecate creating `ml_dtypes` dtypes from string names via `np.dtype('typename')` in a future release, possibly in favor of direct imports. String names are still supported today, but be aware that this may change.
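When the unpacked layout of the narrow integer types matters, for example when writing `int4` weights to disk, the values can be packed two per byte by hand. The `pack_int4` and `unpack_int4` helpers below are hypothetical, not part of ml_dtypes; they are a sketch that assumes 4-bit two's-complement nibbles with the earlier element stored in the low nibble.

```python
import numpy as np
import ml_dtypes


def pack_int4(arr):
    """Hypothetical helper: pack an int4 array into uint8 bytes, two values per byte."""
    # Reinterpret each value's two's-complement bits and keep only the low nibble.
    nibbles = arr.ravel().astype(np.int8).view(np.uint8) & 0x0F
    if nibbles.size % 2:  # pad odd-length input with a zero nibble
        nibbles = np.append(nibbles, np.uint8(0))
    # Even positions go in the low nibble, odd positions in the high nibble.
    return nibbles[0::2] | (nibbles[1::2] << 4)


def unpack_int4(packed, count):
    """Hypothetical helper: inverse of pack_int4; `count` drops any padding nibble."""
    nibbles = np.empty(packed.size * 2, dtype=np.uint8)
    nibbles[0::2] = packed & 0x0F
    nibbles[1::2] = packed >> 4
    # Sign-extend 4-bit two's complement: 0..7 stay as-is, 8..15 map to -8..-1.
    signed = (nibbles.astype(np.int8) ^ 8) - 8
    return signed[:count].astype(ml_dtypes.int4)


values = np.arange(-8, 8).astype(ml_dtypes.int4)
packed = pack_int4(values)
print(values.itemsize, values.nbytes, packed.nbytes)  # 1 16 8
```

The packed form halves storage, but it is only a storage format: unpack before doing arithmetic with NumPy.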
Install
pip install ml-dtypes
Imports
- bfloat16
from ml_dtypes import bfloat16
- ml_dtypes module (for registration)
import ml_dtypes
import numpy as np
np.dtype('bfloat16')
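The string lookup and the direct scalar type resolve to the same dtype. Given the deprecation proposal mentioned in the warnings, a sketch of code that avoids depending on string registration simply passes the imported type:

```python
import numpy as np
import ml_dtypes

# String lookup: works because importing ml_dtypes registered the name with NumPy.
by_name = np.dtype("bfloat16")
# Direct lookup: passes the scalar type and does not rely on string registration.
by_type = np.dtype(ml_dtypes.bfloat16)
print(by_name == by_type)  # True

# Creating an array with the scalar type directly.
x = np.array([1.0, 2.0, 3.0], dtype=ml_dtypes.bfloat16)
print(x.dtype)  # bfloat16
```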
Quickstart
import numpy as np
from ml_dtypes import bfloat16, float8_e5m2
# Create an array with bfloat16 dtype
b_array = np.zeros(4, dtype=bfloat16)
print(f"bfloat16 array: {b_array}, dtype: {b_array.dtype}")
# Create an array using a string name (registered when ml_dtypes is imported)
f8_array = np.array([0.5, 1.0, 1.5, 2.0], dtype='float8_e5m2')
print(f"float8_e5m2 array: {f8_array}, dtype: {f8_array.dtype}")
# Perform a basic operation
sum_f8 = np.sum(f8_array)
print(f"Sum of float8_e5m2 array: {sum_f8}, type: {type(sum_f8).__name__}")
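Because these formats keep far fewer significand bits than float32, casting rounds to the nearest representable value. The sketch below uses two arbitrary float32 inputs, casts them down and back up, and reads the spacing at 1.0 from `ml_dtypes.finfo(...).eps`.

```python
import numpy as np
import ml_dtypes

x = np.array([1.01, 3.3], dtype=np.float32)

# Round to each narrow format, then widen back to float32 to inspect the result.
as_bf16 = x.astype(ml_dtypes.bfloat16).astype(np.float32)
as_f8 = x.astype(ml_dtypes.float8_e5m2).astype(np.float32)

print(as_bf16.tolist())  # [1.0078125, 3.296875]
print(as_f8.tolist())    # [1.0, 3.5]

# Gap between 1.0 and the next representable value in each format.
print(float(ml_dtypes.finfo(ml_dtypes.bfloat16).eps))     # 0.0078125
print(float(ml_dtypes.finfo(ml_dtypes.float8_e5m2).eps))  # 0.25
```

With only two mantissa bits, `float8_e5m2` represents values between 2 and 4 in steps of 0.5, which is why 3.3 becomes 3.5.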