TensorFlow Model Analysis

0.48.0 · active · verified Fri Apr 17

TensorFlow Model Analysis (TFMA) is a library for performing deep analysis and evaluation of TensorFlow models, especially useful for understanding model performance on different data slices. It is built on Apache Beam, enabling scalable analysis. The current version is 0.48.0, and it follows a regular release cadence, often aligning with TensorFlow and TFX releases.

Common errors

Warnings

Install

Imports

Quickstart

This quickstart demonstrates how to set up a dummy TensorFlow model, create synthetic TFRecord data, define an `EvalConfig`, and run a basic model analysis with TFMA locally. It shows how to obtain overall and sliced metrics.

import tensorflow as tf
import tensorflow_model_analysis as tfma
import os
import shutil

# 1. Build and export a minimal Keras model.
# The serving signature takes a 'feature_1' tensor and returns a dict whose
# 'predictions' key holds the sigmoid output.
class SimplePredictionModel(tf.keras.Model):
    """Single-dense-unit demo model with an explicit serving signature."""

    def __init__(self):
        super().__init__()
        # One sigmoid unit: maps a scalar feature to a probability in (0, 1).
        self.dense = tf.keras.layers.Dense(1, activation='sigmoid')

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None], dtype=tf.float32, name='feature_1')
    ])
    def serving_default(self, feature_1):
        """Serving entry point: batch of scalars -> {'predictions': probs}."""
        batched = tf.expand_dims(feature_1, axis=-1)  # [N] -> [N, 1] for Dense
        return {'predictions': self.dense(batched)}

model = SimplePredictionModel()
# Run the serving function once so the Dense layer builds its weights.
_ = model.serving_default(tf.constant([1.0, 2.0]))

# Export to a fresh SavedModel directory, wiring the traced function in as
# the 'serving_default' signature TFMA will invoke.
model_dir = '/tmp/tfma_quickstart_model'
if os.path.exists(model_dir):
    shutil.rmtree(model_dir)
tf.saved_model.save(
    model,
    model_dir,
    signatures={'serving_default': model.serving_default},
)

# 2. Write dummy evaluation data as serialized tf.train.Example records,
# which is the input format tfma.run_model_analysis reads by default.
data_path = '/tmp/tfma_quickstart_data.tfrecord'
if os.path.exists(data_path):
    os.remove(data_path)

def _serialized_example(value, label):
    """Serialize one Example carrying a float 'feature_1' and a float 'label'."""
    feats = {
        'feature_1': tf.train.Feature(float_list=tf.train.FloatList(value=[value])),
        'label': tf.train.Feature(float_list=tf.train.FloatList(value=[label])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feats)).SerializeToString()

# Ten rows: feature is the row index; label alternates 1, 0, 1, 0, ...
with tf.io.TFRecordWriter(data_path) as writer:
    for i in range(10):
        writer.write(_serialized_example(float(i), 1.0 if i % 2 == 0 else 0.0))

# 3. Configure the evaluation: which model signature/keys to read, which
#    metrics to compute, and how to slice the data.
_model_spec = tfma.ModelSpec(
    signature_name='serving_default',
    label_key='label',
    prediction_key='predictions',  # key in the serving output dict
)
_metrics_spec = tfma.MetricsSpec(metrics=[
    tfma.MetricConfig(class_name='ExampleCount'),
    tfma.MetricConfig(class_name='Accuracy'),
])
eval_config = tfma.EvalConfig(
    model_specs=[_model_spec],
    metrics_specs=[_metrics_spec],
    slicing_specs=[
        tfma.SlicingSpec(),  # empty spec = overall (unsliced) metrics
        tfma.SlicingSpec(feature_keys=['feature_1']),  # per feature_1 value
    ],
)

# 4. Run the analysis locally. TFMA executes on Apache Beam; with no
# pipeline options supplied it uses the in-process DirectRunner. For a
# distributed run, pass Beam PipelineOptions (e.g. for DataflowRunner).
output_dir = '/tmp/tfma_quickstart_output'
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)

print(f"Running TFMA with model: {model_dir}, data: {data_path}, output: {output_dir}")
results = tfma.run_model_analysis(
    eval_config=eval_config,
    model_location=model_dir,
    data_location=data_path,
    output_path=output_dir,
)

print(f"TFMA analysis complete. Results written to: {output_dir}")
# To inspect the results later (e.g. in a Jupyter notebook):
#   result = tfma.load_eval_result(output_dir)
#   tfma.view.render_slicing_metrics(result)

view raw JSON →