WebDataset
WebDataset is a high-performance Python I/O system for deep learning and data processing (version 1.0.2 at the time of writing). It implements PyTorch's IterableDataset interface, enabling efficient streaming access to datasets stored in POSIX tar archives. It supports sharding for large datasets and is compatible with PyTorch's DataLoader, making it suitable for scalable data pipelines over high-latency storage for images, audio, video, and other data types. The library is actively maintained, with frequent releases adding features and fixing bugs.
Warnings
- gotcha WebDataset implements PyTorch's `IterableDataset` and thus does not provide a `__len__` method by default. Code expecting `len(dataset)` will raise a `TypeError`. To provide a length, you must explicitly add `with_length(N)` to your pipeline. This also impacts deterministic epoch balancing in distributed training.
- breaking Uppercase decoder strings such as `decode('PIL')` or `decode('numpy')` were deprecated. Use the lowercase shorthand specs (`decode('pil')`, `decode('rgb')`, `decode('torchrgb')`) or pass explicit handler functions such as `wds.imagehandler('pil')`. This change improves clarity and flexibility.
- gotcha Using the `pipe:` protocol with untrusted or unescaped URLs can lead to shell injection vulnerabilities, as `webdataset` executes shell commands.
- gotcha WebDataset relies heavily on external command-line tools like `curl`, `gsutil`, `aws`, and `file` for core I/O and type detection. This can affect portability across different operating systems or environments where these tools are not available or behave differently, and complicates error handling.
- gotcha Achieving precisely balanced epochs and avoiding sample repetition in multi-worker or distributed training setups (especially with `resampled=True` plus shuffling) is complex. Older examples that pass a `repeat` argument may no longer match the current API. Without proper configuration, each worker can repeat its share of shards endlessly.
- gotcha Long delays before the first batch, or inconsistent batch completion times, can occur due to large batch sizes, large shuffle buffers requiring time to fill, or slow underlying disk/storage access. This is often a configuration issue rather than a `webdataset` bug.
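The `len()` gotcha above can be illustrated with a minimal plain-Python stand-in for an iterable-style dataset (no PyTorch required). The classes here are illustrative, not the WebDataset API; `with_length(N)` in WebDataset works analogously by attaching a declared length to the pipeline:

```python
class IterableOnly:
    """Stand-in for an IterableDataset: iteration works, len() does not."""
    def __init__(self, items):
        self.items = items

    def __iter__(self):
        return iter(self.items)


class WithLength(IterableOnly):
    """Wrapper that declares a length, analogous to .with_length(N)."""
    def __init__(self, items, n):
        super().__init__(items)
        self.n = n

    def __len__(self):
        return self.n


ds = IterableOnly(range(5))
try:
    len(ds)  # no __len__ defined
except TypeError:
    print("len() raises TypeError on an iterable-only dataset")

print(len(WithLength(range(5), 5)))  # prints 5
```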
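For the `pipe:` injection risk above, one standard-library mitigation is to shell-quote any externally supplied path with `shlex.quote` before interpolating it into the command. This is a sketch; the `gsutil cat` pipeline and the helper name are assumptions for illustration:

```python
import shlex


def make_pipe_url(bucket_path: str) -> str:
    """Build a pipe: URL with the untrusted path shell-quoted."""
    return "pipe:gsutil cat " + shlex.quote(bucket_path)


# A benign path passes through unchanged; a hostile "path" is quoted
# so it can no longer break out of the command:
print(make_pipe_url("gs://bucket/shard-000000.tar"))
print(make_pipe_url("gs://bucket/x.tar; rm -rf ~"))
```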
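The epoch-balancing gotcha comes down to simple arithmetic: when shards are split across workers, the shard counts rarely divide evenly, so a naive epoch ends at different times for different workers. A pure-Python simulation (shard counts made up; the round-robin split is a simplification of how worker splitting typically behaves):

```python
def split_shards(shards, num_workers):
    """Round-robin shard assignment across workers."""
    return [shards[w::num_workers] for w in range(num_workers)]


shards = [f"shard-{i:04d}.tar" for i in range(10)]
for w, share in enumerate(split_shards(shards, 4)):
    print(f"worker {w}: {len(share)} shards")
# Workers 0 and 1 get 3 shards; workers 2 and 3 get 2 -- an unbalanced
# epoch unless you resample (resampled=True) and cap iteration explicitly.
```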
Install
- `pip install webdataset`
- `pip install git+https://github.com/webdataset/webdataset.git`
Imports
- webdataset
import webdataset as wds
- WebDataset
dataset = wds.WebDataset(url)
Quickstart
import webdataset as wds
import torch
from itertools import islice

# Example URL to a public WebDataset shard. In practice this would be your dataset path(s).
# For local files:    url = "file:./my_dataset-{0000..0009}.tar"
# For cloud storage:  url = "pipe:gsutil cat gs://my-bucket/dataset-{0000..0009}.tar"
url = "http://storage.googleapis.com/nvdata-openimages/openimages-train-000000.tar"

# Preprocessing function applied after .to_tuple("jpg", "json"),
# so each sample arrives as an (image, metadata) tuple.
def preprocess(sample):
    image, metadata = sample
    # .decode("pil") has already turned the jpg bytes into a PIL.Image.
    # A real pipeline would apply torchvision transforms here; this
    # quickstart substitutes a fixed-size random tensor so that batching
    # works without torchvision installed.
    processed_image = torch.randn(3, 224, 224)  # C, H, W placeholder
    # Extract a label from the metadata, falling back to a placeholder.
    label = 0
    if isinstance(metadata, dict) and metadata.get("annotations"):
        try:
            label = metadata["annotations"][0]["category_id"]
        except (KeyError, IndexError, TypeError):
            pass
    return processed_image, label

# Create a WebDataset pipeline
dataset = (
    wds.WebDataset(url)       # load from URL
    .shuffle(100)             # shuffle samples within a buffer
    .decode("pil")            # decode images using PIL (requires Pillow installed)
    .to_tuple("jpg", "json")  # extract 'jpg' and 'json' components as a tuple
    .map(preprocess)          # apply custom preprocessing
    .batched(16)              # collate samples into batches
)

# Use with PyTorch DataLoader (optional, for parallel loading and iteration).
# If you don't use a DataLoader, you can iterate directly over 'dataset'.
# from torch.utils.data import DataLoader
# dataloader = DataLoader(dataset, num_workers=4, batch_size=None)  # batch_size=None because .batched() is used above

print(f"Accessing the first 2 batches from: {url}")
for i, (images, labels) in enumerate(islice(dataset, 2)):
    print(f"Batch {i+1}:")
    print(f"  Images shape: {images.shape}")
    print(f"  Labels: {labels}")
print("Quickstart complete.")
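The quickstart reads an existing shard; the format it expects is a plain POSIX tar in which files sharing a basename form one sample, with the extension naming the field (e.g. `000000.jpg` + `000000.json`). A minimal shard can be written with the standard library alone (WebDataset also provides `wds.TarWriter` and `wds.ShardWriter` for this); the filenames and payloads below are made up for illustration:

```python
import io
import json
import tarfile


def add_bytes(tar: tarfile.TarFile, name: str, payload: bytes) -> None:
    """Add an in-memory payload to the tar under the given name."""
    info = tarfile.TarInfo(name=name)
    info.size = len(payload)
    tar.addfile(info, io.BytesIO(payload))


with tarfile.open("toy-shard-000000.tar", "w") as tar:
    for key in range(3):
        base = f"{key:06d}"
        # Fake image bytes; a real shard would store actual JPEG data.
        add_bytes(tar, base + ".jpg", b"\xff\xd8fake-jpeg-bytes")
        add_bytes(tar, base + ".json", json.dumps({"label": key}).encode())

# Files sharing a basename ("000000.jpg", "000000.json") become fields
# "jpg" and "json" of a single sample when the shard is read back.
with tarfile.open("toy-shard-000000.tar") as tar:
    print(tar.getnames())
```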