TensorFlow Transform

1.17.0 · active · verified Fri Apr 17

TensorFlow Transform (TFT) is a library for preprocessing data with TensorFlow. You define a preprocessing function once. TFT runs it over the raw training data *before* training, computing any full-pass statistics it needs (such as means, variances, or vocabularies), and exports the result as a TensorFlow graph that applies the identical transformation at *inference* time. Because training and serving use the same graph, this avoids training/serving skew. The analysis phase runs on Apache Beam, so the same code can process data locally or on a distributed runner, and TFT is a core component of TensorFlow Extended (TFX). The current version is 1.17.0; releases follow TensorFlow's release cadence.
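The consistency argument above rests on one idea: statistics are computed once over the full training set, then frozen into a function that is reused unchanged at serving time. The stdlib-only sketch below illustrates that idea with min-max scaling. It is not the TFT API; `analyze()` and `make_transform()` are hypothetical stand-ins for a full-pass analyzer and for the exported transform graph.

```python
# Plain-Python illustration of the analyze-then-transform idea; this is NOT
# the TFT API. analyze() and make_transform() are hypothetical stand-ins for a
# full-pass analyzer and for the exported transform graph.
from typing import Callable, List, Tuple


def analyze(values: List[float]) -> Tuple[float, float]:
    """Full pass over the training data: record the minimum and maximum."""
    return min(values), max(values)


def make_transform(lo: float, hi: float) -> Callable[[float], float]:
    """Freeze the analyzed constants into a reusable min-max scaling function."""
    span = hi - lo

    def transform(v: float) -> float:
        # Uses the constants captured at analysis time; a zero span maps to 0.0.
        return (v - lo) / span if span else 0.0

    return transform


train_x = [10.0, 20.0, 30.0, 40.0, 50.0]
scale = make_transform(*analyze(train_x))

# Training and serving call the same function with the same constants.
training_features = [scale(v) for v in train_x]
serving_feature = scale(30.0)

# Recomputing statistics on a single serving request instead skews the value.
naive_serving = make_transform(*analyze([30.0]))(30.0)

print(training_features)  # [0.0, 0.25, 0.5, 0.75, 1.0]
print(serving_feature, naive_serving)  # 0.5 0.0
```

The serving request for 30.0 receives the same value (0.5) it would have had during training, whereas rescaling it with its own one-example statistics collapses it to 0.0.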

Install

pip install tensorflow-transform

Imports

import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam

Quickstart

This quickstart uses `tensorflow-transform` to preprocess a small in-memory dataset locally. Its `preprocessing_fn` scales a numeric feature to z-scores, maps a string feature to vocabulary indices and one-hot encodes them, and passes a third feature through unchanged. The function is then applied with `AnalyzeAndTransformDataset` on Apache Beam's DirectRunner.
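The categorical path follows the same analyze-then-transform pattern: one full pass builds a vocabulary, then each string is looked up and one-hot encoded. The stdlib-only sketch below mirrors that flow for the quickstart's `y` column; it is not TFT code, and the helper names `build_vocab` and `one_hot` are hypothetical. It orders the vocabulary by descending frequency and breaks ties alphabetically; that tie-break is an assumption of this sketch, and TFT's ordering of equal-frequency terms may differ.

```python
from collections import Counter
from typing import Dict, List


def build_vocab(tokens: List[str]) -> Dict[str, int]:
    """Analyze step: one full pass that counts every token."""
    counts = Counter(tokens)
    # Most frequent first; ties broken alphabetically (this sketch's choice).
    ordered = sorted(counts, key=lambda t: (-counts[t], t))
    return {token: i for i, token in enumerate(ordered)}


def one_hot(index: int, depth: int) -> List[float]:
    """Indices outside [0, depth), such as -1 for unseen tokens, give all zeros."""
    return [1.0 if i == index else 0.0 for i in range(depth)]


y = ['apple', 'banana', 'apple', 'orange', 'banana']
vocab = build_vocab(y)  # {'apple': 0, 'banana': 1, 'orange': 2}

# Transform step: look up each token (-1 if unseen) and one-hot encode it.
encoded = [one_hot(vocab.get(t, -1), depth=3) for t in y]
unseen = one_hot(vocab.get('cherry', -1), depth=3)  # [0.0, 0.0, 0.0]
```

Mapping unseen strings to -1 and letting the one-hot step produce an all-zero vector is one simple out-of-vocabulary policy; reserving extra buckets for unknown values is a common alternative.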

import tempfile

import apache_beam as beam
import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from apache_beam.runners.direct import direct_runner
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import schema_utils

# 1. Define the schema of the raw data
_RAW_DATA_FEATURE_SPEC = {
    'x': tf.io.FixedLenFeature([], tf.float32),
    'y': tf.io.FixedLenFeature([], tf.string),
    's': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
}
_RAW_DATA_METADATA = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec(_RAW_DATA_FEATURE_SPEC))

# 2. Define the preprocessing function
def preprocessing_fn(inputs):
    """Preprocesses raw inputs into transformed features."""
    outputs = {}
    # Full-pass analyzer: mean and variance are computed over the whole dataset.
    outputs['x_scaled'] = tft.scale_to_z_score(inputs['x'])
    # Build a vocabulary over 'y' and map each string to its integer index.
    # Strings missing from the vocabulary map to -1 (the default value).
    y_index = tft.compute_and_apply_vocabulary(
        inputs['y'], vocab_filename='vocab_y')
    # tf.one_hot turns the index into a vector; depth=3 assumes at most
    # 3 distinct values of 'y' (an index of -1 yields an all-zero vector).
    outputs['y_one_hot'] = tf.one_hot(y_index, depth=3)
    outputs['s_identity'] = inputs['s']  # Pass through unchanged
    return outputs

# 3. Prepare some raw data
raw_data = [
    {'x': 10.0, 'y': 'apple', 's': 1.0},
    {'x': 20.0, 'y': 'banana', 's': 2.0},
    {'x': 30.0, 'y': 'apple', 's': 3.0},
    {'x': 40.0, 'y': 'orange', 's': 4.0},
    {'x': 50.0, 'y': 'banana', 's': 5.0},
]

# 4. Run the transform locally using Apache Beam's DirectRunner
print('Transformed data:')
with beam.Pipeline(runner=direct_runner.DirectRunner()) as p:
    # AnalyzeAndTransformDataset needs a temp dir for its analysis artifacts.
    with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
        # Create a PCollection of raw data
        raw_data_pcollection = p | 'CreateRawData' >> beam.Create(raw_data)

        # Analyze (compute full-pass statistics) and Transform (apply them).
        # The result is ((transformed_data, transformed_metadata), transform_fn).
        (transformed_data, transformed_metadata), transform_fn = (
            (raw_data_pcollection, _RAW_DATA_METADATA)
            | 'AnalyzeAndTransform' >> tft_beam.AnalyzeAndTransformDataset(
                preprocessing_fn))

        # Print each transformed instance when the pipeline runs.
        _ = transformed_data | 'PrintTransformedData' >> beam.Map(print)

print('Preprocessing complete.')
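To sanity-check the `x_scaled` column, the z-scores for the five `x` values can be recomputed by hand. This sketch assumes `tft.scale_to_z_score` divides by the population (biased) standard deviation, as `tft.var` is documented to compute; the vocabulary-dependent `y_one_hot` column is not checked here. The DirectRunner does not guarantee the order in which rows are printed, so compare values rather than positions.

```python
import math
import statistics

# The 'x' column from the quickstart's raw_data.
x = [10.0, 20.0, 30.0, 40.0, 50.0]

mean = statistics.fmean(x)   # 30.0
std = statistics.pstdev(x)   # population std: sqrt(200), about 14.142
x_scaled = [(v - mean) / std for v in x]

# Expected values: -sqrt(2), -sqrt(2)/2, 0, sqrt(2)/2, sqrt(2)
expected = [k * math.sqrt(2) / 2 for k in (-2, -1, 0, 1, 2)]
print([round(v, 4) for v in x_scaled])
```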
