Advanced Features
PyTorch Graph provides advanced features for power users and researchers.
Overview
Advanced features include:
Custom Visualization Parameters: Fine-tune output appearance
Data Export and Analysis: Export graph data for offline analysis
Performance Optimization: Optimize for large models and complex graphs
Integration with Workflows: Seamless integration with existing PyTorch workflows
Custom Styling: Advanced customization options
Custom Visualization Parameters
Fine-tune your visualizations with custom parameters:
from pytorch_graph import ComputationalGraphTracker

# `model` and `input_tensor` are not defined in this snippet; create them
# before running it.

# Create tracker with custom settings
tracker = ComputationalGraphTracker(
    model=model,
    track_memory=True,
    track_timing=True,
    track_tensor_ops=True,
)

# Record one forward and one backward pass
tracker.start_tracking()
output = model(input_tensor)
loss = output.sum()
loss.backward()
tracker.stop_tracking()

# Save with custom parameters
tracker.save_graph_png(
    filepath="custom_visualization.png",
    width=2000,        # Custom width
    height=1500,       # Custom height
    dpi=300,           # High DPI for publication
    show_legend=True,  # Show legend
    node_size=30,      # Node size
    font_size=14,      # Font size
)
High-Resolution Output
For publication-quality images:
# Ultra-high resolution for large displays (expects an existing `tracker`)
tracker.save_graph_png(
    filepath="ultra_hd_graph.png",
    width=4000,
    height=3000,
    dpi=300,
    node_size=50,
    font_size=20
)
Data Export and Analysis
Export graph data for offline analysis:
import json
from collections import Counter

# Export complete graph data (expects an existing `tracker`)
tracker.export_graph("complete_graph_data.json")

# Load and analyze exported data
with open("complete_graph_data.json", 'r') as f:
    graph_data = json.load(f)

print(f"Total nodes: {len(graph_data['nodes'])}")
print(f"Total edges: {len(graph_data['edges'])}")

# Analyze node types: count how many nodes share each 'operation_type'.
# Counter keeps first-seen order, so the report order matches the export.
node_types = Counter(node['operation_type'] for node in graph_data['nodes'])

print("Node type distribution:")
for node_type, count in node_types.items():
    print(f" {node_type}: {count}")
Custom Analysis Functions
Create custom analysis functions:
def analyze_graph_complexity(graph_data):
    """Analyze the complexity of a computational graph.

    Args:
        graph_data: Mapping with a 'nodes' list and an 'edges' list. Every
            edge must provide 'source_id' and 'target_id'.

    Returns:
        A dict with:
          - 'total_nodes': number of entries in 'nodes'
          - 'total_edges': number of entries in 'edges'
          - 'average_connections': edges per node (0 when there are no nodes)
          - 'most_connected_node': ``(node_id, connection_count)`` for the
            node that appears most often as an edge endpoint, or None when
            there are no edges

    Raises:
        KeyError: If 'nodes', 'edges', 'source_id' or 'target_id' is missing.
    """
    from collections import Counter

    nodes = graph_data['nodes']
    edges = graph_data['edges']

    # Calculate metrics
    total_nodes = len(nodes)
    total_edges = len(edges)
    avg_connections = total_edges / total_nodes if total_nodes > 0 else 0

    # Count how often each node id appears as an edge endpoint
    node_connections = Counter()
    for edge in edges:
        node_connections[edge['source_id']] += 1
        node_connections[edge['target_id']] += 1

    # max() keeps the first maximum, so ties go to the endpoint seen first
    most_connected = (
        max(node_connections.items(), key=lambda item: item[1])
        if node_connections
        else None
    )

    return {
        'total_nodes': total_nodes,
        'total_edges': total_edges,
        'average_connections': avg_connections,
        'most_connected_node': most_connected,
    }
# Use custom analysis on the `graph_data` loaded from the exported JSON file
complexity_analysis = analyze_graph_complexity(graph_data)
print(f"Graph complexity: {complexity_analysis}")
Performance Optimization
Optimize for large models and complex graphs:
import torch

# Optimize for large models (`large_model` is not defined in this snippet;
# create it before running it)
tracker = ComputationalGraphTracker(
    model=large_model,
    track_memory=True,       # Keep memory tracking
    track_timing=True,       # Keep timing
    track_tensor_ops=False,  # Disable for performance
)

# Use smaller input for testing
test_input = torch.randn(1, 3, 224, 224, requires_grad=True)

# Record one forward and one backward pass
tracker.start_tracking()
output = large_model(test_input)
loss = output.sum()
loss.backward()
tracker.stop_tracking()

# Save with optimized settings
tracker.save_graph_png(
    "large_model_graph.png",
    width=3000,    # Larger canvas for complex graphs
    height=2000,
    dpi=200,       # Lower DPI for faster rendering
    node_size=20,
    font_size=10,
)
Memory-Efficient Tracking
For memory-constrained environments:
# Minimal tracking for memory efficiency: all three tracking flags are off
tracker = ComputationalGraphTracker(
    model=model,
    track_memory=False,      # Disable memory tracking
    track_timing=False,      # Disable timing
    track_tensor_ops=False   # Disable tensor operations
)
# Process in chunks for very large models
def process_large_model_in_chunks(model, input_tensor, chunk_size=1000):
    """Track one forward/backward pass with memory tracking disabled.

    Args:
        model: Callable model; invoked once as ``model(input_tensor)``.
        input_tensor: Input passed to the model in a single call.
        chunk_size: NOTE(review): never read by the body, so no chunking
            actually happens here. Confirm whether chunked processing is
            still planned or the parameter can be dropped.

    Returns:
        The ``ComputationalGraphTracker`` after ``stop_tracking()`` was called.
    """
    tracker = ComputationalGraphTracker(model, track_memory=False)
    tracker.start_tracking()
    # Process model: the whole input goes through in one forward/backward pass
    output = model(input_tensor)
    loss = output.sum()
    loss.backward()
    tracker.stop_tracking()
    return tracker
Integration with Workflows
Seamlessly integrate with existing PyTorch workflows:
def train_with_graph_tracking(model, dataloader, num_epochs=10):
    """Training loop with computational graph tracking.

    On the first batch of every epoch the computational graph is tracked,
    saved as ``epoch_{epoch}_computational_graph.png`` and summarized on
    stdout. Every batch, including that first one, then goes through the
    regular optimization step.

    Args:
        model: Model passed to ``track_computational_graph`` and called on
            every batch.
        dataloader: Iterable yielding ``(data, target)`` pairs.
        num_epochs: Number of passes over ``dataloader``.

    Note:
        ``optimizer`` and ``criterion`` are not parameters of this function;
        they are looked up in the enclosing scope and must be defined
        before it is called.
    """
    for epoch in range(num_epochs):
        for batch_idx, (data, target) in enumerate(dataloader):
            # Track computational graph for first batch of each epoch
            if batch_idx == 0:
                tracker = track_computational_graph(model, data)
                # Save graph for this epoch
                tracker.save_graph_png(
                    f"epoch_{epoch}_computational_graph.png",
                    width=1600,
                    height=1200,
                    dpi=300
                )
                # Get performance metrics
                summary = tracker.get_graph_summary()
                print(f"Epoch {epoch}: {summary['total_nodes']} operations, "
                      f"{summary['execution_time']:.4f}s")
            # Your existing training code; zero_grad() runs before the real
            # forward/backward pass of every batch, tracked or not
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
Model Comparison Workflow
Compare multiple models systematically:
def compare_models_comprehensive(models, input_shapes, output_dir="comparison"):
    """Comprehensive model comparison with visualizations.

    For every model an architecture diagram and a computational-graph image
    are written to ``output_dir``, and key metrics are collected. The
    collected metrics are also saved as ``comparison_results.json`` in the
    same directory.

    Args:
        models: Mapping of model name to a ``(model, input_shape)`` pair.
        input_shapes: Not read by this function; the shape of each model is
            taken from ``models``. Kept for interface compatibility.
        output_dir: Directory for all generated files; created if missing.

    Returns:
        dict mapping each model name to a dict with 'parameters',
        'model_size', 'operations' and 'execution_time'.
    """
    # Local imports keep the example self-contained
    import json
    import os

    os.makedirs(output_dir, exist_ok=True)
    results = {}

    for name, (model, input_shape) in models.items():
        print(f"Analyzing {name}...")

        # Architecture visualization
        generate_architecture_diagram(
            model=model,
            input_shape=input_shape,
            output_path=os.path.join(output_dir, f"{name}_architecture.png"),
            title=f"{name} Architecture",
        )

        # Computational graph tracking
        input_tensor = torch.randn(*input_shape, requires_grad=True)
        tracker = track_computational_graph(model, input_tensor)
        tracker.save_graph_png(
            os.path.join(output_dir, f"{name}_computational_graph.png"),
            width=1600,
            height=1200,
            dpi=300,
        )

        # Analysis
        model_analysis = analyze_model(model, input_shape=input_shape)
        graph_analysis = analyze_computational_graph(model, input_tensor)
        graph_summary = graph_analysis['summary']
        results[name] = {
            'parameters': model_analysis['total_parameters'],
            'model_size': model_analysis['model_size_mb'],
            'operations': graph_summary['total_nodes'],
            'execution_time': graph_summary['execution_time'],
        }

    # Save comparison results
    with open(os.path.join(output_dir, "comparison_results.json"), 'w') as f:
        json.dump(results, f, indent=2)

    return results
Custom Styling
Create custom visualization styles:
def create_custom_style_graph(tracker, output_path, style_config=None):
    """Create a graph with custom styling.

    Args:
        tracker: Object exposing ``save_graph_png(filepath=..., **options)``.
        output_path: Destination path, passed on as ``filepath``.
        style_config: Optional mapping of style overrides. Recognized keys
            are 'width', 'height', 'dpi', 'show_legend', 'node_size' and
            'font_size'; other keys are ignored. ``None`` means "use every
            default".
    """
    # This would be implemented in the library
    # For now, we use the standard method with custom parameters
    defaults = {
        'width': 1600,
        'height': 1200,
        'dpi': 300,
        'show_legend': True,
        'node_size': 25,
        'font_size': 12,
    }
    if style_config is None:
        style_config = {}

    # Only the known option names are forwarded, each falling back to its default
    options = {key: style_config.get(key, default) for key, default in defaults.items()}
    tracker.save_graph_png(filepath=output_path, **options)
# Custom style configuration; the keys are the options that
# create_custom_style_graph reads from its style_config argument
custom_style = {
    'width': 2000,
    'height': 1500,
    'dpi': 300,
    'show_legend': True,
    'node_size': 30,
    'font_size': 14
}
create_custom_style_graph(tracker, "custom_style_graph.png", custom_style)
Batch Processing
Process multiple models in batch:
def batch_process_models(models, input_shapes, output_dir="batch_output"):
    """Process multiple models in batch.

    Each model gets its own sub-directory of ``output_dir`` holding
    ``architecture.png``, ``computational_graph.png`` and ``graph_data.json``.

    Args:
        models: Mapping of model name to a ``(model, input_shape)`` pair.
        input_shapes: Not read by this function; shapes come from ``models``.
        output_dir: Root directory for the results; created if missing.
    """
    import os

    os.makedirs(output_dir, exist_ok=True)

    for name, entry in models.items():
        model, input_shape = entry
        print(f"Processing {name}...")

        # One sub-directory per model
        target_dir = os.path.join(output_dir, name)
        os.makedirs(target_dir, exist_ok=True)
        diagram_file = os.path.join(target_dir, "architecture.png")
        graph_file = os.path.join(target_dir, "computational_graph.png")
        data_file = os.path.join(target_dir, "graph_data.json")

        # Architecture diagram
        generate_architecture_diagram(
            model=model,
            input_shape=input_shape,
            output_path=diagram_file,
            title=f"{name} Architecture",
        )

        # Computational graph image
        sample = torch.randn(*input_shape, requires_grad=True)
        graph_tracker = track_computational_graph(model, sample)
        graph_tracker.save_graph_png(graph_file, width=1600, height=1200, dpi=300)

        # Raw graph data
        graph_tracker.export_graph(data_file)

        print(f"Completed {name}")
Advanced Examples
Research Paper Workflow
Complete workflow for research papers:
def research_paper_workflow(model, input_shape, model_name):
    """Complete workflow for research paper figures.

    Writes three files into the current working directory, all prefixed with
    ``model_name``: ``_architecture_research.png``,
    ``_computational_graph.png`` and ``_analysis.json``.

    Args:
        model: Model to visualize and analyze.
        input_shape: Shape unpacked into ``torch.randn`` to build the sample
            input, and passed to the architecture diagram generator.
        model_name: Prefix for the output files and text for the diagram title.
    """
    # Local import keeps the example self-contained
    import json

    print(f"Generating research figures for {model_name}...")

    # Architecture diagram (research style)
    generate_architecture_diagram(
        model=model,
        input_shape=input_shape,
        output_path=f"{model_name}_architecture_research.png",
        style="research_paper",
        title=f"{model_name} Architecture",
        dpi=300,
    )

    # Computational graph
    input_tensor = torch.randn(*input_shape, requires_grad=True)
    tracker = track_computational_graph(model, input_tensor)
    tracker.save_graph_png(
        f"{model_name}_computational_graph.png",
        width=2000,
        height=1500,
        dpi=300,
        show_legend=True,
        node_size=25,
        font_size=12,
    )

    # Analysis data
    analysis = analyze_computational_graph(model, input_tensor, detailed=True)

    # Save analysis results; default=str stringifies any value that json
    # cannot serialize natively instead of raising TypeError
    with open(f"{model_name}_analysis.json", 'w') as f:
        json.dump(analysis, f, indent=2, default=str)

    print(f"Research figures generated for {model_name}")
Performance Profiling
Detailed performance profiling:
def profile_model_performance(model, input_tensor, num_runs=10):
    """Detailed performance profiling.

    Runs ``num_runs`` tracked forward/backward passes and reports wall-clock
    statistics. Each timed interval covers tracker construction, the forward
    pass, the backward pass and ``stop_tracking()``.

    Args:
        model: Callable model, invoked as ``model(input_tensor)``.
        input_tensor: Input for every run.
        num_runs: Number of profiled runs; must be at least 1.

    Returns:
        dict with 'execution_times' (seconds per run), 'memory_usage' (the
        truthy 'memory_usage' summary values that were collected) and
        'statistics' ('average_time', 'std_time', 'min_time', 'max_time').
        'std_time' is the population standard deviation.

    Raises:
        ValueError: If ``num_runs`` is smaller than 1.
    """
    import time

    # Without at least one run the statistics below are undefined
    if num_runs < 1:
        raise ValueError("num_runs must be at least 1")

    execution_times = []
    memory_usage = []

    # No zero_grad() is issued between runs; only timing is of interest here
    for _ in range(num_runs):
        # perf_counter() is monotonic and high-resolution, unlike time.time()
        start_time = time.perf_counter()
        tracker = ComputationalGraphTracker(
            model=model,
            track_memory=True,
            track_timing=True,
            track_tensor_ops=True,
        )
        tracker.start_tracking()
        output = model(input_tensor)
        loss = output.sum()
        loss.backward()
        tracker.stop_tracking()
        execution_times.append(time.perf_counter() - start_time)

        summary = tracker.get_graph_summary()
        if summary['memory_usage']:
            memory_usage.append(summary['memory_usage'])

    # Calculate statistics
    run_count = len(execution_times)
    avg_time = sum(execution_times) / run_count
    std_time = (sum((t - avg_time) ** 2 for t in execution_times) / run_count) ** 0.5
    min_time = min(execution_times)
    max_time = max(execution_times)

    print(f"Performance Profiling ({num_runs} runs):")
    print(f" Average execution time: {avg_time:.4f}s ± {std_time:.4f}s")
    print(f" Min execution time: {min_time:.4f}s")
    print(f" Max execution time: {max_time:.4f}s")
    if memory_usage:
        avg_memory = sum(memory_usage) / len(memory_usage)
        print(f" Average memory usage: {avg_memory}")

    return {
        'execution_times': execution_times,
        'memory_usage': memory_usage,
        'statistics': {
            'average_time': avg_time,
            'std_time': std_time,
            'min_time': min_time,
            'max_time': max_time,
        },
    }
Best Practices
Use appropriate parameters for your use case
Optimize for performance when working with large models
Export data for offline analysis
Batch process multiple models efficiently
Monitor memory usage in memory-constrained environments
Use high DPI for publication-quality output
Troubleshooting
Common Issues
- Memory issues with large models
Use track_tensor_ops=False and smaller input tensors
- Slow rendering with complex graphs
Reduce DPI or use smaller canvas sizes
- Export file too large
Consider filtering the exported data
- Integration issues
Ensure proper error handling in your workflows
See Also
Architecture Visualization - For architecture diagram generation
Computational Graph Tracking - For computational graph analysis
Model Analysis - For model analysis functions