TensorFlow Estimator

2.15.0 · maintenance · verified Thu Apr 09

The TensorFlow Estimator library provides a high-level API (`tf.estimator`) that wraps the training, evaluation, and prediction loops for machine learning models. It is still maintained for existing projects, but it is officially not recommended for new code; Keras is the preferred high-level API in modern TensorFlow. Its releases generally follow the TensorFlow core release cadence.
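
Because Keras is the recommended alternative, a rough Keras counterpart to the `DNNRegressor` used in the Quickstart may help with orientation. This is a minimal sketch, not an official migration recipe: it keeps the two hidden layers of 10 units and the same toy data, and the ReLU activations match `DNNRegressor`'s default `activation_fn`, but the `adam` optimizer, the mean-squared-error loss, and the epoch count are assumptions chosen here rather than Estimator defaults.

```python
import numpy as np
import tensorflow as tf

# Same toy data as the Quickstart, reshaped to [batch, 1] for Keras
x = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32)
y = np.array([[0.], [-1.], [-2.], [-3.]], dtype=np.float32)

# Two hidden layers of 10 ReLU units, one linear output unit
model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(10, activation='relu'),
    tf.keras.layers.Dense(10, activation='relu'),
    tf.keras.layers.Dense(1),
])

# Optimizer and loss are assumptions, not DNNRegressor's defaults
model.compile(optimizer='adam', loss='mse')

# 4 samples / batch_size 2 = 2 steps per epoch, so 50 epochs = 100 steps
model.fit(x, y, batch_size=2, epochs=50, verbose=0)

preds = model.predict(np.array([[5.], [6.]], dtype=np.float32), verbose=0)
for p in preds:
    print(f"Prediction: {p[0]:.2f}")
```

The batch size and epoch count were picked so the number of gradient steps matches the Quickstart's `steps=100`; the predicted values themselves will differ between runs and between the two APIs.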

Warnings

Install

Imports
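
The Quickstart reaches the library through the `tf.estimator` namespace of the main `tensorflow` package. A minimal import check, assuming a TensorFlow release that still ships Estimator (this page tracks 2.15.0), might look like:

```python
import tensorflow as tf

# The Estimator API is exposed as tf.estimator in releases that still bundle it
print("TensorFlow version:", tf.__version__)
print("DNNRegressor available:", hasattr(tf.estimator, "DNNRegressor"))
```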

Quickstart

This quickstart shows how to create a `DNNRegressor` with `tf.estimator`, train it, and use it for prediction. It defines one input function for training and another for prediction, trains the model for 100 steps on a toy dataset where y = 1 - x, and then predicts on two unseen inputs.

import tensorflow as tf

# Define feature columns
feature_columns = [
    tf.feature_column.numeric_column('x', shape=[1])
]

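# Define the estimator; checkpoints are written to model_dir, so rerunning
# this script resumes training from the last saved checkpoint
estimator = tf.estimator.DNNRegressor(
    feature_columns=feature_columns,
    hidden_units=[10, 10],  # two hidden layers of 10 units each
    model_dir='/tmp/DNNRegressor_model'
)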

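# Define input function for training (toy data where y = 1 - x).
# repeat() makes the dataset infinite, so `steps` in train() bounds training.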
def input_fn_train():
    features = {'x': tf.constant([1., 2., 3., 4.])}
    labels = tf.constant([0., -1., -2., -3.])
    return tf.data.Dataset.from_tensor_slices((features, labels)).repeat().batch(2)

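# Define input function for prediction: features only, no labels, and no
# repeat(), so predict() stops once both inputs have been consumed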
def input_fn_predict():
    features = {'x': tf.constant([5., 6.])}
    return tf.data.Dataset.from_tensor_slices(features).batch(2)

# Train the estimator
print('Training the model...')
estimator.train(input_fn=input_fn_train, steps=100)
print('Training complete.')

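# Predict: predict() yields one dict per example; for DNNRegressor the
# 'predictions' entry is an array holding the single regression output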
print('Making predictions...')
predictions = list(estimator.predict(input_fn=input_fn_predict))
for p in predictions:
    print(f"Prediction: {p['predictions'][0]:.2f}")

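The overview also mentions evaluation, which the Quickstart does not exercise. A minimal sketch of `estimator.evaluate` follows, assuming the training data is simply reused as an evaluation set (a real project would hold out separate data). The name `input_fn_eval` is hypothetical, and a temporary `model_dir` is used so no checkpoint from an earlier run is picked up.

```python
import tempfile

import tensorflow as tf

feature_columns = [tf.feature_column.numeric_column('x', shape=[1])]

estimator = tf.estimator.DNNRegressor(
    feature_columns=feature_columns,
    hidden_units=[10, 10],
    model_dir=tempfile.mkdtemp(),  # fresh directory: training starts at step 0
)

def input_fn_train():
    features = {'x': tf.constant([1., 2., 3., 4.])}
    labels = tf.constant([0., -1., -2., -3.])
    return tf.data.Dataset.from_tensor_slices((features, labels)).repeat().batch(2)

def input_fn_eval():
    # Hypothetical eval input_fn: same toy data, but without repeat(),
    # so evaluation ends when the dataset is exhausted
    features = {'x': tf.constant([1., 2., 3., 4.])}
    labels = tf.constant([0., -1., -2., -3.])
    return tf.data.Dataset.from_tensor_slices((features, labels)).batch(2)

estimator.train(input_fn=input_fn_train, steps=100)

# With steps=None (the default), evaluate() runs until the input is exhausted
metrics = estimator.evaluate(input_fn=input_fn_eval)
for name, value in sorted(metrics.items()):
    print(f"{name}: {value}")
```

The returned dict includes `loss` and `global_step` (the checkpoint step that was evaluated), alongside head-specific regression metrics.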