TensorFlow Decision Forests
TensorFlow Decision Forests (TF-DF) is a Python library that integrates state-of-the-art decision forest algorithms (such as Random Forests and Gradient Boosted Trees) directly into TensorFlow and Keras. It enables training, serving, and interpreting these models for classification, regression, and ranking tasks. Built on the highly optimized C++ Yggdrasil Decision Forests (YDF) library, it is actively maintained with frequent releases, typically every few months.
Warnings
- breaking TensorFlow Version Compatibility: Each TF-DF version is strictly tied to a specific TensorFlow version due to ABI compatibility. Using incompatible versions will lead to cryptic C++ runtime errors (e.g., 'undefined symbol').
- deprecated Loss function `LAMBDA_MART_NDCG5` has been renamed to `LAMBDA_MART_NDCG`.
- gotcha Keras 3 Incompatibility: TF-DF is not yet compatible with Keras 3. Use Keras 2, for example by setting `TF_USE_LEGACY_KERAS=1` before importing TensorFlow.
- gotcha Windows Support Limitations: A native Windows pip package is not available.
- gotcha No GPU/TPU Support: TF-DF models currently do not leverage GPUs or TPUs for training or inference.
- gotcha YDF Recommended for New Projects: New projects are advised to use Yggdrasil Decision Forests (YDF) directly rather than TF-DF.
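The version and Keras warnings above are easier to act on when the installed versions are visible before TF-DF is imported. The following is a minimal sketch using only the standard library. The helper names `installed_version` and `environment_report` are hypothetical and not part of TF-DF. The sketch only reports what is installed; it does not encode which versions are compatible, so compare the output against the compatibility table published with TF-DF.

```python
import os
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def environment_report():
    """Collect the facts relevant to the warnings above (hypothetical helper)."""
    return {
        "tensorflow": installed_version("tensorflow"),
        "tensorflow_decision_forests": installed_version("tensorflow_decision_forests"),
        "keras": installed_version("keras"),
        "TF_USE_LEGACY_KERAS": os.environ.get("TF_USE_LEGACY_KERAS"),
    }


if __name__ == "__main__":
    for name, value in environment_report().items():
        shown = value if value is not None else "not installed / not set"
        print(f"{name}: {shown}")
```

Running this before the first `import tensorflow_decision_forests` can help tell a version mismatch apart from other causes of an 'undefined symbol' error.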
Install
- pip install tensorflow_decision_forests
- pip install tensorflow_decision_forests wurlitzer
Imports
- tensorflow_decision_forests
import tensorflow_decision_forests as tfdf
- RandomForestModel
tfdf.keras.RandomForestModel
- pd_dataframe_to_tf_dataset
tfdf.keras.pd_dataframe_to_tf_dataset
Quickstart
import os
# Select Keras 2 (legacy Keras), since TF-DF is not compatible with Keras 3.
# Set this before importing TensorFlow so that it takes effect.
os.environ['TF_USE_LEGACY_KERAS'] = '1'
import pandas as pd
import tensorflow as tf
import tensorflow_decision_forests as tfdf
# Create a dummy dataset (replace with your actual data loading)
data = {
    'feature_1': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'feature_2': ['A', 'B', 'A', 'C', 'B', 'A', 'C', 'B', 'A', 'C'],
    'label': [0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
}
df = pd.DataFrame(data)
# Convert the Pandas DataFrame to a TensorFlow Dataset
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(df, label='label')
# Create and train a Random Forest model
model = tfdf.keras.RandomForestModel()
model.fit(train_ds)
# (Optional) Evaluate the model
# model.evaluate(train_ds) # Use a separate test_ds for proper evaluation
print("Model trained successfully!")
model.summary()
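The quickstart comments out evaluation because it has no separate test set. The sketch below shows one way to hold out part of the data first. `split_rows` and `train_and_evaluate` are hypothetical helper names, and the 30% test ratio and fixed seed are arbitrary assumptions. The TF-DF calls are the ones used in the quickstart plus the standard Keras `compile` and `evaluate` methods.

```python
import random


def split_rows(n_rows, test_ratio=0.3, seed=42):
    """Return (train_idx, test_idx): a shuffled, disjoint split of range(n_rows)."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)
    n_test = max(1, int(n_rows * test_ratio))
    return sorted(indices[n_test:]), sorted(indices[:n_test])


def train_and_evaluate(df, label="label"):
    """Train on one part of `df` and evaluate on the held-out remainder."""
    # Imported lazily so that split_rows works without TF-DF installed.
    import tensorflow_decision_forests as tfdf

    train_idx, test_idx = split_rows(len(df))
    train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(df.iloc[train_idx], label=label)
    test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(df.iloc[test_idx], label=label)

    model = tfdf.keras.RandomForestModel()
    model.fit(train_ds)
    # Metrics are not needed for training; they are only used by evaluate().
    model.compile(metrics=["accuracy"])
    return model.evaluate(test_ds, return_dict=True)
```

With the `df` built in the quickstart, `print(train_and_evaluate(df))` would train on seven rows and evaluate on the remaining three. A dataset that small only demonstrates the mechanics, so the reported accuracy carries no meaning.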