MosaicML Streaming

0.13.0 · active · verified Wed Apr 15

MosaicML Streaming (StreamingDataset) provides PyTorch-compatible datasets that can be efficiently streamed from cloud-based object stores (S3, GCS, Azure Blob Storage, Hugging Face Hub) or local filesystems. It enables training on large datasets without needing to download them entirely beforehand, improving data loading performance and reducing storage costs. The library is actively maintained with frequent updates, currently at version 0.13.0.

Warnings

Install
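
pip install mosaicml-streaming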

Imports
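
from streaming import MDSWriter, StreamingDataset
from torch.utils.data import DataLoader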

Quickstart

This quickstart demonstrates how to initialize `StreamingDataset` and integrate it with `torch.utils.data.DataLoader`. It first writes a tiny local MDS dataset with `MDSWriter` so the example runs immediately. For cloud usage, point the `remote` parameter at your cloud object storage path and make sure the necessary cloud provider credentials are configured in your environment.

import os

from streaming import MDSWriter, StreamingDataset
from torch.utils.data import DataLoader

# Define local paths for quickstart demonstration
# In a real scenario, 'remote' would point to your cloud MDS dataset
# (e.g., "s3://my-bucket/data" or "gs://my-bucket/data").
# Ensure cloud credentials are set in environment variables for cloud remotes.
local_remote_path = "quickstart_mds_data"
local_cache_path = "quickstart_mds_cache"

# --- Create a small MDS dataset for local testing if it doesn't exist ---
# In practice you'd point `remote` at an existing MDS dataset (often in cloud
# storage); here we write a tiny one with `streaming.MDSWriter` so the example
# runs end to end. MDSWriter produces the `index.json` and shard files in the
# exact format StreamingDataset expects.
if not os.path.exists(local_remote_path):
    print(f"Creating sample MDS data in '{local_remote_path}'...")
    columns = {'id': 'int', 'text': 'str'}  # column name -> MDS encoding
    with MDSWriter(out=local_remote_path, columns=columns) as writer:
        for i in range(4):
            writer.write({'id': i, 'text': f'sample {i}'})
    print("Sample MDS data created.")
else:
    print(f"Using existing sample MDS data in '{local_remote_path}'.")

os.makedirs(local_cache_path, exist_ok=True)
# --- End of sample MDS creation ---

# 1. Initialize the StreamingDataset
dataset = StreamingDataset(
    local=local_cache_path,  # Local cache directory for downloaded shards
    remote=local_remote_path, # Path to your MDS dataset (local or cloud)
    shuffle=True,
    batch_size=1, # Should match the DataLoader's per-device batch size
    # Other parameters like `predownload` can be tuned for performance
)

# 2. Create a PyTorch DataLoader
dataloader = DataLoader(
    dataset=dataset,
    batch_size=1, # DataLoader batch size
    num_workers=0, # Use 0 workers for simple local testing to avoid multiprocessing issues
)

# 3. Iterate over the data
print(f"Dataset has {len(dataset)} samples.")
for i, batch in enumerate(dataloader):
    # Each batch is a dict mapping column names to batched values,
    # e.g. {'id': tensor([0]), 'text': ['sample 0']} with the schema above.
    print(f"Batch {i}: {batch}")
    if i >= 1: # Stop after two batches for demonstration
        break

# Note: For production use, remember to configure cloud credentials
# (e.g., via environment variables like AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY,
# or cloud provider CLI configs) if 'remote' points to cloud storage.
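
To produce a real dataset, `MDSWriter` can also write shards straight to cloud storage by passing a remote URI as `out` (or a `(local, remote)` tuple to keep a local copy). A minimal sketch, assuming a hypothetical bucket `s3://my-bucket/data` and AWS credentials already configured in the environment:

from streaming import MDSWriter

# Hypothetical destination; replace with your own bucket/prefix.
columns = {'id': 'int', 'text': 'str'}
with MDSWriter(out='s3://my-bucket/data', columns=columns, compression='zstd') as writer:
    for i in range(1000):
        writer.write({'id': i, 'text': f'sample {i}'})

A training job can then set `remote='s3://my-bucket/data'` in `StreamingDataset`, and shards are downloaded and cached in `local` on demand.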
