TensorFlow I/O

0.37.1 · active · verified Wed Apr 15

TensorFlow I/O is an extension library that adds file systems and file formats not covered by TensorFlow's built-in support. Its datasets plug into TensorFlow's `tf.data` API, so they can feed machine learning input pipelines like any other `tf.data.Dataset`. The current version is 0.37.1. Releases are frequent, often monthly or bi-monthly, to track new TensorFlow versions and features.

Quickstart

This quickstart loads the MNIST dataset directly from gzip-compressed files at remote URLs using `tfio.IODataset.from_mnist`, which exercises TensorFlow I/O's support for remote file systems and additional data formats. The data is then preprocessed and used to train a simple Keras model.

import tensorflow as tf
import tensorflow_io as tfio

# Read the MNIST data into an IODataset directly from URLs
dataset_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
d_train = tfio.IODataset.from_mnist(
    dataset_url + "train-images-idx3-ubyte.gz",
    dataset_url + "train-labels-idx1-ubyte.gz",
)

# Shuffle the elements of the dataset.
d_train = d_train.shuffle(buffer_size=1024)

# By default image data is uint8, so convert to float32 using map().
d_train = d_train.map(lambda x, y: (tf.image.convert_image_dtype(x, tf.float32), y))

# Batch the data just like any other tf.data.Dataset.
d_train = d_train.batch(32)

# Build a simple Keras model
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax),
])

# Compile the model.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

print("Fitting model...")
# Use a small number of steps per epoch for quick demonstration
model.fit(d_train, epochs=1, steps_per_epoch=2)
print("Model fitting complete.")
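
The MNIST files read in the quickstart are gzip-compressed IDX files: a short header (two zero bytes, a type code, a dimension count, then one big-endian 32-bit size per dimension) followed by the raw bytes. The sketch below is not TensorFlow I/O's implementation; it is a standard-library illustration of that layout, assuming unsigned-byte data only. The helpers `parse_idx` and `make_idx` and the tiny 2x2 "images" are hypothetical stand-ins so the example runs without network access.

```python
import gzip
import struct


def parse_idx(raw: bytes):
    """Parse an uncompressed IDX buffer of unsigned bytes into (shape, data)."""
    # Header: two zero bytes, a type code, and the number of dimensions.
    zero, dtype_code, ndim = struct.unpack(">HBB", raw[:4])
    if zero != 0 or dtype_code != 0x08:
        raise ValueError("only unsigned-byte IDX data is handled in this sketch")
    # Each dimension size is a big-endian 32-bit unsigned integer.
    shape = struct.unpack(f">{ndim}I", raw[4:4 + 4 * ndim])
    data = raw[4 + 4 * ndim:]
    expected = 1
    for size in shape:
        expected *= size
    if len(data) != expected:
        raise ValueError("payload length does not match the header shape")
    return shape, data


def make_idx(shape, payload: bytes) -> bytes:
    """Build a gzip-compressed IDX buffer (hypothetical helper for this demo)."""
    header = struct.pack(">HBB", 0, 0x08, len(shape))
    header += struct.pack(f">{len(shape)}I", *shape)
    return gzip.compress(header + payload)


# Two fake 2x2 "images" and their labels, standing in for the real MNIST files.
images_gz = make_idx((2, 2, 2), bytes([0, 255, 128, 64, 1, 2, 3, 4]))
labels_gz = make_idx((2,), bytes([7, 3]))

image_shape, pixels = parse_idx(gzip.decompress(images_gz))
label_shape, labels = parse_idx(gzip.decompress(labels_gz))

# Pair each image with its label, similar in spirit to the (x, y) elements
# that the quickstart's dataset yields.
pixels_per_image = image_shape[1] * image_shape[2]
pairs = [
    (list(pixels[i * pixels_per_image:(i + 1) * pixels_per_image]), labels[i])
    for i in range(image_shape[0])
]
print(image_shape, label_shape)
print(pairs)
```

In the real files the image header declares three dimensions (count, rows, columns) and the label header declares one, which is why the quickstart passes the image and label files as two separate arguments.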
