Model summary in PyTorch

1.8.0 · active · verified Mon Apr 13

torchinfo is a Python library that prints a detailed summary of a PyTorch model, similar to TensorFlow's `model.summary()`. For each layer it can show the layer type, input/output shapes, kernel size, number of parameters, and multiply-add operations (Mult-Adds). The library is a rewritten successor to older projects such as `torchsummary` and `torchsummaryX`. The current version is 1.8.0, released in May 2023. There has been no release since then, but the project is listed as actively maintained within the PyTorch ecosystem.

Install

pip install torchinfo

Imports

from torchinfo import summary

Quickstart

This example defines a simple convolutional neural network (CNN) as an `nn.Module` subclass and passes it to `torchinfo.summary`, which prints the architecture with per-layer output shapes and parameter counts. The `input_size` argument includes the batch dimension. It lets torchinfo build a dummy input and run a forward pass to compute intermediate shapes. If neither `input_size` nor `input_data` is given, no forward pass is performed and the summary omits input/output shapes.

import torch
import torch.nn as nn
from torchinfo import summary

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        # Input features for the linear layer, for a 28x28 input
        # (conv stride 1, no padding; 2x2 max pooling with stride 2):
        # 28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4
        # conv2 outputs 20 channels, so the flattened size is 20 * 4 * 4 = 320
        self.fc = nn.Linear(320, 10)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten to (batch, 320), keeping the batch dimension
        x = self.fc(x)
        return x

model = SimpleCNN()
# Provide an example input_size (batch_size, channels, height, width)
summary(model, input_size=(1, 1, 28, 28))
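You can check the shape arithmetic in the `SimpleCNN` comments and the parameter total by hand. The sketch below is a minimal pure-Python check. The helper names `conv2d_out`, `conv2d_params`, and `linear_params` are hypothetical and not part of torchinfo. It uses PyTorch's floor-mode output-size formula for `Conv2d`/`MaxPool2d`, with default stride 1 and padding 0 for the convolutions and stride 2 for the 2x2 pools. The parameter total it computes should match the "Total params" line that torchinfo prints for this model.

```python
def conv2d_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length along one spatial dimension for Conv2d/MaxPool2d (floor mode)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_params(in_ch, out_ch, kernel, bias=True):
    """Weights (out_ch * in_ch * k * k) plus one bias per output channel."""
    return out_ch * in_ch * kernel * kernel + (out_ch if bias else 0)


def linear_params(in_features, out_features, bias=True):
    """Weight matrix plus one bias per output feature."""
    return out_features * in_features + (out_features if bias else 0)


# Trace SimpleCNN's spatial size for a 28x28 input.
h = 28
h = conv2d_out(h, kernel=5)            # conv1 -> 24
h = conv2d_out(h, kernel=2, stride=2)  # pool1 -> 12
h = conv2d_out(h, kernel=5)            # conv2 -> 8
h = conv2d_out(h, kernel=2, stride=2)  # pool2 -> 4
flat_features = 20 * h * h             # 20 channels * 4 * 4 = 320

total_params = (
    conv2d_params(1, 10, 5)            # 260
    + conv2d_params(10, 20, 5)         # 5020
    + linear_params(flat_features, 10) # 3210
)
print(h, flat_features, total_params)  # 4 320 8490
```

ReLU and pooling layers have no parameters, so only the two convolutions and the linear layer contribute to the total.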
