MLX
MLX is an array framework for machine learning on Apple Silicon, developed by Apple machine learning research. It provides a Python API that closely follows NumPy for array operations, plus higher-level packages such as `mlx.nn` and `mlx.optimizers` with PyTorch-like APIs. MLX supports composable function transformations for automatic differentiation, automatic vectorization, and computation-graph optimization. Key features include lazy computation, dynamic graph construction, and a unified memory model shared by the CPU and GPU. It is actively developed with frequent releases, currently at version 0.31.1.
Warnings
- gotcha MLX is primarily optimized for Apple Silicon. A CUDA backend is under active development, but performance and feature coverage on non-Apple hardware (e.g., NVIDIA GPUs on Linux) may lag behind: expect performance differences and missing operations.
- gotcha `mlx.nn.Conv2d` expects input images in NHWC (batch, height, width, channels) format, which differs from the NCHW (batch, channels, height, width) format commonly used by frameworks like PyTorch. This requires transposing input tensors when porting models or using data loaders that provide NCHW data.
- gotcha MLX uses lazy computation, meaning operations build a computation graph but do not execute immediately. Arrays are only materialized when their values are needed (e.g., printed, converted to NumPy, or explicitly evaluated with `mx.eval()`). Users expecting immediate execution might be surprised by this behavior.
- gotcha Function transformations such as `mlx.core.value_and_grad()` operate on 'pure' functions: model parameters and other state should be passed explicitly as arguments to the function being transformed, rather than read from global mutable state or object attributes, to ensure correct gradient computation and graph optimization.
- gotcha For optimal performance on Apple Silicon, it is crucial to use a native `arm` Python environment. Running MLX in an `x86` Python environment via Rosetta emulation will result in significantly degraded performance.
Install
-
pip install mlx
Imports
- mlx.core
import mlx.core as mx
- mlx.nn
import mlx.nn as nn
- mlx.optimizers
import mlx.optimizers as optim
Quickstart
import mlx.core as mx
# Create an MLX array
a = mx.array([1, 2, 3], mx.float32)
print(f"Array a: {a}, dtype: {a.dtype}, shape: {a.shape}")
# Perform an operation (e.g., exponential)
b = mx.exp(a)
print(f"Array b (exp(a)): {b}")
# Perform a matrix multiplication
c = mx.array([[1, 2], [3, 4]])
d = mx.array([[5, 6], [7, 8]])
e = mx.matmul(c, d)
print(f"Matrix c: {c}\nMatrix d: {d}\nMatrix product c @ d: {e}")
# Explicitly materialize the lazy computation graph
# (printing an array, as above, also forces its evaluation)
mx.eval(e)
print("Computation materialized.")