ML-Dtypes
Version 0.5.4, verified Tue May 12. Auth: no. Python install: draft. Quickstart: draft.
ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used in machine learning libraries. These include bfloat16; several 8-bit, 6-bit, and 4-bit floating-point formats; and narrow integer types (int1, int2, int4, uint1, uint2, uint4). It is currently at version 0.5.4 and is updated regularly, largely driven by its use in projects such as JAX.
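A quick way to see what these formats trade away is to query their numeric limits. This sketch assumes a recent ml_dtypes release that exports `finfo` and `iinfo` helpers (mirroring `numpy.finfo` and `numpy.iinfo`) along with the `float8_e4m3fn`, `float8_e5m2`, and `int4` types; the `limits` dict is just an illustrative name.

```python
import ml_dtypes

# Query the limits of a few float formats via ml_dtypes.finfo, which mirrors
# numpy.finfo for these custom types.
limits = {}
for name in ("bfloat16", "float8_e4m3fn", "float8_e5m2"):
    info = ml_dtypes.finfo(getattr(ml_dtypes, name))
    limits[name] = {"bits": info.bits, "max": float(info.max), "eps": float(info.eps)}
    print(f"{name:>14}: bits={info.bits}, max={float(info.max):g}, eps={float(info.eps):g}")

# Narrow integers have integer limits instead; int4 is two's complement.
int4_info = ml_dtypes.iinfo(ml_dtypes.int4)
print(f"{'int4':>14}: min={int4_info.min}, max={int4_info.max}")
```

The float8 maxima (448 for e4m3fn, 57344 for e5m2) are far below float32's, which is why these types are usually paired with explicit scaling.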
pip install ml-dtypes
Common errors
error AttributeError: module 'ml_dtypes' has no attribute 'float8_e4m3b11'
cause This error occurs when code accesses a dtype such as `float8_e4m3b11` (or `int2`, `float4_e2m1fn`, etc.) that the installed `ml_dtypes` release does not export, because the dtype was renamed, removed, or not yet added in that version. It usually signals a version mismatch with a consuming library such as JAX or TensorFlow.
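Before changing versions, it can help to confirm which of the dtypes your code (or a consuming library) expects are actually exported by the installed release. A minimal diagnostic, assuming `ml_dtypes` itself imports and exposes `__version__`; the `expected` list is illustrative, so edit it to match the names in your traceback.

```python
import ml_dtypes

# Illustrative list: replace with the dtype names your code or traceback mentions.
expected = ["bfloat16", "float8_e4m3fn", "float8_e5m2", "float8_e4m3b11", "int2", "float4_e2m1fn"]

# hasattr never raises, so every missing name is reported in a single pass.
missing = [name for name in expected if not hasattr(ml_dtypes, name)]

print(f"ml_dtypes {ml_dtypes.__version__}")
print("missing dtypes:", ", ".join(missing) if missing else "none")
```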
fix Ensure that ml-dtypes, NumPy, and the libraries that consume ml-dtypes (such as JAX and TensorFlow) are at mutually compatible versions. Upgrade ml-dtypes to the latest release (pip install --upgrade ml-dtypes), or downgrade it to a version known to work with your other libraries. Check the documentation or release notes for the exact dtypes each version provides.
error TypeError: ufunc 'isnan' not supported for the input types
cause This error arises when NumPy's `isnan` ufunc is called with `ml_dtypes` types (e.g., `float8_e8m0fnu`), indicating an incompatibility, typically with older versions of NumPy that do not fully support these custom data types.
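If upgrading NumPy is not immediately possible, one workaround (our assumption, not an officially documented fix) is to upcast to `float32` before calling the ufunc. For `bfloat16` the upcast is exact, so NaN positions are preserved; the sketch assumes the dtype you are actually using supports casting to `float32`.

```python
import numpy as np
from ml_dtypes import bfloat16

x = np.array([1.0, np.nan, -2.5], dtype=bfloat16)

# Upcast first: every bfloat16 value, including NaN, is exactly representable
# in float32, so np.isnan runs on a built-in NumPy dtype and sees the same values.
nan_mask = np.isnan(x.astype(np.float32))
print(nan_mask)
```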
fix Upgrade NumPy to a recent release (e.g., NumPy 2.x), which offers better compatibility with ml_dtypes' custom float types: pip install --upgrade numpy.
error INFO: pip is looking at multiple versions of ml-dtypes to determine which version
cause This informational message, often preceding a build failure, indicates that `pip` is struggling to resolve conflicting dependency requirements for `ml-dtypes` among various installed or requested packages (e.g., TensorFlow, JAX, and NumPy), leading to an inability to find a compatible set of versions.
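When pinning, start from what is actually installed. This sketch uses only the standard library's `importlib.metadata` (Python 3.8+); the helper name `installed_versions` is hypothetical, and packages that are not installed are reported as `None` instead of raising.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(dists):
    """Hypothetical helper: map each distribution name to its installed version, or None."""
    result = {}
    for dist in dists:
        try:
            result[dist] = version(dist)
        except PackageNotFoundError:
            result[dist] = None  # not installed in this environment
    return result


report = installed_versions(["ml_dtypes", "numpy", "jax", "jaxlib", "tensorflow"])
for dist, ver in report.items():
    print(f"{dist:>12}: {ver or 'not installed'}")
```

The printed versions are the starting point for checking each project's stated requirements before writing a pin.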
fix Use a separate virtual environment per project, and pin mutually compatible versions of ml-dtypes, NumPy, JAX, and TensorFlow according to each project's official documentation. For example: pip install ml-dtypes==0.5.4 numpy==1.26.4 jax==0.4.23.
error np.ndarray of ml_dtypes bfloat16 is interpreted as complex64
cause This is a reported bug where a consuming library (e.g., `mlx.core`) incorrectly interprets `ml_dtypes.bfloat16` NumPy arrays as `complex64` instead of the intended bfloat16 type.
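Until the consuming library is fixed, one defensive option (an assumption on our part, not something `mlx.core` documents) is to hand it a dtype it certainly understands. The helper `to_portable` is hypothetical: it widens `bfloat16` arrays to `float32`, which is lossless because every bfloat16 value is exactly representable in float32, and leaves other arrays untouched. The consuming library can then cast to its own bfloat16 type if it has one.

```python
import numpy as np
from ml_dtypes import bfloat16


def to_portable(arr):
    """Hypothetical helper: replace bfloat16 with float32 before handing an array
    to a library that may misread ml_dtypes' custom dtype."""
    if arr.dtype == np.dtype(bfloat16):
        return arr.astype(np.float32)  # exact widening, no rounding
    return arr


weights = np.array([0.1, -2.0, 3.5], dtype=bfloat16)
safe = to_portable(weights)
print(safe.dtype, safe)
```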
fix Check for updates to the consuming library (e.g., `mlx.core`) or to ml-dtypes itself; the bug may already be fixed in a newer release. If it is not, cast explicitly to a type the consuming library supports, if it offers that option, or report the issue to that library's maintainers.
Warnings
breaking Values pickled with `ml_dtypes` releases before 0.5.4 are incompatible with the current release, because 0.5.4 removed `NPY_NEEDS_PYAPI` from the dtype flags. Such values must be regenerated with `ml_dtypes>=0.5.4`.
fix Regenerate any pickled `ml_dtypes` values using the current version of the library.
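A minimal sketch of regenerating a value under the current release, assuming you can rebuild the array from its original source data (pickles from older releases are incompatible, so there is nothing to convert in place). The round trip confirms the new pickle loads back with the same dtype and values. The second half shows an alternative that avoids pickling the custom dtype entirely by storing raw bit patterns as `uint16`; that variant is our suggestion, not something the ml_dtypes release notes prescribe.

```python
import pickle

import numpy as np
import ml_dtypes

# Rebuild the array from its original source data, then pickle it with the
# currently installed ml_dtypes.
arr = np.array([0.5, 1.0, -3.0], dtype=ml_dtypes.bfloat16)
payload = pickle.dumps(arr)

restored = pickle.loads(payload)
print(restored.dtype, restored)

# Alternative: persist the raw 16-bit patterns in a built-in dtype (uint16)
# and reinterpret them as bfloat16 after loading.
bits = arr.view(np.uint16)
rebuilt = pickle.loads(pickle.dumps(bits)).view(ml_dtypes.bfloat16)
```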
deprecated The `float8_e4m3b11` dtype was deprecated in version 0.3.0. The change caused `AttributeError`s in older releases of dependent libraries, such as JAX and TensorFlow, that still expected the dtype to exist.
fix Update dependent libraries (e.g., JAX, TensorFlow) to versions compatible with `ml_dtypes` >= 0.3.0, or pin `ml_dtypes` to 0.2.0 if an older dependency is strictly required.
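Code that must run against both old and new ml_dtypes releases can look the type up defensively instead of hard-coding one attribute name. This sketch assumes that `float8_e4m3b11fnuz` is the name current releases export for this format, and falls back to the older `float8_e4m3b11` spelling; the helper `resolve_e4m3b11` is hypothetical.

```python
import ml_dtypes


def resolve_e4m3b11():
    """Hypothetical helper: return the e4m3b11 float8 type under whichever name
    this ml_dtypes install provides, trying the (assumed) newer name first."""
    for name in ("float8_e4m3b11fnuz", "float8_e4m3b11"):
        float_type = getattr(ml_dtypes, name, None)
        if float_type is not None:
            return float_type
    raise AttributeError(f"ml_dtypes {ml_dtypes.__version__} exports no e4m3b11 float8 type")


e4m3b11 = resolve_e4m3b11()
print(e4m3b11.__name__)
```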
gotcha The narrow integer types (e.g., `int2`, `int4`, `uint2`, `uint4`) use an 'unpacked' representation: each element is padded to a full byte in memory, because NumPy does not natively support types smaller than a byte. An array of N `int4` values therefore occupies N bytes rather than N/2, so these types use more memory than their bit width suggests.
fix Be aware of the unpacked representation for narrow integer types, especially when memory efficiency is critical. The lower bits store the value, while the upper bits are ignored.
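The sketch below makes the one-byte-per-element layout visible and shows one way to pack two `int4` values per byte for storage or transfer. `pack_int4` and `unpack_int4` are hypothetical helpers, not ml_dtypes functions; they assume a low-nibble-first layout, which any consumer of the packed bytes would have to agree on.

```python
import numpy as np
import ml_dtypes

vals = np.array([-8, -1, 0, 7], dtype=ml_dtypes.int4)
print(vals.dtype.itemsize, vals.nbytes)  # one full byte per int4 element


def pack_int4(a):
    """Hypothetical helper: pack int4 values two per byte, low nibble first."""
    nibbles = a.astype(np.int8).view(np.uint8) & 0x0F  # keep the 4-bit two's-complement pattern
    if nibbles.size % 2:
        nibbles = np.append(nibbles, np.uint8(0))  # pad odd lengths with a zero nibble
    return nibbles[0::2] | (nibbles[1::2] << 4)


def unpack_int4(packed, count):
    """Hypothetical inverse of pack_int4: recover `count` int4 values."""
    nibbles = np.empty(packed.size * 2, dtype=np.uint8)
    nibbles[0::2] = packed & 0x0F
    nibbles[1::2] = packed >> 4
    signed = nibbles[:count].astype(np.int8)
    signed = np.where(signed >= 8, signed - 16, signed)  # sign-extend from 4 bits
    return signed.astype(ml_dtypes.int4)


packed = pack_int4(vals)
print(packed, packed.nbytes)
```

Packing halves the byte count, at the cost of an explicit unpack step before NumPy can operate on the values again.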
gotcha There is an open proposal to deprecate the creation of `ml_dtypes` dtypes via string names using `np.dtype('typename')` in a future release, potentially in favor of direct imports. While currently supported, users should be aware of this potential future change.
fix Prefer direct imports (e.g., `from ml_dtypes import bfloat16`) over string-based `np.dtype` creation where possible, to future-proof your code against potential deprecations.
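In practice, the recommendation amounts to passing the imported type objects wherever NumPy accepts a dtype. The final comparison only illustrates that both spellings resolve to the same dtype today; the string form is the one the proposal may deprecate.

```python
import numpy as np
from ml_dtypes import bfloat16, float8_e4m3fn

# Pass the imported type objects directly; no string lookup is involved.
acts = np.zeros(3, dtype=bfloat16)
scales = np.array([0.5, 1.0, 2.0], dtype=float8_e4m3fn)

print(acts.dtype, scales.dtype)

# Today both spellings resolve to the same dtype.
print(np.dtype(bfloat16) == np.dtype("bfloat16"))
```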
breaking Building `ml-dtypes` from source requires a C++ compiler (e.g., g++) for its C++ extensions. Without one, installation fails with `command 'g++' failed: No such file or directory`.
fix Ensure a C++ compiler (such as g++ or clang++) is installed in the build environment. For Alpine Linux, this can be resolved by installing the `build-base` or `g++` package.
Install compatibility (draft, last tested 2026-05-12)
python  os / libc      status       wheel  install  import  disk
3.9     alpine (musl)  build_error  -      -        -       -
3.9     alpine (musl)               -      -        -       -
3.9     slim (glibc)                wheel  4.4s     0.27s   120M
3.9     slim (glibc)                -      -        -       -
3.10    alpine (musl)  build_error  -      -        -       -
3.10    alpine (musl)               -      -        -       -
3.10    slim (glibc)                wheel  3.8s     0.21s   110M
3.10    slim (glibc)                -      -        0.17s   110M
3.11    alpine (musl)  build_error  -      -        -       -
3.11    alpine (musl)               -      -        -       -
3.11    slim (glibc)                wheel  3.7s     0.33s   116M
3.11    slim (glibc)                -      -        0.27s   116M
3.12    alpine (musl)  build_error  -      -        -       -
3.12    alpine (musl)               -      -        -       -
3.12    slim (glibc)                wheel  3.6s     0.37s   105M
3.12    slim (glibc)                -      -        0.28s   105M
3.13    alpine (musl)  build_error  -      -        -       -
3.13    alpine (musl)               -      -        -       -
3.13    slim (glibc)                wheel  3.9s     0.27s   104M
3.13    slim (glibc)                -      -        0.29s   104M
Imports
- bfloat16
  from ml_dtypes import bfloat16
- ml_dtypes module (importing it registers the string dtype names with NumPy)
  import ml_dtypes
  import numpy as np
  np.dtype('bfloat16')
Quickstart (draft, last tested 2026-04-24)
import numpy as np
from ml_dtypes import bfloat16, float8_e5m2
# Create an array with bfloat16 dtype
b_array = np.zeros(4, dtype=bfloat16)
print(f"bfloat16 array: {b_array}, dtype: {b_array.dtype}")
# Create an array using a string name (after ml_dtypes import registers it)
f8_array = np.array([0.5, 1.0, 1.5, 2.0], dtype='float8_e5m2')
print(f"float8_e5m2 array: {f8_array}, dtype: {f8_array.dtype}")
# Perform a basic operation
sum_f8 = np.sum(f8_array)
print(f"Sum of float8_e5m2 array: {sum_f8}, type: {type(sum_f8).__name__}")
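With only two mantissa bits in `float8_e5m2`, reductions computed in the narrow type can lose precision as they grow longer. A common practice (general numerical advice, not something ml_dtypes mandates) is to accumulate in `float32` and round back to the narrow type once at the end:

```python
import numpy as np
from ml_dtypes import float8_e5m2

f8_array = np.array([0.5, 1.0, 1.5, 2.0], dtype=float8_e5m2)

# Accumulate in float32, then round the final result back to float8_e5m2 once.
total32 = f8_array.astype(np.float32).sum()
total8 = np.asarray(total32).astype(float8_e5m2)

print(total32, total8)
```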