LiteRT

2.1.4 · active · verified Wed Apr 15

LiteRT is Google's high-performance, open-source inference framework for deploying machine learning and generative AI models on edge devices, including mobile, desktop, web, and IoT platforms. It evolved from TensorFlow Lite and offers improved performance, unified APIs, and broad hardware acceleration (CPU, GPU, NPU). It is production-ready and powers on-device GenAI experiences in several Google products. The current PyPI version is 2.1.4.

Warnings

Install

pip install ai-edge-litert

Imports

from ai_edge_litert.interpreter import Interpreter

Quickstart

This quickstart demonstrates how to load and run a LiteRT (.tflite) model using the Python runtime. It initializes the interpreter, prepares a dummy input tensor, performs inference, and retrieves the output. Replace `model.tflite` with the actual path to your LiteRT model.

import os

import numpy as np

try:
    # LiteRT's current Python package; fall back to the legacy
    # TFLite runtime package name if it is not installed.
    from ai_edge_litert.interpreter import Interpreter
except ImportError:
    from tflite_runtime.interpreter import Interpreter

# Ensure you have a .tflite model file, e.g., downloaded from Google AI Edge.
# For this example, we'll assume 'model.tflite' exists in the current directory.
# Replace 'model.tflite' with your actual model path.
model_path = os.environ.get('LITERT_MODEL_PATH', 'model.tflite')

try:
    # Fail early with a clear error if the model file is missing; the
    # interpreter otherwise raises a less specific ValueError.
    if not os.path.exists(model_path):
        raise FileNotFoundError(model_path)

    # Load the LiteRT model and allocate tensors.
    interpreter = Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    # Get input and output tensor details.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Assuming a single input tensor for simplicity
    input_shape = input_details[0]['shape']
    input_dtype = input_details[0]['dtype']

    # Create a dummy input tensor (replace with actual data for your model)
    input_data = np.array(np.random.random_sample(input_shape), dtype=input_dtype)

    # Set the tensor to point to the input data to be inferred.
    interpreter.set_tensor(input_details[0]['index'], input_data)

    # Run inference.
    interpreter.invoke()

    # Get the output tensor.
    # Assuming a single output tensor for simplicity
    output_data = interpreter.get_tensor(output_details[0]['index'])

    print(f"Model loaded from: {model_path}")
    print(f"Input shape: {input_shape}, Dtype: {input_dtype}")
    print(f"Output data shape: {output_data.shape}, Dtype: {output_data.dtype}")
    print(f"First 5 output values: {output_data.flatten()[:5]}")

except FileNotFoundError:
    print(f"Error: Model file not found at '{model_path}'. Please provide a valid .tflite model path.")
except Exception as e:
    print(f"An error occurred during model inference: {e}")
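Random input data exercises the interpreter but tells you nothing about model quality; real use means converting your data to the input tensor's shape and dtype, and interpreting the raw output. The NumPy-only sketch below shows a typical pre- and post-processing pair for an image classifier. The [1, 224, 224, 3] float32 input shape, the [0, 1] normalization, and the logits-shaped output are illustrative assumptions about a hypothetical model, not part of the LiteRT API; check your own model's input_details and documentation.

```python
import numpy as np

def preprocess(image: np.ndarray, input_shape) -> np.ndarray:
    """Scale a uint8 HxWx3 image to [0, 1] floats and add a batch dimension.

    Assumes the (hypothetical) model expects float32 input in [0, 1];
    many models instead expect [-1, 1] or raw uint8, so check input_details.
    """
    _, height, width, _ = input_shape
    if image.shape[:2] != (height, width):
        raise ValueError(f"Expected a {height}x{width} image, got {image.shape[:2]}")
    scaled = image.astype(np.float32) / 255.0
    return np.expand_dims(scaled, axis=0)  # [H, W, 3] -> [1, H, W, 3]

def top_k(logits: np.ndarray, k: int = 5):
    """Return (class index, probability) pairs for the k highest-scoring classes."""
    logits = logits.flatten()
    exp = np.exp(logits - logits.max())  # numerically stable softmax
    probs = exp / exp.sum()
    top = np.argsort(probs)[::-1][:k]
    return [(int(i), float(probs[i])) for i in top]

# Dummy data stands in for a decoded image and for interpreter.get_tensor(...):
fake_image = np.zeros((224, 224, 3), dtype=np.uint8)
batch = preprocess(fake_image, input_shape=(1, 224, 224, 3))
print(batch.shape, batch.dtype)  # (1, 224, 224, 3) float32

fake_logits = np.array([0.1, 2.0, -1.0, 0.5])
print(top_k(fake_logits, k=2))
```

The `batch` array is what you would pass to `interpreter.set_tensor(...)` in the quickstart above, and `top_k` would be applied to the `output_data` it returns.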
