Torchviz

0.0.3 · active · verified Tue Apr 14

Torchviz is a small Python package for creating visualizations of PyTorch execution graphs and traces. Its main function, `make_dot`, follows the autograd graph recorded behind an output tensor (starting from its `grad_fn`) and draws it with Graphviz. This lets you visually inspect the computational flow of a neural network model, which helps when studying its architecture, debugging, or optimizing it. The current version is 0.0.3, with releases typically tied to PyTorch ecosystem updates.
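Torchviz's `make_dot` draws the autograd graph that PyTorch records behind a tensor. Each node is a backward function reachable from the tensor's `grad_fn` through `next_functions`, and the walk ends at `AccumulateGrad` nodes for leaf tensors that require gradients. The sketch below walks that same structure with plain PyTorch to show what such a drawing contains. `walk_autograd` is a hypothetical helper for illustration, not part of Torchviz, and it prints only node class names (no shapes or parameter names).

```python
import torch


def walk_autograd(fn, depth=0, seen=None):
    """Depth-first walk of the backward graph rooted at `fn` (a grad_fn node).

    Hypothetical illustration helper, NOT part of torchviz. Prints each node's
    class name indented by depth and returns the names in visit order.
    """
    if seen is None:
        seen = set()
    if fn is None or fn in seen:  # None marks an input that needs no gradient
        return []
    seen.add(fn)
    name = type(fn).__name__
    print("  " * depth + name)
    names = [name]
    # next_functions is a tuple of (node, output_index) pairs
    for next_fn, _ in fn.next_functions:
        names.extend(walk_autograd(next_fn, depth + 1, seen))
    return names


w = torch.randn(3, requires_grad=True)  # leaf that requires grad
x = torch.randn(3)                      # plain input, no grad
y = (w * x).sum()

order = walk_autograd(y.grad_fn)
# Prints SumBackward0, then MulBackward0, then AccumulateGrad (for w)
```

Because `x` does not require gradients, its slot in `MulBackward0.next_functions` is `None` and it contributes no node. This is also why an input only shows up in a Torchviz graph when it requires gradients.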

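Install

pip install torchviz

Torchviz depends on the `graphviz` Python package. Rendering images additionally requires the Graphviz system binaries (for example `brew install graphviz` on macOS or `apt-get install graphviz` on Debian/Ubuntu).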

Imports

from torchviz import make_dot

Quickstart

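This quickstart defines a simple PyTorch neural network, runs a forward pass, and then uses `make_dot` to build a visualization of the resulting computational graph, which can optionally be saved to a file. Setting `requires_grad=True` on the input is optional: the model's parameters already require gradients, so a graph is recorded either way, but the flag makes the input appear as a node too. Rendering to an image requires Graphviz to be installed on your system.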

import torch
import torch.nn as nn
from torchviz import make_dot

# Define a simple neural network
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Instantiate the model
model = SimpleNN()

# Generate a random input; requires_grad=True adds the input itself as a node
# (the layer parameters already require gradients, so a graph is built either way)
input_data = torch.randn(1, 10, requires_grad=True)

# Perform a forward pass
output = model(input_data)

# Visualize the computational graph
graph = make_dot(output, params=dict(model.named_parameters()))

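# Save the visualization to computational_graph.png (requires the Graphviz binaries);
# cleanup=True deletes the intermediate DOT source file after rendering
# graph.render("computational_graph", format="png", cleanup=True)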

# In Jupyter/Colab, the graph is displayed inline when it is the last expression in a cell
# graph

view raw JSON →