torchprofile
torchprofile is a lightweight Python library for counting the Multiply-Accumulate Operations (MACs) of PyTorch models, from which FLOPs can be estimated. Its current version is 0.1.0. Releases are infrequent, but the library remains maintained, with a focus on core profiling capabilities.
Common errors
-
ModuleNotFoundError: No module named 'torchprofile'
cause The `torchprofile` library is not installed in the current Python environment.
fix Run `pip install torchprofile`.
-
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same
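cause The input tensor and the model are on different devices (e.g., CPU vs. GPU). `torchprofile` requires them to be on the same device.
fix Move both the model and the input tensor to the same device before profiling, for example: `model = model.to('cuda')` and `inputs = inputs.to('cuda')`. Note that `Tensor.to` returns a new tensor rather than moving the tensor in place, so the result must be assigned back.
-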
TypeError: object of type <class '...'> has no len()
cause `torchprofile` encountered a custom layer or operation within your model for which it does not have a predefined handler, causing a failure during graph traversal.
fix You likely need to implement and register a custom MACs handler for that specific module type. Refer to `torchprofile.handlers` for examples and instructions on `register_macs_handler`.
Warnings
- gotcha torchprofile counts Multiply-Accumulate Operations (MACs), not FLOPs. Tools and communities define FLOPs differently (a common convention is 1 MAC = 2 FLOPs), so counts from different tools may disagree unless they are converted to the same unit first.
- gotcha torchprofile might not automatically count operations for highly custom or non-standard `nn.Module` implementations. This can lead to underestimated MACs for models containing such layers.
- breaking As a 0.x version library, torchprofile's API is subject to change without strict semantic versioning. Updates, especially between minor versions (e.g., 0.0.x to 0.1.x), might introduce breaking changes.
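The first two warnings can be checked by hand with simple arithmetic. The sketch below is not part of torchprofile: `macs_to_flops` and `conv2d_macs` are hypothetical helper names, the factor of 2 is only the common convention mentioned above, and bias additions are ignored. It estimates the MACs of a standard 2D convolution (useful as a sanity check when a custom layer may have been skipped) and converts a MAC count to FLOPs.

```python
def macs_to_flops(macs: int, flops_per_mac: int = 2) -> int:
    """Convert a MAC count to FLOPs.

    Assumes the common convention 1 MAC = 2 FLOPs (one multiply, one add).
    Pass a different flops_per_mac if the tool you compare against differs.
    """
    return macs * flops_per_mac


def conv2d_macs(out_h: int, out_w: int, in_channels: int, out_channels: int,
                kernel_h: int, kernel_w: int, groups: int = 1) -> int:
    """Hand-computed MACs for a 2D convolution (bias ignored).

    Each output element needs (in_channels / groups) * kernel_h * kernel_w
    multiply-accumulates, and there are out_h * out_w * out_channels outputs.
    """
    macs_per_output = (in_channels // groups) * kernel_h * kernel_w
    return out_h * out_w * out_channels * macs_per_output


# First layer of ResNet-18 on a 224x224 input:
# 7x7 kernel, 3 -> 64 channels, stride 2, so the output is 112x112.
first_conv = conv2d_macs(112, 112, 3, 64, 7, 7)
print(first_conv)                 # 118013952
print(macs_to_flops(first_conv))  # 236027904
```

If a model contains a layer that torchprofile does not count, a hand estimate like this can be added to the reported total to approximate the full figure.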
Install
-
pip install torchprofile
Imports
- profile_macs
from torchprofile import profile_macs
Quickstart
import torch
from torchvision.models import resnet18
from torchprofile import profile_macs
# Create a sample PyTorch model (randomly initialized) and switch it to inference mode
model = resnet18().eval()
# Define a dummy input tensor (batch_size, channels, height, width)
# Ensure the input device matches the model's device if applicable
inputs = torch.randn(1, 3, 224, 224)
# Profile the model MACs
macs = profile_macs(model, inputs)
print(f"ResNet-18 MACs: {macs / 1e9:.2f} G")