TorchData
TorchData is a Python library of composable data-loading modules for PyTorch. It extends `torch.utils.data.DataLoader` and `torch.utils.data.Dataset`/`IterableDataset` toward scalable, performant data pipelines, focusing on newer features such as `StatefulDataLoader` for checkpointing and `torchdata.nodes` for flexible data-processing graphs. The current version is 0.11.0. After a period of re-evaluation, development has resumed with a focus on iterative enhancements to the existing PyTorch data primitives.
Warnings
- breaking DataPipes and DataLoader2, core components of earlier TorchData versions, were removed from the library starting with v0.9.0, having been marked deprecated in v0.8.0. Subsequent releases, including 0.11.0, neither include nor maintain them.
- breaking Python 3.8 support was dropped in TorchData v0.9.0.
- deprecated TorchData no longer publishes conda builds, as PyTorch's official conda channel is itself deprecated.
- gotcha `StatefulDataLoader` has edge-case behaviors around `num_workers=0` and the initial seeding of `RandomSampler` when loading a saved state; these can produce unexpected iteration order if not handled carefully.
Install
-
pip install torchdata
Imports
- StatefulDataLoader
from torchdata.stateful_dataloader import StatefulDataLoader
- nodes
import torchdata.nodes as nodes
- DataPipes (removed in v0.9.0; works only with torchdata v0.8.x and earlier)
from torchdata.datapipes.iter import IterableWrapper
- DataLoader2 (removed in v0.9.0; works only with torchdata v0.8.x and earlier)
from torchdata.dataloader2 import DataLoader2
Quickstart
import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader
# Create a dummy dataset
data = torch.randn(100, 10)
labels = torch.randint(0, 2, (100,))
dataset = TensorDataset(data, labels)
# Use StatefulDataLoader as a drop-in replacement for torch.utils.data.DataLoader
batch_size = 16
dataloader = StatefulDataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,  # for simplicity, keep loading in the main process
)
print(f"Number of batches: {len(dataloader)}")
# Iterate partway through an epoch, then checkpoint
for i, (batch_data, batch_labels) in enumerate(dataloader):
    if i % 10 == 0:
        print(f"  Batch {i}: data_shape={batch_data.shape}, labels_shape={batch_labels.shape}")
    # In a real scenario, perform training steps here
    if i == 3:
        break  # stop mid-epoch so there is meaningful state to resume from
# Example of saving and loading state (checkpointing)
# This is a key feature of StatefulDataLoader
state = dataloader.state_dict()
print(f"\nSaved dataloader state: {state.keys()}")
# Simulate a restart: build a fresh dataloader and restore the saved state
new_dataloader = StatefulDataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
new_dataloader.load_state_dict(state)
print("Loaded dataloader state.")
# Iteration resumes from where the first dataloader left off
print("Resuming iteration (continues from the saved state):")
for i, (batch_data, batch_labels) in enumerate(new_dataloader):
    if i < 3:
        print(f"  Resumed Batch {i}: data_shape={batch_data.shape}, labels_shape={batch_labels.shape}")