TensorFlow Estimator
The TensorFlow Estimator library provides a high-level API for training, evaluating, and running predictions with machine learning models. It is deprecated and not recommended for new code; Keras is the preferred API in modern TensorFlow. TensorFlow 2.15 included the final release of the `tensorflow-estimator` package, and Estimators are not available in TensorFlow 2.16 or later. Until then, the package generally followed the TensorFlow core release cadence.
Warnings
- deprecated Estimators are not recommended for new code. TensorFlow's official stance is to prefer Keras for new development due to its simpler API, better integration with TF2.x eager execution, and broader community support.
- gotcha Estimators are built around TF1.x-style graph execution. They run under TF2.x, but `model_fn` and input functions are called while a graph is being built, not eagerly, so the tensors inside them are symbolic and cannot be inspected with ordinary Python debugging the way they can in a Keras model. This makes debugging less intuitive and can leave TF2.x features that depend on eager execution unused.
- gotcha The `tensorflow-estimator` package provides the `tf.estimator` namespace and is normally installed automatically as a dependency of the main `tensorflow` package. It is not usable on its own: installing `tensorflow-estimator` without `tensorflow` does not give you a working `tf.estimator` API.
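Following the recommendation above, here is a rough Keras counterpart to the `DNNRegressor` in the Quickstart below. It is a sketch, not an official migration recipe, and assumes a TensorFlow 2.x install with `tf.keras`. The ReLU activations and Adagrad optimizer at learning rate 0.05 are chosen to mirror `DNNRegressor`'s defaults (an assumption worth checking against your version), and 50 epochs of 2-sample batches over the 4 training points give the same 100 training steps.

```python
import numpy as np
import tensorflow as tf

# Same toy data as the Estimator Quickstart: y = 1 - x
x_train = np.array([[1.0], [2.0], [3.0], [4.0]], dtype=np.float32)
y_train = np.array([[0.0], [-1.0], [-2.0], [-3.0]], dtype=np.float32)

# Two hidden layers of 10 units, like hidden_units=[10, 10]
model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(10, activation="relu"),
    tf.keras.layers.Dense(10, activation="relu"),
    tf.keras.layers.Dense(1),  # linear output for regression
])
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.05), loss="mse")

# No input_fn or model_dir: Keras accepts NumPy arrays directly
model.fit(x_train, y_train, batch_size=2, epochs=50, verbose=0)

predictions = model.predict(np.array([[5.0], [6.0]], dtype=np.float32), verbose=0)
for p in predictions:
    print(f"Prediction: {p[0]:.2f}")
```

The weights live in the `model` object rather than in checkpoints under a `model_dir`; call `model.save(...)` if you need them on disk.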
Install
- pip install "tensorflow<2.16"
- pip install "tensorflow-estimator<2.16"
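Because Estimators were dropped in TensorFlow 2.16, it can help to check an environment before relying on `tf.estimator`. The sketch below uses only the standard library and does not import TensorFlow. It assumes TensorFlow was installed under the distribution name `tensorflow` (variants such as `tensorflow-cpu` register a different name), and the helper names `tf_version_has_estimator` and `estimator_install_status` are hypothetical.

```python
import importlib.util
from importlib import metadata


def tf_version_has_estimator(version: str) -> bool:
    """Return True if a TensorFlow version string predates the 2.16 removal of tf.estimator."""
    major, minor = (int(part) for part in version.split(".")[:2])
    return (major, minor) < (2, 16)


def estimator_install_status(dist_name: str = "tensorflow") -> str:
    """Describe whether tf.estimator should be importable, without importing TensorFlow.

    Assumes TensorFlow was installed under the distribution name `dist_name`.
    """
    if importlib.util.find_spec("tensorflow") is None:
        return "tensorflow is not installed"
    try:
        version = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return f"could not find distribution metadata for {dist_name!r}"
    if not tf_version_has_estimator(version):
        return f"tensorflow {version} no longer includes tf.estimator; pin tensorflow<2.16"
    if importlib.util.find_spec("tensorflow_estimator") is None:
        return "tensorflow_estimator is missing; install tensorflow-estimator"
    return f"tensorflow {version} with tf.estimator available"


print(estimator_install_status())
```

The version check runs before the `tensorflow_estimator` check on purpose: a leftover 2.15 `tensorflow-estimator` next to TensorFlow 2.16+ still does not give you a working `tf.estimator`.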
Imports
- Estimator
import tensorflow as tf
estimator = tf.estimator.Estimator(...)
- DNNClassifier
import tensorflow as tf
classifier = tf.estimator.DNNClassifier(...)
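`tf.estimator.Estimator` is the base class that premade models such as `DNNClassifier` build on; to use it directly you pass a `model_fn(features, labels, mode)` that returns a `tf.estimator.EstimatorSpec`. The sketch below fills in the `Estimator(...)` call with a hypothetical one-feature linear model written in the TF1-compatible style Estimators expect, and records `tf.executing_eagerly()` inside `model_fn` to show the graph-execution behavior described under Warnings. It assumes TensorFlow < 2.16; the model, the `eager_flags` list, and the learning rate are illustrative choices.

```python
import tensorflow as tf  # assumes TensorFlow < 2.16, which still ships tf.estimator

tf1 = tf.compat.v1
eager_flags = []  # illustration only: tf.executing_eagerly() as seen inside model_fn


def model_fn(features, labels, mode):
    """Hypothetical linear model y = w * x + b, written as an Estimator model_fn."""
    eager_flags.append(tf.executing_eagerly())
    x = tf.reshape(features["x"], [-1, 1])
    w = tf1.get_variable("w", shape=[1, 1], initializer=tf1.zeros_initializer())
    b = tf1.get_variable("b", shape=[1], initializer=tf1.zeros_initializer())
    predictions = tf.matmul(x, w) + b

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions={"predictions": predictions})

    loss = tf.reduce_mean(tf.square(tf.reshape(labels, [-1, 1]) - predictions))
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss)

    # TRAIN: a TF1-style optimizer that also increments the global step
    optimizer = tf1.train.GradientDescentOptimizer(learning_rate=0.01)
    train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)


def input_fn():
    features = {"x": tf.constant([1.0, 2.0, 3.0, 4.0])}
    labels = tf.constant([0.0, -1.0, -2.0, -3.0])
    return tf.data.Dataset.from_tensor_slices((features, labels)).repeat().batch(2)


# No model_dir given, so the Estimator writes checkpoints to a temporary directory
estimator = tf.estimator.Estimator(model_fn=model_fn)
estimator.train(input_fn=input_fn, steps=50)
print("model_fn ran eagerly:", any(eager_flags))  # prints False: model_fn was traced into a graph
```

Because `model_fn` only builds graph ops, a `print` of `x` inside it would show a symbolic tensor rather than values, which is the debugging limitation noted in the Warnings.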
Quickstart
import tensorflow as tf
# Define feature columns
feature_columns = [
    tf.feature_column.numeric_column('x', shape=[1])
]
# Define the estimator
estimator = tf.estimator.DNNRegressor(
    feature_columns=feature_columns,
    hidden_units=[10, 10],
    model_dir='/tmp/DNNRegressor_model'  # checkpoints are saved here; rerunning resumes from the latest one
)
# Define input function for training
def input_fn_train():
    features = {'x': tf.constant([1., 2., 3., 4.])}
    labels = tf.constant([0., -1., -2., -3.])
    return tf.data.Dataset.from_tensor_slices((features, labels)).repeat().batch(2)
# Define input function for prediction
def input_fn_predict():
    features = {'x': tf.constant([5., 6.])}
    return tf.data.Dataset.from_tensor_slices(features).batch(2)
# Train the estimator
print('Training the model...')
estimator.train(input_fn=input_fn_train, steps=100)
print('Training complete.')
# Predict
print('Making predictions...')
predictions = list(estimator.predict(input_fn=input_fn_predict))
for p in predictions:
    print(f"Prediction: {p['predictions'][0]:.2f}")
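The introduction lists evaluation alongside training and prediction, but the Quickstart does not call it. The following is a self-contained sketch, assuming TensorFlow < 2.16: it retrains the same `DNNRegressor` (without a fixed `model_dir`, so a temporary directory is used) and calls `estimator.evaluate()` on hypothetical held-out points from the same line y = 1 - x. `evaluate()` returns a dict of metrics; `loss`, `average_loss`, and `global_step` are among the keys a `DNNRegressor` reports.

```python
import tensorflow as tf  # assumes TensorFlow < 2.16, which still ships tf.estimator


def input_fn_train():
    features = {"x": tf.constant([1.0, 2.0, 3.0, 4.0])}
    labels = tf.constant([0.0, -1.0, -2.0, -3.0])
    return tf.data.Dataset.from_tensor_slices((features, labels)).repeat().batch(2)


def input_fn_eval():
    # Hypothetical held-out points on the same line y = 1 - x; no .repeat(),
    # so evaluation stops when the dataset is exhausted
    features = {"x": tf.constant([5.0, 6.0, 7.0, 8.0])}
    labels = tf.constant([-4.0, -5.0, -6.0, -7.0])
    return tf.data.Dataset.from_tensor_slices((features, labels)).batch(2)


estimator = tf.estimator.DNNRegressor(
    feature_columns=[tf.feature_column.numeric_column("x", shape=[1])],
    hidden_units=[10, 10],
)  # no model_dir: a temporary directory is used
estimator.train(input_fn=input_fn_train, steps=100)

# evaluate() restores the latest checkpoint and returns a dict of metrics
metrics = estimator.evaluate(input_fn=input_fn_eval)
for name in ("loss", "average_loss", "global_step"):
    print(f"{name}: {metrics[name]}")
```

The evaluation input function deliberately omits `.repeat()`. With an infinitely repeating dataset you would have to pass `steps=` to `evaluate()` or it would never finish.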