Grain (ML Data Library)
Grain is a Python library from Google for efficiently loading and transforming data for machine learning model training and evaluation, particularly with JAX. It emphasizes flexibility, performance, and determinism in data processing pipelines. The library is actively developed (currently at version 0.2.16), with frequent releases that add features, fix bugs, and deprecate old APIs.
Warnings
- breaking Custom implementations of `RandomAccessDataSource` must now accept an `int` index in `__getitem__`. While legacy paths handling `SupportsIndex` still work at runtime, type checkers may flag errors. Switch to `int` for full compatibility.
- deprecated Support for Python 3.10 has been dropped; the library now requires Python >=3.11.
- deprecated Experimental APIs `grain.python.experimental.MultiprocessPrefetchIterDataset` and `grain.python.experimental.ConcatenateMapDataset` have been deprecated. Use their graduated versions `grain.IterDataset.mp_prefetch` and `grain.MapDataset.concatenate` respectively.
- gotcha When using Python multiprocessing for parallel data loading and transformations, all custom transformation functions (e.g., `MapTransform` subclasses) must be picklable. Non-picklable objects or closures can lead to errors during serialization.
- gotcha Choose between `MapDataset` and `IterDataset` based on access patterns. `MapDataset` supports efficient random access and is suitable for debugging or when order-dependent operations are needed. `IterDataset` (often created via `MapDataset.to_iter_dataset()`) is designed for performant, sequential iteration, typically used for training loops, especially with prefetching.
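The `__getitem__` change above amounts to annotating the index as a plain `int`. A minimal sketch of a conforming random-access source, using a hypothetical `ArraySource` class (the name is illustrative, not part of Grain's API):

```python
class ArraySource:
    """Hypothetical minimal random-access source: len() plus int indexing."""

    def __init__(self, data):
        self._data = list(data)

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, index: int):
        # Annotated as `int` rather than `SupportsIndex`, per the breaking change.
        return self._data[index]

source = ArraySource([10, 20, 30])
print(len(source), source[1])  # 3 20
```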
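The picklability gotcha can be checked without spinning up workers: anything you hand to a multiprocessing pipeline must survive `pickle`. A sketch in plain Python (the `AddOne` class is illustrative, not a Grain API):

```python
import pickle

class AddOne:
    """A module-level callable: picklable, so safe to send to worker processes."""

    def __call__(self, x):
        return x + 1

# Round-trips through pickle without error.
restored = pickle.loads(pickle.dumps(AddOne()))
print(restored(41))  # 42

# A lambda is not picklable, so it would fail when Grain serializes
# transforms for worker processes.
try:
    pickle.dumps(lambda x: x + 1)
except Exception as err:
    print("cannot pickle:", type(err).__name__)
```

Defining transforms as module-level classes (rather than lambdas or closures over local state) is the simplest way to stay picklable.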
Install
-
pip install grain
Imports
- MapDataset
import grain
dataset = grain.MapDataset.source([...])
- IterDataset
import grain
iter_dataset = dataset.to_iter_dataset()
Quickstart
import grain
dataset = (
    grain.MapDataset.source([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    .shuffle(seed=42)      # Shuffles elements globally.
    .map(lambda x: x + 1)  # Maps each element.
    .batch(batch_size=2)   # Batches consecutive elements.
)

print("Processing dataset:")
for batch in dataset:
    print(batch)