TensorFlow Transform
TensorFlow Transform (TFT) is a library for preprocessing data with TensorFlow. Users define a preprocessing function that is applied to raw data *before* training, and TFT exports that function as a TensorFlow graph that can be reused at *inference* time, ensuring consistency between training and serving. It is typically used in conjunction with Apache Beam for distributed processing and is a key component of TensorFlow Extended (TFX). TFT releases track the TFX release cadence rather than TensorFlow's (1.17.0 at the time of writing), and each release pins a compatible range of TensorFlow versions.
Common errors
- AttributeError: module 'tensorflow.compat.v2.io.gfile' has no attribute 'Exists'
  cause: The TF2 `tf.io.gfile` API uses lowercase method names (`exists`, `makedirs`); the CamelCase names such as `Exists` belong to the TF1 `tf.gfile` API. The error usually means TF1-style code is running against a TF2 gfile module, often because the installed `tensorflow` version is mismatched with what `tensorflow-transform` requires, or the underlying filesystem (e.g., local, GCS) is not accessible.
  fix: Ensure `tensorflow` and `tensorflow-transform` versions are compatible and use the lowercase TF2 method names. Verify file paths are correct and accessible by the running user/service. For cloud storage, ensure appropriate authentication and permissions are set for the Beam runner.
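A quick sanity check for this fix, assuming TensorFlow is installed locally:

```python
import tempfile

import tensorflow as tf

# TF2 gfile methods are lowercase: exists, makedirs, listdir, ...
# (tf.compat.v1.gfile.Exists is the CamelCase TF1 spelling.)
print('TensorFlow version:', tf.__version__)
tmp_exists = tf.io.gfile.exists(tempfile.gettempdir())
print('temp dir visible to gfile:', tmp_exists)
```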
- TypeError: unsupported operand type(s) for +: 'Tensor' and 'NoneType'
  cause: This typically happens inside `preprocessing_fn` when a feature that is expected to be present is missing from some input records, so a `None` value reaches a TensorFlow operation that expects a `Tensor`.
  fix: Use `tf.io.FixedLenFeature` with `default_value` when defining your `_RAW_DATA_FEATURE_SPEC` so missing values are filled in during parsing. Alternatively, use `tf.where` or `tf.cond` to handle `None` or empty tensors explicitly within your `preprocessing_fn`.
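The `default_value` behaviour can be seen with a plain `tf.io.parse_single_example` call on a record that omits the feature entirely (a small sketch; the feature name `s` is just an example):

```python
import tensorflow as tf

spec = {'s': tf.io.FixedLenFeature([], tf.float32, default_value=0.0)}

# An Example proto with no 's' feature at all.
empty_example = tf.train.Example().SerializeToString()
parsed = tf.io.parse_single_example(empty_example, spec)

# The missing feature is filled with the default rather than None.
print(float(parsed['s']))  # 0.0
```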
- tf.errors.FailedPreconditionError: Table not initialized.
  cause: Often occurs with vocabulary-based transformations such as `tft.compute_and_apply_vocabulary` (formerly `tft.string_to_int`): the lookup table built from the vocabulary generated during the Analyze phase is not properly initialized before the Transform phase or when attempting inference.
  fix: Ensure the `AnalyzeAndTransformDataset` pipeline completes successfully, generating the `transform_fn` and the associated vocabulary. When serving, load the `transform_fn` correctly and ensure all assets (including vocabulary files) are present and accessible in the exported transform output.
Warnings
- breaking TensorFlow 2.x Compatibility: The `preprocessing_fn` passed to `AnalyzeAndTransformDataset` is traced into a TensorFlow graph, and for TFT versions >= 1.0, it runs in a TF2 context. Mixing TF1 `tf.compat.v1` APIs or session-based operations directly within `preprocessing_fn` can lead to errors.
- gotcha Two-Pass Transformation Model: TFT operates in two phases: 'Analyze' and 'Transform'. The 'Analyze' phase computes statistics (e.g., min/max for scaling, vocabulary for string-to-int) over the entire dataset. The 'Transform' phase then applies these computed statistics to individual data points. New users often misunderstand that `preprocessing_fn` is traced and executed as a graph, not a simple row-wise Python function.
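The two phases can be illustrated with a plain NumPy analogy of what a min-max scaler like `tft.scale_to_0_1` computes (an illustration only, not TFT's implementation):

```python
import numpy as np

xs = np.array([10.0, 20.0, 30.0, 40.0, 50.0])

# "Analyze": a full pass over the dataset computes global statistics.
x_min, x_max = xs.min(), xs.max()

# "Transform": those statistics are applied to each row as constants.
scaled = (xs - x_min) / (x_max - x_min)
print(scaled.tolist())  # [0.0, 0.25, 0.5, 0.75, 1.0]
```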
- gotcha Apache Beam Integration for Scale: While TFT can run locally with Beam's `DirectRunner`, its primary use case is distributed processing with other Beam runners (e.g., Dataflow, Flink, Spark). Misconfigured runners or I/O connectors, and the handling of large datasets, are common sources of errors and performance bottlenecks.
Install
-
pip install tensorflow-transform
Imports
- tensorflow_transform
import tensorflow_transform as tft
- tensorflow_transform.tf_utils
import tensorflow_transform.tf_utils as tf_utils
- tensorflow_transform.beam
import tensorflow_transform.beam as tft_beam
- DatasetMetadata
from tensorflow_transform.tf_metadata import dataset_metadata
Quickstart
import tempfile

import apache_beam as beam
import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import schema_utils

# 1. Define the schema of the raw data
_RAW_DATA_FEATURE_SPEC = {
    'x': tf.io.FixedLenFeature([], tf.float32),
    'y': tf.io.FixedLenFeature([], tf.string),
    's': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
}
_RAW_DATA_METADATA = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec(_RAW_DATA_FEATURE_SPEC))

# 2. Define the preprocessing function
def preprocessing_fn(inputs):
    """Preprocesses raw inputs into transformed features."""
    outputs = {}
    outputs['x_scaled'] = tft.scale_to_z_score(inputs['x'])
    # Integerize 'y'; the vocabulary is computed during the Analyze phase
    # and written as an asset named 'vocab_y'.
    y_int = tft.compute_and_apply_vocabulary(inputs['y'], vocab_filename='vocab_y')
    outputs['y_one_hot'] = tf.one_hot(y_int, depth=3)  # 'y' has 3 unique values
    outputs['s_identity'] = inputs['s']  # pass through unchanged
    return outputs

# 3. Prepare some raw data
raw_data = [
    {'x': 10.0, 'y': 'apple', 's': 1.0},
    {'x': 20.0, 'y': 'banana', 's': 2.0},
    {'x': 30.0, 'y': 'apple', 's': 3.0},
    {'x': 40.0, 'y': 'orange', 's': 4.0},
    {'x': 50.0, 'y': 'banana', 's': 5.0},
]

# 4. Run the transform locally; beam.Pipeline() defaults to the DirectRunner.
# TFT needs a Context with a temp_dir for intermediate Analyze-phase output.
with beam.Pipeline() as p, tft_beam.Context(temp_dir=tempfile.mkdtemp()):
    raw_data_pcollection = p | 'CreateRawData' >> beam.Create(raw_data)

    # Analyze (compute full-pass statistics) and Transform (apply them).
    # The first element of the result is a (data, metadata) pair.
    (transformed_data, transformed_metadata), transform_fn = (
        (raw_data_pcollection, _RAW_DATA_METADATA)
        | 'AnalyzeAndTransform' >> tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))

    print('Transformed data:')
    _ = transformed_data | 'PrintTransformedData' >> beam.Map(print)

print('Preprocessing complete.')