WebDataset

1.0.2 · active · verified Sat Apr 11

WebDataset is a high-performance Python I/O library for deep learning and data processing. It implements PyTorch's IterableDataset interface, providing efficient streaming access to datasets stored in POSIX tar archives. It supports sharding for large datasets and works with PyTorch's DataLoader, enabling scalable, purely sequential (streaming) data pipelines for images, audio, video, and other data types. The library is actively maintained, with frequent releases adding features and bug fixes.
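The tar-archive convention is simple: files that share a basename (e.g. `sample0.jpg` and `sample0.json`) are grouped into one training sample, and a dataset is split into numbered tar "shards". A minimal sketch of writing such a shard using only the standard library (the shard name and file contents below are placeholders, not real data):

```python
import io
import tarfile

def write_shard(path, samples):
    """Write (basename, {extension: bytes}) pairs into a tar shard."""
    with tarfile.open(path, "w") as tar:
        for basename, files in samples:
            for ext, data in files.items():
                # Files sharing a basename become one WebDataset sample
                info = tarfile.TarInfo(name=f"{basename}.{ext}")
                info.size = len(data)
                tar.addfile(info, io.BytesIO(data))

samples = [
    ("sample0", {"jpg": b"<jpeg bytes>", "json": b'{"label": 0}'}),
    ("sample1", {"jpg": b"<jpeg bytes>", "json": b'{"label": 1}'}),
]
write_shard("dataset-0000.tar", samples)
```

In practice you would write real JPEG and JSON bytes; the library also ships its own shard-writing utilities, but any tool that produces tars with this naming convention works.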

Warnings

Install
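WebDataset is distributed on PyPI and can be installed with pip:

```shell
pip install webdataset
```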

Imports

Quickstart

This quickstart demonstrates how to create a `webdataset` pipeline to load data from a remote TAR file, apply shuffling and decoding, extract specific components like images and JSON metadata, and then preprocess and batch the samples. It shows the typical 'fluid' interface with chained method calls and how it integrates with PyTorch-style data iteration. It fetches from a publicly available OpenImages shard, decodes using PIL, and extracts components into tuples.

import webdataset as wds
import torch
from itertools import islice

# Example URL to a public WebDataset shard. In a real scenario, this would be your dataset path(s).
# For local files: url = "file:./my_dataset-{0000..0009}.tar"
# For cloud storage: url = "pipe:gsutil cat gs://my-bucket/dataset-{0000..0009}.tar"
url = "http://storage.googleapis.com/nvdata-openimages/openimages-train-000000.tar"

# Define a simple preprocessing function (e.g., for images and labels)
def preprocess(sample):
    # After .to_tuple("jpg", "json") each sample is an (image, metadata)
    # tuple; .decode("pil") has already turned the image bytes into a
    # PIL.Image and parsed the JSON metadata.
    image, metadata = sample

    # Resize so samples can be stacked into a batch, then convert to a
    # C, H, W float tensor in [0, 1]. A real pipeline would use
    # torchvision transforms here instead.
    image = image.convert("RGB").resize((224, 224))
    pixels = torch.frombuffer(bytearray(image.tobytes()), dtype=torch.uint8)
    processed_image = pixels.reshape(224, 224, 3).permute(2, 0, 1).float() / 255.0

    # Extract a label from the metadata if present, else keep a placeholder
    label = 0  # Placeholder label
    if isinstance(metadata, dict) and 'annotations' in metadata:
        try:
            label = metadata['annotations'][0]['category_id']
        except (IndexError, KeyError):
            pass

    return processed_image, label

# Create a WebDataset pipeline
dataset = (
    wds.WebDataset(url) # Load from URL
    .shuffle(100)        # Shuffle samples within a buffer
    .decode("pil")       # Decode images using PIL (requires Pillow installed)
    .to_tuple("jpg", "json") # Extract 'jpg' and 'json' components as a tuple
    .map(preprocess)    # Preprocess each (image, metadata) tuple
    .batched(16)        # Batch samples
)

# Use with PyTorch DataLoader (optional, for parallel loading and iteration)
# If you don't use PyTorch, you can iterate directly over 'dataset'
# from torch.utils.data import DataLoader
# dataloader = DataLoader(dataset, num_workers=4, batch_size=None) # batch_size=None if .batched() is used above

print(f"Accessing the first 2 batches from: {url}")

# Iterate over a few batches
for i, (images, labels) in enumerate(islice(dataset, 2)):
    print(f"Batch {i + 1}:")
    print(f"  Images shape: {images.shape}")
    print(f"  Labels: {labels}")

print("Quickstart complete.")
