TorchData

0.11.0 · active · verified Fri Apr 10

TorchData is a Python library of composable data-loading modules for PyTorch. It extends `torch.utils.data.DataLoader` and `torch.utils.data.Dataset`/`IterableDataset` to build scalable, performant data pipelines, focusing on newer features such as `StatefulDataLoader` for checkpointing and `torchdata.nodes` for flexible data-processing graphs. The current version is 0.11.0. After a period of re-evaluation, development has resumed with a focus on iterative enhancements to existing PyTorch data primitives.

Warnings

The legacy `torchdata.datapipes` and `DataLoader2` APIs were deprecated and removed in release 0.10.0; use `StatefulDataLoader` and `torchdata.nodes` instead.

Install

pip install torchdata

Imports

import torch
from torchdata.stateful_dataloader import StatefulDataLoader

Quickstart

This quickstart demonstrates `StatefulDataLoader`, TorchData's drop-in replacement for `torch.utils.data.DataLoader` that adds checkpointing via `state_dict()` and `load_state_dict()`.

import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader

# Create a dummy dataset
data = torch.randn(100, 10)
labels = torch.randint(0, 2, (100,))
dataset = TensorDataset(data, labels)

# Use StatefulDataLoader as a drop-in replacement for torch.utils.data.DataLoader
batch_size = 16
dataloader = StatefulDataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0  # use 0 workers for simplicity
)

print(f"Number of batches: {len(dataloader)}")

# Iterate through the data
for epoch in range(2):
    print(f"\nEpoch {epoch + 1}")
    for i, (batch_data, batch_labels) in enumerate(dataloader):
        if i % 10 == 0:
            print(f"  Batch {i}: data_shape={batch_data.shape}, labels_shape={batch_labels.shape}")
        # In a real scenario, perform training steps here

# Checkpointing: consume part of an epoch, then save the dataloader state.
# This is a key feature of StatefulDataLoader.
it = iter(dataloader)
for i in range(3):
    batch_data, batch_labels = next(it)
print("\nConsumed 3 batches; saving state mid-epoch")

# state_dict() captures the position of the in-progress iterator
state = dataloader.state_dict()
print(f"Saved dataloader state keys: {list(state.keys())}")

# Simulate a restart: build a fresh dataloader and restore the saved state
new_dataloader = StatefulDataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
new_dataloader.load_state_dict(state)
print("Loaded dataloader state.")

# Iteration resumes from the saved position (the 4th batch), not the epoch start
for i, (batch_data, batch_labels) in enumerate(new_dataloader):
    if i < 3:
        print(f"  Resumed batch {i}: data_shape={batch_data.shape}, labels_shape={batch_labels.shape}")
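The save-and-resume pattern above follows the same `state_dict()`/`load_state_dict()` protocol used throughout PyTorch. A minimal, pure-Python sketch of that protocol (using a hypothetical `CheckpointableCounter`, not a TorchData class):

```python
class CheckpointableCounter:
    """Iterates 0..n-1 and can snapshot/restore its position."""

    def __init__(self, n):
        self.n = n
        self.pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos >= self.n:
            raise StopIteration
        value = self.pos
        self.pos += 1
        return value

    def state_dict(self):
        # Snapshot everything needed to resume iteration later
        return {"pos": self.pos}

    def load_state_dict(self, state):
        # Restore the saved position; iteration continues from there
        self.pos = state["pos"]


counter = CheckpointableCounter(5)
first_two = [next(counter), next(counter)]  # consume 0, 1
state = counter.state_dict()                # checkpoint at pos=2

fresh = CheckpointableCounter(5)            # simulate a restart
fresh.load_state_dict(state)
resumed = list(fresh)                       # yields 2, 3, 4
print(first_two, resumed)                   # → [0, 1] [2, 3, 4]
```

Any object implementing this pair of methods can participate in checkpointing; `StatefulDataLoader` additionally propagates the protocol to samplers and datasets that define it.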
