torchprofile

0.1.0 · active · verified Fri Apr 17

torchprofile is a lightweight Python library for counting the Multiply-Accumulate Operations (MACs) of PyTorch models, from which FLOPs are commonly estimated. It records the model's operations with `torch.jit.trace` and counts MACs operator by operator. The current version is 0.1.0. Releases are infrequent, and the project focuses on its core profiling functionality.
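A MAC is one multiplication followed by one addition, which is why FLOP counts are commonly reported as roughly twice the MAC count. The pure-Python sketch below works this out by hand for a single linear layer. `linear_macs` and `macs_to_flops` are hypothetical helpers for illustration, not torchprofile functions, and the 2× factor is a convention rather than a value torchprofile reports.

```python
# Hand count of MACs for a fully connected (linear) layer, plus the common
# "1 MAC = 1 multiply + 1 add = 2 FLOPs" conversion.
# Illustrative only; this is not torchprofile's implementation.


def linear_macs(in_features: int, out_features: int, batch: int = 1) -> int:
    """MACs for y = x @ W.T with x of shape (batch, in_features).

    Each of the batch * out_features outputs is a dot product of length
    in_features, i.e. in_features multiply-accumulates. Bias adds are ignored.
    """
    return batch * out_features * in_features


def macs_to_flops(macs: int) -> int:
    """Convert MACs to FLOPs under the 2-FLOPs-per-MAC convention."""
    return 2 * macs


if __name__ == "__main__":
    # ResNet-18's final classifier maps 512 features to 1000 classes.
    macs = linear_macs(512, 1000)
    print(f"fc MACs: {macs:,}  FLOPs: {macs_to_flops(macs):,}")
```

For the 512-to-1000 classifier this prints `fc MACs: 512,000  FLOPs: 1,024,000`.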

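Install

Install with `pip install torchprofile`. The quickstart also uses `torchvision` (`pip install torchvision`).

Imports

`from torchprofile import profile_macs`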

Quickstart

This quickstart creates a torchvision ResNet-18 and uses `torchprofile.profile_macs` to count its MACs for a single 224×224 RGB input tensor, then prints the total in GMACs (billions of MACs).

```python
import torch
from torchvision.models import resnet18
from torchprofile import profile_macs

# Create a sample PyTorch model
model = resnet18()

# Define a dummy input tensor (batch_size, channels, height, width)
# Ensure the input device matches the model's device if applicable
inputs = torch.randn(1, 3, 224, 224)

# Profile the model MACs
macs = profile_macs(model, inputs)

print(f"ResNet-18 MACs: {macs / 1e9:.2f} G")
```
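To sanity-check a `profile_macs` result by hand, convolution MACs can be summed layer by layer. The pure-Python sketch below assumes the usual convention of out_h × out_w × out_channels × (in_channels / groups) × k² MACs per Conv2d, square inputs and kernels, and no bias. `Conv2dSpec`, `conv_out_size` and `conv_chain_macs` are hypothetical helpers, not part of torchprofile, and torchprofile's own per-operator rules may differ for some layers.

```python
# Hand cross-check of MACs for a chain of Conv2d layers (illustrative only).
# Assumes square inputs/kernels, no bias, and the convention
# MACs = out_h * out_w * out_channels * (in_channels / groups) * k * k.
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Conv2dSpec:
    """Hypothetical description of a Conv2d layer (not a torchprofile type)."""
    in_channels: int
    out_channels: int
    kernel_size: int
    stride: int = 1
    padding: int = 0
    dilation: int = 1
    groups: int = 1


def conv_out_size(size: int, spec: Conv2dSpec) -> int:
    """Spatial output size, using the formula from the PyTorch Conv2d docs."""
    effective = size + 2 * spec.padding - spec.dilation * (spec.kernel_size - 1) - 1
    return effective // spec.stride + 1


def conv_chain_macs(size: int, specs: list[Conv2dSpec]) -> int:
    """Propagate a square input through the convs and sum their MACs."""
    total = 0
    for spec in specs:
        size = conv_out_size(size, spec)
        # Each output element is a dot product over (in_channels / groups) * k * k inputs.
        per_output = (spec.in_channels // spec.groups) * spec.kernel_size ** 2
        total += size * size * spec.out_channels * per_output
    return total


if __name__ == "__main__":
    # ResNet-18 stem: 7x7 conv, 3 -> 64 channels, stride 2, padding 3.
    stem = Conv2dSpec(3, 64, kernel_size=7, stride=2, padding=3)
    print(f"stem conv MACs: {conv_chain_macs(224, [stem]):,}")
```

For the ResNet-18 stem convolution on a 224×224 input, the output is 112×112 and the sketch prints `stem conv MACs: 118,013,952`.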
