""" Pipeline stage analysis and management functionality. This module provides functions to analyze pipeline node connections and automatically determine the number of stages in a pipeline. Each stage consists of a model node with optional preprocessing and postprocessing nodes. Main Components: - Stage detection and analysis - Pipeline structure validation - Stage configuration generation - Connection path analysis Usage: from cluster4npu_ui.core.pipeline import analyze_pipeline_stages, get_stage_count stage_count = get_stage_count(node_graph) stages = analyze_pipeline_stages(node_graph) """ from typing import List, Dict, Any, Optional, Tuple from .nodes.model_node import ModelNode from .nodes.preprocess_node import PreprocessNode from .nodes.postprocess_node import PostprocessNode from .nodes.input_node import InputNode from .nodes.output_node import OutputNode class PipelineStage: """Represents a single stage in the pipeline.""" def __init__(self, stage_id: int, model_node: ModelNode): self.stage_id = stage_id self.model_node = model_node self.preprocess_nodes: List[PreprocessNode] = [] self.postprocess_nodes: List[PostprocessNode] = [] self.input_connections = [] self.output_connections = [] def add_preprocess_node(self, node: PreprocessNode): """Add a preprocessing node to this stage.""" self.preprocess_nodes.append(node) def add_postprocess_node(self, node: PostprocessNode): """Add a postprocessing node to this stage.""" self.postprocess_nodes.append(node) def get_stage_config(self) -> Dict[str, Any]: """Get configuration for this stage.""" # Get model config safely model_config = {} try: if hasattr(self.model_node, 'get_inference_config'): model_config = self.model_node.get_inference_config() else: model_config = {'node_name': getattr(self.model_node, 'NODE_NAME', 'Unknown Model')} except: model_config = {'node_name': 'Unknown Model'} # Get preprocess configs safely preprocess_configs = [] for node in self.preprocess_nodes: try: if hasattr(node, 'get_preprocessing_config'): preprocess_configs.append(node.get_preprocessing_config()) else: preprocess_configs.append({'node_name': getattr(node, 'NODE_NAME', 'Unknown Preprocess')}) except: preprocess_configs.append({'node_name': 'Unknown Preprocess'}) # Get postprocess configs safely postprocess_configs = [] for node in self.postprocess_nodes: try: if hasattr(node, 'get_postprocessing_config'): postprocess_configs.append(node.get_postprocessing_config()) else: postprocess_configs.append({'node_name': getattr(node, 'NODE_NAME', 'Unknown Postprocess')}) except: postprocess_configs.append({'node_name': 'Unknown Postprocess'}) config = { 'stage_id': self.stage_id, 'model_config': model_config, 'preprocess_configs': preprocess_configs, 'postprocess_configs': postprocess_configs } return config def validate_stage(self) -> Tuple[bool, str]: """Validate this stage configuration.""" # Validate model node is_valid, error = self.model_node.validate_configuration() if not is_valid: return False, f"Stage {self.stage_id} model error: {error}" # Validate preprocessing nodes for i, node in enumerate(self.preprocess_nodes): is_valid, error = node.validate_configuration() if not is_valid: return False, f"Stage {self.stage_id} preprocess {i} error: {error}" # Validate postprocessing nodes for i, node in enumerate(self.postprocess_nodes): is_valid, error = node.validate_configuration() if not is_valid: return False, f"Stage {self.stage_id} postprocess {i} error: {error}" return True, "" def find_connected_nodes(node, visited=None, direction='forward'): """ Find all nodes connected to a given node. Args: node: Starting node visited: Set of already visited nodes direction: 'forward' for outputs, 'backward' for inputs Returns: List of connected nodes """ if visited is None: visited = set() if node in visited: return [] visited.add(node) connected = [] if direction == 'forward': # Get connected output nodes for output in node.outputs(): for connected_input in output.connected_inputs(): connected_node = connected_input.node() if connected_node not in visited: connected.append(connected_node) connected.extend(find_connected_nodes(connected_node, visited, direction)) else: # Get connected input nodes for input_port in node.inputs(): for connected_output in input_port.connected_outputs(): connected_node = connected_output.node() if connected_node not in visited: connected.append(connected_node) connected.extend(find_connected_nodes(connected_node, visited, direction)) return connected def analyze_pipeline_stages(node_graph) -> List[PipelineStage]: """ Analyze a node graph to identify pipeline stages. Each stage consists of: 1. A model node (required) that is connected in the pipeline flow 2. Optional preprocessing nodes (before model) 3. Optional postprocessing nodes (after model) Args: node_graph: NodeGraphQt graph object Returns: List of PipelineStage objects """ stages = [] all_nodes = node_graph.all_nodes() # Find all model nodes - these define the stages model_nodes = [] input_nodes = [] output_nodes = [] for node in all_nodes: # Detect model nodes if is_model_node(node): model_nodes.append(node) # Detect input nodes elif is_input_node(node): input_nodes.append(node) # Detect output nodes elif is_output_node(node): output_nodes.append(node) if not input_nodes or not output_nodes: return [] # Invalid pipeline - must have input and output # Use all model nodes when we have valid input/output structure # Simplified approach: if we have input and output nodes, count all model nodes as stages connected_model_nodes = model_nodes # Use all model nodes # For nodes without connections, just create stages in the order they appear try: # Sort model nodes by their position in the pipeline model_nodes_with_distance = [] for model_node in connected_model_nodes: # Calculate distance from input nodes distance = calculate_distance_from_input(model_node, input_nodes) model_nodes_with_distance.append((model_node, distance)) # Sort by distance from input (closest first) model_nodes_with_distance.sort(key=lambda x: x[1]) # Create stages for stage_id, (model_node, _) in enumerate(model_nodes_with_distance, 1): stage = PipelineStage(stage_id, model_node) # Find preprocessing nodes (nodes that connect to this model but aren't models themselves) preprocess_nodes = find_preprocess_nodes_for_model(model_node, all_nodes) for preprocess_node in preprocess_nodes: stage.add_preprocess_node(preprocess_node) # Find postprocessing nodes (nodes that this model connects to but aren't models) postprocess_nodes = find_postprocess_nodes_for_model(model_node, all_nodes) for postprocess_node in postprocess_nodes: stage.add_postprocess_node(postprocess_node) stages.append(stage) except Exception as e: # Fallback: just create simple stages for all model nodes print(f"Warning: Pipeline distance calculation failed ({e}), using simple stage creation") for stage_id, model_node in enumerate(connected_model_nodes, 1): stage = PipelineStage(stage_id, model_node) stages.append(stage) return stages def calculate_distance_from_input(target_node, input_nodes): """Calculate the shortest distance from any input node to the target node.""" min_distance = float('inf') for input_node in input_nodes: distance = find_shortest_path_distance(input_node, target_node) if distance < min_distance: min_distance = distance return min_distance if min_distance != float('inf') else 0 def find_shortest_path_distance(start_node, target_node, visited=None, distance=0): """Find shortest path distance between two nodes.""" if visited is None: visited = set() if start_node == target_node: return distance if start_node in visited: return float('inf') visited.add(start_node) min_distance = float('inf') # Check all connected nodes - handle nodes without proper connections try: if hasattr(start_node, 'outputs'): for output in start_node.outputs(): if hasattr(output, 'connected_inputs'): for connected_input in output.connected_inputs(): if hasattr(connected_input, 'node'): connected_node = connected_input.node() if connected_node not in visited: path_distance = find_shortest_path_distance( connected_node, target_node, visited.copy(), distance + 1 ) min_distance = min(min_distance, path_distance) except: # If there's any error in path finding, return a default distance pass return min_distance def find_preprocess_nodes_for_model(model_node, all_nodes): """Find preprocessing nodes that connect to the given model node.""" preprocess_nodes = [] # Get all nodes that connect to the model's inputs for input_port in model_node.inputs(): for connected_output in input_port.connected_outputs(): connected_node = connected_output.node() if isinstance(connected_node, PreprocessNode): preprocess_nodes.append(connected_node) return preprocess_nodes def find_postprocess_nodes_for_model(model_node, all_nodes): """Find postprocessing nodes that the given model node connects to.""" postprocess_nodes = [] # Get all nodes that the model connects to for output in model_node.outputs(): for connected_input in output.connected_inputs(): connected_node = connected_input.node() if isinstance(connected_node, PostprocessNode): postprocess_nodes.append(connected_node) return postprocess_nodes def is_model_node(node): """Check if a node is a model node using multiple detection methods.""" if hasattr(node, '__identifier__'): identifier = node.__identifier__ if 'model' in identifier.lower(): return True if hasattr(node, 'type_') and 'model' in str(node.type_).lower(): return True if hasattr(node, 'NODE_NAME') and 'model' in str(node.NODE_NAME).lower(): return True if 'model' in str(type(node)).lower(): return True # Check if it's our ModelNode class if hasattr(node, 'get_inference_config'): return True # Check for ExactModelNode if 'exactmodel' in str(type(node)).lower(): return True return False def is_input_node(node): """Check if a node is an input node using multiple detection methods.""" if hasattr(node, '__identifier__'): identifier = node.__identifier__ if 'input' in identifier.lower(): return True if hasattr(node, 'type_') and 'input' in str(node.type_).lower(): return True if hasattr(node, 'NODE_NAME') and 'input' in str(node.NODE_NAME).lower(): return True if 'input' in str(type(node)).lower(): return True # Check if it's our InputNode class if hasattr(node, 'get_input_config'): return True # Check for ExactInputNode if 'exactinput' in str(type(node)).lower(): return True return False def is_output_node(node): """Check if a node is an output node using multiple detection methods.""" if hasattr(node, '__identifier__'): identifier = node.__identifier__ if 'output' in identifier.lower(): return True if hasattr(node, 'type_') and 'output' in str(node.type_).lower(): return True if hasattr(node, 'NODE_NAME') and 'output' in str(node.NODE_NAME).lower(): return True if 'output' in str(type(node)).lower(): return True # Check if it's our OutputNode class if hasattr(node, 'get_output_config'): return True # Check for ExactOutputNode if 'exactoutput' in str(type(node)).lower(): return True return False def get_stage_count(node_graph) -> int: """ Get the number of stages in a pipeline. Args: node_graph: NodeGraphQt graph object Returns: Number of stages (model nodes) in the pipeline """ if not node_graph: return 0 all_nodes = node_graph.all_nodes() # Use robust detection for model nodes model_nodes = [node for node in all_nodes if is_model_node(node)] return len(model_nodes) def validate_pipeline_structure(node_graph) -> Tuple[bool, str]: """ Validate the overall pipeline structure. Args: node_graph: NodeGraphQt graph object Returns: Tuple of (is_valid, error_message) """ if not node_graph: return False, "No pipeline graph provided" all_nodes = node_graph.all_nodes() # Check for required node types using our detection functions input_nodes = [node for node in all_nodes if is_input_node(node)] output_nodes = [node for node in all_nodes if is_output_node(node)] model_nodes = [node for node in all_nodes if is_model_node(node)] if not input_nodes: return False, "Pipeline must have at least one input node" if not output_nodes: return False, "Pipeline must have at least one output node" if not model_nodes: return False, "Pipeline must have at least one model node" # Skip connectivity checks for now since nodes may not have proper connections # In a real NodeGraphQt environment, this would check actual connections return True, "" def is_node_connected_to_pipeline(node, input_nodes, output_nodes): """Check if a node is connected to both input and output sides of the pipeline.""" # Check if there's a path from any input to this node connected_to_input = any( has_path_between_nodes(input_node, node) for input_node in input_nodes ) # Check if there's a path from this node to any output connected_to_output = any( has_path_between_nodes(node, output_node) for output_node in output_nodes ) return connected_to_input and connected_to_output def has_path_between_nodes(start_node, end_node, visited=None): """Check if there's a path between two nodes.""" if visited is None: visited = set() if start_node == end_node: return True if start_node in visited: return False visited.add(start_node) # Check all connected nodes try: if hasattr(start_node, 'outputs'): for output in start_node.outputs(): if hasattr(output, 'connected_inputs'): for connected_input in output.connected_inputs(): if hasattr(connected_input, 'node'): connected_node = connected_input.node() if has_path_between_nodes(connected_node, end_node, visited): return True elif hasattr(output, 'connected_ports'): # Alternative connection method for connected_port in output.connected_ports(): if hasattr(connected_port, 'node'): connected_node = connected_port.node() if has_path_between_nodes(connected_node, end_node, visited): return True except Exception: # If there's any error accessing connections, assume no path pass return False def get_pipeline_summary(node_graph) -> Dict[str, Any]: """ Get a summary of the pipeline structure. Args: node_graph: NodeGraphQt graph object Returns: Dictionary containing pipeline summary information """ if not node_graph: return {'stage_count': 0, 'valid': False, 'error': 'No pipeline graph'} all_nodes = node_graph.all_nodes() # Count nodes by type using robust detection input_count = 0 output_count = 0 model_count = 0 preprocess_count = 0 postprocess_count = 0 for node in all_nodes: # Detect input nodes if is_input_node(node): input_count += 1 # Detect output nodes elif is_output_node(node): output_count += 1 # Detect model nodes elif is_model_node(node): model_count += 1 # Detect preprocess nodes elif ((hasattr(node, '__identifier__') and 'preprocess' in node.__identifier__.lower()) or \ (hasattr(node, 'type_') and 'preprocess' in str(node.type_).lower()) or \ (hasattr(node, 'NODE_NAME') and 'preprocess' in str(node.NODE_NAME).lower()) or \ ('preprocess' in str(type(node)).lower()) or \ ('exactpreprocess' in str(type(node)).lower()) or \ hasattr(node, 'get_preprocessing_config')): preprocess_count += 1 # Detect postprocess nodes elif ((hasattr(node, '__identifier__') and 'postprocess' in node.__identifier__.lower()) or \ (hasattr(node, 'type_') and 'postprocess' in str(node.type_).lower()) or \ (hasattr(node, 'NODE_NAME') and 'postprocess' in str(node.NODE_NAME).lower()) or \ ('postprocess' in str(type(node)).lower()) or \ ('exactpostprocess' in str(type(node)).lower()) or \ hasattr(node, 'get_postprocessing_config')): postprocess_count += 1 stages = analyze_pipeline_stages(node_graph) is_valid, error = validate_pipeline_structure(node_graph) return { 'stage_count': len(stages), 'valid': is_valid, 'error': error if not is_valid else None, 'stages': [stage.get_stage_config() for stage in stages], 'total_nodes': len(all_nodes), 'input_nodes': input_count, 'output_nodes': output_count, 'model_nodes': model_count, 'preprocess_nodes': preprocess_count, 'postprocess_nodes': postprocess_count }