Torchviz
Torchviz is a small Python package for visualizing PyTorch execution graphs and traces. Its main function, `make_dot`, walks the autograd graph recorded behind a tensor (starting from its `grad_fn`) and turns it into a Graphviz diagram. This helps with understanding a model's architecture, debugging, and optimization. The current version is 0.0.3, and releases are infrequent.
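To make "walks the autograd graph" concrete, here is a simplified, hand-rolled traversal over the public `grad_fn` and `next_functions` attributes of PyTorch autograd nodes. It is a minimal sketch for illustration, not torchviz's implementation. It skips what `make_dot` adds on top, such as parameter names, tensor shapes, and saved tensors. `walk_autograd_graph` is a hypothetical helper name.

```python
import torch


def walk_autograd_graph(tensor: torch.Tensor) -> list:
    """Return the class names of backward nodes reachable from tensor.grad_fn.

    Illustrative sketch only; this is not torchviz's own code.
    """
    names = []
    seen = set()  # holds the node objects themselves so each is visited once
    stack = [tensor.grad_fn] if tensor.grad_fn is not None else []
    while stack:
        fn = stack.pop()
        if fn in seen:
            continue
        seen.add(fn)
        names.append(type(fn).__name__)
        # next_functions is a tuple of (node, input_index) pairs; node may be None
        for next_fn, _ in fn.next_functions:
            if next_fn is not None:
                stack.append(next_fn)
    return names


x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()
print(walk_autograd_graph(y))  # backward nodes such as SumBackward0 and MulBackward0
```

Leaf tensors that require gradients show up as `AccumulateGrad` nodes, which is where `make_dot` attaches parameter and input boxes.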
Warnings
- breaking: Torchviz depends on the Graphviz software being installed on your operating system, not just on the Python `graphviz` package. `make_dot` itself only builds a `graphviz.Digraph` object. Rendering that object (with `.render()` or inline display in Jupyter) calls the Graphviz `dot` executable. Without a system-wide Graphviz installation (e.g., `brew install graphviz` on macOS, `sudo apt-get install graphviz` on Ubuntu, or the installer from graphviz.org on Windows), rendering fails. The usual error is `graphviz.ExecutableNotFound`, saying that `dot` is not on your PATH. In some environments the failure is less obvious.
- gotcha: For `make_dot` to produce a meaningful graph, autograd must have recorded the operations that produced the output. That requires at least one tensor feeding into those operations to have `requires_grad=True`. Module parameters (e.g., the weights of `nn.Linear`) require gradients by default, so a model's output usually has a graph even when the input does not. Setting `requires_grad=True` on the input additionally adds the input as a node. If nothing requires a gradient, or the forward pass runs under `torch.no_grad()`, the output has no `grad_fn`. The resulting graph is then empty apart from the output tensor itself.
- deprecated: The `make_dot_from_trace` function, intended for use with `torch.jit.trace`, is less robust. The official documentation notes that it "does not always work", especially with newer PyTorch versions. Past issues reported it failing with PyTorch 1.0. Prefer `make_dot` for autograd graphs.
- gotcha: The `show_attrs=True` and `show_saved=True` parameters display extra node details: `grad_fn` attributes and the tensors saved for the backward pass. They are only supported on PyTorch 1.9 and later. On older versions they may not have the intended effect or may lead to errors.
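The `requires_grad` and PyTorch-version gotchas can be checked before calling `make_dot`. The sketch below relies only on `Tensor.grad_fn` and `torch.__version__`. `has_autograd_graph` and `supports_attrs_and_saved` are hypothetical helper names, and the version parsing assumes a `major.minor...` string such as `2.3.0+cu121`. The sketch also shows that a module's own parameters are enough to give its output a `grad_fn`, even when the input does not require gradients.

```python
import torch


def has_autograd_graph(t: torch.Tensor) -> bool:
    """True if autograd recorded the operations that produced `t`."""
    return t.grad_fn is not None


def supports_attrs_and_saved() -> bool:
    """True if PyTorch is 1.9+ (needed for show_attrs/show_saved).

    Assumes torch.__version__ starts with 'major.minor'.
    """
    major, minor = (int(part) for part in torch.__version__.split(".")[:2])
    return (major, minor) >= (1, 9)


layer = torch.nn.Linear(4, 2)
x = torch.randn(1, 4)  # requires_grad defaults to False

out = layer(x)
print(has_autograd_graph(out))  # True: the layer's weight and bias require grad

with torch.no_grad():
    out_no_grad = layer(x)
print(has_autograd_graph(out_no_grad))  # False: nothing was recorded

plain = x * 3
print(has_autograd_graph(plain))  # False: no tensor involved requires grad

print(supports_attrs_and_saved())
```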
Install
- pip install torchviz
- pip install graphviz  # System-level Graphviz is also required (e.g., `brew install graphviz` on macOS, `sudo apt-get install graphviz` on Ubuntu)
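The Python `graphviz` package shells out to the Graphviz `dot` executable, so a quick way to confirm the system-level install is to look for `dot` on your PATH. This is a minimal sketch using only the standard library. `find_graphviz_dot` is a hypothetical helper name. Running `dot -V` in a terminal is an equivalent manual check.

```python
import shutil
from typing import Optional


def find_graphviz_dot() -> Optional[str]:
    """Return the path to the Graphviz `dot` executable, or None if it is not on PATH."""
    return shutil.which("dot")


dot_path = find_graphviz_dot()
if dot_path is None:
    print("Graphviz `dot` not found on PATH; install the system Graphviz package.")
else:
    print(f"Graphviz found at {dot_path}")
```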
Imports
- make_dot
from torchviz import make_dot
- make_dot_from_trace
from torchviz import make_dot_from_trace
Quickstart
import torch
import torch.nn as nn
from torchviz import make_dot

# Define a simple neural network
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Instantiate the model
model = SimpleNN()

# Generate a random input; requires_grad=True also adds the input itself as a graph node
# (the model's parameters already require gradients by default)
input_data = torch.randn(1, 10, requires_grad=True)

# Perform a forward pass so autograd records the operations
output = model(input_data)

# Build the computational graph; `params` labels parameter nodes with their names
graph = make_dot(output, params=dict(model.named_parameters()))

# To save the visualization to a file (requires system Graphviz), e.g. as PNG:
# graph.render("computational_graph", format="png", cleanup=True)

# In Jupyter/Colab, make `graph` the last expression in a cell to display it inline
# graph