Torchsummary
`torchsummary` provides a Keras-like `model.summary()` for PyTorch models. It prints each layer's type, output shape, and parameter count, followed by totals for trainable and non-trainable parameters and an estimated memory size, which gives a quick view of a network's architecture and memory footprint. The current version is 1.5.1; releases occur as needed for bug fixes and minor enhancements rather than on a fixed schedule.
Common errors
- RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[1, 1, 224, 224] to have 3 channels, but got 1 channels instead
  cause: The `input_size` passed to `summary()` does not match the number of input channels expected by the model's first layer.
  fix: Adjust `input_size` to reflect the `(channels, height, width)` of your input data. For example, if the model expects 3 channels, use `input_size=(3, H, W)`.
- TypeError: summary() got an unexpected keyword argument 'input_data'
  cause: `torchsummary.summary()` was called with a Keras-like or otherwise non-existent argument name. The argument that specifies the input shape (without the batch dimension) is `input_size`.
  fix: Replace `input_data` or `input_shape` with `input_size`.
- AttributeError: 'torch.device' object has no attribute 'lower' (or a similar device-related error when passing `device` to `summary()`)
  cause: The `device` argument of `summary()` expects a string (`"cpu"` or `"cuda"`), not a `torch.device` object.
  fix: Pass a string instead of the `torch.device` object, for example `device=your_torch_device_object.type`. `str(your_torch_device_object)` also works for devices without an index, but it yields `"cuda:0"` for an indexed device, which version 1.5.1 rejects.
Warnings
- gotcha The `input_size` parameter must be a tuple representing the shape of a *single input sample* (excluding the batch dimension). For example, a batch of 64 images of size (3, 224, 224) would require `input_size=(3, 224, 224)`. Mismatches are a common cause of `RuntimeError` or incorrect summaries.
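- gotcha `torchsummary` does not trace or script the model. It registers forward hooks on the model's submodules and runs a single forward pass on a random tensor built from `input_size`. Models with dynamic control flow (e.g., if-statements or loops whose behavior depends on input data values) are summarized only along the path that this random input takes, so layers on other branches are missing from the table, and the forward pass may raise a `RuntimeError` if the model cannot handle random float input.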
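- gotcha The `device` parameter in `summary()` expects a string (`"cpu"` or `"cuda"`), not a `torch.device` object. In version 1.5.1 the value is checked against exactly these two strings, so an indexed string such as `"cuda:0"` is rejected as well.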
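The `input_size` and `device` warnings above can be handled in one place. The sketch below uses a hypothetical helper, `summary_args` (not part of `torchsummary`), that derives both arguments from a batch shape and a device. It assumes the device is either a string or an object with a `.type` attribute, as `torch.device` has.

```python
def summary_args(batch_shape, device):
    """Hypothetical helper: build keyword arguments for torchsummary.summary()."""
    # Drop the leading batch dimension: (N, C, H, W) -> (C, H, W)
    input_size = tuple(batch_shape[1:])
    # torch.device exposes `.type` ("cpu" or "cuda"); plain strings are used as-is
    device_type = getattr(device, "type", device)
    # Strip a device index such as ":0" and normalise the case
    device_str = str(device_type).split(":")[0].lower()
    return {"input_size": input_size, "device": device_str}


# A batch of 64 RGB images of 224x224 pixels on the first GPU
args = summary_args((64, 3, 224, 224), "cuda:0")
```

With a real model and batch, the call would then look like `summary(model, **summary_args(batch.shape, device))`.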
Install
- pip install torchsummary
Imports
- summary
from torchsummary.torchsummary import summary
from torchsummary import summary
Quickstart
import torch
import torch.nn as nn
from torchsummary import summary
# Define a simple model
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.flatten = nn.Flatten()
        # Calculate input size for linear layer: (28 - 5 + 1) / 2 = 12
        # So, 10 channels * 12 * 12 pixels = 1440
        self.fc1 = nn.Linear(10 * 12 * 12, 50)
        self.relu2 = nn.ReLU()
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.flatten(x)
        x = self.relu2(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleCNN()
# Determine device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Print summary for an input image of 1 channel, 28x28 pixels
# input_size should be (channels, height, width) *without* the batch dimension
# summary() prints the table itself and returns None, so it is not wrapped in print()
summary(model, input_size=(1, 28, 28), device=device.type)
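As a sanity check on the printed table, the shape arithmetic from the comments in `SimpleCNN` and the parameter counts it implies can be worked out by hand. This is a minimal sketch in plain Python. `conv2d_out` is a hypothetical helper that applies the standard output-size formula for convolution and pooling layers without dilation, and the counts assume every layer has a bias term, which is the PyTorch default.

```python
def conv2d_out(size, kernel_size, stride=1, padding=0):
    """Output length along one spatial dimension of a conv or pooling layer."""
    return (size + 2 * padding - kernel_size) // stride + 1


# Spatial sizes for a 1x28x28 input
after_conv = conv2d_out(28, kernel_size=5)                     # conv1: 5x5 kernel, stride 1
after_pool = conv2d_out(after_conv, kernel_size=2, stride=2)   # pool1: stride defaults to the kernel size
flat_features = 10 * after_pool * after_pool                   # 10 channels after conv1

# Parameter counts: weights plus biases
conv1_params = 10 * (1 * 5 * 5) + 10   # out_channels * (in_channels * k * k) + biases
fc1_params = flat_features * 50 + 50
fc2_params = 50 * 10 + 10
total_params = conv1_params + fc1_params + fc2_params

print(after_conv, after_pool, flat_features, total_params)
```

The ReLU, pooling, and flatten layers have no parameters, so the total comes entirely from `conv1`, `fc1`, and `fc2`.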