Grain (ML Data Library)

0.2.16 · active · verified Sat Apr 11

Grain is a Python library from Google for efficiently loading and transforming data, primarily for training and evaluating machine learning models, particularly with JAX. It emphasizes flexibility, speed, and determinism in data processing pipelines. The library is under active development (currently version 0.2.16), and its frequent releases add features, fix bugs, and deprecate older APIs.

Quickstart

This example creates a simple `MapDataset` from a Python list, applies common transformations (shuffling, mapping, and batching), and iterates over the processed data. It shows how a pipeline is built declaratively by chaining method calls.

import grain

dataset = (
    grain.MapDataset.source([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
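    .shuffle(seed=42)  # Shuffles all 11 elements globally; the fixed seed makes the order reproducible.
    .map(lambda x: x + 1)  # Applies the function to each element.
    .batch(batch_size=2)  # Groups consecutive elements into batches of 2; the last batch holds the one leftover element.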
)

print("Processing dataset:")
for batch in dataset:
    print(batch)
