180 lines
6.5 KiB
Python
180 lines
6.5 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Final test to verify the stage detection implementation works correctly.
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
# Put this script's own directory first on the import path so the project's
# ``core`` package resolves when the script is run directly.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Set up Qt environment: select the "offscreen" platform plugin so no display
# server is needed. Qt reads this variable when the application object is
# constructed, so it must be set before the QApplication below.
os.environ['QT_QPA_PLATFORM'] = 'offscreen'

from PyQt5.QtWidgets import QApplication

# Qt permits only one QApplication per process; constructing a second one
# fails. Reuse an existing instance (e.g. when this module is imported next to
# other Qt tests) and only create one when none exists yet. The application is
# created before the project imports below — presumably because the node
# classes need a live Qt application; confirm. Keeping ``app`` as a module
# global keeps the instance referenced while the tests run.
app = QApplication.instance()
if app is None:
    app = QApplication(sys.argv)
|
|
|
|
from core.pipeline import (
|
|
is_model_node, is_input_node, is_output_node,
|
|
get_stage_count, get_pipeline_summary
|
|
)
|
|
from core.nodes.model_node import ModelNode
|
|
from core.nodes.input_node import InputNode
|
|
from core.nodes.output_node import OutputNode
|
|
from core.nodes.preprocess_node import PreprocessNode
|
|
from core.nodes.postprocess_node import PostprocessNode
|
|
|
|
|
|
class MockNodeGraph:
    """Minimal stand-in for a node graph, used by the stage-detection tests.

    Provides only ``all_nodes()`` and ``add_node()``; the public ``nodes``
    list holds the added nodes in insertion order.
    """

    def __init__(self):
        # Backing store; all_nodes() hands out this very list, not a copy.
        self.nodes = []

    def all_nodes(self):
        """Return the live list of nodes added so far (the same list object, not a copy)."""
        return self.nodes

    def add_node(self, node):
        """Append *node* to the graph and print it together with its class name."""
        kind = type(node).__name__
        self.nodes.append(node)
        print(f"Added node: {node} (type: {kind})")
|
|
|
|
|
|
def _check_stage_count(graph, expected):
    """Print the stage count ``get_stage_count`` reports for *graph* and assert it equals *expected*.

    Returns the reported count.
    """
    stage_count = get_stage_count(graph)
    print(f" Stage count: {stage_count}")
    # Keep the singular/plural wording of the failure message ("1 stage", "2 stages").
    noun = "stage" if expected == 1 else "stages"
    assert stage_count == expected, f"Expected {expected} {noun}, got {stage_count}"
    return stage_count


def test_comprehensive_pipeline():
    """Build a pipeline one node at a time and check stage counting and the summary.

    Only model nodes are expected to add a stage; input, output, preprocess
    and postprocess nodes must leave the count unchanged. The final graph is
    then checked against ``get_pipeline_summary``.
    """
    print("Testing Comprehensive Pipeline...")

    graph = MockNodeGraph()

    # Step 1: an empty graph has no stages.
    print("\n1. Empty pipeline:")
    _check_stage_count(graph, 0)

    # Steps 2..N: (description, node class to add, expected stage count afterwards).
    steps = [
        ("Add input node", InputNode, 0),
        ("Add model node", ModelNode, 1),
        ("Add output node", OutputNode, 1),
        ("Add preprocess node", PreprocessNode, 1),
        ("Add postprocess node", PostprocessNode, 1),
        ("Add second model node", ModelNode, 2),
        ("Add third model node", ModelNode, 3),
    ]
    for number, (label, node_cls, expected) in enumerate(steps, start=2):
        print(f"\n{number}. {label}:")
        graph.add_node(node_cls())
        _check_stage_count(graph, expected)

    # Final step: the summary must expose every field and agree with the graph built above.
    print(f"\n{len(steps) + 2}. Get pipeline summary:")
    summary = get_pipeline_summary(graph)
    print(f" Summary: {summary}")

    expected_fields = ['stage_count', 'valid', 'total_nodes', 'model_nodes', 'input_nodes', 'output_nodes']
    for key in expected_fields:
        assert key in summary, f"Missing field '{key}' in summary"

    # summary key -> (expected value, wording used in the failure message)
    expected_values = {
        'stage_count': (3, 'stages'),
        'model_nodes': (3, 'model nodes'),
        'input_nodes': (1, 'input node'),
        'output_nodes': (1, 'output node'),
        'total_nodes': (7, 'total nodes'),
    }
    for key, (expected, description) in expected_values.items():
        actual = summary[key]
        assert actual == expected, f"Expected {expected} {description} in summary, got {actual}"

    print("✓ All comprehensive tests passed!")
|
|
|
|
|
|
def test_node_detection_robustness():
    """Check that each ``is_*_node`` predicate accepts exactly its own node type.

    Every detector must recognise its own node class and reject all the
    others, including the preprocess/postprocess nodes, which none of the
    three detectors used here should claim. (The comprehensive pipeline test
    relies on the same fact: its summary counts 3 model, 1 input and 1 output
    node in a graph that also holds one preprocess and one postprocess node.)
    """
    print("\nTesting Node Detection Robustness...")

    # Real node instances, keyed by the display name used in failure messages.
    nodes = {
        'Model': ModelNode(),
        'Input': InputNode(),
        'Output': OutputNode(),
        'Preprocess': PreprocessNode(),
        'Postprocess': PostprocessNode(),
    }
    # Detectors under test, keyed by the role name used in failure messages.
    detectors = {
        'model': is_model_node,
        'input': is_input_node,
        'output': is_output_node,
    }

    # Positive checks: each detector recognises its own node type.
    for role, detect in detectors.items():
        name = role.capitalize()
        assert detect(nodes[name]), f"{name} node not detected correctly"

    # Cross-detection checks: no detector claims a node of another type.
    for name, node in nodes.items():
        for role, detect in detectors.items():
            if name.lower() != role:
                assert not detect(node), f"{name} node incorrectly detected as {role}"

    print("✓ Node detection robustness tests passed!")
|
|
|
|
|
|
def main():
    """Run every test in this module in order and print a summary.

    Any exception raised while running the tests (or printing the summary)
    is reported with its traceback and the process exits with status 1.
    """
    banner = "=" * 60
    print("Running Final Implementation Tests...")
    print(banner)

    verified_features = (
        "✓ Model node detection works correctly",
        "✓ Stage counting updates when model nodes are added",
        "✓ Pipeline summary provides accurate information",
        "✓ Node detection is robust and handles edge cases",
        "✓ Multiple stages are correctly counted",
    )

    try:
        for run_test in (test_node_detection_robustness, test_comprehensive_pipeline):
            run_test()

        print("\n" + banner)
        print("🎉 ALL TESTS PASSED! The stage detection implementation is working correctly.")
        print("\nKey Features Verified:")
        for feature in verified_features:
            print(feature)
    except Exception as e:
        # Top-level boundary: report the failure and signal it via the exit status.
        print(f"\n❌ Test failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the suite only when executed as a script, not when imported.
    main()