S3TorchConnector

1.5.0 · active · verified Thu Apr 16

S3TorchConnector provides an efficient integration for PyTorch `Dataset` and `DataLoader` to stream data directly from Amazon S3. It enables training machine learning models on S3-resident data without needing to download it locally, optimized for large-scale and distributed workloads. The current version is 1.5.0, with releases typically aligning with PyTorch and AWS SDK updates.

Common errors

Warnings

Install

Imports

Quickstart

This quickstart demonstrates how to use both `S3MapDataset` and `S3IterableDataset` with PyTorch's `DataLoader`. It shows how to initialize them with an S3 URI, iterate through data, and includes error handling for common S3 access issues. Ensure your AWS credentials are configured and the specified S3 bucket/prefix contains data.

import torch
from torch.utils.data import DataLoader
from s3torchconnector import S3MapDataset, S3IterableDataset
import os
import io

# --- S3 access configuration ---
# AWS credentials must already be configured (AWS CLI profile, environment
# variables such as AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_REGION,
# or an attached IAM role).
#
# IMPORTANT: Replace 's3torchconnector-example-bucket' and 'quickstart-data/'
# with your actual S3 bucket and prefix. The prefix should contain some files
# (e.g., text files) so the examples below can load data.
S3_BUCKET = os.getenv('S3_QUICKSTART_BUCKET', 's3torchconnector-example-bucket')
S3_PREFIX = os.getenv('S3_QUICKSTART_PREFIX', 'quickstart-data/')
S3_URI = f"s3://{S3_BUCKET}/{S3_PREFIX}"

# Emit the three startup notices as one stdout write, one message per line.
print(
    f"Attempting to connect to S3 URI: {S3_URI}",
    "Please ensure this bucket/prefix exists and contains data, and your AWS credentials are configured.",
    "If you encounter 'Forbidden' or 'NoCredentialsError', check your AWS setup.",
    sep="\n",
)

# --- S3MapDataset Example ---
# S3MapDataset first lists all objects under the given S3 URI prefix, then allows
# indexed access. Suitable when you need a fixed-size dataset and random access.
try:
    print("\n--- S3MapDataset Demonstration ---")
    # FIX: s3torchconnector datasets are built with the `from_prefix` classmethod
    # (the class exposes no public single-URI constructor), and `region` is a
    # required keyword that must match the bucket's region (override via AWS_REGION).
    map_dataset = S3MapDataset.from_prefix(
        S3_URI, region=os.environ.get("AWS_REGION", "us-east-1")
    )
    print(f"S3MapDataset initialized. Found {len(map_dataset)} objects.")

    if len(map_dataset) > 0:
        # FIX: items are S3Reader objects -- file-like streams over the S3 object's
        # bytes -- not io.BytesIO, so duck-type on `.read` instead of isinstance.
        item_data = map_dataset[0]
        if hasattr(item_data, "read"):
            # Read only the first 100 bytes so large objects aren't downloaded whole.
            content_sample = item_data.read(100).decode('utf-8', errors='ignore')
            print(f"Sample from first item (S3MapDataset): '{content_sample}'...")
        else:
            print(f"First item type: {type(item_data)}")

        # FIX: DataLoader's default collate_fn cannot stack file-like objects;
        # `collate_fn=list` keeps each batch as a plain list of readers.
        map_dataloader = DataLoader(map_dataset, batch_size=2, num_workers=0, collate_fn=list)  # num_workers=0 for simplicity
        print("Iterating through S3MapDataset with DataLoader:")
        for i, batch in enumerate(map_dataloader):
            print(f"MapDataset Batch {i}: {len(batch)} items.")
            if len(batch) > 0 and hasattr(batch[0], "read"):
                print(f"  First item in batch content sample: {batch[0].read(30).decode('utf-8', errors='ignore')}...")
            if i >= 1: # Limit iterations for a quick example
                break
    else:
        print("S3MapDataset found no objects. Please ensure your S3 bucket/prefix contains files.")

except Exception as e:
    # Broad catch is deliberate here: this is a best-effort demonstration and any
    # S3/credential failure should be reported, not crash the quickstart.
    print(f"Error during S3MapDataset example: {e}")
    print("Ensure AWS credentials are valid and the S3 path exists and is accessible.")

# --- S3IterableDataset Example ---
# S3IterableDataset streams objects one by one as they are iterated.
# Suitable for very large datasets where listing all objects upfront is too slow
# or memory intensive.
try:
    print("\n--- S3IterableDataset Demonstration ---")
    # FIX: construct via the `from_prefix` classmethod (no public single-URI
    # constructor); `region` is a required keyword matching the bucket's region.
    iterable_dataset = S3IterableDataset.from_prefix(
        S3_URI, region=os.environ.get("AWS_REGION", "us-east-1")
    )

    # FIX: items are S3Reader file-like objects, which the default collate_fn
    # cannot stack; `collate_fn=list` keeps each batch as a plain list of readers.
    # NOTE(review): for num_workers > 0 the connector is expected to shard data
    # across workers itself -- confirm against the installed version's docs.
    iterable_dataloader = DataLoader(iterable_dataset, batch_size=2, num_workers=0, collate_fn=list)
    print("Iterating through S3IterableDataset with DataLoader:")
    for i, batch in enumerate(iterable_dataloader):
        print(f"IterableDataset Batch {i}: {len(batch)} items.")
        if len(batch) > 0 and hasattr(batch[0], "read"):
            print(f"  First item in batch content sample: {batch[0].read(30).decode('utf-8', errors='ignore')}...")
        if i >= 1: # Limit iterations
            break
    print("S3IterableDataset iteration complete (limited for quickstart).")

except Exception as e:
    # Best-effort demonstration: report S3/credential failures instead of crashing.
    print(f"Error during S3IterableDataset example: {e}")
    print("Ensure AWS credentials are valid and the S3 path exists and is accessible.")

print("\nQuickstart examples concluded.")

view raw JSON →