Computational Graph Tracking API

This module records the computational graph of a PyTorch model's execution and provides functions to analyze, visualize, and export it.

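Core Functions

  • track_computational_graph - Track the computational graph of a model's execution

  • analyze_computational_graph - Return a dictionary analysis of the computational graph

  • track_computational_graph_execution - Alias for track_computational_graph

  • analyze_computational_graph_execution - Alias for analyze_computational_graph

  • visualize_computational_graph - Render the graph with Plotly or Matplotlib

  • export_computational_graph - Export the graph to a file

  • save_computational_graph_png - Save the graph as a PNG image

Classes

  • ComputationalGraphTracker - Tracks the computational graph of a model's execution

  • GraphNode - A node in the computational graph

  • GraphEdge - An edge in the computational graph

  • OperationType - Enum of operation types that can be tracked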

Function Details

track_computational_graph

track_computational_graph(model, input_tensor, track_memory=True, track_timing=True, track_tensor_ops=True)

Track the computational graph of a PyTorch model execution.

Parameters:

  • model (torch.nn.Module): PyTorch model to track

  • input_tensor (torch.Tensor): Input tensor for the forward pass

  • track_memory (bool, optional): Whether to track memory usage (default: True)

  • track_timing (bool, optional): Whether to track execution timing (default: True)

  • track_tensor_ops (bool, optional): Whether to track tensor operations (default: True)

Returns: ComputationalGraphTracker - Tracker instance with execution data

Raises:

  • ImportError: If PyTorch is not available

Example:

import torch
import torch.nn as nn
from pytorch_graph import track_computational_graph

model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

input_tensor = torch.randn(1, 784, requires_grad=True)
tracker = track_computational_graph(
    model=model,
    input_tensor=input_tensor,
    track_memory=True,
    track_timing=True,
    track_tensor_ops=True
)

# Get summary
summary = tracker.get_graph_summary()
print(f"Total operations: {summary['total_nodes']}")
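The summary documents `operation_types` as a count of each operation type, so a small helper can rank the most frequent types. `top_operation_types` is a hypothetical helper, not part of the library, and it assumes the dict maps type names to integer counts; the sample values are made up for illustration.

```python
def top_operation_types(summary, n=3):
    """Return the n most frequent (operation_type, count) pairs.

    Hypothetical helper: assumes summary['operation_types'] maps a
    type name to an integer count, as the summary docs describe.
    """
    counts = summary.get("operation_types", {})
    # Sort by count descending, then by name so ties have a stable order.
    ranked = sorted(counts.items(), key=lambda item: (-item[1], str(item[0])))
    return ranked[:n]


# Illustrative summary dict with made-up values, not real tracker output.
example_summary = {
    "total_nodes": 7,
    "operation_types": {"forward": 3, "tensor_op": 2, "backward": 2},
}
for op_type, count in top_operation_types(example_summary, n=2):
    print(f"{op_type}: {count}")
```

With a real tracker, the call would be `top_operation_types(tracker.get_graph_summary())`.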

analyze_computational_graph

analyze_computational_graph(model, input_tensor, detailed=True)

Analyze the computational graph of a PyTorch model execution.

Parameters:

  • model (torch.nn.Module): PyTorch model to analyze

  • input_tensor (torch.Tensor): Input tensor for the forward pass

  • detailed (bool, optional): Whether to include detailed analysis (default: True)

Returns: dict - Dictionary containing computational graph analysis

Raises:

  • ImportError: If PyTorch is not available

Example:

from pytorch_graph import analyze_computational_graph

# model and input_tensor as defined in the track_computational_graph example

analysis = analyze_computational_graph(
    model=model,
    input_tensor=input_tensor,
    detailed=True
)

summary = analysis['summary']
print(f"Total operations: {summary['total_nodes']}")
print(f"Execution time: {summary['execution_time']:.4f}s")
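When the analysis dict has no `performance` entry, a rough throughput figure can be derived from the summary alone. This sketch assumes "operations per second" means `total_nodes / execution_time`; the library's own `operations_per_second` metric may be computed differently, and `operations_per_second` here is a hypothetical helper.

```python
def operations_per_second(summary):
    """Estimate throughput as total_nodes / execution_time.

    Assumption: this approximates the library's 'operations_per_second'
    metric, which may be computed differently. Returns None when the
    node count or a positive execution time is missing.
    """
    nodes = summary.get("total_nodes")
    seconds = summary.get("execution_time")
    if nodes is None or not seconds:
        return None
    return nodes / seconds


rate = operations_per_second({"total_nodes": 50, "execution_time": 0.25})
print(f"~{rate:.1f} ops/s")
```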

track_computational_graph_execution

track_computational_graph_execution(model, input_tensor, track_memory=True, track_timing=True, track_tensor_ops=True)

Track the computational graph of a PyTorch model execution (alias for track_computational_graph).

Parameters:

  • model (torch.nn.Module): PyTorch model to track

  • input_tensor (torch.Tensor): Input tensor for the forward pass

  • track_memory (bool, optional): Whether to track memory usage (default: True)

  • track_timing (bool, optional): Whether to track execution timing (default: True)

  • track_tensor_ops (bool, optional): Whether to track tensor operations (default: True)

Returns: ComputationalGraphTracker - Tracker instance with execution data

analyze_computational_graph_execution

analyze_computational_graph_execution(model, input_tensor, detailed=True)

Analyze the computational graph of a PyTorch model execution (alias for analyze_computational_graph).

Parameters:

  • model (torch.nn.Module): PyTorch model to analyze

  • input_tensor (torch.Tensor): Input tensor for the forward pass

  • detailed (bool, optional): Whether to include detailed analysis (default: True)

Returns: dict - Dictionary containing computational graph analysis

visualize_computational_graph

visualize_computational_graph(model, input_tensor, renderer='plotly')

Visualize the computational graph of a PyTorch model execution.

Parameters:

  • model (torch.nn.Module): PyTorch model to visualize

  • input_tensor (torch.Tensor): Input tensor for the forward pass

  • renderer (str, optional): Rendering backend, 'plotly' or 'matplotlib' (default: 'plotly')

Returns: Visualization object (Plotly figure or Matplotlib figure)

Raises:

  • ImportError: If PyTorch is not available

Example:

from pytorch_graph import visualize_computational_graph

fig = visualize_computational_graph(
    model=model,
    input_tensor=input_tensor,
    renderer='plotly'
)
fig.show()
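The `renderer` argument accepts 'plotly' or 'matplotlib'. One common way to validate such a string option, raising the `ValueError` listed under Error Handling for anything else, is sketched below. `resolve_renderer` is hypothetical, and its case-insensitive matching is an assumption; the library's own validation may differ.

```python
SUPPORTED_RENDERERS = ("plotly", "matplotlib")


def resolve_renderer(renderer="plotly"):
    """Normalize and validate a renderer name (hypothetical helper)."""
    # This sketch accepts mixed case and surrounding whitespace.
    name = str(renderer).strip().lower()
    if name not in SUPPORTED_RENDERERS:
        raise ValueError(
            f"Unsupported renderer {renderer!r}; "
            f"expected one of: {', '.join(SUPPORTED_RENDERERS)}"
        )
    return name


print(resolve_renderer("Matplotlib"))
```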

export_computational_graph

export_computational_graph(model, input_tensor, filepath, format='json')

Export the computational graph of a PyTorch model execution to a file.

Parameters:

  • model (torch.nn.Module): PyTorch model to export

  • input_tensor (torch.Tensor): Input tensor for the forward pass

  • filepath (str): Output file path

  • format (str, optional): Export format; 'json' is the only documented format (default: 'json')

Returns: str - Path to the exported file

Raises:

  • ImportError: If PyTorch is not available

Example:

from pytorch_graph import export_computational_graph

filepath = export_computational_graph(
    model=model,
    input_tensor=input_tensor,
    filepath='graph.json'
)
print(f"Graph exported to: {filepath}")
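The JSON export is described only as a file containing nodes and edges (the Data Export example reads `graph_data['nodes']` and `graph_data['edges']`). The sketch below writes and re-reads such a file with the standard library, assuming each node and edge is a plain dict. The field names are borrowed from GraphNode and GraphEdge, and `write_graph_json` plus the exact on-disk schema are assumptions, not the library's format.

```python
import json
import os
import tempfile


def write_graph_json(nodes, edges, filepath):
    """Write nodes/edges to a JSON file and return the path (hypothetical schema)."""
    with open(filepath, "w", encoding="utf-8") as f:
        json.dump({"nodes": nodes, "edges": edges}, f, indent=2)
    return filepath


nodes = [
    {"id": "n0", "name": "linear_1", "operation_type": "forward"},
    {"id": "n1", "name": "relu", "operation_type": "forward"},
]
edges = [{"source_id": "n0", "target_id": "n1", "edge_type": "data_flow"}]

path = write_graph_json(
    nodes, edges, os.path.join(tempfile.gettempdir(), "graph_sketch.json")
)
with open(path, encoding="utf-8") as f:
    loaded = json.load(f)
print(len(loaded["nodes"]), len(loaded["edges"]))
```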

save_computational_graph_png

save_computational_graph_png(model, input_tensor, filepath='computational_graph.png', width=1200, height=800, dpi=300, show_legend=True, node_size=20, font_size=10)

Save the computational graph as a high-quality PNG image.

Parameters:

  • model (torch.nn.Module): PyTorch model to visualize

  • input_tensor (torch.Tensor): Input tensor for the forward pass

  • filepath (str, optional): Output PNG file path (default: "computational_graph.png")

  • width (int, optional): Image width in pixels (default: 1200)

  • height (int, optional): Image height in pixels (default: 800)

  • dpi (int, optional): Output resolution in dots per inch (default: 300)

  • show_legend (bool, optional): Whether to show legend (default: True)

  • node_size (int, optional): Size of nodes in the graph (default: 20)

  • font_size (int, optional): Font size for labels (default: 10)

Returns: str - Path to the saved PNG file

Raises:

  • ImportError: If PyTorch is not available

Example:

from pytorch_graph import save_computational_graph_png

png_path = save_computational_graph_png(
    model=model,
    input_tensor=input_tensor,
    filepath="graph.png",
    width=1600,
    height=1200,
    dpi=300
)
print(f"PNG saved to: {png_path}")
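How `width`, `height`, and `dpi` interact is not specified here. If the PNG is rendered with Matplotlib (an assumption), a target pixel size maps to a figure size in inches by dividing by `dpi`, because Matplotlib saves a figure of `figsize` inches at `dpi` as `figsize * dpi` pixels (unless `bbox_inches='tight'` crops it). `pixels_to_figsize` is a hypothetical helper.

```python
def pixels_to_figsize(width, height, dpi):
    """Convert a target pixel size to a Matplotlib figsize in inches.

    Matplotlib saves figsize=(w_in, h_in) at a given dpi as
    (w_in * dpi) x (h_in * dpi) pixels, so we divide by dpi.
    Assumption: the library sizes its PNG output this way.
    """
    if dpi <= 0:
        raise ValueError("dpi must be positive")
    return width / dpi, height / dpi


w_in, h_in = pixels_to_figsize(1600, 1200, 300)
print(f"{w_in:.2f} x {h_in:.2f} inches")  # 5.33 x 4.00 inches
```

Under this reading, raising `dpi` while keeping `width` and `height` fixed yields a smaller figure in inches, so fonts and nodes appear relatively larger in the final image.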

ComputationalGraphTracker Class

class ComputationalGraphTracker(model, track_memory=True, track_timing=True, track_tensor_ops=True)

Tracks the computational graph of a PyTorch model's execution.

Parameters:

  • model (torch.nn.Module): PyTorch model to track

  • track_memory (bool, optional): Whether to track memory usage (default: True)

  • track_timing (bool, optional): Whether to track execution timing (default: True)

  • track_tensor_ops (bool, optional): Whether to track tensor operations (default: True)

Methods:

start_tracking()

Start tracking the computational graph.

Returns: None

stop_tracking()

Stop tracking the computational graph.

Returns: None
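Because tracking is bracketed by `start_tracking()` and `stop_tracking()`, a context manager can guarantee that `stop_tracking()` runs even if the forward or backward pass raises. `tracking` is a hypothetical helper, not part of the documented API, and `StubTracker` is a stand-in for `ComputationalGraphTracker` so the sketch runs on its own.

```python
from contextlib import contextmanager


@contextmanager
def tracking(tracker):
    """Start tracking on entry and always stop on exit (hypothetical helper)."""
    tracker.start_tracking()
    try:
        yield tracker
    finally:
        # Runs on normal exit and when the body raises.
        tracker.stop_tracking()


class StubTracker:
    """Minimal stand-in that records start/stop calls."""

    def __init__(self):
        self.events = []

    def start_tracking(self):
        self.events.append("start")

    def stop_tracking(self):
        self.events.append("stop")


stub = StubTracker()
try:
    with tracking(stub):
        raise RuntimeError("simulated model failure")
except RuntimeError:
    pass
print(stub.events)  # ['start', 'stop']
```

With a real tracker, the body of the `with` block would hold the model call, for example `output = model(input_tensor)` followed by `output.sum().backward()`.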

get_graph_summary()

Get a summary of the computational graph.

Returns: dict - Dictionary containing graph summary with keys:

  • total_nodes (int): Total number of nodes in the graph

  • total_edges (int): Total number of edges in the graph

  • execution_time (float): Total execution time in seconds

  • memory_usage (str): Memory usage information

  • operation_types (dict): Count of each operation type

  • model_size_mb (float): Model size in megabytes

get_graph_data()

Get the complete graph data for visualization.

Returns: dict - Dictionary containing:

  • nodes (list): List of GraphNode objects as dictionaries

  • edges (list): List of GraphEdge objects as dictionaries
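To see how `total_nodes`, `total_edges`, and `operation_types` relate to the lists returned by `get_graph_data()`, the sketch below recomputes those three keys from plain dicts. It assumes each node dict carries an `operation_type` entry (as GraphNode does) and that the library counts nodes and edges the same way; the timing, memory, and model-size keys are omitted because their derivation is not documented. `summarize_graph_data` is hypothetical.

```python
from collections import Counter


def summarize_graph_data(graph_data):
    """Recompute a few summary fields from {'nodes': [...], 'edges': [...]}.

    Assumes each node dict has an 'operation_type' entry, mirroring GraphNode.
    """
    nodes = graph_data.get("nodes", [])
    edges = graph_data.get("edges", [])
    op_counts = Counter(node.get("operation_type", "unknown") for node in nodes)
    return {
        "total_nodes": len(nodes),
        "total_edges": len(edges),
        "operation_types": dict(op_counts),
    }


sample = {
    "nodes": [
        {"id": "n0", "operation_type": "layer_op"},
        {"id": "n1", "operation_type": "tensor_op"},
        {"id": "n2", "operation_type": "layer_op"},
    ],
    "edges": [
        {"source_id": "n0", "target_id": "n1"},
        {"source_id": "n1", "target_id": "n2"},
    ],
}
print(summarize_graph_data(sample))
```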

export_graph(filepath, format='json')

Export the computational graph to a file.

Parameters:

  • filepath (str): Output file path

  • format (str, optional): Export format; 'json' is the only documented format (default: 'json')

Returns: None

visualize_graph(renderer='plotly')

Visualize the computational graph.

Parameters:

  • renderer (str, optional): Rendering backend, 'plotly' or 'matplotlib' (default: 'plotly')

Returns: Visualization object (Plotly figure or Matplotlib figure)

save_graph_png(filepath, width=1200, height=800, dpi=300, show_legend=True, node_size=20, font_size=10)

Save the computational graph as a PNG image. This method accepts the same styling options as save_computational_graph_png.

Parameters:

  • filepath (str): Output PNG file path

  • width (int, optional): Image width in pixels (default: 1200)

  • height (int, optional): Image height in pixels (default: 800)

  • dpi (int, optional): Output resolution in dots per inch (default: 300)

  • show_legend (bool, optional): Whether to show legend (default: True)

  • node_size (int, optional): Size of nodes in the graph (default: 20)

  • font_size (int, optional): Font size for labels (default: 10)

Returns: str - Path to the saved PNG file

GraphNode Class

class GraphNode(id, name, operation_type, module_name=None, input_shapes=None, output_shapes=None, parameters=None, execution_time=None, memory_usage=None, metadata=None, parent_ids=None, child_ids=None, timestamp=None)

Represents a node in the computational graph.

Parameters:

  • id (str): Unique identifier for the node

  • name (str): Name of the operation

  • operation_type (OperationType): Type of operation

  • module_name (str, optional): Name of the PyTorch module

  • input_shapes (list, optional): List of input tensor shapes

  • output_shapes (list, optional): List of output tensor shapes

  • parameters (dict, optional): Operation parameters

  • execution_time (float, optional): Execution time in seconds

  • memory_usage (int, optional): Memory usage in bytes

  • metadata (dict, optional): Additional metadata

  • parent_ids (list, optional): List of parent node IDs

  • child_ids (list, optional): List of child node IDs

  • timestamp (float, optional): Timestamp of execution
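For building or mocking nodes in tests, a dataclass with the same field names and `None` defaults is a reasonable stand-in. `GraphNodeSketch` is hypothetical, not the library's class: it takes a plain string where the real class takes an `OperationType` member. `asdict` shows one way a node becomes the dictionary form that `get_graph_data()` is documented to return.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional


@dataclass
class GraphNodeSketch:
    """Stand-in mirroring GraphNode's documented fields (not the library class)."""

    id: str
    name: str
    operation_type: str  # the real class takes an OperationType member
    module_name: Optional[str] = None
    input_shapes: Optional[List[Any]] = None
    output_shapes: Optional[List[Any]] = None
    parameters: Optional[Dict[str, Any]] = None
    execution_time: Optional[float] = None  # seconds
    memory_usage: Optional[int] = None  # bytes
    metadata: Optional[Dict[str, Any]] = None
    parent_ids: Optional[List[str]] = None
    child_ids: Optional[List[str]] = None
    timestamp: Optional[float] = None


node = GraphNodeSketch(
    id="n0",
    name="linear",
    operation_type="layer_op",
    module_name="0",  # nn.Sequential names its children "0", "1", ...
    input_shapes=[(1, 784)],
    output_shapes=[(1, 128)],
    child_ids=["n1"],
)
as_dict = asdict(node)
print(as_dict["output_shapes"])
```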

GraphEdge Class

class GraphEdge(source_id, target_id, edge_type, tensor_shape=None, metadata=None)

Represents an edge in the computational graph.

Parameters:

  • source_id (str): ID of the source node

  • target_id (str): ID of the target node

  • edge_type (str): Type of edge (e.g., 'data_flow', 'gradient_flow')

  • tensor_shape (tuple, optional): Shape of the tensor flowing through the edge

  • metadata (dict, optional): Additional metadata
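Since edges refer to nodes only through `source_id` and `target_id`, a quick consistency check is to confirm that every edge endpoint names an existing node. `dangling_edges` is a hypothetical helper that works on the dictionary form of nodes and edges.

```python
def dangling_edges(nodes, edges):
    """Return edges whose source_id or target_id matches no node id."""
    known_ids = {node["id"] for node in nodes}
    return [
        edge
        for edge in edges
        if edge["source_id"] not in known_ids or edge["target_id"] not in known_ids
    ]


nodes = [{"id": "n0"}, {"id": "n1"}]
edges = [
    {"source_id": "n0", "target_id": "n1", "edge_type": "data_flow"},
    {"source_id": "n1", "target_id": "n9", "edge_type": "data_flow"},  # n9 is missing
]
print(dangling_edges(nodes, edges))
```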

OperationType Enum

class OperationType

Types of operations that can be tracked.

Values:

  • FORWARD: Forward pass operation

  • BACKWARD: Backward pass operation

  • TENSOR_OP: Tensor operation (add, multiply, etc.)

  • LAYER_OP: Layer operation

  • GRADIENT_OP: Gradient operation

  • MEMORY_OP: Memory operation

  • CUSTOM: Custom operation
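The member names are documented, but their underlying values are not. The stand-in `Enum` below uses lowercase string values (an assumption) to show how an operation type can round-trip through a JSON export by storing `.value` and looking the member up again. `OperationTypeSketch` is hypothetical.

```python
from enum import Enum


class OperationTypeSketch(Enum):
    """Stand-in for OperationType; the string values are assumptions."""

    FORWARD = "forward"
    BACKWARD = "backward"
    TENSOR_OP = "tensor_op"
    LAYER_OP = "layer_op"
    GRADIENT_OP = "gradient_op"
    MEMORY_OP = "memory_op"
    CUSTOM = "custom"


# Enum members are not JSON-serializable, so store .value and look up by value.
stored = OperationTypeSketch.LAYER_OP.value
restored = OperationTypeSketch(stored)
print(stored, restored is OperationTypeSketch.LAYER_OP)  # layer_op True
```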

Examples

Basic Computational Graph Tracking

import torch
import torch.nn as nn
from pytorch_graph import track_computational_graph

model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

input_tensor = torch.randn(1, 784, requires_grad=True)

# Track computational graph
tracker = track_computational_graph(
    model=model,
    input_tensor=input_tensor,
    track_memory=True,
    track_timing=True,
    track_tensor_ops=True
)

# Get summary
summary = tracker.get_graph_summary()
print(f"Total operations: {summary['total_nodes']}")
print(f"Execution time: {summary['execution_time']:.4f}s")

Advanced Usage with ComputationalGraphTracker

from pytorch_graph import ComputationalGraphTracker

# model and input_tensor as defined in Basic Computational Graph Tracking

# Create tracker with custom settings
tracker = ComputationalGraphTracker(
    model=model,
    track_memory=True,
    track_timing=True,
    track_tensor_ops=True
)

# Start tracking
tracker.start_tracking()

# Run your model
output = model(input_tensor)
loss = output.sum()
loss.backward()

# Stop tracking
tracker.stop_tracking()

# Get comprehensive analysis
summary = tracker.get_graph_summary()
print(f"Operations: {summary['total_nodes']:,}")
print(f"Execution time: {summary['execution_time']:.4f}s")

# Save visualization
tracker.save_graph_png(
    "advanced_graph.png",
    width=2000,
    height=1500,
    dpi=300,
    show_legend=True,
    node_size=30,
    font_size=14
)

Graph Analysis

from pytorch_graph import analyze_computational_graph

# Comprehensive analysis
analysis = analyze_computational_graph(
    model=model,
    input_tensor=input_tensor,
    detailed=True
)

summary = analysis['summary']
print(f"Total operations: {summary['total_nodes']:,}")
print(f"Execution time: {summary['execution_time']:.4f}s")

# Performance metrics
if 'performance' in analysis:
    perf = analysis['performance']
    print(f"Operations per second: {perf['operations_per_second']:.2f}")
    print(f"Memory usage: {perf['memory_usage']}")

# Layer-wise analysis
if 'layer_analysis' in analysis:
    for layer_name, operations in analysis['layer_analysis'].items():
        print(f"{layer_name}: {len(operations)} operations")

Data Export

# Export to JSON
tracker.export_graph("graph_data.json")

# Load and inspect exported data
import json
with open("graph_data.json", 'r') as f:
    graph_data = json.load(f)

print(f"Nodes: {len(graph_data['nodes'])}")
print(f"Edges: {len(graph_data['edges'])}")
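Once the exported JSON is loaded, the edge list can be used to walk the graph so that every node appears after the nodes feeding into it. The sketch below applies Kahn's algorithm, assuming edge dicts use the `source_id`/`target_id` keys documented for GraphEdge and that every endpoint names an existing node. `topological_order` is hypothetical, and the small inline graph lets it run without a `graph_data.json` file.

```python
from collections import deque


def topological_order(graph_data):
    """Order node ids so every edge points from an earlier node to a later one."""
    node_ids = [node["id"] for node in graph_data["nodes"]]
    successors = {node_id: [] for node_id in node_ids}
    in_degree = {node_id: 0 for node_id in node_ids}
    for edge in graph_data["edges"]:
        successors[edge["source_id"]].append(edge["target_id"])
        in_degree[edge["target_id"]] += 1

    # Start from nodes with no incoming edges, preserving their file order.
    queue = deque(node_id for node_id in node_ids if in_degree[node_id] == 0)
    order = []
    while queue:
        current = queue.popleft()
        order.append(current)
        for nxt in successors[current]:
            in_degree[nxt] -= 1
            if in_degree[nxt] == 0:
                queue.append(nxt)

    # Any node never reaching in-degree 0 sits on a cycle.
    if len(order) != len(node_ids):
        raise ValueError("graph contains a cycle")
    return order


sample_graph = {
    "nodes": [{"id": "relu"}, {"id": "linear_2"}, {"id": "linear_1"}],
    "edges": [
        {"source_id": "linear_1", "target_id": "relu"},
        {"source_id": "relu", "target_id": "linear_2"},
    ],
}
print(topological_order(sample_graph))  # ['linear_1', 'relu', 'linear_2']
```

If the tracker records gradient-flow edges pointing back toward earlier nodes, filtering edges by `edge_type` before ordering may be necessary to avoid the cycle error.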

Visualization

from pytorch_graph import visualize_computational_graph

# Create interactive visualization
fig = visualize_computational_graph(
    model=model,
    input_tensor=input_tensor,
    renderer='plotly'
)
fig.show()

# Save as PNG
from pytorch_graph import save_computational_graph_png

png_path = save_computational_graph_png(
    model=model,
    input_tensor=input_tensor,
    filepath="computational_graph.png",
    width=1600,
    height=1200,
    dpi=300
)

Error Handling

The functions can raise the following exceptions:

  • ImportError: If PyTorch is not available

  • RuntimeError: If model execution fails

  • ValueError: If invalid parameters are provided

  • FileNotFoundError: If the output directory does not exist
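Because a missing output directory surfaces as `FileNotFoundError`, one option is to create the parent directory before calling an export or PNG function. `ensure_parent_dir` is a hypothetical helper built on `pathlib`; whether the library creates directories itself is not documented.

```python
import tempfile
from pathlib import Path


def ensure_parent_dir(filepath):
    """Create the parent directory of filepath if needed and return it as a Path."""
    path = Path(filepath)
    path.parent.mkdir(parents=True, exist_ok=True)
    return path


base = tempfile.mkdtemp()
target = ensure_parent_dir(Path(base) / "outputs" / "graphs" / "graph.json")
print(target.parent.is_dir())  # True
```

The returned path can then be passed on, for example `filepath=str(ensure_parent_dir("outputs/graph.json"))` in a call to export_computational_graph.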

See Also

  • architecture_visualization - For architecture diagram generation

  • Model Analysis API - For model analysis functions

  • Utilities API - For utility classes and functions