Cluster/tests/test_node_detection.py
2025-07-17 17:04:56 +08:00

125 lines
3.9 KiB
Python

#!/usr/bin/env python3
"""
Test script to verify node detection methods work correctly.
"""
import sys
import os
# Put the directory containing this file at the front of the import search path.
# NOTE(review): the `core.*` imports below only resolve if `core` is reachable
# from that directory or from another sys.path entry -- confirm whether the
# parent (project root) directory was the intended path.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# Ask Qt for the "offscreen" platform via the environment; this is set before
# PyQt5 is imported below.
import os
os.environ['QT_QPA_PLATFORM'] = 'offscreen'
# Create a QApplication instance before the project's node classes are imported.
# NOTE(review): presumably the node classes need a live QApplication -- confirm.
from PyQt5.QtWidgets import QApplication
import sys
app = QApplication(sys.argv)
# Detection helpers under test and the node classes they are applied to.
from core.pipeline import is_model_node, is_input_node, is_output_node, get_stage_count
from core.nodes.model_node import ModelNode
from core.nodes.input_node import InputNode
from core.nodes.output_node import OutputNode
from core.nodes.preprocess_node import PreprocessNode
from core.nodes.postprocess_node import PostprocessNode
class MockNodeGraph:
    """Minimal stand-in for a node graph, used here to drive ``get_stage_count``.

    It only stores nodes in insertion order and hands them back through
    ``all_nodes``; no Qt objects or graph logic are involved.
    """

    def __init__(self) -> None:
        # Nodes in the order they were added. Kept as a public attribute so
        # existing code that reads ``graph.nodes`` directly keeps working.
        self.nodes = []

    def all_nodes(self) -> list:
        """Return a new list holding every node added so far, in insertion order.

        A copy is returned so that a caller mutating the result (or adding
        nodes while iterating over it) cannot corrupt the mock's own state.
        """
        return list(self.nodes)

    def add_node(self, node) -> None:
        """Append ``node`` to the graph."""
        self.nodes.append(node)
def test_node_detection():
    """Check the node-type predicates and the stage counter from ``core.pipeline``.

    One node of each kind is created. ``is_input_node``, ``is_model_node`` and
    ``is_output_node`` must accept their own node type and reject the other
    types probed below. ``get_stage_count`` must report 1 for a graph holding
    a single model node and 2 after a second model node is added.

    Every result is printed first and asserted afterwards, so the complete
    diagnostic output is available even when an assertion fails.

    Raises:
        AssertionError: If a predicate or a stage count gives an unexpected
            result.
    """
    print("Testing Node Detection Methods...")

    # Create node instances
    input_node = InputNode()
    model_node = ModelNode()
    output_node = OutputNode()
    # Constructed only to confirm these classes can be instantiated; they are
    # not passed to any predicate or added to the graph below.
    preprocess_node = PreprocessNode()
    postprocess_node = PostprocessNode()

    # Test detection (each predicate on its own node type: should be True)
    matching_checks = (
        ("Input node detection", is_input_node, input_node),
        ("Model node detection", is_model_node, model_node),
        ("Output node detection", is_output_node, output_node),
    )
    # Test cross-detection (should be False)
    cross_checks = (
        ("Model node detected as input", is_input_node, model_node),
        ("Input node detected as model", is_model_node, input_node),
        ("Output node detected as model", is_model_node, output_node),
    )

    outcomes = []
    for expected, checks in ((True, matching_checks), (False, cross_checks)):
        for label, predicate, node in checks:
            detected = predicate(node)
            print(f"{label}: {detected}")
            outcomes.append((label, expected, detected))

    # Test with mock graph
    graph = MockNodeGraph()
    graph.add_node(input_node)
    graph.add_node(model_node)
    graph.add_node(output_node)
    stage_count = get_stage_count(graph)
    print(f"Stage count: {stage_count}")

    # Add another model node
    model_node2 = ModelNode()
    graph.add_node(model_node2)
    stage_count2 = get_stage_count(graph)
    print(f"Stage count after adding second model: {stage_count2}")

    # Previously the predicate results were only printed; assert them so a
    # wrong detection actually fails the test.
    for label, expected, detected in outcomes:
        assert bool(detected) is expected, (
            f"{label}: expected {expected}, got {detected!r}"
        )
    assert stage_count == 1, f"Expected 1 stage, got {stage_count}"
    assert stage_count2 == 2, f"Expected 2 stages, got {stage_count2}"
    print("✓ Node detection tests passed")
def test_node_properties():
    """Print the attributes of each node class that are inspected for detection.

    For a fresh ``ModelNode``, ``InputNode`` and ``OutputNode`` (in that
    order) this prints the instance's type, its ``__identifier__`` and
    ``NODE_NAME`` attributes (the string ``'None'`` when an attribute is
    missing), and whether the instance has the matching ``get_*_config``
    attribute. Nothing is asserted; the output is purely diagnostic.
    """
    print("\nTesting Node Properties...")

    # (display label, node class, name of the config accessor to probe)
    probes = (
        ("Model", ModelNode, "get_inference_config"),
        ("Input", InputNode, "get_input_config"),
        ("Output", OutputNode, "get_output_config"),
    )
    # (text shown in the output, attribute actually looked up)
    reported_attrs = (
        ("identifier", "__identifier__"),
        ("NODE_NAME", "NODE_NAME"),
    )

    # Hold on to every instance until the function returns, so earlier nodes
    # are not released while later ones are being created.
    kept_alive = []
    for label, node_cls, accessor in probes:
        node = node_cls()
        kept_alive.append(node)
        print(f"{label} node type: {type(node)}")
        for shown_name, attr_name in reported_attrs:
            print(f"{label} node {shown_name}: {getattr(node, attr_name, 'None')}")
        print(f"Has {accessor}: {hasattr(node, accessor)}")
def main():
    """Run the property dump and the detection test, then report the outcome.

    Prints a banner, calls ``test_node_properties`` followed by
    ``test_node_detection``, and prints a success message. If either call
    raises an ``Exception``, the error message and traceback are printed and
    the process exits with status 1.
    """
    import traceback

    separator = "=" * 50
    print("Running Node Detection Tests...")
    print(separator)
    try:
        for run_test in (test_node_properties, test_node_detection):
            run_test()
        print("\n" + separator)
        print("All tests passed! ✓")
    except Exception as exc:
        print(f"\n❌ Test failed: {exc}")
        traceback.print_exc()
        sys.exit(1)


if __name__ == '__main__':
    main()