Architecture Visualization API
This module provides functions and a visualizer class for generating architecture diagrams, 3D visualizations, and HTML reports from PyTorch models.
Function Details
generate_architecture_diagram
- generate_architecture_diagram(model, input_shape, output_path='architecture.png', title=None, format='png', style='flowchart')
Generate an architecture diagram from a PyTorch model and save it to a file. By default the diagram uses the flowchart style and is written as a PNG.
Parameters:
model (torch.nn.Module): The PyTorch model to visualize
input_shape (tuple): The input tensor shape for the model
output_path (str, optional): Output file path (default: “architecture.png”)
title (str, optional): Diagram title (auto-generated if None)
format (str, optional): Output format (‘png’ or ‘txt’, default: “png”)
style (str, optional): Diagram style (‘flowchart’, ‘standard’, or ‘research_paper’, default: “flowchart”)
Returns: str - Path to the generated diagram file
Raises:
ImportError: If PyTorch is not available
ValueError: If unsupported format is specified
Example:
import torch.nn as nn
from pytorch_graph import generate_architecture_diagram
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
# Generate architecture diagram
path = generate_architecture_diagram(
    model=model,
    input_shape=(1, 784),
    output_path="model_architecture.png",
    title="My Neural Network",
    style="flowchart"
)
print(f"Diagram saved to: {path}")
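Because an unsupported format raises ValueError, it can be convenient to check the options before calling the function. The following is a minimal sketch: check_diagram_options is a hypothetical helper, not part of pytorch-graph, and its allowed values are simply copied from the parameter descriptions above.

```python
# Hypothetical helper, not part of pytorch-graph. The allowed values are
# copied from the parameter list of generate_architecture_diagram above.
ALLOWED_FORMATS = ("png", "txt")
ALLOWED_STYLES = ("flowchart", "standard", "research_paper")


def check_diagram_options(format="png", style="flowchart"):
    """Return (format, style) unchanged, or raise ValueError with a clear message."""
    if format not in ALLOWED_FORMATS:
        raise ValueError(f"format must be one of {ALLOWED_FORMATS}, got {format!r}")
    if style not in ALLOWED_STYLES:
        raise ValueError(f"style must be one of {ALLOWED_STYLES}, got {style!r}")
    return format, style
```

The returned pair can then be passed on as the format and style arguments of generate_architecture_diagram.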
save_architecture_diagram
- save_architecture_diagram(model, input_shape, output_path='architecture.png', **kwargs)
Generate and save an enhanced flowchart architecture diagram (alias for generate_architecture_diagram).
Parameters:
model (torch.nn.Module): The PyTorch model to visualize
input_shape (tuple): The input tensor shape for the model
output_path (str, optional): Output file path (default: “architecture.png”)
**kwargs: Additional arguments passed to generate_architecture_diagram
Returns: str - Path to the generated diagram file
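Since this function is described as an alias, extra keyword arguments such as title or style should reach generate_architecture_diagram unchanged. The forwarding pattern can be illustrated with stand-ins. Both functions below are hypothetical stubs written for this illustration, not the library's code.

```python
def generate_stub(model, input_shape, output_path="architecture.png",
                  title=None, format="png", style="flowchart"):
    # Stand-in for generate_architecture_diagram: it only reports what it received.
    return {"output_path": output_path, "title": title,
            "format": format, "style": style}


def save_stub(model, input_shape, output_path="architecture.png", **kwargs):
    # Stand-in for save_architecture_diagram: forwards everything to the stub above.
    return generate_stub(model, input_shape, output_path, **kwargs)


# The style keyword passes through the alias; unspecified options keep their defaults.
received = save_stub(None, (1, 784), "mlp.png", style="standard")
```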
generate_research_paper_diagram
- generate_research_paper_diagram(model, input_shape, output_path='model_architecture_paper.png', title=None)
Generate a research paper quality architecture diagram.
Parameters:
model (torch.nn.Module): The PyTorch model to visualize
input_shape (tuple): The input tensor shape for the model
output_path (str, optional): Output file path (default: “model_architecture_paper.png”)
title (str, optional): Diagram title (auto-generated if None)
Returns: str - Path to the generated diagram file
generate_flowchart_diagram
- generate_flowchart_diagram(model, input_shape, output_path='model_flowchart.png', title=None)
Generate a clean flowchart-style architecture diagram with vertical flow.
Parameters:
model (torch.nn.Module): The PyTorch model to visualize
input_shape (tuple): The input tensor shape for the model
output_path (str, optional): Output file path (default: “model_flowchart.png”)
title (str, optional): Diagram title (auto-generated if None)
Returns: str - Path to the generated diagram file
visualize
- visualize(model, input_shape=None, renderer='plotly', **kwargs)
Visualize a PyTorch model in 3D.
Parameters:
model (torch.nn.Module): PyTorch model to visualize
input_shape (tuple, optional): Input tensor shape (if None, the function tries to infer it)
renderer (str, optional): Rendering backend (‘plotly’ or ‘matplotlib’, default: ‘plotly’)
**kwargs: Additional visualization parameters
Returns: Visualization object (Plotly figure or Matplotlib figure)
Raises:
ImportError: If PyTorch is not available
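The renderer argument names a plotting backend, so one way to pick it is to check which backend is installed. The sketch below uses only the standard library. choose_renderer is a hypothetical helper, and treating 'plotly' and 'matplotlib' as importable module names is an assumption based on the renderer values listed above.

```python
import importlib.util


def choose_renderer(preferred=("plotly", "matplotlib")):
    """Return the first name in `preferred` that is importable, or None."""
    for name in preferred:
        # find_spec returns None when a top-level module is not installed.
        if importlib.util.find_spec(name) is not None:
            return name
    return None
```

The result can be passed as renderer=... to visualize(). A return value of None means neither backend was found.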
visualize_model
- visualize_model(model, input_shape=None, renderer='plotly', **kwargs)
Alias for visualize() function for backward compatibility.
Parameters: Same as visualize()
Returns: Visualization object
compare_models
- compare_models(models, names=None, input_shapes=None, renderer='plotly', **kwargs)
Compare multiple PyTorch models in a single visualization.
Parameters:
models (list): List of PyTorch models
names (list, optional): Optional list of model names
input_shapes (list, optional): Optional list of input shapes for each model
renderer (str, optional): Rendering backend (‘plotly’ or ‘matplotlib’, default: ‘plotly’)
**kwargs: Additional visualization parameters
Returns: Comparison visualization object
Raises:
ImportError: If PyTorch is not available
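The models, names, and input_shapes lists are parallel, so a mismatch in their lengths is easy to introduce. The sketch below checks the lengths before the call. prepare_comparison_args is a hypothetical helper, and the fallback naming scheme (class name plus index) is an assumption made for this example, not the library's documented default.

```python
def prepare_comparison_args(models, names=None, input_shapes=None):
    """Check that the parallel lists line up and fill in names if missing."""
    models = list(models)
    if names is None:
        # Assumed naming scheme for illustration: class name plus position.
        names = [f"{type(m).__name__}_{i}" for i, m in enumerate(models)]
    if len(names) != len(models):
        raise ValueError(f"got {len(models)} models but {len(names)} names")
    if input_shapes is not None and len(input_shapes) != len(models):
        raise ValueError(
            f"got {len(models)} models but {len(input_shapes)} input shapes"
        )
    return {"models": models, "names": list(names), "input_shapes": input_shapes}
```

The resulting dictionary can be unpacked into the call, for example compare_models(**prepare_comparison_args(models), renderer='plotly').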
create_architecture_report
- create_architecture_report(model, input_shape=None, output_path='pytorch-graph_report.html')
Create a comprehensive HTML report of the PyTorch model architecture.
Parameters:
model (torch.nn.Module): PyTorch model to analyze
input_shape (tuple, optional): Input tensor shape
output_path (str, optional): Path for the output HTML file (default: “pytorch-graph_report.html”)
Returns: None
Raises:
ImportError: If PyTorch is not available
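Because the function returns None, the caller has to confirm that the report was written. The sketch below creates the output directory first and checks for the file afterwards. write_report is a hypothetical wrapper. It takes the report function as an argument so that it does not depend on how pytorch-graph is imported.

```python
from pathlib import Path


def write_report(report_fn, model, input_shape, output_path):
    """Call report_fn, then confirm that the HTML file exists."""
    out = Path(output_path)
    # Create missing parent directories so the write cannot fail on them.
    out.parent.mkdir(parents=True, exist_ok=True)
    report_fn(model, input_shape=input_shape, output_path=str(out))
    if not out.is_file():
        raise RuntimeError(f"report was not written to {out}")
    return out
```

A call would look like write_report(create_architecture_report, model, (1, 784), "reports/my_model_report.html").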
PyTorchVisualizer Class
- class PyTorchVisualizer(renderer='plotly', layout_style='hierarchical', spacing=2.0, theme='plotly_dark', width=1200, height=800)
Main class for visualizing PyTorch neural network architectures in 3D.
Parameters:
renderer (str, optional): Rendering backend (‘plotly’ or ‘matplotlib’, default: ‘plotly’)
layout_style (str, optional): Layout algorithm (‘hierarchical’, ‘circular’, ‘spring’, ‘custom’, default: ‘hierarchical’)
spacing (float, optional): Spacing between layers (default: 2.0)
theme (str, optional): Color theme for visualization (default: ‘plotly_dark’)
width (int, optional): Figure width in pixels (default: 1200)
height (int, optional): Figure height in pixels (default: 800)
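When the same settings are reused across scripts, the constructor arguments can be kept in one place. The sketch below mirrors the documented defaults in a dataclass. VisualizerConfig is a hypothetical name, and the validation covers only the renderer and layout values listed above.

```python
from dataclasses import asdict, dataclass, replace


@dataclass(frozen=True)
class VisualizerConfig:
    """Hypothetical container for the PyTorchVisualizer constructor arguments."""
    renderer: str = "plotly"
    layout_style: str = "hierarchical"
    spacing: float = 2.0
    theme: str = "plotly_dark"
    width: int = 1200
    height: int = 800

    def __post_init__(self):
        # Only the values documented above are accepted.
        if self.renderer not in ("plotly", "matplotlib"):
            raise ValueError(f"unknown renderer: {self.renderer!r}")
        if self.layout_style not in ("hierarchical", "circular", "spring", "custom"):
            raise ValueError(f"unknown layout_style: {self.layout_style!r}")


# Derive a wider variant without touching the other defaults.
wide = replace(VisualizerConfig(), width=1400, height=900)
```

The configuration can be unpacked into the constructor, for example PyTorchVisualizer(**asdict(wide)).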
Methods:
- visualize(model, input_shape=None, title=None, show_connections=True, show_labels=True, show_parameters=False, show_activations=False, optimize_layout=True, device='auto', export_path=None, **kwargs)
Visualize a PyTorch model.
Parameters:
model (torch.nn.Module): PyTorch model to visualize
input_shape (tuple, optional): Input tensor shape (required for detailed analysis)
title (str, optional): Plot title
show_connections (bool, optional): Whether to show connections between layers (default: True)
show_labels (bool, optional): Whether to show layer labels (default: True)
show_parameters (bool, optional): Whether to show parameter count visualization (default: False)
show_activations (bool, optional): Whether to include activation statistics (default: False)
optimize_layout (bool, optional): Whether to optimize layer positions (default: True)
device (str, optional): Device for model analysis (‘auto’, ‘cpu’, ‘cuda’, default: ‘auto’)
export_path (str, optional): Path to export the visualization (default: None)
**kwargs: Additional rendering options
Returns: Rendered visualization object
- get_model_summary(model, input_shape=None, device='auto')
Get a comprehensive summary of the PyTorch model.
Parameters:
model (torch.nn.Module): PyTorch model
input_shape (tuple, optional): Input tensor shape
device (str, optional): Device for analysis (default: ‘auto’)
Returns: dict - Dictionary containing model summary
- analyze_model(model, input_shape=None, device='auto', detailed=True)
Perform comprehensive analysis of the PyTorch model.
Parameters:
model (torch.nn.Module): PyTorch model
input_shape (tuple, optional): Input tensor shape
device (str, optional): Device for analysis (default: ‘auto’)
detailed (bool, optional): Whether to include detailed analysis (default: True)
Returns: dict - Dictionary containing comprehensive analysis
- profile_model(model, input_shape, device='cpu', num_runs=100)
Profile PyTorch model performance.
Parameters:
model (torch.nn.Module): PyTorch model
input_shape (tuple): Input tensor shape
device (str, optional): Device for profiling (default: ‘cpu’)
num_runs (int, optional): Number of timing runs (default: 100)
Returns: dict - Dictionary with profiling results
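This page does not describe how profile_model takes its measurements. A common approach is to run a few warm-up calls and then average num_runs timed calls. The sketch below applies that approach to any callable. time_callable, its warmup parameter, and every result key except mean_time_ms (which appears in the example later on this page) are assumptions.

```python
import statistics
import time


def time_callable(fn, num_runs=100, warmup=5):
    """Time fn() num_runs times and report statistics in milliseconds."""
    if num_runs < 1:
        raise ValueError("num_runs must be at least 1")
    for _ in range(warmup):
        fn()  # warm-up calls are not measured
    samples_ms = []
    for _ in range(num_runs):
        start = time.perf_counter()
        fn()
        samples_ms.append((time.perf_counter() - start) * 1000.0)
    return {
        "mean_time_ms": statistics.fmean(samples_ms),
        "min_time_ms": min(samples_ms),
        "max_time_ms": max(samples_ms),
        "num_runs": num_runs,
    }


# Time a cheap pure-Python workload as a stand-in for a model forward pass.
result = time_callable(lambda: sum(range(1000)), num_runs=20)
```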
- compare_models(models, names=None, input_shapes=None, device='auto', **kwargs)
Compare multiple PyTorch models in a single visualization.
Parameters:
models (list): List of PyTorch models
names (list, optional): Optional list of model names
input_shapes (list, optional): Optional list of input shapes for each model
device (str, optional): Device for analysis (default: ‘auto’)
**kwargs: Additional visualization options
Returns: Rendered comparison visualization
- visualize_feature_maps(model, input_tensor, layer_names=None, max_channels=16)
Visualize feature maps from convolutional layers.
Parameters:
model (torch.nn.Module): PyTorch CNN model
input_tensor (torch.Tensor): Input tensor for feature extraction
layer_names (list, optional): Specific conv layers to visualize
max_channels (int, optional): Maximum channels per layer to visualize (default: 16)
Returns: Feature map visualization
- create_training_visualization(model, input_shape, num_epochs=1)
Create visualizations showing model changes during training.
Parameters:
model (torch.nn.Module): PyTorch model
input_shape (tuple): Input tensor shape
num_epochs (int, optional): Number of training epochs to simulate (default: 1)
Returns: list - List of visualizations for each epoch
- export_architecture_report(model, input_shape=None, output_path='pytorch_report.html', include_profiling=True)
Export a comprehensive HTML report of the PyTorch model architecture.
Parameters:
model (torch.nn.Module): PyTorch model
input_shape (tuple, optional): Input tensor shape
output_path (str, optional): Path for the output HTML file (default: “pytorch_report.html”)
include_profiling (bool, optional): Whether to include performance profiling (default: True)
Returns: None
- set_theme(theme)
Set the visualization theme.
Parameters:
theme (str): Theme name
Returns: None
- set_layout_style(layout_style)
Set the layout style for positioning layers.
Parameters:
layout_style (str): Layout style name
Returns: None
- set_spacing(spacing)
Set the spacing between layers.
Parameters:
spacing (float): Spacing value
Returns: None
Examples
Basic Architecture Diagram
import torch.nn as nn
from pytorch_graph import generate_architecture_diagram
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
# Generate architecture diagram
path = generate_architecture_diagram(
    model=model,
    input_shape=(1, 784),
    output_path="model_architecture.png",
    title="My Neural Network",
    style="flowchart"
)
print(f"Diagram saved to: {path}")
Research Paper Quality Diagram
from pytorch_graph import generate_research_paper_diagram
# Reuses `model` from the basic example above
path = generate_research_paper_diagram(
    model=model,
    input_shape=(1, 784),
    output_path="paper_architecture.png",
    title="Research Model Architecture"
)
3D Visualization
from pytorch_graph import visualize
fig = visualize(
    model=model,
    input_shape=(1, 784),
    renderer='plotly',
    title="3D Model Visualization",
    show_parameters=True,
    show_activations=True
)
fig.show()
Model Comparison
from pytorch_graph import compare_models
# mlp_model, cnn_model, and resnet_model are assumed to be defined elsewhere
models = [mlp_model, cnn_model, resnet_model]
names = ["MLP", "CNN", "ResNet"]
input_shapes = [(1, 784), (1, 3, 32, 32), (1, 3, 224, 224)]
fig = compare_models(
    models=models,
    names=names,
    input_shapes=input_shapes,
    renderer='plotly'
)
fig.show()
Comprehensive Report
from pytorch_graph import create_architecture_report
create_architecture_report(
    model=model,
    input_shape=(1, 784),
    output_path="my_model_report.html"
)
Advanced Usage with PyTorchVisualizer
from pytorch_graph import PyTorchVisualizer
# Create visualizer with custom settings
visualizer = PyTorchVisualizer(
    renderer='plotly',
    layout_style='hierarchical',
    spacing=2.5,
    theme='plotly_white',
    width=1400,
    height=900
)
# Visualize with advanced options
fig = visualizer.visualize(
    model=model,
    input_shape=(1, 784),
    title="Advanced Model Visualization",
    show_connections=True,
    show_labels=True,
    show_parameters=True,
    show_activations=True,
    optimize_layout=True,
    device='auto'
)
# Get model analysis
analysis = visualizer.analyze_model(model, input_shape=(1, 784), detailed=True)
print(f"Total parameters: {analysis['basic_info']['total_parameters']:,}")
# Profile performance
profiling = visualizer.profile_model(model, input_shape=(1, 784), num_runs=50)
print(f"Average inference time: {profiling['mean_time_ms']:.2f} ms")
Error Handling
The functions in this module raise the following exceptions:
ImportError: If PyTorch is not available
ValueError: If unsupported format or style is specified
FileNotFoundError: If output directory doesn’t exist
RuntimeError: If model execution fails
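The exceptions listed above can each be handled separately at the call site. The sketch below shows one way to do that. try_generate is a hypothetical wrapper that receives the diagram function as an argument, so the same pattern works for any of the generate_* functions.

```python
def try_generate(generate_fn, model, input_shape, output_path):
    """Call generate_fn and return its path, or None if a listed error occurs."""
    try:
        return generate_fn(model, input_shape, output_path=output_path)
    except ImportError:
        print("PyTorch is not available; install it to generate diagrams.")
    except ValueError as exc:
        print(f"Unsupported format or style: {exc}")
    except FileNotFoundError:
        print(f"Output directory for {output_path!r} does not exist.")
    except RuntimeError as exc:
        print(f"Model execution failed: {exc}")
    return None
```

A call would look like try_generate(generate_architecture_diagram, model, (1, 784), "out/model.png"), with a None result signalling that no diagram was produced.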
See Also
computational_graph_tracking - For computational graph analysis
Model Analysis API - For model analysis functions
Utilities API - For utility classes and functions