Computational Graph Tracking

PyTorch Graph records the operations executed as a model runs and lets you visualize, analyze, and export the resulting computational graph.

Overview

The computational graph tracking module allows you to:

  • Track Complete Graphs: Capture every operation in your model’s computational graph

  • Analyze Performance: Monitor execution time, memory usage, and operation counts

  • Visualize Operations: Create professional diagrams of the computational graph

  • Export Data: Save graph data in JSON format for further analysis

  • Replace torchviz Cleanly: Use a familiar make_dot(...) API that renders PNG without a system Graphviz install

Key Features

  • Maximal Traversal: No artificial limits on graph depth or operation count

  • Full Method Names: Complete operation names without truncation

  • Smart Arrow Positioning: Arrows connect node edges properly without crossing over boxes

  • Compact Layout: Eliminates gaps and breaks for continuous flow

  • Professional Quality: High-resolution output up to 300 DPI

  • Richer Exports: Save the same graph as PNG, DOT, or JSON

Basic Usage

Track a computational graph with the convenience function:

import torch
import torch.nn as nn
from pytorch_graph import track_computational_graph

model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

input_tensor = torch.randn(1, 784, requires_grad=True)

# Track computational graph
tracker = track_computational_graph(
    model=model,
    input_tensor=input_tensor,
    track_memory=True,
    track_timing=True,
    track_tensor_ops=True
)

# Save visualization
tracker.save_graph_png("computational_graph.png")
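This section does not document how the tracker works internally. PyTorch itself, however, exposes the graph that any such tool has to walk. Every non-leaf tensor has a grad_fn, and each grad_fn lists the functions that produced its inputs in next_functions. The following minimal sketch traverses that structure using only those public attributes. It is illustrative and is not how track_computational_graph is implemented. The helper name collect_autograd_graph is hypothetical.

```python
import torch
import torch.nn as nn


def collect_autograd_graph(output):
    """Walk the autograd graph backwards from `output` (hypothetical helper).

    Returns (names, edges): `names` maps a node id to the backward function's
    class name, and `edges` holds (producer_id, consumer_id) pairs.
    """
    nodes = {}  # id(fn) -> fn; keeping fn alive keeps each id unique
    edges = []
    stack = [output.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or id(fn) in nodes:
            continue  # no grad_fn, or node already visited
        nodes[id(fn)] = fn
        # next_functions holds (function, input_index) pairs for fn's inputs;
        # the function is None when that input does not require grad.
        for next_fn, _ in fn.next_functions:
            if next_fn is not None:
                edges.append((id(next_fn), id(fn)))
                stack.append(next_fn)
    names = {node_id: type(fn).__name__ for node_id, fn in nodes.items()}
    return names, edges


model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
input_tensor = torch.randn(1, 784, requires_grad=True)
output = model(input_tensor)
nodes, edges = collect_autograd_graph(output)
print(sorted(set(nodes.values())))
```

The exact class names, such as AddmmBackward0 or AccumulateGrad, vary between PyTorch versions. For that reason, no expected output is shown.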

Torchviz-Style Usage

Use make_dot(...) when you want the familiar torchviz flow, but with richer exports and no Graphviz system dependency for PNG output:

import torch
import torch.nn as nn
from pytorch_graph import make_dot

model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

input_tensor = torch.randn(1, 784, requires_grad=True)
output = model(input_tensor)

graph = make_dot(
    output,
    params=dict(model.named_parameters()),
    model=model,
    inputs=input_tensor,
    output_names=["logits"],
    show_metadata=True
)

graph.render("autograd_graph", format="png")
graph.render("autograd_graph", format="dot")
graph.render("autograd_graph", format="json")
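The "dot" format is plain Graphviz text that can be rendered later on any machine that has Graphviz. The following sketch shows the general shape of such a file for a small node and edge list. The helper to_dot and its node IDs are hypothetical. The DOT that make_dot actually writes, including its attributes and styling, may differ.

```python
def to_dot(nodes, edges, graph_name="autograd_graph"):
    """Render nodes {id: label} and edges [(src, dst)] as Graphviz DOT text."""
    lines = [f"digraph {graph_name} {{", "    node [shape=box];"]
    for node_id, label in nodes.items():
        escaped = label.replace('"', '\\"')  # quotes must be escaped in DOT labels
        lines.append(f'    "{node_id}" [label="{escaped}"];')
    for src, dst in edges:
        lines.append(f'    "{src}" -> "{dst}";')
    lines.append("}")
    return "\n".join(lines)


dot = to_dot(
    {"x": "input", "addmm": "AddmmBackward0", "logits": "logits"},
    [("x", "addmm"), ("addmm", "logits")],
)
print(dot)
```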

Compared to torchviz, this API adds:

  • Named parameter nodes and explicit output nodes

  • PNG rendering without requiring Graphviz on the system

  • Optional model-aware enrichment with module names, tensor shapes, and parameter counts

  • Multi-input tracing when you pass both model and inputs
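Named parameter nodes come from the params mapping passed to make_dot. torchviz resolves these names by building a lookup from id(tensor) to parameter name. It then reads the leaf tensor stored on each AccumulateGrad node through that node's variable attribute. pytorch_graph may do this differently. The following minimal sketch uses the torchviz-style approach. The helper label_leaf_nodes and its "(unnamed)" placeholder are hypothetical.

```python
import torch
import torch.nn as nn


def label_leaf_nodes(output, params):
    """Name every leaf (AccumulateGrad) node reachable from `output`."""
    # Look parameters up by tensor identity, as torchviz does.
    name_by_id = {id(tensor): name for name, tensor in params.items()}
    labels, seen = {}, set()
    stack = [output.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        # AccumulateGrad nodes expose their leaf tensor as `.variable`.
        if hasattr(fn, "variable"):
            labels[id(fn)] = name_by_id.get(id(fn.variable), "(unnamed)")
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return sorted(labels.values())


model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
input_tensor = torch.randn(1, 784, requires_grad=True)
labels = label_leaf_nodes(model(input_tensor), dict(model.named_parameters()))
print(labels)  # ['(unnamed)', '0.bias', '0.weight', '2.bias', '2.weight']
```

The input tensor is a leaf that requires grad but is not in params, so it falls back to the placeholder label.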

Tracker Control

Use the ComputationalGraphTracker class for full control:

from pytorch_graph import ComputationalGraphTracker

# Reuses model and input_tensor from Basic Usage

# Create tracker with custom settings
tracker = ComputationalGraphTracker(
    model=model,
    track_memory=True,
    track_timing=True,
    track_tensor_ops=True
)

# Start tracking
tracker.start_tracking()

# Run your model
output = model(input_tensor)
loss = output.sum()
loss.backward()

# Stop tracking
tracker.stop_tracking()

# Get comprehensive analysis
summary = tracker.get_graph_summary()
print(f"Operations: {summary['total_nodes']:,}")
print(f"Execution time: {summary['execution_time']:.4f}s")

# Save with custom parameters
tracker.save_graph_png(
    "advanced_graph.png",
    width=2000,
    height=1500,
    dpi=300,
    show_legend=True,
    node_size=30,
    font_size=14
)
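A start/stop tracker must attach instrumentation and later detach it. Module hooks are one standard PyTorch mechanism for this, because register_forward_pre_hook and register_forward_hook each return a handle with a remove() method. The following sketch times the forward pass of each leaf module. ModuleTimer is a hypothetical class, not ComputationalGraphTracker. It measures forward calls only, not backward. On CUDA it measures host-side time unless you synchronize.

```python
import time
from collections import defaultdict

import torch
import torch.nn as nn


class ModuleTimer:
    """Hypothetical start/stop tracker: counts and times leaf-module forwards."""

    def __init__(self, model):
        self.model = model
        self.calls = defaultdict(int)
        self.seconds = defaultdict(float)
        self._handles = []
        self._starts = {}

    def start_tracking(self):
        for name, module in self.model.named_modules():
            if next(module.children(), None) is None:  # leaf modules only
                self._handles.append(module.register_forward_pre_hook(self._make_pre_hook(name)))
                self._handles.append(module.register_forward_hook(self._make_post_hook(name)))

    def stop_tracking(self):
        for handle in self._handles:
            handle.remove()  # detach hooks so later forwards are not recorded
        self._handles.clear()

    def _make_pre_hook(self, name):
        def hook(module, inputs):
            self._starts[name] = time.perf_counter()
        return hook

    def _make_post_hook(self, name):
        def hook(module, inputs, output):
            self.calls[name] += 1
            self.seconds[name] += time.perf_counter() - self._starts.pop(name)
        return hook


model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
input_tensor = torch.randn(1, 784, requires_grad=True)

timer = ModuleTimer(model)
timer.start_tracking()
loss = model(input_tensor).sum()
loss.backward()
timer.stop_tracking()
print(dict(timer.calls))  # {'0': 1, '1': 1, '2': 1}
```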

Analysis Functions

Get detailed analysis of your computational graph:

from pytorch_graph import analyze_computational_graph

# Comprehensive analysis
analysis = analyze_computational_graph(
    model=model,
    input_tensor=input_tensor,
    detailed=True
)

summary = analysis['summary']
print(f"Total operations: {summary['total_nodes']:,}")
print(f"Execution time: {summary['execution_time']:.4f}s")

# Performance metrics
if 'performance' in analysis:
    perf = analysis['performance']
    print(f"Operations per second: {perf['operations_per_second']:.2f}")
    print(f"Memory usage: {perf['memory_usage']}")

# Layer-wise analysis
if 'layer_analysis' in analysis:
    for layer_name, operations in analysis['layer_analysis'].items():
        print(f"{layer_name}: {len(operations)} operations")
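Throughput and layer-wise grouping can be derived from a flat list of operation records. The following sketch assumes that operations per second means operation count divided by execution time, which is a plausible but unconfirmed reading of the operations_per_second field. It also assumes a hypothetical record schema of {"name": ..., "layer": ...} dictionaries.

```python
from collections import defaultdict


def summarize_operations(records, execution_time):
    """Group hypothetical operation records by layer and compute throughput."""
    if execution_time <= 0:
        raise ValueError("execution_time must be positive")
    by_layer = defaultdict(list)
    for record in records:
        by_layer[record["layer"]].append(record["name"])
    return dict(by_layer), len(records) / execution_time


records = [
    {"name": "addmm", "layer": "0"},
    {"name": "relu", "layer": "1"},
    {"name": "addmm", "layer": "2"},
    {"name": "t", "layer": "2"},
]
layers, ops_per_second = summarize_operations(records, execution_time=0.5)
for layer_name, operations in layers.items():
    print(f"{layer_name}: {len(operations)} operations")
print(f"Operations per second: {ops_per_second:.2f}")  # 8.00
```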

Data Export

Export graph data for offline analysis:

import json

# Export to JSON
tracker.export_graph("graph_data.json")

# Load and inspect exported data
with open("graph_data.json") as f:
    graph_data = json.load(f)

print(f"Nodes: {len(graph_data['nodes'])}")
print(f"Edges: {len(graph_data['edges'])}")
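Once the data is loaded, standard graph questions need only the standard library. Examples include which nodes have no inputs, which have no consumers, and where fan-in is highest. The following sketch assumes that each node has an "id" key and each edge has "source" and "target" keys. This schema is an assumption, so adjust the keys to match your exported file. A small inline JSON string stands in for the real export.

```python
import json
from collections import Counter


def degree_report(graph_data):
    """Entry nodes, exit nodes, and the node with the most incoming edges."""
    in_degree = Counter(edge["target"] for edge in graph_data["edges"])
    out_degree = Counter(edge["source"] for edge in graph_data["edges"])
    node_ids = [node["id"] for node in graph_data["nodes"]]
    return {
        "entries": [n for n in node_ids if in_degree[n] == 0],
        "exits": [n for n in node_ids if out_degree[n] == 0],
        "max_fan_in": in_degree.most_common(1),
    }


# Stand-in for json.load(f) on a real export file (assumed schema).
graph_data = json.loads("""
{"nodes": [{"id": "input"}, {"id": "linear"}, {"id": "relu"}, {"id": "add"}],
 "edges": [{"source": "input", "target": "linear"},
           {"source": "linear", "target": "relu"},
           {"source": "relu", "target": "add"},
           {"source": "input", "target": "add"}]}
""")
print(degree_report(graph_data))
```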

Visualization Features

High-Quality Output

tracker.save_graph_png(
    filepath="publication_quality.png",
    width=2000,           # Custom width
    height=1500,          # Custom height
    dpi=300,              # High DPI for publication
    show_legend=True,     # Show legend
    node_size=30,         # Node size
    font_size=14          # Font size
)
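Pixel dimensions, DPI, and physical size are related by pixels = inches × dpi. That relationship determines how large the image prints. The following sketch assumes that width and height are pixel counts, which the values suggest but this section does not state. Under that assumption, a 2000 × 1500 image printed at 300 DPI is about 6.67 × 5 inches. The helper names are hypothetical.

```python
def print_size_inches(width_px, height_px, dpi):
    """Physical size of a pixel image when printed at `dpi` dots per inch."""
    return width_px / dpi, height_px / dpi


def pixels_needed(inches, dpi):
    """Pixel count that fills `inches` at `dpi`, rounded to the nearest pixel."""
    return round(inches * dpi)


width_in, height_in = print_size_inches(2000, 1500, dpi=300)
print(f"{width_in:.2f} x {height_in:.2f} in")  # 6.67 x 5.00 in
print(pixels_needed(3.5, dpi=300))  # 1050, e.g. for a 3.5-inch column
```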

Custom Styling

The computational graph visualization includes:

  • Full Method Names: Complete operation names without truncation

  • Smart Arrow Positioning: Arrows connect node edges properly

  • Compact Layout: No gaps or breaks in the graph

  • Professional Styling: Enhanced colors and typography

  • Intelligent Legends: Automatic positioning without overlap
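One common way to make arrows meet node edges, rather than node centers, is to clip each line from center to center at the border of the node's rectangle. The following sketch is a generic geometric version of that idea, not the library's layout code. The helper names and the axis-aligned box model are assumptions. It does not handle overlapping boxes.

```python
def clip_to_box(center, toward, half_width, half_height):
    """Where the ray from `center` toward `toward` exits the node box.

    The box is axis-aligned and centered on `center`; if `toward` lies
    inside the box, it is returned unchanged.
    """
    cx, cy = center
    dx, dy = toward[0] - cx, toward[1] - cy
    if dx == 0 and dy == 0:
        return center
    # Fraction of the center-to-target vector needed to reach each box side.
    t_x = half_width / abs(dx) if dx else float("inf")
    t_y = half_height / abs(dy) if dy else float("inf")
    t = min(t_x, t_y, 1.0)  # the nearer side is hit first
    return (cx + t * dx, cy + t * dy)


def arrow_endpoints(a, b, half_width, half_height):
    """Start and end points for an arrow from the box at `a` to the box at `b`."""
    return (clip_to_box(a, b, half_width, half_height),
            clip_to_box(b, a, half_width, half_height))


print(arrow_endpoints((0, 0), (0, 10), half_width=2, half_height=1))
# ((0.0, 1.0), (0.0, 9.0))
```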

Examples

CNN Computational Graph

cnn_model = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(64, 10)
)

input_tensor = torch.randn(1, 3, 32, 32, requires_grad=True)
tracker = track_computational_graph(cnn_model, input_tensor)
tracker.save_graph_png("cnn_computational_graph.png")
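When you read shape metadata in the CNN graph, it helps to know what each layer should produce. The following sketch applies the output-size formula from the PyTorch Conv2d and MaxPool2d documentation to a 32 × 32 input. It confirms that the Flatten output, 64 × 1 × 1, matches nn.Linear(64, 10). BatchNorm2d and ReLU leave shapes unchanged, and MaxPool2d's stride defaults to its kernel size.

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of Conv2d/MaxPool2d (formula from the PyTorch docs)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


size = 32
after_conv1 = conv_out(size, kernel=3, padding=1)        # Conv2d(3, 32, 3, padding=1)
after_pool = conv_out(after_conv1, kernel=2, stride=2)   # MaxPool2d(2)
after_conv2 = conv_out(after_pool, kernel=3, padding=1)  # Conv2d(32, 64, 3, padding=1)
# AdaptiveAvgPool2d((1, 1)) always yields 1x1, so Flatten gives 64 * 1 * 1 features.
flat_features = 64 * 1 * 1
print(after_conv1, after_pool, after_conv2, flat_features)  # 32 16 16 64
```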

Complex Model Analysis

class ComplexModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

model = ComplexModel()
input_tensor = torch.randn(1, 3, 32, 32, requires_grad=True)

# Track with performance monitoring
tracker = ComputationalGraphTracker(
    model=model,
    track_memory=True,
    track_timing=True,
    track_tensor_ops=True
)

tracker.start_tracking()
output = model(input_tensor)
loss = output.sum()
loss.backward()
tracker.stop_tracking()

# Get performance summary
summary = tracker.get_graph_summary()
print(f"Memory usage: {summary['memory_usage']}")
print(f"Execution time: {summary['execution_time']:.4f}s")
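make_dot's model-aware metadata can include parameter counts. You can check such a number by hand. A Conv2d layer has out × in × k × k weights plus out biases. A BatchNorm2d layer has 2 × C trainable parameters, because its running statistics are buffers. A Linear layer has out × in weights plus out biases. The following arithmetic applies these rules to ComplexModel. The result should match sum(p.numel() for p in model.parameters()) in PyTorch. Whether the library's reported count uses exactly the same definition is not stated here.

```python
def conv2d_params(c_in, c_out, k, bias=True):
    return c_out * c_in * k * k + (c_out if bias else 0)


def batchnorm_params(c):
    # weight (gamma) and bias (beta); running mean/var are buffers, not parameters
    return 2 * c


def linear_params(n_in, n_out, bias=True):
    return n_out * n_in + (n_out if bias else 0)


total = (
    conv2d_params(3, 64, 3) + batchnorm_params(64)
    + conv2d_params(64, 128, 3) + batchnorm_params(128)
    + conv2d_params(128, 256, 3) + batchnorm_params(256)
    + linear_params(256, 128) + linear_params(128, 10)
)
print(f"{total:,}")  # 405,898
```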

Best Practices

  • Use input tensors whose shape and dtype match what your model expects

  • Enable memory tracking for performance analysis

  • Use high DPI (300) for publication-quality output

  • Export graph data for offline analysis of complex models

  • Monitor memory usage when tracking large models

Performance Tips

  • Disable tensor operation tracking for very large models to improve performance

  • Use smaller input tensors for initial testing

Troubleshooting

Common Issues

Memory issues with large models

Pass track_tensor_ops=False to reduce tracking overhead

Long operation names

The system automatically handles long names without truncation

Large graph visualization

Increase image size and use high DPI for better quality

Import errors

Ensure all dependencies are installed: pip install torch matplotlib
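To see which of those packages is missing before you reinstall, you can check whether each one can be found on the current interpreter's path. The following sketch uses the standard library's importlib.util.find_spec. It locates packages without importing them, so it will not detect a package that is installed but broken. The helper name is hypothetical.

```python
import importlib.util


def missing_dependencies(names=("torch", "matplotlib")):
    """Return the top-level packages in `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_dependencies()
if missing:
    print("Missing packages; try: pip install " + " ".join(missing))
else:
    print("torch and matplotlib are importable")
```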

See Also