TensorFlow Datasets

4.9.9 · active · verified Sat Apr 11

TensorFlow Datasets (TFDS) provides a collection of ready-to-use datasets for machine learning frameworks such as TensorFlow, JAX, and PyTorch. It downloads and prepares the data deterministically and exposes it as a `tf.data.Dataset` or as `np.array` objects, which makes it straightforward to build high-performance input pipelines. Stable versions are typically released every few months, alongside nightly builds.

Warnings

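Install

pip install tensorflow-datasets

TFDS does not pull in TensorFlow itself; the quickstart also needs it (pip install tensorflow).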

Imports

import tensorflow_datasets as tfds

Quickstart

This quickstart loads the MNIST dataset with `tfds.load()`, retrieves the training and test splits, and builds a basic `tf.data.Dataset` input pipeline. It also prints dataset metadata and inspects one batch.

import tensorflow_datasets as tfds
import tensorflow as tf

# Load the MNIST dataset
# It will download and prepare the dataset if not already present.
(ds_train, ds_test), info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,  # Returns (image, label) tuples
    with_info=True
)

# Build your input pipeline
ds_train = ds_train.shuffle(1000).batch(32).prefetch(tf.data.AUTOTUNE)
ds_test = ds_test.batch(32).prefetch(tf.data.AUTOTUNE)

# Print dataset metadata, then inspect one batch of 32 examples
print(f"Dataset info: {info.name} version {info.version}")
for images, labels in ds_train.take(1):
    print(f"Batch image shape: {images.shape}, labels: {labels.numpy()}")
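
The `split` argument of `tfds.load()` also accepts slice strings such as `train[:90%]`, so a validation set can be carved out of the training split. The following is a minimal sketch, assuming an integer percentage is sufficient. `holdout_splits` and `load_with_validation` are hypothetical helper names introduced for illustration and are not part of TFDS.

```python
def holdout_splits(base="train", holdout_percent=10):
    """Build two TFDS split strings that reserve the tail of `base` as a holdout.

    Hypothetical helper (not a TFDS function): it only formats strings.
    """
    if not isinstance(holdout_percent, int) or not 0 < holdout_percent < 100:
        raise ValueError("holdout_percent must be an integer between 1 and 99")
    cut = 100 - holdout_percent
    # e.g. 'train[:90%]' keeps the first 90%, 'train[90%:]' keeps the last 10%
    return [f"{base}[:{cut}%]", f"{base}[{cut}%:]"]


def load_with_validation(name="mnist", holdout_percent=10):
    """Load (train, validation) datasets; downloads the data on first use."""
    # Imported lazily so holdout_splits() stays usable without TFDS installed.
    import tensorflow_datasets as tfds

    train_spec, val_spec = holdout_splits("train", holdout_percent)
    ds_train, ds_val = tfds.load(
        name,
        split=[train_spec, val_spec],
        as_supervised=True,  # yield (image, label) tuples
    )
    return ds_train, ds_val


print(holdout_splits())  # ['train[:90%]', 'train[90%:]']

# Usage (requires tensorflow-datasets; triggers a download on first run):
# ds_train, ds_val = load_with_validation("mnist", holdout_percent=10)
```

Keeping the string construction separate from the `tfds.load()` call means the split specification can be checked without downloading anything.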
