Computational Graph Tracking API
This module provides comprehensive computational graph tracking and visualization capabilities for PyTorch models.
Core Functions
Classes
Function Details
track_computational_graph
- track_computational_graph(model, input_tensor, track_memory=True, track_timing=True, track_tensor_ops=True)
Track the computational graph of a PyTorch model execution.
Parameters:
model (torch.nn.Module): PyTorch model to track
input_tensor (torch.Tensor): Input tensor for the forward pass
track_memory (bool, optional): Whether to track memory usage (default: True)
track_timing (bool, optional): Whether to track execution timing (default: True)
track_tensor_ops (bool, optional): Whether to track tensor operations (default: True)
Returns: ComputationalGraphTracker - Tracker instance with execution data
Raises:
ImportError: If PyTorch is not available
Example:
import torch import torch.nn as nn from pytorch-graph import track_computational_graph model = nn.Sequential( nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10) ) input_tensor = torch.randn(1, 784, requires_grad=True) tracker = track_computational_graph( model=model, input_tensor=input_tensor, track_memory=True, track_timing=True, track_tensor_ops=True ) # Get summary summary = tracker.get_graph_summary() print(f"Total operations: {summary['total_nodes']}")
analyze_computational_graph
- analyze_computational_graph(model, input_tensor, detailed=True)
Analyze the computational graph of a PyTorch model execution.
Parameters:
model (torch.nn.Module): PyTorch model to analyze
input_tensor (torch.Tensor): Input tensor for the forward pass
detailed (bool, optional): Whether to include detailed analysis (default: True)
Returns: dict - Dictionary containing computational graph analysis
Raises:
ImportError: If PyTorch is not available
Example:
from pytorch-graph import analyze_computational_graph analysis = analyze_computational_graph( model=model, input_tensor=input_tensor, detailed=True ) summary = analysis['summary'] print(f"Total operations: {summary['total_nodes']}") print(f"Execution time: {summary['execution_time']:.4f}s")
track_computational_graph_execution
- track_computational_graph_execution(model, input_tensor, track_memory=True, track_timing=True, track_tensor_ops=True)
Track the computational graph of a PyTorch model execution (alias for track_computational_graph).
Parameters:
model (torch.nn.Module): PyTorch model to track
input_tensor (torch.Tensor): Input tensor for the forward pass
track_memory (bool, optional): Whether to track memory usage (default: True)
track_timing (bool, optional): Whether to track execution timing (default: True)
track_tensor_ops (bool, optional): Whether to track tensor operations (default: True)
Returns: ComputationalGraphTracker - Tracker instance with execution data
analyze_computational_graph_execution
- analyze_computational_graph_execution(model, input_tensor, detailed=True)
Analyze the computational graph of a PyTorch model execution (alias for analyze_computational_graph).
Parameters:
model (torch.nn.Module): PyTorch model to analyze
input_tensor (torch.Tensor): Input tensor for the forward pass
detailed (bool, optional): Whether to include detailed analysis (default: True)
Returns: dict - Dictionary containing computational graph analysis
visualize_computational_graph
- visualize_computational_graph(model, input_tensor, renderer='plotly')
Visualize the computational graph of a PyTorch model execution.
Parameters:
model (torch.nn.Module): PyTorch model to visualize
input_tensor (torch.Tensor): Input tensor for the forward pass
renderer (str, optional): Rendering backend (‘plotly’ or ‘matplotlib’, default: ‘plotly’)
Returns: Visualization object (Plotly figure or Matplotlib figure)
Raises:
ImportError: If PyTorch is not available
Example:
from pytorch-graph import visualize_computational_graph fig = visualize_computational_graph( model=model, input_tensor=input_tensor, renderer='plotly' ) fig.show()
export_computational_graph
- export_computational_graph(model, input_tensor, filepath, format='json')
Export the computational graph of a PyTorch model execution to a file.
Parameters:
model (torch.nn.Module): PyTorch model to export
input_tensor (torch.Tensor): Input tensor for the forward pass
filepath (str): Output file path
format (str, optional): Export format (‘json’, default: ‘json’)
Returns: str - Path to the exported file
Raises:
ImportError: If PyTorch is not available
Example:
from pytorch-graph import export_computational_graph filepath = export_computational_graph( model=model, input_tensor=input_tensor, filepath='graph.json' ) print(f"Graph exported to: {filepath}")
save_computational_graph_png
- save_computational_graph_png(model, input_tensor, filepath='computational_graph.png', width=1200, height=800, dpi=300, show_legend=True, node_size=20, font_size=10)
Save the computational graph as a high-quality PNG image.
Parameters:
model (torch.nn.Module): PyTorch model to visualize
input_tensor (torch.Tensor): Input tensor for the forward pass
filepath (str, optional): Output PNG file path (default: “computational_graph.png”)
width (int, optional): Image width in pixels (default: 1200)
height (int, optional): Image height in pixels (default: 800)
dpi (int, optional): Dots per inch for high resolution (default: 300)
show_legend (bool, optional): Whether to show legend (default: True)
node_size (int, optional): Size of nodes in the graph (default: 20)
font_size (int, optional): Font size for labels (default: 10)
Returns: str - Path to the saved PNG file
Raises:
ImportError: If PyTorch is not available
Example:
from pytorch-graph import save_computational_graph_png png_path = save_computational_graph_png( model=model, input_tensor=input_tensor, filepath="graph.png", width=1600, height=1200, dpi=300 ) print(f"PNG saved to: {png_path}")
ComputationalGraphTracker Class
- class ComputationalGraphTracker(model, track_memory=True, track_timing=True, track_tensor_ops=True)
Tracks the computational graph of PyTorch model execution.
Parameters:
model (torch.nn.Module): PyTorch model to track
track_memory (bool, optional): Whether to track memory usage (default: True)
track_timing (bool, optional): Whether to track execution timing (default: True)
track_tensor_ops (bool, optional): Whether to track tensor operations (default: True)
Methods:
- start_tracking()
Start tracking the computational graph.
Returns: None
- stop_tracking()
Stop tracking the computational graph.
Returns: None
- get_graph_summary()
Get a summary of the computational graph.
Returns: dict - Dictionary containing graph summary with keys:
total_nodes (int): Total number of nodes in the graph
total_edges (int): Total number of edges in the graph
execution_time (float): Total execution time in seconds
memory_usage (str): Memory usage information
operation_types (dict): Count of each operation type
model_size_mb (float): Model size in megabytes
- get_graph_data()
Get the complete graph data for visualization.
Returns: dict - Dictionary containing:
nodes (list): List of GraphNode objects as dictionaries
edges (list): List of GraphEdge objects as dictionaries
- export_graph(filepath, format='json')
Export the computational graph to a file.
Parameters:
filepath (str): Output file path
format (str, optional): Export format (‘json’, default: ‘json’)
Returns: None
- visualize_graph(renderer='plotly')
Visualize the computational graph.
Parameters:
renderer (str, optional): Rendering backend (‘plotly’ or ‘matplotlib’, default: ‘plotly’)
Returns: Visualization object (Plotly figure or Matplotlib figure)
- save_graph_png(filepath, width=1200, height=800, dpi=300, show_legend=True, node_size=20, font_size=10)
Save the computational graph as a PNG image with enhanced visualization.
Parameters:
filepath (str): Output PNG file path
width (int, optional): Image width in pixels (default: 1200)
height (int, optional): Image height in pixels (default: 800)
dpi (int, optional): Dots per inch for high resolution (default: 300)
show_legend (bool, optional): Whether to show legend (default: True)
node_size (int, optional): Size of nodes in the graph (default: 20)
font_size (int, optional): Font size for labels (default: 10)
Returns: str - Path to the saved PNG file
GraphNode Class
- class GraphNode(id, name, operation_type, module_name=None, input_shapes=None, output_shapes=None, parameters=None, execution_time=None, memory_usage=None, metadata=None, parent_ids=None, child_ids=None, timestamp=None)
Represents a node in the computational graph.
Parameters:
id (str): Unique identifier for the node
name (str): Name of the operation
operation_type (OperationType): Type of operation
module_name (str, optional): Name of the PyTorch module
input_shapes (list, optional): List of input tensor shapes
output_shapes (list, optional): List of output tensor shapes
parameters (dict, optional): Operation parameters
execution_time (float, optional): Execution time in seconds
memory_usage (int, optional): Memory usage in bytes
metadata (dict, optional): Additional metadata
parent_ids (list, optional): List of parent node IDs
child_ids (list, optional): List of child node IDs
timestamp (float, optional): Timestamp of execution
GraphEdge Class
- class GraphEdge(source_id, target_id, edge_type, tensor_shape=None, metadata=None)
Represents an edge in the computational graph.
Parameters:
source_id (str): ID of the source node
target_id (str): ID of the target node
edge_type (str): Type of edge (e.g., ‘data_flow’, ‘gradient_flow’)
tensor_shape (tuple, optional): Shape of the tensor flowing through the edge
metadata (dict, optional): Additional metadata
OperationType Enum
- class OperationType
Types of operations that can be tracked.
Values:
FORWARD: Forward pass operation
BACKWARD: Backward pass operation
TENSOR_OP: Tensor operation (add, multiply, etc.)
LAYER_OP: Layer operation
GRADIENT_OP: Gradient operation
MEMORY_OP: Memory operation
CUSTOM: Custom operation
Examples
Basic Computational Graph Tracking
import torch
import torch.nn as nn
from pytorch-graph import track_computational_graph
model = nn.Sequential(
nn.Linear(784, 128),
nn.ReLU(),
nn.Linear(128, 10)
)
input_tensor = torch.randn(1, 784, requires_grad=True)
# Track computational graph
tracker = track_computational_graph(
model=model,
input_tensor=input_tensor,
track_memory=True,
track_timing=True,
track_tensor_ops=True
)
# Get summary
summary = tracker.get_graph_summary()
print(f"Total operations: {summary['total_nodes']}")
print(f"Execution time: {summary['execution_time']:.4f}s")
Advanced Usage with ComputationalGraphTracker
from pytorch-graph import ComputationalGraphTracker
# Create tracker with custom settings
tracker = ComputationalGraphTracker(
model=model,
track_memory=True,
track_timing=True,
track_tensor_ops=True
)
# Start tracking
tracker.start_tracking()
# Run your model
output = model(input_tensor)
loss = output.sum()
loss.backward()
# Stop tracking
tracker.stop_tracking()
# Get comprehensive analysis
summary = tracker.get_graph_summary()
print(f"Operations: {summary['total_nodes']:,}")
print(f"Execution time: {summary['execution_time']:.4f}s")
# Save visualization
tracker.save_graph_png(
"advanced_graph.png",
width=2000,
height=1500,
dpi=300,
show_legend=True,
node_size=30,
font_size=14
)
Graph Analysis
from pytorch-graph import analyze_computational_graph
# Comprehensive analysis
analysis = analyze_computational_graph(
model=model,
input_tensor=input_tensor,
detailed=True
)
summary = analysis['summary']
print(f"Total operations: {summary['total_nodes']:,}")
print(f"Execution time: {summary['execution_time']:.4f}s")
# Performance metrics
if 'performance' in analysis:
perf = analysis['performance']
print(f"Operations per second: {perf['operations_per_second']:.2f}")
print(f"Memory usage: {perf['memory_usage']}")
# Layer-wise analysis
if 'layer_analysis' in analysis:
for layer_name, operations in analysis['layer_analysis'].items():
print(f"{layer_name}: {len(operations)} operations")
Data Export
# Export to JSON
tracker.export_graph("graph_data.json")
# Load and inspect exported data
import json
with open("graph_data.json", 'r') as f:
graph_data = json.load(f)
print(f"Nodes: {len(graph_data['nodes'])}")
print(f"Edges: {len(graph_data['edges'])}")
Visualization
from pytorch-graph import visualize_computational_graph
# Create interactive visualization
fig = visualize_computational_graph(
model=model,
input_tensor=input_tensor,
renderer='plotly'
)
fig.show()
# Save as PNG
from pytorch-graph import save_computational_graph_png
png_path = save_computational_graph_png(
model=model,
input_tensor=input_tensor,
filepath="computational_graph.png",
width=1600,
height=1200,
dpi=300
)
Error Handling
The functions will raise appropriate exceptions for:
ImportError: If PyTorch is not available
RuntimeError: If model execution fails
ValueError: If invalid parameters are provided
FileNotFoundError: If output directory doesn’t exist
See Also
architecture_visualization - For architecture diagram generation
Model Analysis API - For model analysis functions
Utilities API - For utility classes and functions