545 lines
19 KiB
Python
545 lines
19 KiB
Python
"""
Pipeline stage analysis and management functionality.

This module provides functions to analyze pipeline node connections and automatically
determine the number of stages in a pipeline. Each stage consists of a model node
with optional preprocessing and postprocessing nodes.

Main Components:
    - Stage detection and analysis
    - Pipeline structure validation
    - Stage configuration generation
    - Connection path analysis

Usage:
    from cluster4npu_ui.core.pipeline import analyze_pipeline_stages, get_stage_count

    stage_count = get_stage_count(node_graph)
    stages = analyze_pipeline_stages(node_graph)
"""
|
|
|
|
from typing import List, Dict, Any, Optional, Tuple
|
|
from .nodes.model_node import ModelNode
|
|
from .nodes.preprocess_node import PreprocessNode
|
|
from .nodes.postprocess_node import PostprocessNode
|
|
from .nodes.input_node import InputNode
|
|
from .nodes.output_node import OutputNode
|
|
|
|
|
|
class PipelineStage:
    """Represents a single stage in the pipeline.

    A stage wraps exactly one model node and keeps ordered lists of the
    preprocessing and postprocessing nodes that were attached to it.
    """

    def __init__(self, stage_id: int, model_node: ModelNode):
        """Create a stage around ``model_node`` with no attached nodes.

        Args:
            stage_id: Identifier reported in configs and validation messages.
            model_node: The model node this stage is built around.
        """
        self.stage_id = stage_id
        self.model_node = model_node
        self.preprocess_nodes: List[PreprocessNode] = []
        self.postprocess_nodes: List[PostprocessNode] = []
        # Initialised empty; no method of this class populates them.
        self.input_connections = []
        self.output_connections = []

    def add_preprocess_node(self, node: PreprocessNode):
        """Add a preprocessing node to this stage."""
        self.preprocess_nodes.append(node)

    def add_postprocess_node(self, node: PostprocessNode):
        """Add a postprocessing node to this stage."""
        self.postprocess_nodes.append(node)

    @staticmethod
    def _safe_node_config(node, getter_name: str, fallback_name: str) -> Dict[str, Any]:
        """Return ``node``'s config, degrading to a name-only dict.

        If the node exposes ``getter_name`` its result is returned as-is.
        Otherwise the node's ``NODE_NAME`` (or ``fallback_name``) is reported.
        Any ordinary error raised along the way yields the fallback dict.
        """
        try:
            if hasattr(node, getter_name):
                return getattr(node, getter_name)()
            return {'node_name': getattr(node, 'NODE_NAME', fallback_name)}
        except Exception:
            # Best-effort by design, but only for ordinary errors:
            # KeyboardInterrupt / SystemExit must still propagate.
            return {'node_name': fallback_name}

    def get_stage_config(self) -> Dict[str, Any]:
        """Get configuration for this stage.

        Returns:
            Dict with keys ``stage_id``, ``model_config``,
            ``preprocess_configs`` and ``postprocess_configs``. A node whose
            config cannot be obtained contributes ``{'node_name': ...}``.
        """
        return {
            'stage_id': self.stage_id,
            'model_config': self._safe_node_config(
                self.model_node, 'get_inference_config', 'Unknown Model'
            ),
            'preprocess_configs': [
                self._safe_node_config(node, 'get_preprocessing_config', 'Unknown Preprocess')
                for node in self.preprocess_nodes
            ],
            'postprocess_configs': [
                self._safe_node_config(node, 'get_postprocessing_config', 'Unknown Postprocess')
                for node in self.postprocess_nodes
            ],
        }

    def validate_stage(self) -> Tuple[bool, str]:
        """Validate this stage configuration.

        The model node is checked first, then preprocessing nodes, then
        postprocessing nodes; the first failure is reported.

        Returns:
            ``(True, "")`` when every node validates, otherwise
            ``(False, message)`` naming the stage and the failing node.
        """
        is_valid, error = self.model_node.validate_configuration()
        if not is_valid:
            return False, f"Stage {self.stage_id} model error: {error}"

        groups = (
            ('preprocess', self.preprocess_nodes),
            ('postprocess', self.postprocess_nodes),
        )
        for label, nodes in groups:
            for i, node in enumerate(nodes):
                is_valid, error = node.validate_configuration()
                if not is_valid:
                    return False, f"Stage {self.stage_id} {label} {i} error: {error}"

        return True, ""
|
|
|
|
|
|
def find_connected_nodes(node, visited=None, direction='forward'):
    """
    Collect every node reachable from ``node`` by depth-first traversal.

    Args:
        node: Starting node (not included in the result)
        visited: Set of nodes to treat as already seen; it is updated in place
        direction: 'forward' follows output ports; any other value follows
            input ports

    Returns:
        List of reachable nodes in depth-first discovery order
    """
    seen = set() if visited is None else visited
    if node in seen:
        return []
    seen.add(node)

    def adjacent():
        # Lazily yield the node on the far side of every connection.
        if direction == 'forward':
            for out_port in node.outputs():
                for peer in out_port.connected_inputs():
                    yield peer.node()
        else:
            for in_port in node.inputs():
                for peer in in_port.connected_outputs():
                    yield peer.node()

    reachable = []
    for neighbour in adjacent():
        if neighbour in seen:
            continue
        reachable.append(neighbour)
        reachable += find_connected_nodes(neighbour, seen, direction)
    return reachable
|
|
|
|
|
|
def analyze_pipeline_stages(node_graph) -> List[PipelineStage]:
    """
    Analyze a node graph to identify pipeline stages.

    Each stage consists of:
    1. A model node (required)
    2. Optional preprocessing nodes (connected to the model's inputs)
    3. Optional postprocessing nodes (connected to the model's outputs)

    Every detected model node becomes a stage, provided the graph contains at
    least one input node and one output node. Stages are numbered from 1 in
    order of increasing distance from the input nodes.

    Args:
        node_graph: NodeGraphQt graph object

    Returns:
        List of PipelineStage objects; empty when the graph is missing or has
        no input or no output node
    """
    # Same falsy-graph guard as get_stage_count / validate_pipeline_structure.
    if not node_graph:
        return []

    all_nodes = node_graph.all_nodes()

    # Classify nodes; a node is claimed by the first detector that matches.
    model_nodes = []
    input_nodes = []
    output_nodes = []
    for node in all_nodes:
        if is_model_node(node):
            model_nodes.append(node)
        elif is_input_node(node):
            input_nodes.append(node)
        elif is_output_node(node):
            output_nodes.append(node)

    if not input_nodes or not output_nodes:
        return []  # Invalid pipeline - must have input and output

    try:
        # Stable sort: nodes at equal distance keep the order in which the
        # graph listed them.
        ordered_models = sorted(
            model_nodes,
            key=lambda candidate: calculate_distance_from_input(candidate, input_nodes),
        )

        stages = []
        for stage_id, model_node in enumerate(ordered_models, 1):
            stage = PipelineStage(stage_id, model_node)
            for preprocess_node in find_preprocess_nodes_for_model(model_node, all_nodes):
                stage.add_preprocess_node(preprocess_node)
            for postprocess_node in find_postprocess_nodes_for_model(model_node, all_nodes):
                stage.add_postprocess_node(postprocess_node)
            stages.append(stage)
        return stages
    except Exception as e:
        # Fallback: one bare stage per model node, in graph order. The list is
        # built from scratch so stages created before the failure are not
        # duplicated in the result.
        print(f"Warning: Pipeline distance calculation failed ({e}), using simple stage creation")
        return [
            PipelineStage(stage_id, model_node)
            for stage_id, model_node in enumerate(model_nodes, 1)
        ]
|
|
|
|
|
|
def calculate_distance_from_input(target_node, input_nodes):
    """Return the smallest path distance from any input node to ``target_node``.

    Yields 0 when ``input_nodes`` is empty or when no input node reaches the
    target (every individual distance is infinite).
    """
    unreachable = float('inf')
    best = min(
        (find_shortest_path_distance(source, target_node) for source in input_nodes),
        default=unreachable,
    )
    return 0 if best == unreachable else best
|
|
|
|
|
|
def find_shortest_path_distance(start_node, target_node, visited=None, distance=0):
    """Find shortest path distance between two nodes.

    Performs a breadth-first search along output connections, so each node is
    expanded at most once (the previous depth-first search re-explored every
    simple path and was exponential on layered graphs).

    Args:
        start_node: Node the search starts from.
        target_node: Node to reach.
        visited: Optional set of nodes that must not be traversed.
            ``start_node`` is added to it.
        distance: Offset added to the hop count (distance of ``start_node``).

    Returns:
        ``distance`` plus the number of hops on the shortest path, or
        ``float('inf')`` when the target cannot be reached.
    """
    from collections import deque  # local import keeps the block self-contained

    unreachable = float('inf')
    if visited is None:
        visited = set()

    if start_node == target_node:
        return distance

    if start_node in visited:
        return unreachable

    visited.add(start_node)

    # Work on a private copy so only start_node leaks into the caller's set.
    seen = set(visited)
    pending = deque([(start_node, distance)])

    while pending:
        current, hops = pending.popleft()

        # Gather neighbours defensively: nodes/ports may lack the expected API.
        # NOTE(review): NodeGraphQt documents BaseNode.outputs() as returning a
        # dict of ports; iterating that would yield names, not ports - confirm
        # what the project's node classes return.
        neighbours = []
        try:
            if hasattr(current, 'outputs'):
                for port in current.outputs():
                    if hasattr(port, 'connected_inputs'):
                        peers = port.connected_inputs()
                    elif hasattr(port, 'connected_ports'):
                        # Same fallback accessor has_path_between_nodes uses.
                        peers = port.connected_ports()
                    else:
                        continue
                    for peer in peers:
                        if hasattr(peer, 'node'):
                            neighbours.append(peer.node())
        except Exception:
            # Best-effort: keep whatever neighbours were collected before the
            # failure and carry on with the rest of the search.
            pass

        for neighbour in neighbours:
            if neighbour in seen:
                continue
            if neighbour == target_node:
                # BFS visits nodes in order of hop count, so the first hit is
                # the shortest path.
                return hops + 1
            seen.add(neighbour)
            pending.append((neighbour, hops + 1))

    return unreachable
|
|
|
|
|
|
def find_preprocess_nodes_for_model(model_node, all_nodes):
    """Return the PreprocessNode instances wired directly into ``model_node``.

    Only nodes connected to one of the model's input ports are considered;
    ``all_nodes`` is accepted but not consulted.
    """
    upstream = (
        source_port.node()
        for in_port in model_node.inputs()
        for source_port in in_port.connected_outputs()
    )
    return [candidate for candidate in upstream if isinstance(candidate, PreprocessNode)]
|
|
|
|
|
|
def find_postprocess_nodes_for_model(model_node, all_nodes):
    """Return the PostprocessNode instances fed directly by ``model_node``.

    Only nodes connected to one of the model's output ports are considered;
    ``all_nodes`` is accepted but not consulted.
    """
    downstream = (
        sink_port.node()
        for out_port in model_node.outputs()
        for sink_port in out_port.connected_inputs()
    )
    return [candidate for candidate in downstream if isinstance(candidate, PostprocessNode)]
|
|
|
|
|
|
def is_model_node(node):
    """Report whether ``node`` looks like a model node.

    Checks are tried in order and the first hit wins: the ``__identifier__``,
    ``type_`` and ``NODE_NAME`` attributes and the class's string form are
    searched case-insensitively for 'model'; a ``get_inference_config``
    attribute also qualifies.
    """
    keyword = 'model'
    return bool(
        (hasattr(node, '__identifier__') and keyword in node.__identifier__.lower())
        or (hasattr(node, 'type_') and keyword in str(node.type_).lower())
        or (hasattr(node, 'NODE_NAME') and keyword in str(node.NODE_NAME).lower())
        or keyword in str(type(node)).lower()
        # Duck-typing check for the project's ModelNode API
        or hasattr(node, 'get_inference_config')
        # ExactModelNode class-name check
        or 'exactmodel' in str(type(node)).lower()
    )
|
|
|
|
|
|
def is_input_node(node):
    """Report whether ``node`` looks like an input node.

    Checks are tried in order and the first hit wins: the ``__identifier__``,
    ``type_`` and ``NODE_NAME`` attributes and the class's string form are
    searched case-insensitively for 'input'; a ``get_input_config`` attribute
    also qualifies.
    """
    keyword = 'input'
    return bool(
        (hasattr(node, '__identifier__') and keyword in node.__identifier__.lower())
        or (hasattr(node, 'type_') and keyword in str(node.type_).lower())
        or (hasattr(node, 'NODE_NAME') and keyword in str(node.NODE_NAME).lower())
        or keyword in str(type(node)).lower()
        # Duck-typing check for the project's InputNode API
        or hasattr(node, 'get_input_config')
        # ExactInputNode class-name check
        or 'exactinput' in str(type(node)).lower()
    )
|
|
|
|
|
|
def is_output_node(node):
    """Report whether ``node`` looks like an output node.

    Checks are tried in order and the first hit wins: the ``__identifier__``,
    ``type_`` and ``NODE_NAME`` attributes and the class's string form are
    searched case-insensitively for 'output'; a ``get_output_config``
    attribute also qualifies.
    """
    keyword = 'output'
    return bool(
        (hasattr(node, '__identifier__') and keyword in node.__identifier__.lower())
        or (hasattr(node, 'type_') and keyword in str(node.type_).lower())
        or (hasattr(node, 'NODE_NAME') and keyword in str(node.NODE_NAME).lower())
        or keyword in str(type(node)).lower()
        # Duck-typing check for the project's OutputNode API
        or hasattr(node, 'get_output_config')
        # ExactOutputNode class-name check
        or 'exactoutput' in str(type(node)).lower()
    )
|
|
|
|
|
|
def get_stage_count(node_graph) -> int:
    """
    Count the stages in a pipeline.

    One stage is counted for every node that ``is_model_node`` accepts,
    whether or not that node is connected to anything.

    Args:
        node_graph: NodeGraphQt graph object

    Returns:
        Number of model nodes in the graph; 0 for a falsy graph
    """
    if not node_graph:
        return 0

    return sum(1 for candidate in node_graph.all_nodes() if is_model_node(candidate))
|
|
|
|
|
|
def validate_pipeline_structure(node_graph) -> Tuple[bool, str]:
    """
    Check that the graph contains the node kinds a pipeline needs.

    Only presence is verified: at least one input node, one output node and
    one model node. Connectivity between them is not examined.

    Args:
        node_graph: NodeGraphQt graph object

    Returns:
        Tuple of (is_valid, error_message); the message is empty when valid
    """
    if not node_graph:
        return False, "No pipeline graph provided"

    nodes = node_graph.all_nodes()

    # Classify with the shared detection helpers before checking anything.
    found = {
        "input": [n for n in nodes if is_input_node(n)],
        "output": [n for n in nodes if is_output_node(n)],
        "model": [n for n in nodes if is_model_node(n)],
    }

    # Report the first missing kind, in input -> output -> model order.
    for kind in ("input", "output", "model"):
        if not found[kind]:
            return False, f"Pipeline must have at least one {kind} node"

    # Connectivity checks are skipped for now since nodes may not have proper
    # connections.
    return True, ""
|
|
|
|
|
|
def is_node_connected_to_pipeline(node, input_nodes, output_nodes):
    """Report whether ``node`` sits on a path between the pipeline's ends.

    True only when some input node has a path to ``node`` and ``node`` has a
    path to some output node. Both directions are always evaluated.
    """
    # Upstream side: can any input node reach this node?
    fed_by_input = any(
        has_path_between_nodes(source, node)
        for source in input_nodes
    )

    # Downstream side: can this node reach any output node?
    feeds_output = any(
        has_path_between_nodes(node, sink)
        for sink in output_nodes
    )

    return fed_by_input and feeds_output
|
|
|
|
|
|
def has_path_between_nodes(start_node, end_node, visited=None):
    """Report whether ``end_node`` is reachable from ``start_node``.

    Follows output ports depth-first. ``visited`` (created when omitted) is
    shared across the whole search and records every node that was expanded.
    Any error raised while reading a node's connections is treated as
    "no path through this node".
    """
    if visited is None:
        visited = set()

    if start_node == end_node:
        return True

    if start_node in visited:
        return False

    visited.add(start_node)

    try:
        if not hasattr(start_node, 'outputs'):
            return False

        for out_port in start_node.outputs():
            # Prefer connected_inputs(); fall back to connected_ports().
            if hasattr(out_port, 'connected_inputs'):
                peers = out_port.connected_inputs()
            elif hasattr(out_port, 'connected_ports'):
                peers = out_port.connected_ports()
            else:
                continue

            for peer in peers:
                if hasattr(peer, 'node') and has_path_between_nodes(peer.node(), end_node, visited):
                    return True
    except Exception:
        # If there's any error accessing connections, assume no path
        pass

    return False
|
|
|
|
|
|
def get_pipeline_summary(node_graph) -> Dict[str, Any]:
    """
    Get a summary of the pipeline structure.

    Each node is counted under the first category it matches, tested in the
    order input, output, model, preprocess, postprocess. Stage information
    comes from ``analyze_pipeline_stages`` and validity from
    ``validate_pipeline_structure``.

    Args:
        node_graph: NodeGraphQt graph object

    Returns:
        Dictionary with keys ``stage_count``, ``valid``, ``error``, ``stages``,
        ``total_nodes``, ``input_nodes``, ``output_nodes``, ``model_nodes``,
        ``preprocess_nodes`` and ``postprocess_nodes``. The same keys are
        present when no graph is supplied, with zero counts.
    """
    if not node_graph:
        # Return the full key set so callers can index the summary the same
        # way whether or not a graph exists.
        return {
            'stage_count': 0,
            'valid': False,
            'error': 'No pipeline graph',
            'stages': [],
            'total_nodes': 0,
            'input_nodes': 0,
            'output_nodes': 0,
            'model_nodes': 0,
            'preprocess_nodes': 0,
            'postprocess_nodes': 0,
        }

    def _looks_like(node, keyword, config_getter):
        # Same heuristics as the is_*_node helpers: keyword in identifier,
        # type_, NODE_NAME or class text, or presence of the config getter.
        # (An 'exact' + keyword class check is subsumed by the class-text test.)
        return (
            (hasattr(node, '__identifier__') and keyword in node.__identifier__.lower())
            or (hasattr(node, 'type_') and keyword in str(node.type_).lower())
            or (hasattr(node, 'NODE_NAME') and keyword in str(node.NODE_NAME).lower())
            or keyword in str(type(node)).lower()
            or hasattr(node, config_getter)
        )

    all_nodes = node_graph.all_nodes()

    counts = {'input': 0, 'output': 0, 'model': 0, 'preprocess': 0, 'postprocess': 0}
    for node in all_nodes:
        if is_input_node(node):
            counts['input'] += 1
        elif is_output_node(node):
            counts['output'] += 1
        elif is_model_node(node):
            counts['model'] += 1
        elif _looks_like(node, 'preprocess', 'get_preprocessing_config'):
            counts['preprocess'] += 1
        elif _looks_like(node, 'postprocess', 'get_postprocessing_config'):
            counts['postprocess'] += 1

    stages = analyze_pipeline_stages(node_graph)
    is_valid, error = validate_pipeline_structure(node_graph)

    return {
        'stage_count': len(stages),
        'valid': is_valid,
        'error': error if not is_valid else None,
        'stages': [stage.get_stage_config() for stage in stages],
        'total_nodes': len(all_nodes),
        'input_nodes': counts['input'],
        'output_nodes': counts['output'],
        'model_nodes': counts['model'],
        'preprocess_nodes': counts['preprocess'],
        'postprocess_nodes': counts['postprocess'],
    }