Petastorm

0.13.1 · active · verified Thu Apr 16

Petastorm is a Python library that enables single-node or distributed training of machine learning models directly from datasets stored in Parquet format. It provides data access for popular frameworks like TensorFlow, PyTorch, and Apache Spark. The current stable version is 0.13.1, with releases typically following a feature-driven cadence, often including release candidates before stable versions.

Common errors

Warnings

Install

Imports

Quickstart

This quickstart demonstrates how to define a data schema, write sample data to a Parquet dataset using `make_writer`, and then read it back using `make_reader`. The example cleans up the temporary directory after execution. For real-world usage, consider configuring `reader_pool_type` and `num_epochs` based on your training requirements.

import os
import shutil
import numpy as np
from petastorm import make_reader, make_writer
from petastorm.unischema import Unischema, UnischemaField, ScalarCodec
from petastorm.codecs import CompressedNdarrayCodec

# 1. Define a schema for your data
MySchema = Unischema(
    'MySchema',
    [
        UnischemaField('id', np.int32, (), ScalarCodec(np.int32), False),
        UnischemaField('value', np.float64, (), ScalarCodec(np.float64), False),
        UnischemaField('image', np.uint8, (10, 10, 3), CompressedNdarrayCodec(), False),
    ]
)

# 2. Define a dataset path (using a temporary local directory for example)
dataset_url = 'file:///tmp/petastorm_example_data'
# Clean up previous data if it exists
if os.path.exists('/tmp/petastorm_example_data'):
    shutil.rmtree('/tmp/petastorm_example_data')

# 3. Write some dummy data to the Parquet dataset
print(f"Writing dummy data to {dataset_url}...")
with make_writer(dataset_url, MySchema, row_group_size_bytes=2 * 1024 * 1024) as writer:
    for i in range(10):
        writer.write(
            MySchema.make_row(
                id=i,
                value=float(i * 10),
                image=np.random.randint(0, 256, size=(10, 10, 3), dtype=np.uint8)
            )
        )
print(f"Successfully wrote 10 rows.")

# 4. Read data using make_reader
# reader_pool_type='thread' is often suitable for local development.
# For production, 'process' might be preferred depending on data access patterns.
print("\nReading data from the dataset:")
with make_reader(dataset_url, reader_pool_type='thread', num_epochs=1) as reader:
    for i, row in enumerate(reader):
        print(f"Row {i}: id={row.id}, value={row.value}, image_shape={row.image.shape}")
        if i >= 2: # Print only a few rows for brevity
            break
print("Finished reading example data.")

# Clean up the temporary dataset
shutil.rmtree('/tmp/petastorm_example_data')

view raw JSON →