SageMaker Inference Toolkit

1.10.1 · active · verified Thu Apr 16

The sagemaker-inference toolkit is an open-source Python library designed to simplify the creation of serving containers for machine learning models on Amazon SageMaker. It provides a model serving stack built on Multi Model Server (MMS), enabling users to easily implement custom inference logic. The current version is 1.10.1, with a regular release cadence addressing bug fixes and new features, including support for newer Python versions and improved dependency management.

Common errors

Warnings

Install

Imports

Quickstart

This quickstart demonstrates the core pattern for using `sagemaker-inference` to create a custom inference handler. It defines a `CustomInferenceHandler` class that extends `DefaultInferenceHandler`, overriding `default_model_fn`, `default_input_fn`, `default_predict_fn`, and `default_output_fn`. These functions are responsible for loading the model, deserializing input, making predictions, and serializing output, respectively. This file would typically be part of your model archive within a SageMaker custom container.

import os
import json

from sagemaker_inference.default_inference_handler import DefaultInferenceHandler
from sagemaker_inference import content_types, decoder, encoder

class CustomInferenceHandler(DefaultInferenceHandler):
    """Demonstration handler for the sagemaker-inference override points.

    Each ``default_*_fn`` below replaces the toolkit's stock behavior:
    model loading, request deserialization, prediction, and response
    serialization, in that order.
    """

    def default_model_fn(self, model_dir, context=None):
        """Return a placeholder "model" dict; a real handler would
        deserialize a trained artifact found under ``model_dir``.
        """
        print(f"Loading model from: {model_dir}")
        # A real implementation would load an artifact here, e.g.:
        #   model = joblib.load(os.path.join(model_dir, 'model.pkl'))
        return {"status": "model_loaded", "path": model_dir}

    def default_input_fn(self, input_data, content_type, context=None):
        """Deserialize the request body.

        JSON is delegated to the toolkit decoder; CSV is treated as a
        plain comma-separated byte string. Any other content type raises
        ``ValueError``.
        """
        if content_type == content_types.CSV:
            # Keep the CSV path trivial for this example: one row of text.
            return input_data.decode('utf-8').split(',')
        if content_type == content_types.JSON:
            return decoder.decode(input_data, content_type)
        raise ValueError(f"Unsupported content type: {content_type}")

    def default_predict_fn(self, data, model, context=None):
        """Produce a dummy prediction keyed by the payload's shape.

        A dict with an ``instances`` key doubles each instance; a bare
        list tags each element; anything else is echoed as a string.
        """
        print(f"Performing prediction with model: {model} and data: {data}")
        if isinstance(data, dict) and 'instances' in data:
            return {"predictions": [value * 2 for value in data['instances']]}
        if isinstance(data, list):
            return {"predictions": [value + "_processed" for value in data]}
        return {"predictions": f"Processed: {data}"}

    def default_output_fn(self, prediction, accept, context=None):
        """Serialize ``prediction`` to the requested accept type.

        Only JSON is supported; other accept types raise ``ValueError``.
        """
        if accept != content_types.JSON:
            raise ValueError(f"Unsupported accept type: {accept}")
        return encoder.encode(prediction, accept)

# To run this in a SageMaker container, you would have a Dockerfile
# that installs sagemaker-inference and multi-model-server, copies this file
# as 'inference.py' and sets up the entrypoint to start the model server.
# e.g., using sagemaker_inference.model_server.start_model_server()

# Example of how to manually test the handler (not typically run directly in a quickstart)
if __name__ == '__main__':
    # Manual smoke test of the handler outside a running model server.
    inference_handler = CustomInferenceHandler()
    loaded_model = inference_handler.default_model_fn('/opt/ml/model')  # Simulates model_dir

    # Round-trip a JSON request through input -> predict -> output.
    json_payload = json.dumps({"instances": [1, 2, 3]}).encode('utf-8')
    parsed_json = inference_handler.default_input_fn(json_payload, content_types.JSON)
    json_result = inference_handler.default_predict_fn(parsed_json, loaded_model)
    json_body = inference_handler.default_output_fn(json_result, content_types.JSON)
    print(f"JSON Inference Result: {json_body.decode('utf-8')}")

    # Round-trip a CSV request; the response is still serialized as JSON.
    csv_payload = b'hello,world'
    parsed_csv = inference_handler.default_input_fn(csv_payload, content_types.CSV)
    csv_result = inference_handler.default_predict_fn(parsed_csv, loaded_model)
    csv_body = inference_handler.default_output_fn(csv_result, content_types.JSON)  # Output as JSON for simplicity
    print(f"CSV Inference Result: {csv_body.decode('utf-8')}")

view raw JSON →