XProf Profiler Plugin

2.22.1 · active · verified Sun Apr 12

The XProf Profiler Plugin is a tool for profiling and analyzing the performance of machine learning programs written in TensorFlow, JAX, and PyTorch/XLA. It helps users understand, debug, and optimize how their programs run on CPUs, GPUs, and TPUs. The current version is 2.22.1. Release numbers follow the TensorFlow versioning scheme, and new versions are published frequently.

Install

pip install tensorboard-plugin-profile

Quickstart

This quickstart generates profiling data for a small TensorFlow model and then launches TensorBoard to visualize it. Once installed, the `tensorboard-plugin-profile` package integrates with TensorBoard automatically. After running the Python code, open a new terminal, run the `tensorboard` command that the script prints, and view the profile data in your browser under the 'Profile' tab.

import tensorflow as tf
from datetime import datetime
import os

# Ensure log directory exists
log_dir = os.path.join("logs", "profile", datetime.now().strftime("%Y%m%d-%H%M%S"))
os.makedirs(log_dir, exist_ok=True)

# Small dummy model and data for profiling
model = tf.keras.Sequential([
    tf.keras.Input(shape=(10,)),
    tf.keras.layers.Dense(10, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy')
data = tf.random.normal(shape=(100, 10))
# Random 0/1 labels, cast to float32 to match the sigmoid output
labels = tf.cast(tf.random.uniform(shape=(100, 1), maxval=2, dtype=tf.int32), tf.float32)

# Option 1: Programmatic profiling with tf.profiler.experimental
print(f"Starting programmatic profile, data will be in {log_dir}")
with tf.profiler.experimental.Profile(log_dir):
    model.fit(data, labels, epochs=2, batch_size=32)

# Option 2: Use the Keras TensorBoard callback to profile specific batches
# tb_callback = tf.keras.callbacks.TensorBoard(
#     log_dir=log_dir,
#     profile_batch=(1, 3)  # profile batches 1 through 3
# )
# model.fit(data, labels, epochs=2, batch_size=32, callbacks=[tb_callback])

print("Profiling data generated. To view, run TensorBoard in your terminal:")
print(f"tensorboard --logdir={os.path.abspath('logs')}")
print("Then open your browser to http://localhost:6006/#profile")
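
After the script finishes, you can check that profile data was actually written before you start TensorBoard by looking for XPlane files. The sketch below assumes the TensorFlow 2 profiler's usual output layout, `<log_dir>/plugins/profile/<run>/<host>.xplane.pb`. `find_xplane_files` is a hypothetical helper and is not part of TensorFlow or the plugin.

```python
from pathlib import Path
from typing import Dict, List


def find_xplane_files(logdir: str) -> Dict[str, List[Path]]:
    """Return {run_dir: [xplane files]} for every profiling run under logdir.

    Hypothetical helper: assumes the profiler writes one <host>.xplane.pb
    file per host into a .../plugins/profile/<run>/ directory.
    """
    root = Path(logdir)
    runs: Dict[str, List[Path]] = {}
    if not root.is_dir():
        return runs  # nothing profiled yet, or the path is wrong

    for path in sorted(root.rglob("*.xplane.pb")):
        run_dir = path.parent
        # Keep only files sitting directly inside plugins/profile/<run>/
        if run_dir.parent.name == "profile" and run_dir.parent.parent.name == "plugins":
            # Key by the run directory relative to logdir to avoid name clashes
            key = run_dir.relative_to(root).as_posix()
            runs.setdefault(key, []).append(path)
    return runs
```

For example, `find_xplane_files("logs")` should return one entry per profiling run. An empty dict suggests the profiler wrote nothing under that directory, in which case the Profile tab will likely have no runs to show.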
