TensorFlow Model Optimization Toolkit

A suite of tools for optimizing machine learning models for deployment, including quantization-aware training, pruning, and weight clustering. Current version 0.8.0, tested against TensorFlow 2.14.1. Released approximately every 4–6 months.

pip install tensorflow-model-optimization
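
A quick sanity check after installing (the versions in the comments are simply the ones this page was verified against; yours may differ):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

print('TF version:', tf.__version__)        # e.g. 2.14.1
print('TFMOT version:', tfmot.__version__)  # e.g. 0.8.0
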
error ModuleNotFoundError: No module named 'tensorflow_model_optimization.quantization.keras.quantize_model'
cause quantize_model is a function, not a submodule, so it cannot be imported as a module path.
fix
Use: from tensorflow_model_optimization.quantization.keras import quantize_model (or access it as tfmot.quantization.keras.quantize_model after import tensorflow_model_optimization as tfmot)
error ValueError: Unknown layer: QuantizeWrapper. Please ensure this layer is imported.
cause When loading a quantized model, the custom QuantizeWrapper class is not registered.
fix
Use: with tfmot.quantization.keras.quantize_scope(): model = tf.keras.models.load_model('path')
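
A minimal sketch of the save/reload round trip, assuming a quantized model named q_aware_model like the one built in the example at the bottom of this page (the file name is illustrative):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Saving works as usual; the QuantizeWrapper layers are custom objects.
q_aware_model.save('qat_model.h5')

# quantize_scope() registers the wrapper classes while loading.
with tfmot.quantization.keras.quantize_scope():
    reloaded = tf.keras.models.load_model('qat_model.h5')
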
error AttributeError: module 'tensorflow_model_optimization' has no attribute 'sparsity'
cause The sparsity namespace is not always exposed as an attribute of the top-level package; import it from the keras submodule explicitly.
fix
Use: from tensorflow_model_optimization.sparsity.keras import prune_low_magnitude
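
A short sketch of typical usage, assuming an existing tf.keras model named model; the 50% constant sparsity schedule is an arbitrary illustration:

import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.sparsity.keras import prune_low_magnitude

pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0)
}
pruned_model = prune_low_magnitude(model, **pruning_params)

# The UpdatePruningStep callback must be passed to fit() so the masks update.
callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]
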
gotcha The 'quantize_model' function only works with Keras v2 (tf.keras). Using tfmot with tf.compat.v1.keras or standalone Keras may fail.
fix Ensure you are using tf.keras, not keras directly. Upgrade to TensorFlow 2.14+.
breaking In version 0.7.0+, the default QAT API changed: QuantizeWrapperV2 now preserves the order of weights. Existing models serialized with QuantizeWrapper may not load correctly.
fix Re-apply quantization after upgrading to 0.7.0+. For backward compatibility, use tfmot.quantization.keras.QuantizeWrapper (old) instead of QuantizeWrapperV2 if needed.
gotcha Combining pruning and QAT wrappers on the same model can break under @tf.function tracing, and the resulting errors are hard to debug.
fix Avoid decorating quantized/pruned models with @tf.function. Call the model eagerly or use the built-in Keras training loop (model.fit), as sketched below.
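
A sketch of both workarounds, reusing the q_aware_model and training data from the example further down:

# Option 1: let the built-in Keras loop handle training.
q_aware_model.fit(train_images, train_labels, epochs=1)

# Option 2: a custom loop that calls the model eagerly instead of
# wrapping the train step in @tf.function.
optimizer = tf.keras.optimizers.Adam()
with tf.GradientTape() as tape:
    probs = q_aware_model(train_images, training=True)
    loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(train_labels, probs))
grads = tape.gradient(loss, q_aware_model.trainable_variables)
optimizer.apply_gradients(zip(grads, q_aware_model.trainable_variables))
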
deprecated Weight clustering API is no longer actively maintained and may be removed in future versions.
fix Use QAT or pruning instead; consider alternatives like TensorFlow Lite's built-in clustering.
breaking Version 0.7.2 removed support for PeepholeLSTMCell. Loading models using this layer will fail.
fix Replace PeepholeLSTMCell with standard LSTM or custom implementation before upgrading.

The example below demonstrates quantization-aware training on a small model, then converts it to TFLite.

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Quantization aware training example
def get_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(32, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    return model

model = get_model()
quantize_model = tfmot.quantization.keras.quantize_model
q_aware_model = quantize_model(model)

# Compile and train
q_aware_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# Dummy data: random inputs stand in for a real training set
train_images = np.random.random((100, 784)).astype('float32')
train_labels = np.random.randint(10, size=(100,))
q_aware_model.fit(train_images, train_labels, epochs=1, verbose=0)

# Convert to TFLite
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
print('TFLite model size:', len(tflite_model))
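
As a rough follow-up, the converted bytes can be checked with the TFLite interpreter; with this conversion path the model's inputs and outputs stay float32, so the same dummy data works:

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_detail = interpreter.get_input_details()[0]
output_detail = interpreter.get_output_details()[0]

interpreter.set_tensor(input_detail['index'], train_images[:1])
interpreter.invoke()
print('Predicted class:', np.argmax(interpreter.get_tensor(output_detail['index'])))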