ML-Dtypes
ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used in machine learning libraries. These include bfloat16, several 8-bit, 6-bit, and 4-bit floating-point representations, and narrow integer types (int1, int2, int4, uint1, uint2, uint4). It is currently at version 0.5.4 and receives regular updates, driven primarily by its use in projects like JAX.
Common errors
-
AttributeError: module 'ml_dtypes' has no attribute 'float8_e4m3b11'
cause The requested dtype (for example `float8_e4m3b11`, `int2`, or `float4_e2m1fn`) is not exposed by the installed `ml_dtypes`: it may be missing, renamed, or not available in that version. This usually points to a version mismatch with a consuming library such as JAX or TensorFlow.
fix Make sure `ml-dtypes` and the libraries around it (JAX, TensorFlow, NumPy) are at mutually compatible versions. Either upgrade (`pip install --upgrade ml-dtypes`) or downgrade to a release known to work with your other libraries. Refer to the documentation or release notes for exact dtype availability.
-
TypeError: ufunc 'isnan' not supported for the input types
cause NumPy's `isnan` ufunc was called on an `ml_dtypes` type (e.g., `float8_e8m0fnu`) that it cannot handle. This typically happens with older versions of NumPy that do not fully support these custom data types.
fix Upgrade NumPy to a more recent version (e.g., NumPy 2.x or later) that offers improved compatibility with the custom float types in `ml_dtypes`. Run `pip install --upgrade numpy`.
-
INFO: pip is looking at multiple versions of ml-dtypes to determine which version
cause This informational message often precedes a build failure. It means `pip` is struggling to reconcile conflicting requirements on `ml-dtypes` among the installed or requested packages (e.g., TensorFlow, JAX, and NumPy) and may not find a compatible set of versions.
fix Work inside a virtual environment and explicitly pin compatible versions of `ml-dtypes`, `numpy`, JAX, and TensorFlow based on their official documentation. For example: `pip install ml-dtypes==0.5.4 numpy==1.26.4 jax==0.4.23`.
-
np.ndarray of bfloat16 using ml_dtypes is being interpreted as complex64
cause This is a reported bug in which a consuming library (e.g., `mlx.core`) interprets NumPy arrays of `ml_dtypes.bfloat16` as `complex64` instead of bfloat16.
fix Check for updates to the consuming library (e.g., `mlx.core`) and to `ml-dtypes`, since a fix is likely to arrive in a new release. If no fix is available, cast explicitly in your code if the consuming library provides such an option, or report the issue to that library's maintainers.
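For the `AttributeError` case above, one defensive pattern is to probe for a dtype by name before using it. The sketch below is only an illustration: `resolve_dtype` is a hypothetical helper, not part of `ml_dtypes`, and the candidate names are assumptions (`float8_e4m3b11fnuz` is assumed here to be the newer spelling of `float8_e4m3b11`; check the release notes for your version).

```python
import importlib


def resolve_dtype(candidates, module_name="ml_dtypes"):
    """Return (name, dtype) for the first candidate the module exposes, else None.

    Hypothetical helper for illustration; it is not an ml_dtypes API.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # The package is not installed at all.
        return None
    for name in candidates:
        dtype = getattr(module, name, None)
        if dtype is not None:
            return name, dtype
    return None


# Candidate names are assumptions; list the newest spelling first.
found = resolve_dtype(["float8_e4m3b11fnuz", "float8_e4m3b11"])
if found is None:
    print("No matching dtype in the installed ml_dtypes; check the versions.")
else:
    print(f"Using {found[0]}")
```

Probing with `getattr` keeps the failure in your own code, where a clear message about versions can be raised, rather than deep inside a dependency.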
Warnings
- breaking Values pickled with previous versions of `ml_dtypes` (prior to 0.5.4, due to `NPY_NEEDS_PYAPI` removal from dtype flags) are incompatible with the current release. These values must be regenerated with `ml_dtypes>=0.5.4`.
- deprecated The `float8_e4m3b11` dtype was deprecated starting from version 0.3.0. This change caused `AttributeError` issues in older versions of dependent libraries like JAX and TensorFlow that expected its presence.
- gotcha The narrow integer types (e.g., `int2`, `int4`, `uint2`, `uint4`) are implemented as 'unpacked' representations. Each element is padded up to a byte in memory because NumPy does not natively support types smaller than a single byte. This means these types may consume more memory than their bit-width suggests if memory layout is not carefully considered.
- gotcha There is an open proposal to deprecate the creation of `ml_dtypes` dtypes via string names using `np.dtype('typename')` in a future release, potentially in favor of direct imports. While currently supported, users should be aware of this potential future change.
- breaking Building `ml-dtypes` from source requires a C++ compiler (e.g., g++) for its C++ extensions. If no suitable compiler is present in the environment, the installation fails with an error such as "command 'g++' failed: No such file or directory".
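The unpacked layout of the narrow integer types can be observed directly. The sketch below assumes that casting between `np.int8` and `int4` with `astype` is available in your installed version. `pack_int4_pairs` is a hypothetical helper written for this illustration (it is not an `ml_dtypes` function) and it assumes an even number of elements.

```python
import numpy as np
from ml_dtypes import int4

# int4 covers -8..7; each element still occupies one full byte.
unpacked = np.arange(-8, 8, dtype=np.int8).astype(int4)
print("itemsize:", unpacked.dtype.itemsize)
print("unpacked bytes:", unpacked.nbytes)


def pack_int4_pairs(values):
    """Pack pairs of 4-bit values into single bytes, low nibble first.

    Hypothetical helper; assumes an even number of elements.
    """
    # Go through int8, reinterpret as uint8, and keep the low 4 bits
    # (two's complement nibble) of every element.
    nibbles = values.astype(np.int8).view(np.uint8) & 0x0F
    return nibbles[0::2] | (nibbles[1::2] << 4)


packed = pack_int4_pairs(unpacked)
print("packed bytes:", packed.nbytes)
```

If storage size matters, packing like this before writing to disk or sending over the network halves the footprint, at the cost of an unpacking step on the way back.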
Install
-
pip install ml-dtypes
Imports
- bfloat16
from ml_dtypes import bfloat16
- ml_dtypes module (for registration)
import ml_dtypes
import numpy as np
np.dtype('bfloat16')
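The two import styles above should resolve to the same dtype. A minimal sketch comparing them, assuming string-name registration is still supported in your installed version:

```python
import numpy as np
import ml_dtypes  # importing the module registers the dtype names with NumPy

# Direct reference to the scalar type exported by ml_dtypes.
direct = np.dtype(ml_dtypes.bfloat16)

# Lookup by string name; this relies on the import above.
by_name = np.dtype("bfloat16")

print(direct == by_name)
print(direct.itemsize)
```

Passing the imported type object avoids depending on import order and on the string registry, which makes it the more robust choice in library code.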
Quickstart
import numpy as np
from ml_dtypes import bfloat16, float8_e5m2
# Create an array with bfloat16 dtype
b_array = np.zeros(4, dtype=bfloat16)
print(f"bfloat16 array: {b_array}, dtype: {b_array.dtype}")
# Create an array using a string name (after ml_dtypes import registers it)
f8_array = np.array([0.5, 1.0, 1.5, 2.0], dtype='float8_e5m2')
print(f"float8_e5m2 array: {f8_array}, dtype: {f8_array.dtype}")
# Perform a basic operation
sum_f8 = np.sum(f8_array)
print(f"Sum of float8_e5m2 array: {sum_f8}, type: {type(sum_f8)}")
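Because these types carry very few mantissa bits, it helps to check their limits before relying on them. The sketch below uses the `finfo` function exported by `ml_dtypes` and then rounds a few `float32` values to bfloat16. The sample values are arbitrary and chosen only to make the rounding visible.

```python
import numpy as np
from ml_dtypes import bfloat16, finfo, float8_e5m2

# Inspect the numeric limits of each type.
for dtype in (bfloat16, float8_e5m2):
    info = finfo(dtype)
    print(f"{info.dtype}: bits={info.bits}, eps={info.eps}, max={info.max}")

# bfloat16 stores 8 significant bits, so nearby float32 values collapse.
original = np.array([1.0, 1.001, 256.0, 257.0], dtype=np.float32)
rounded = original.astype(bfloat16)

for before, after in zip(original, rounded):
    print(f"float32 {float(before)} -> bfloat16 {float(after)}")
```

Comparing a value before and after the cast like this is a quick way to judge whether a given dtype has enough precision for your data.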