Torch Model Archiver
Torch Model Archiver is a dedicated command-line tool that packages a trained PyTorch model (typically an eager-mode `.pth` checkpoint or a TorchScript file), together with its handler and auxiliary files, into a `.mar` (Model ARchive) file. These `.mar` files are designed to be consumed and served by TorchServe for inference. The tool is part of the larger PyTorch/Serve ecosystem and is released in step with TorchServe; the current version is 0.12.0.
Warnings
- breaking Starting with TorchServe v0.11.1 (the ecosystem `torch-model-archiver` ships as part of), token authorization is enabled by default for all HTTP/S and gRPC APIs.
- gotcha When archiving a model that registers custom metrics via `add_metric`, note that the default inferred metric type changed to `COUNTER` in v0.8.2.
- gotcha When using `--extra-files` to include additional Python modules or configuration files, remember that all files are flattened into a single folder within the `.mar` archive. This might require adjusting relative import paths in your handler or model code.
- gotcha Compatibility between `torch-model-archiver` and `TorchServe` versions is critical. Using mismatched versions can lead to unexpected behavior or failure to serve models correctly.
- gotcha Older versions of `torch-model-archiver` might not fully support or be optimized for newer PyTorch versions, especially PyTorch 2.x features.
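Because `--extra-files` flattens everything into the archive root, a handler should import shipped modules as top-level names. A minimal sketch of building such a command (the file names here are hypothetical):

```python
# Hypothetical extra files: a helper module and a label-mapping file.
extra_files = ["utils/preprocessing.py", "index_to_name.json"]

cmd = [
    "torch-model-archiver",
    "--model-name", "mymodel",
    "--version", "1.0",
    "--model-file", "model.py",
    "--serialized-file", "model.pth",
    "--handler", "handler.py",
    # --extra-files takes a comma-separated list; every file lands in the
    # archive root, so handler.py should use `import preprocessing`,
    # not `from utils import preprocessing`.
    "--extra-files", ",".join(extra_files),
]
print(" ".join(cmd))
```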
Install
-
pip install torch-model-archiver
Imports
- torch-model-archiver
This library is primarily a command-line interface (CLI) tool and is not typically imported for direct Python usage. It's executed as a shell command.
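Since it is invoked as a shell command, a script can check that the executable is on the PATH before shelling out; a small sketch using only the standard library:

```python
import shutil

def archiver_available() -> bool:
    """Return True if the torch-model-archiver CLI is on the PATH."""
    return shutil.which("torch-model-archiver") is not None

if not archiver_available():
    print("torch-model-archiver not found; install with: pip install torch-model-archiver")
```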
Quickstart
# Assume you have a PyTorch model 'model.py' and a serialized state_dict 'model.pth',
# plus a handler 'handler.py' (or a built-in one such as 'image_classifier').
# Minimal dummy model.py and handler.py for demonstration:
#
# model.py:
#     import torch.nn as nn
#
#     class MyModel(nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.linear = nn.Linear(10, 1)
#
#         def forward(self, x):
#             return self.linear(x)
#
# handler.py (minimal):
#     from ts.torch_handler.base_handler import BaseHandler
#
#     class MyHandler(BaseHandler):
#         def preprocess(self, data):
#             # Implement your data preprocessing logic
#             return data
#
#         def postprocess(self, data):
#             # Implement your data postprocessing logic
#             return data
# Example archiving command (adapted from the TorchServe docs, which use a
# densenet161 model file, a state dict such as 'densenet161-8d451a50.pth', and
# the built-in 'image_classifier' handler). Replace the paths below with your
# actual files and make sure the handler logic matches the model.
#
# Make a dummy model_store directory
import os
os.makedirs('model_store', exist_ok=True)
# Illustrative only: a fully runnable quickstart needs real model.py, .pth, and
# handler files (the TorchServe repo ships example assets). 'image_classifier'
# is a built-in handler, used here for demonstration.
print("To create a model archive (.mar) file:")
print("torch-model-archiver --model-name mymodel --version 1.0 "
      "--model-file path/to/my_model.py --serialized-file path/to/my_model_state.pth "
      "--handler image_classifier --export-path model_store -f")
print("\nThis command will create 'model_store/mymodel.mar'")
# Example: driving the archiver from Python via subprocess.
import subprocess

model_name = "densenet161"
model_version = "1.0"
# Placeholder paths for demonstration; substitute your own model files.
model_file_path = "./dummy_model.py"
serialized_file_path = "./dummy_state.pth"
export_path = "model_store"
handler_name = "image_classifier"  # built-in handler, used for simplicity

# Create dummy files so the command has something to point at.
with open(model_file_path, "w") as f:
    f.write(
        "import torch.nn as nn\n"
        "class MyModel(nn.Module):\n"
        "    def __init__(self):\n"
        "        super().__init__()\n"
        "        self.linear = nn.Linear(10, 1)\n"
        "    def forward(self, x):\n"
        "        return self.linear(x)\n"
    )

# Minimal serialized file (a real run needs a state_dict matching the model).
import torch
torch.save({'state_dict': {}}, serialized_file_path)
cmd = [
    "torch-model-archiver",
    "--model-name", model_name,
    "--version", model_version,
    "--model-file", model_file_path,
    "--serialized-file", serialized_file_path,
    "--handler", handler_name,
    "--export-path", export_path,
    "-f",  # force overwrite if the .mar already exists
]
try:
    # Execution is commented out in this quickstart because it requires
    # torch-model-archiver on PATH; the command structure is what matters.
    # subprocess.run(cmd, check=True, capture_output=True)
    # print(f"Successfully created {export_path}/{model_name}.mar")
    pass
except subprocess.CalledProcessError as e:
    print(f"Error archiving model: {e.stderr.decode()}")
except FileNotFoundError:
    print("Error: 'torch-model-archiver' command not found. "
          "Ensure the library is installed and on your PATH.")
# Clean up dummy files
os.remove(model_file_path)
os.remove(serialized_file_path)
# os.rmdir(export_path) # Don't remove if you expect a .mar file for a real test
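After archiving, you can sanity-check the result: a `.mar` file is a standard zip archive, so its contents can be listed with Python's `zipfile` (the path below is a placeholder):

```python
import zipfile

def list_mar_contents(mar_path):
    """Return the file names packed into a .mar archive (a plain zip)."""
    with zipfile.ZipFile(mar_path) as zf:
        return zf.namelist()

# Expect your model file, serialized weights, any extra files, and the
# MAR-INF/MANIFEST.json metadata written by the archiver, e.g.:
# print(list_mar_contents("model_store/mymodel.mar"))
```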