Main Module API

This module provides the main public API for PyTorch Graph. All functions and classes are available directly from the pytorch-graph package.

Core Functions

Architecture Visualization

Model Analysis

Computational Graph Tracking

Classes

Core Classes

Utility Classes

Data Classes

Enums

Complete API Reference

Architecture Visualization Functions

generate_architecture_diagram

generate_architecture_diagram(model, input_shape, output_path='architecture.png', title=None, format='png', style='flowchart')

Generate an enhanced flowchart architecture diagram from a PyTorch model and save as PNG.

Parameters:

  • model (torch.nn.Module): The PyTorch model to visualize

  • input_shape (tuple): The input tensor shape for the model

  • output_path (str, optional): Output file path (default: “architecture.png”)

  • title (str, optional): Diagram title (auto-generated if None)

  • format (str, optional): Output format (‘png’ or ‘txt’, default: “png”)

  • style (str, optional): Diagram style (‘flowchart’, ‘standard’, or ‘research_paper’, default: “flowchart”)

Returns: str - Path to the generated diagram file

save_architecture_diagram

save_architecture_diagram(model, input_shape, output_path='architecture.png', **kwargs)

Generate and save an enhanced flowchart architecture diagram (alias for generate_architecture_diagram).

Parameters:

  • model (torch.nn.Module): The PyTorch model to visualize

  • input_shape (tuple): The input tensor shape for the model

  • output_path (str, optional): Output file path (default: “architecture.png”)

  • **kwargs: Additional arguments passed to generate_architecture_diagram

Returns: str - Path to the generated diagram file

generate_research_paper_diagram

generate_research_paper_diagram(model, input_shape, output_path='model_architecture_paper.png', title=None)

Generate a research paper quality architecture diagram.

Parameters:

  • model (torch.nn.Module): The PyTorch model to visualize

  • input_shape (tuple): The input tensor shape for the model

  • output_path (str, optional): Output file path (default: “model_architecture_paper.png”)

  • title (str, optional): Diagram title (auto-generated if None)

Returns: str - Path to the generated diagram file

generate_flowchart_diagram

generate_flowchart_diagram(model, input_shape, output_path='model_flowchart.png', title=None)

Generate a clean flowchart-style architecture diagram with vertical flow.

Parameters:

  • model (torch.nn.Module): The PyTorch model to visualize

  • input_shape (tuple): The input tensor shape for the model

  • output_path (str, optional): Output file path (default: “model_flowchart.png”)

  • title (str, optional): Diagram title (auto-generated if None)

Returns: str - Path to the generated diagram file

visualize

visualize(model, input_shape=None, renderer='plotly', **kwargs)

Visualize a PyTorch model in 3D.

Parameters:

  • model (torch.nn.Module): PyTorch model to visualize

  • input_shape (tuple, optional): Input tensor shape (if None will try to infer)

  • renderer (str, optional): Rendering backend (‘plotly’ or ‘matplotlib’, default: ‘plotly’)

  • **kwargs: Additional visualization parameters

Returns: Visualization object (Plotly figure or Matplotlib figure)

visualize_model

visualize_model(model, input_shape=None, renderer='plotly', **kwargs)

Alias for visualize() function for backward compatibility.

Parameters: Same as visualize()

Returns: Visualization object

compare_models

compare_models(models, names=None, input_shapes=None, renderer='plotly', **kwargs)

Compare multiple PyTorch models in a single visualization.

Parameters:

  • models (list): List of PyTorch models

  • names (list, optional): Optional list of model names

  • input_shapes (list, optional): Optional list of input shapes for each model

  • renderer (str, optional): Rendering backend (‘plotly’ or ‘matplotlib’, default: ‘plotly’)

  • **kwargs: Additional visualization parameters

Returns: Comparison visualization object

create_architecture_report

create_architecture_report(model, input_shape=None, output_path='pytorch-graph_report.html')

Create a comprehensive HTML report of the PyTorch model architecture.

Parameters:

  • model (torch.nn.Module): PyTorch model to analyze

  • input_shape (tuple, optional): Input tensor shape

  • output_path (str, optional): Path for the output HTML file (default: “pytorch-graph_report.html”)

Returns: None

Model Analysis Functions

analyze_model

analyze_model(model, input_shape=None, detailed=True)

Analyze a PyTorch model and return detailed statistics.

Parameters:

  • model (torch.nn.Module): PyTorch model to analyze

  • input_shape (tuple, optional): Input tensor shape

  • detailed (bool, optional): Whether to include detailed layer analysis (default: True)

Returns: dict - Dictionary containing model analysis

profile_model

profile_model(model, input_shape, device='cpu')

Profile a PyTorch model for performance analysis.

Parameters:

  • model (torch.nn.Module): PyTorch model to profile

  • input_shape (tuple): Input tensor shape

  • device (str, optional): Device to run profiling on (‘cpu’ or ‘cuda’, default: ‘cpu’)

Returns: dict - Profiling results dictionary

extract_activations

extract_activations(model, input_tensor, layer_names=None)

Extract intermediate activations from a PyTorch model.

Parameters:

  • model (torch.nn.Module): PyTorch model to extract activations from

  • input_tensor (torch.Tensor): Input tensor for forward pass

  • layer_names (list, optional): Specific layer names to extract (if None, extracts all)

Returns: dict - Dictionary of layer names to activation tensors

Computational Graph Functions

track_computational_graph

track_computational_graph(model, input_tensor, track_memory=True, track_timing=True, track_tensor_ops=True)

Track the computational graph of a PyTorch model execution.

Parameters:

  • model (torch.nn.Module): PyTorch model to track

  • input_tensor (torch.Tensor): Input tensor for the forward pass

  • track_memory (bool, optional): Whether to track memory usage (default: True)

  • track_timing (bool, optional): Whether to track execution timing (default: True)

  • track_tensor_ops (bool, optional): Whether to track tensor operations (default: True)

Returns: ComputationalGraphTracker - Tracker instance with execution data

analyze_computational_graph

analyze_computational_graph(model, input_tensor, detailed=True)

Analyze the computational graph of a PyTorch model execution.

Parameters:

  • model (torch.nn.Module): PyTorch model to analyze

  • input_tensor (torch.Tensor): Input tensor for the forward pass

  • detailed (bool, optional): Whether to include detailed analysis (default: True)

Returns: dict - Dictionary containing computational graph analysis

track_computational_graph_execution

track_computational_graph_execution(model, input_tensor, track_memory=True, track_timing=True, track_tensor_ops=True)

Track the computational graph of a PyTorch model execution (alias for track_computational_graph).

Parameters:

  • model (torch.nn.Module): PyTorch model to track

  • input_tensor (torch.Tensor): Input tensor for the forward pass

  • track_memory (bool, optional): Whether to track memory usage (default: True)

  • track_timing (bool, optional): Whether to track execution timing (default: True)

  • track_tensor_ops (bool, optional): Whether to track tensor operations (default: True)

Returns: ComputationalGraphTracker - Tracker instance with execution data

analyze_computational_graph_execution

analyze_computational_graph_execution(model, input_tensor, detailed=True)

Analyze the computational graph of a PyTorch model execution (alias for analyze_computational_graph).

Parameters:

  • model (torch.nn.Module): PyTorch model to analyze

  • input_tensor (torch.Tensor): Input tensor for the forward pass

  • detailed (bool, optional): Whether to include detailed analysis (default: True)

Returns: dict - Dictionary containing computational graph analysis

visualize_computational_graph

visualize_computational_graph(model, input_tensor, renderer='plotly')

Visualize the computational graph of a PyTorch model execution.

Parameters:

  • model (torch.nn.Module): PyTorch model to visualize

  • input_tensor (torch.Tensor): Input tensor for the forward pass

  • renderer (str, optional): Rendering backend (‘plotly’ or ‘matplotlib’, default: ‘plotly’)

Returns: Visualization object (Plotly figure or Matplotlib figure)

export_computational_graph

export_computational_graph(model, input_tensor, filepath, format='json')

Export the computational graph of a PyTorch model execution to a file.

Parameters:

  • model (torch.nn.Module): PyTorch model to export

  • input_tensor (torch.Tensor): Input tensor for the forward pass

  • filepath (str): Output file path

  • format (str, optional): Export format (‘json’, default: ‘json’)

Returns: str - Path to the exported file

save_computational_graph_png

save_computational_graph_png(model, input_tensor, filepath='computational_graph.png', width=1200, height=800, dpi=300, show_legend=True, node_size=20, font_size=10)

Save the computational graph as a high-quality PNG image.

Parameters:

  • model (torch.nn.Module): PyTorch model to visualize

  • input_tensor (torch.Tensor): Input tensor for the forward pass

  • filepath (str, optional): Output PNG file path (default: “computational_graph.png”)

  • width (int, optional): Image width in pixels (default: 1200)

  • height (int, optional): Image height in pixels (default: 800)

  • dpi (int, optional): Dots per inch for high resolution (default: 300)

  • show_legend (bool, optional): Whether to show legend (default: True)

  • node_size (int, optional): Size of nodes in the graph (default: 20)

  • font_size (int, optional): Font size for labels (default: 10)

Returns: str - Path to the saved PNG file

Quick Start Examples

Basic Usage

import torch
import torch.nn as nn
from pytorch-graph import (
    generate_architecture_diagram,
    analyze_model,
    track_computational_graph
)

# Create a simple model
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

# Generate architecture diagram
generate_architecture_diagram(
    model=model,
    input_shape=(1, 784),
    output_path="model_architecture.png",
    title="My Neural Network"
)

# Analyze model
analysis = analyze_model(model, input_shape=(1, 784))
print(f"Total parameters: {analysis['basic_info']['total_parameters']:,}")

# Track computational graph
input_tensor = torch.randn(1, 784, requires_grad=True)
tracker = track_computational_graph(model, input_tensor)
tracker.save_graph_png("computational_graph.png")

Advanced Usage

from pytorch-graph import (
    PyTorchVisualizer,
    ComputationalGraphTracker,
    ModelAnalyzer
)

# Create visualizer
visualizer = PyTorchVisualizer(
    renderer='plotly',
    layout_style='hierarchical',
    spacing=2.5
)

# 3D visualization
fig = visualizer.visualize(
    model=model,
    input_shape=(1, 784),
    show_parameters=True,
    show_activations=True
)
fig.show()

# Comprehensive analysis
analyzer = ModelAnalyzer()
analysis = analyzer.analyze(model, input_shape=(1, 784), detailed=True)

# Computational graph tracking
tracker = ComputationalGraphTracker(
    model=model,
    track_memory=True,
    track_timing=True,
    track_tensor_ops=True
)

tracker.start_tracking()
output = model(input_tensor)
loss = output.sum()
loss.backward()
tracker.stop_tracking()

# Get summary
summary = tracker.get_graph_summary()
print(f"Operations: {summary['total_nodes']:,}")
print(f"Execution time: {summary['execution_time']:.4f}s")

Error Handling

All functions will raise appropriate exceptions for:

  • ImportError: If PyTorch is not available

  • RuntimeError: If model execution fails

  • ValueError: If invalid parameters are provided

  • TypeError: If model is not a PyTorch module

  • FileNotFoundError: If output directory doesn’t exist

See Also

  • architecture_visualization - For architecture diagram generation

  • computational_graph_tracking - For computational graph analysis

  • Model Analysis API - For model analysis functions

  • Utilities API - For utility classes and functions