MLX Language Models
mlx-lm provides tools for loading, fine-tuning, and generating text with large language models (LLMs) on Apple Silicon, built on the MLX framework. It integrates directly with the Hugging Face Hub for model access. The library is actively developed, with frequent patch releases; it is currently at version 0.31.2.
Warnings
- gotcha MLX (and thus mlx-lm) is primarily optimized for Apple Silicon (macOS devices with M-series chips). While it can run on CPU, performance will be significantly slower, and larger models might exceed memory limits.
- breaking Compatibility with the `transformers` library can be sensitive across `mlx-lm` versions. Significant changes to `transformers` (e.g., transition to v5) have required corresponding `mlx-lm` updates.
- gotcha Not all Hugging Face models can be directly loaded or will perform optimally with `mlx-lm`. Many require a conversion step to the MLX format, especially for quantization or specific architectures.
- gotcha Batch generation and KV caching mechanisms have received numerous improvements and fixes across versions. Older versions might exhibit inefficiencies, incorrect behavior with varying prompt lengths, or issues with specific cache strategies.
Install
-
pip install mlx-lm
Imports
- load
from mlx_lm import load
- generate
from mlx_lm import generate
- convert
from mlx_lm.convert import convert
Quickstart
import mlx_lm as lm
from mlx_lm.sample_utils import make_sampler

# Load a model and its tokenizer from the Hugging Face Hub
# (mlx-community repos are already converted to MLX format)
model, tokenizer = lm.load("mlx-community/Phi-3-mini-4k-instruct-8bit")

# Instruct-tuned models expect their chat template, not raw text
messages = [{"role": "user", "content": "Write a short story about a cat who learns to fly."}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

# generate() takes no stream= or temp= arguments in current versions:
# stream with stream_generate(), and set temperature via a sampler
sampler = make_sampler(temp=0.7)
print("Generated text:")
for response in lm.stream_generate(
    model,
    tokenizer,
    prompt=prompt,
    max_tokens=200,
    sampler=sampler,
):
    print(response.text, end="", flush=True)
print()
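The fine-tuning side mentioned above is exposed through a LoRA command-line entry point. A minimal sketch, assuming the `mlx_lm.lora` CLI of recent versions; the model repo, data directory (expected to contain `train.jsonl` and `valid.jsonl`), and iteration count are illustrative:

```shell
# LoRA fine-tuning on a local JSONL dataset (all values illustrative)
python -m mlx_lm.lora \
    --model mlx-community/Phi-3-mini-4k-instruct-8bit \
    --train \
    --data ./data \
    --iters 600
```

The resulting adapter weights can be reused at generation time or fused back into the base model; check the mlx-lm LoRA documentation for the flags your installed version supports.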