# Cluster/ui/windows/pipeline_editor.py
# 2025-07-17 17:04:56 +08:00
#
# 667 lines
# 26 KiB
# Python

# """
# Pipeline Editor window with stage counting functionality.
# This module provides the main pipeline editor interface with visual node-based
# pipeline design and automatic stage counting display.
# Main Components:
# - PipelineEditor: Main pipeline editor window
# - Stage counting display in canvas
# - Node graph integration
# - Pipeline validation and analysis
# Usage:
# from cluster4npu_ui.ui.windows.pipeline_editor import PipelineEditor
# editor = PipelineEditor()
# editor.show()
# """
# import sys
# from PyQt5.QtWidgets import (QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
# QLabel, QStatusBar, QFrame, QPushButton, QAction,
# QMenuBar, QToolBar, QSplitter, QTextEdit, QMessageBox,
# QScrollArea)
# from PyQt5.QtCore import Qt, QTimer, pyqtSignal
# from PyQt5.QtGui import QFont, QPixmap, QIcon, QTextCursor
# try:
# from NodeGraphQt import NodeGraph
# from NodeGraphQt.constants import IN_PORT, OUT_PORT
# NODEGRAPH_AVAILABLE = True
# except ImportError:
# NODEGRAPH_AVAILABLE = False
# print("NodeGraphQt not available. Install with: pip install NodeGraphQt")
# from ...core.pipeline import get_stage_count, analyze_pipeline_stages, get_pipeline_summary
# from ...core.nodes.exact_nodes import (
# ExactInputNode, ExactModelNode, ExactPreprocessNode,
# ExactPostprocessNode, ExactOutputNode
# )
# # Expose the original node class names as well; if their modules are missing, alias them to the Exact* classes
# try:
# from ...core.nodes.model_node import ModelNode
# from ...core.nodes.preprocess_node import PreprocessNode
# from ...core.nodes.postprocess_node import PostprocessNode
# from ...core.nodes.input_node import InputNode
# from ...core.nodes.output_node import OutputNode
# except ImportError:
# # Use ExactNodes as fallback
# ModelNode = ExactModelNode
# PreprocessNode = ExactPreprocessNode
# PostprocessNode = ExactPostprocessNode
# InputNode = ExactInputNode
# OutputNode = ExactOutputNode
# class StageCountWidget(QWidget):
# """Widget to display stage count information in the pipeline editor."""
# def __init__(self, parent=None):
# super().__init__(parent)
# self.stage_count = 0
# self.pipeline_valid = True
# self.pipeline_error = ""
# self.setup_ui()
# self.setFixedSize(200, 80)
# def setup_ui(self):
# """Setup the stage count widget UI."""
# layout = QVBoxLayout()
# layout.setContentsMargins(10, 5, 10, 5)
# # Stage count label
# self.stage_label = QLabel("Stages: 0")
# self.stage_label.setFont(QFont("Arial", 11, QFont.Bold))
# self.stage_label.setStyleSheet("color: #2E7D32; font-weight: bold;")
# # Status label
# self.status_label = QLabel("Ready")
# self.status_label.setFont(QFont("Arial", 9))
# self.status_label.setStyleSheet("color: #666666;")
# # Error label (initially hidden)
# self.error_label = QLabel("")
# self.error_label.setFont(QFont("Arial", 8))
# self.error_label.setStyleSheet("color: #D32F2F;")
# self.error_label.setWordWrap(True)
# self.error_label.setMaximumHeight(30)
# self.error_label.hide()
# layout.addWidget(self.stage_label)
# layout.addWidget(self.status_label)
# layout.addWidget(self.error_label)
# self.setLayout(layout)
# # Style the widget
# self.setStyleSheet("""
# StageCountWidget {
# background-color: #F5F5F5;
# border: 1px solid #E0E0E0;
# border-radius: 5px;
# }
# """)
# def update_stage_count(self, count: int, valid: bool = True, error: str = ""):
# """Update the stage count display."""
# self.stage_count = count
# self.pipeline_valid = valid
# self.pipeline_error = error
# # Update stage count
# self.stage_label.setText(f"Stages: {count}")
# # Update status and styling
# if not valid:
# self.stage_label.setStyleSheet("color: #D32F2F; font-weight: bold;")
# self.status_label.setText("Invalid Pipeline")
# self.status_label.setStyleSheet("color: #D32F2F;")
# self.error_label.setText(error)
# self.error_label.show()
# else:
# self.stage_label.setStyleSheet("color: #2E7D32; font-weight: bold;")
# if count == 0:
# self.status_label.setText("No stages defined")
# self.status_label.setStyleSheet("color: #FF8F00;")
# else:
# self.status_label.setText(f"Pipeline ready ({count} stage{'s' if count != 1 else ''})")
# self.status_label.setStyleSheet("color: #2E7D32;")
# self.error_label.hide()
# class PipelineEditor(QMainWindow):
# """
# Main pipeline editor window with stage counting functionality.
# This window provides a visual node-based pipeline editor with automatic
# stage detection and counting displayed in the canvas.
# """
# # Signals
# pipeline_changed = pyqtSignal()
# stage_count_changed = pyqtSignal(int)
# def __init__(self, parent=None):
# super().__init__(parent)
# self.node_graph = None
# self.stage_count_widget = None
# self.analysis_timer = None
# self.previous_stage_count = 0 # Track previous stage count for comparison
# self.setup_ui()
# self.setup_node_graph()
# self.setup_analysis_timer()
# # Connect signals
# self.pipeline_changed.connect(self.analyze_pipeline)
# # Initial analysis
# print("Pipeline Editor initialized")
# self.analyze_pipeline()
# def setup_ui(self):
# """Setup the main UI components."""
# self.setWindowTitle("Pipeline Editor - Cluster4NPU")
# self.setGeometry(100, 100, 1200, 800)
# # Create central widget
# central_widget = QWidget()
# self.setCentralWidget(central_widget)
# # Create main layout
# main_layout = QVBoxLayout()
# central_widget.setLayout(main_layout)
# # Create splitter for main content
# splitter = QSplitter(Qt.Horizontal)
# main_layout.addWidget(splitter)
# # Left panel for node graph
# self.graph_widget = QWidget()
# self.graph_layout = QVBoxLayout()
# self.graph_widget.setLayout(self.graph_layout)
# splitter.addWidget(self.graph_widget)
# # Right panel for properties and tools
# right_panel = QWidget()
# right_panel.setMaximumWidth(300)
# right_layout = QVBoxLayout()
# right_panel.setLayout(right_layout)
# # Stage count widget (added first, so it sits at the top of the right panel)
# self.stage_count_widget = StageCountWidget()
# right_layout.addWidget(self.stage_count_widget)
# # Properties panel
# properties_label = QLabel("Properties")
# properties_label.setFont(QFont("Arial", 10, QFont.Bold))
# right_layout.addWidget(properties_label)
# self.properties_text = QTextEdit()
# self.properties_text.setMaximumHeight(200)
# self.properties_text.setReadOnly(True)
# right_layout.addWidget(self.properties_text)
# # Pipeline info panel
# info_label = QLabel("Pipeline Info")
# info_label.setFont(QFont("Arial", 10, QFont.Bold))
# right_layout.addWidget(info_label)
# self.info_text = QTextEdit()
# self.info_text.setReadOnly(True)
# right_layout.addWidget(self.info_text)
# splitter.addWidget(right_panel)
# # Set splitter proportions
# splitter.setSizes([800, 300])
# # Create toolbar
# self.create_toolbar()
# # Create status bar
# self.create_status_bar()
# # Apply styling
# self.apply_styling()
# def create_toolbar(self):
# """Create the toolbar with pipeline operations."""
# toolbar = self.addToolBar("Pipeline Operations")
# # Add nodes actions
# add_input_action = QAction("Add Input", self)
# add_input_action.triggered.connect(self.add_input_node)
# toolbar.addAction(add_input_action)
# add_model_action = QAction("Add Model", self)
# add_model_action.triggered.connect(self.add_model_node)
# toolbar.addAction(add_model_action)
# add_preprocess_action = QAction("Add Preprocess", self)
# add_preprocess_action.triggered.connect(self.add_preprocess_node)
# toolbar.addAction(add_preprocess_action)
# add_postprocess_action = QAction("Add Postprocess", self)
# add_postprocess_action.triggered.connect(self.add_postprocess_node)
# toolbar.addAction(add_postprocess_action)
# add_output_action = QAction("Add Output", self)
# add_output_action.triggered.connect(self.add_output_node)
# toolbar.addAction(add_output_action)
# toolbar.addSeparator()
# # Pipeline actions
# validate_action = QAction("Validate Pipeline", self)
# validate_action.triggered.connect(self.validate_pipeline)
# toolbar.addAction(validate_action)
# clear_action = QAction("Clear Pipeline", self)
# clear_action.triggered.connect(self.clear_pipeline)
# toolbar.addAction(clear_action)
# def create_status_bar(self):
# """Create the status bar."""
# self.status_bar = QStatusBar()
# self.setStatusBar(self.status_bar)
# self.status_bar.showMessage("Ready")
# def setup_node_graph(self):
# """Setup the node graph widget."""
# if not NODEGRAPH_AVAILABLE:
# # Show error message
# error_label = QLabel("NodeGraphQt not available. Please install it to use the pipeline editor.")
# error_label.setAlignment(Qt.AlignCenter)
# error_label.setStyleSheet("color: red; font-size: 14px;")
# self.graph_layout.addWidget(error_label)
# return
# # Create node graph
# self.node_graph = NodeGraph()
# # Register node types - use ExactNode classes
# print("Registering nodes with NodeGraphQt...")
# # Register each ExactNode class in its own try/except so one failure does not block the others
# try:
# self.node_graph.register_node(ExactInputNode)
# print(f"✓ Registered ExactInputNode with identifier {ExactInputNode.__identifier__}")
# except Exception as e:
# print(f"✗ Failed to register ExactInputNode: {e}")
# try:
# self.node_graph.register_node(ExactModelNode)
# print(f"✓ Registered ExactModelNode with identifier {ExactModelNode.__identifier__}")
# except Exception as e:
# print(f"✗ Failed to register ExactModelNode: {e}")
# try:
# self.node_graph.register_node(ExactPreprocessNode)
# print(f"✓ Registered ExactPreprocessNode with identifier {ExactPreprocessNode.__identifier__}")
# except Exception as e:
# print(f"✗ Failed to register ExactPreprocessNode: {e}")
# try:
# self.node_graph.register_node(ExactPostprocessNode)
# print(f"✓ Registered ExactPostprocessNode with identifier {ExactPostprocessNode.__identifier__}")
# except Exception as e:
# print(f"✗ Failed to register ExactPostprocessNode: {e}")
# try:
# self.node_graph.register_node(ExactOutputNode)
# print(f"✓ Registered ExactOutputNode with identifier {ExactOutputNode.__identifier__}")
# except Exception as e:
# print(f"✗ Failed to register ExactOutputNode: {e}")
# print("Node graph setup completed successfully")
# # Connect node graph signals
# self.node_graph.node_created.connect(self.on_node_created)
# self.node_graph.node_deleted.connect(self.on_node_deleted)
# self.node_graph.connection_changed.connect(self.on_connection_changed)
# # Connect additional signals for more comprehensive updates
# if hasattr(self.node_graph, 'nodes_deleted'):
# self.node_graph.nodes_deleted.connect(self.on_nodes_deleted)
# if hasattr(self.node_graph, 'connection_sliced'):
# self.node_graph.connection_sliced.connect(self.on_connection_changed)
# # Add node graph widget to layout
# self.graph_layout.addWidget(self.node_graph.widget)
# def setup_analysis_timer(self):
# """Setup timer for pipeline analysis."""
# self.analysis_timer = QTimer()
# self.analysis_timer.setSingleShot(True)
# self.analysis_timer.timeout.connect(self.analyze_pipeline)
# self.analysis_timer.setInterval(500) # 500ms delay
# def apply_styling(self):
# """Apply custom styling to the editor."""
# self.setStyleSheet("""
# QMainWindow {
# background-color: #FAFAFA;
# }
# QToolBar {
# background-color: #FFFFFF;
# border: 1px solid #E0E0E0;
# spacing: 5px;
# padding: 5px;
# }
# QToolBar QAction {
# padding: 5px 10px;
# margin: 2px;
# border: 1px solid #E0E0E0;
# border-radius: 3px;
# background-color: #FFFFFF;
# }
# QToolBar QAction:hover {
# background-color: #F5F5F5;
# }
# QTextEdit {
# border: 1px solid #E0E0E0;
# border-radius: 3px;
# padding: 5px;
# background-color: #FFFFFF;
# }
# QLabel {
# color: #333333;
# }
# """)
# def add_input_node(self):
# """Add an input node to the pipeline."""
# if self.node_graph:
# print("Adding Input Node via toolbar...")
# # Try multiple identifier formats
# identifiers = [
# 'com.cluster.input_node',
# 'com.cluster.input_node.ExactInputNode',
# 'com.cluster.input_node.ExactInputNode.ExactInputNode'
# ]
# node = self.create_node_with_fallback(identifiers, "Input Node")
# self.schedule_analysis()
# def add_model_node(self):
# """Add a model node to the pipeline."""
# if self.node_graph:
# print("Adding Model Node via toolbar...")
# # Try multiple identifier formats
# identifiers = [
# 'com.cluster.model_node',
# 'com.cluster.model_node.ExactModelNode',
# 'com.cluster.model_node.ExactModelNode.ExactModelNode'
# ]
# node = self.create_node_with_fallback(identifiers, "Model Node")
# self.schedule_analysis()
# def add_preprocess_node(self):
# """Add a preprocess node to the pipeline."""
# if self.node_graph:
# print("Adding Preprocess Node via toolbar...")
# # Try multiple identifier formats
# identifiers = [
# 'com.cluster.preprocess_node',
# 'com.cluster.preprocess_node.ExactPreprocessNode',
# 'com.cluster.preprocess_node.ExactPreprocessNode.ExactPreprocessNode'
# ]
# node = self.create_node_with_fallback(identifiers, "Preprocess Node")
# self.schedule_analysis()
# def add_postprocess_node(self):
# """Add a postprocess node to the pipeline."""
# if self.node_graph:
# print("Adding Postprocess Node via toolbar...")
# # Try multiple identifier formats
# identifiers = [
# 'com.cluster.postprocess_node',
# 'com.cluster.postprocess_node.ExactPostprocessNode',
# 'com.cluster.postprocess_node.ExactPostprocessNode.ExactPostprocessNode'
# ]
# node = self.create_node_with_fallback(identifiers, "Postprocess Node")
# self.schedule_analysis()
# def add_output_node(self):
# """Add an output node to the pipeline."""
# if self.node_graph:
# print("Adding Output Node via toolbar...")
# # Try multiple identifier formats
# identifiers = [
# 'com.cluster.output_node',
# 'com.cluster.output_node.ExactOutputNode',
# 'com.cluster.output_node.ExactOutputNode.ExactOutputNode'
# ]
# node = self.create_node_with_fallback(identifiers, "Output Node")
# self.schedule_analysis()
# def create_node_with_fallback(self, identifiers, node_type):
# """Try to create a node with multiple identifier fallbacks."""
# for identifier in identifiers:
# try:
# node = self.node_graph.create_node(identifier)
# print(f"✓ Successfully created {node_type} with identifier: {identifier}")
# return node
# except Exception as e:
# continue
# print(f"Failed to create {node_type} with any identifier: {identifiers}")
# return None
# def validate_pipeline(self):
# """Validate the current pipeline configuration."""
# if not self.node_graph:
# return
# print("🔍 Validating pipeline...")
# summary = get_pipeline_summary(self.node_graph)
# if summary['valid']:
# print(f"Pipeline validation passed - {summary['stage_count']} stages, {summary['total_nodes']} nodes")
# QMessageBox.information(self, "Pipeline Validation",
# f"Pipeline is valid!\n\n"
# f"Stages: {summary['stage_count']}\n"
# f"Total nodes: {summary['total_nodes']}")
# else:
# print(f"Pipeline validation failed: {summary['error']}")
# QMessageBox.warning(self, "Pipeline Validation",
# f"Pipeline validation failed:\n\n{summary['error']}")
# def clear_pipeline(self):
# """Clear the entire pipeline."""
# if self.node_graph:
# print("🗑️ Clearing entire pipeline...")
# self.node_graph.clear_session()
# self.schedule_analysis()
# def schedule_analysis(self):
# """Schedule pipeline analysis after a delay."""
# if self.analysis_timer:
# self.analysis_timer.start()
# def analyze_pipeline(self):
# """Analyze the current pipeline and update stage count."""
# if not self.node_graph:
# return
# try:
# # Get pipeline summary
# summary = get_pipeline_summary(self.node_graph)
# current_stage_count = summary['stage_count']
# # Print detailed pipeline analysis
# self.print_pipeline_analysis(summary, current_stage_count)
# # Update stage count widget
# self.stage_count_widget.update_stage_count(
# current_stage_count,
# summary['valid'],
# summary.get('error', '')
# )
# # Update info panel
# self.update_info_panel(summary)
# # Update status bar
# if summary['valid']:
# self.status_bar.showMessage(f"Pipeline ready - {current_stage_count} stages")
# else:
# self.status_bar.showMessage(f"Pipeline invalid - {summary.get('error', 'Unknown error')}")
# # Update previous count for next comparison
# self.previous_stage_count = current_stage_count
# # Emit signal
# self.stage_count_changed.emit(current_stage_count)
# except Exception as e:
# print(f"X Pipeline analysis error: {str(e)}")
# self.stage_count_widget.update_stage_count(0, False, f"Analysis error: {str(e)}")
# self.status_bar.showMessage(f"Analysis error: {str(e)}")
# def print_pipeline_analysis(self, summary, current_stage_count):
# """Print detailed pipeline analysis to terminal."""
# # Check if stage count changed
# if current_stage_count != self.previous_stage_count:
# if self.previous_stage_count == 0 and current_stage_count > 0:
# print(f"Initial stage count: {current_stage_count}")
# elif current_stage_count != self.previous_stage_count:
# change = current_stage_count - self.previous_stage_count
# if change > 0:
# print(f"Stage count increased: {self.previous_stage_count} → {current_stage_count} (+{change})")
# else:
# print(f"Stage count decreased: {self.previous_stage_count} → {current_stage_count} ({change})")
# # Always print current pipeline status for clarity
# print(f"Current Pipeline Status:")
# print(f" • Stages: {current_stage_count}")
# print(f" • Total Nodes: {summary['total_nodes']}")
# print(f" • Model Nodes: {summary['model_nodes']}")
# print(f" • Input Nodes: {summary['input_nodes']}")
# print(f" • Output Nodes: {summary['output_nodes']}")
# print(f" • Preprocess Nodes: {summary['preprocess_nodes']}")
# print(f" • Postprocess Nodes: {summary['postprocess_nodes']}")
# print(f" • Valid: {'V' if summary['valid'] else 'X'}")
# if not summary['valid'] and summary.get('error'):
# print(f" • Error: {summary['error']}")
# # Print stage details if available
# if summary.get('stages') and len(summary['stages']) > 0:
# print(f"Stage Details:")
# for i, stage in enumerate(summary['stages'], 1):
# model_name = stage['model_config'].get('node_name', 'Unknown Model')
# preprocess_count = len(stage['preprocess_configs'])
# postprocess_count = len(stage['postprocess_configs'])
# stage_info = f" Stage {i}: {model_name}"
# if preprocess_count > 0:
# stage_info += f" (with {preprocess_count} preprocess)"
# if postprocess_count > 0:
# stage_info += f" (with {postprocess_count} postprocess)"
# print(stage_info)
# elif current_stage_count > 0:
# print(f"{current_stage_count} stage(s) detected but details not available")
# print("─" * 50) # Separator line
# def update_info_panel(self, summary):
# """Update the pipeline info panel with analysis results."""
# info_text = f"""Pipeline Analysis:
# Stage Count: {summary['stage_count']}
# Valid: {'Yes' if summary['valid'] else 'No'}
# {f"Error: {summary['error']}" if summary.get('error') else ""}
# Node Statistics:
# - Total Nodes: {summary['total_nodes']}
# - Input Nodes: {summary['input_nodes']}
# - Model Nodes: {summary['model_nodes']}
# - Preprocess Nodes: {summary['preprocess_nodes']}
# - Postprocess Nodes: {summary['postprocess_nodes']}
# - Output Nodes: {summary['output_nodes']}
# Stages:"""
# for i, stage in enumerate(summary.get('stages', []), 1):
# info_text += f"\n Stage {i}: {stage['model_config']['node_name']}"
# if stage['preprocess_configs']:
# info_text += f" (with {len(stage['preprocess_configs'])} preprocess)"
# if stage['postprocess_configs']:
# info_text += f" (with {len(stage['postprocess_configs'])} postprocess)"
# self.info_text.setPlainText(info_text)
# def on_node_created(self, node):
# """Handle node creation."""
# node_type = self.get_node_type_name(node)
# print(f"+ Node added: {node_type}")
# self.schedule_analysis()
# def on_node_deleted(self, node):
# """Handle node deletion."""
# node_type = self.get_node_type_name(node)
# print(f"- Node removed: {node_type}")
# self.schedule_analysis()
# def on_nodes_deleted(self, nodes):
# """Handle multiple node deletion."""
# node_types = [self.get_node_type_name(node) for node in nodes]
# print(f"- Multiple nodes removed: {', '.join(node_types)}")
# self.schedule_analysis()
# def on_connection_changed(self, input_port, output_port):
# """Handle connection changes."""
# print(f"🔗 Connection changed: {input_port} <-> {output_port}")
# self.schedule_analysis()
# def get_node_type_name(self, node):
# """Get a readable name for the node type."""
# if hasattr(node, 'NODE_NAME'):
# return node.NODE_NAME
# elif hasattr(node, '__identifier__'):
# # Convert identifier to readable name
# identifier = node.__identifier__
# if 'model' in identifier:
# return "Model Node"
# elif 'input' in identifier:
# return "Input Node"
# elif 'output' in identifier:
# return "Output Node"
# elif 'preprocess' in identifier:
# return "Preprocess Node"
# elif 'postprocess' in identifier:
# return "Postprocess Node"
# # Fallback to class name
# return type(node).__name__
# def get_current_stage_count(self):
# """Get the current stage count."""
# return self.stage_count_widget.stage_count if self.stage_count_widget else 0
# def get_pipeline_summary(self):
# """Get the current pipeline summary."""
# if self.node_graph:
# return get_pipeline_summary(self.node_graph)
# return {'stage_count': 0, 'valid': False, 'error': 'No pipeline graph'}
# def main():
# """Main function for testing the pipeline editor."""
# from PyQt5.QtWidgets import QApplication
# app = QApplication(sys.argv)
# editor = PipelineEditor()
# editor.show()
# sys.exit(app.exec_())
# if __name__ == '__main__':
# main()