Model summary in PyTorch
torchinfo is a Python library that prints a detailed summary of PyTorch models, similar to TensorFlow's `model.summary()`. It reports layer types, input/output shapes, kernel sizes, parameter counts, and computational cost (mult-adds). The library is a modern rewrite that succeeds older projects such as `torchsummary` and `torchsummaryX`. The current version, 1.8.0, was released in May 2023; despite the time since that release, the project is considered maintained and is widely used within the PyTorch ecosystem.
Warnings
- breaking Python 3.6 support was dropped in version 1.6.0. Users on Python 3.6 should install `torchinfo` v1.5.4 or earlier.
- gotcha When using `input_size` without explicitly specifying `dtypes`, `torchinfo` defaults to `torch.float`. This can lead to `RuntimeError` for models expecting integer tensor inputs (e.g., embedding layers in NLP models like HuggingFace Transformers).
- gotcha `torchinfo` attempts to infer the device for model and input data, but mismatches between the model's device and the `input_data`'s device can cause `RuntimeError` (e.g., 'Expected all tensors to be on the same device').
- gotcha In Jupyter Notebooks or Google Colab, the output of `summary(model, ...)` may not display correctly unless explicitly wrapped in a `print()` statement.
- gotcha Enabling `cache_forward_pass=True` for performance can lead to incorrect or outdated summaries if the model architecture or input data/sizes are modified without invalidating the cache.
- gotcha For 1D input tensors, `input_size` must be specified as a tuple with a trailing comma, e.g., `(10,)` rather than `(10)`, since Python evaluates `(10)` as the plain integer `10`.
- gotcha Version 1.8.0 switched to `tensor.untyped_storage()` for PyTorch 2.0 compatibility. Verify that your installed PyTorch version supports this API before upgrading `torchinfo`.
Install
-
pip install torchinfo
Imports
- summary
from torchinfo import summary
Quickstart
import torch
import torch.nn as nn
from torchinfo import summary

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        # Spatial size: conv1: 28 - 5 + 1 = 24 -> pool1: 12;
        # conv2: 12 - 5 + 1 = 8 -> pool2: 4.
        # conv2 has 20 output channels, so 20 * 4 * 4 = 320 features.
        self.fc = nn.Linear(320, 10)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.view(-1, 320)  # flatten
        x = self.fc(x)
        return x

model = SimpleCNN()
# input_size is (batch_size, channels, height, width)
summary(model, input_size=(1, 1, 28, 28))