# Cluster/core/functions/InferencePipeline.py
# 2025-07-17 17:04:56 +08:00
#
# 595 lines
# 25 KiB
# Python

from typing import List, Dict, Any, Optional, Callable, Union
import threading
import queue
import time
import traceback
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from Multidongle import MultiDongle, PreProcessor, PostProcessor, DataProcessor
@dataclass
class StageConfig:
    """Configuration for a single pipeline stage.

    PipelineStage forwards the device/firmware fields to MultiDongle and keeps
    the processor fields for its own use.
    """
    # Identifier used in log messages and as the key in PipelineData.stage_results
    stage_id: str
    # Passed to MultiDongle as ``port_id``
    port_ids: List[int]
    scpu_fw_path: str
    ncpu_fw_path: str
    model_path: str
    # Forwarded to MultiDongle(upload_fw=...)
    upload_fw: bool = False
    # Capacity of the stage's input/output queues (also forwarded to MultiDongle)
    max_queue_size: int = 50
    # Inter-stage processing
    input_preprocessor: Optional[PreProcessor] = None  # Before this stage
    output_postprocessor: Optional[PostProcessor] = None  # After this stage
    # Stage-specific processing
    stage_preprocessor: Optional[PreProcessor] = None  # MultiDongle preprocessor
    stage_postprocessor: Optional[PostProcessor] = None  # MultiDongle postprocessor
    # Forwarded to MultiDongle(auto_detect=...). Declared last so existing
    # positional construction keeps working.
    auto_detect: bool = False
@dataclass
class PipelineData:
    """Container for one item travelling through the pipeline.

    Each stage stores its output in ``stage_results`` under its stage id and
    replaces ``data`` with that output before handing the item on.
    """
    # Current payload: the raw input (image, features, ...) or a stage's output
    data: Any
    # Free-form annotations such as per-stage timestamps
    metadata: Dict[str, Any]
    # Per-stage outputs keyed by stage identifier
    stage_results: Dict[str, Any]
    # Identifier of this item's trip through the pipeline
    pipeline_id: str
    # Time at which the item was created
    timestamp: float
class PipelineStage:
    """Single stage in the inference pipeline.

    Wraps one MultiDongle instance and a daemon worker thread. The worker
    takes PipelineData items from ``input_queue``, runs them through
    ``_process_data`` and puts them on ``output_queue``, dropping the oldest
    queued item when that queue is full.
    """

    # Seconds to keep polling MultiDongle for one frame's result, and the
    # timeout passed to each individual poll.
    RESULT_TIMEOUT_S = 5.0
    POLL_TIMEOUT_S = 0.1

    def __init__(self, config: StageConfig):
        """Build the stage from ``config``.

        Constructs the MultiDongle and the stage queues. Device start-up
        happens in initialize(); the worker thread starts in start().
        """
        self.config = config
        self.stage_id = config.stage_id
        # Initialize MultiDongle for this stage
        self.multidongle = MultiDongle(
            port_id=config.port_ids,
            scpu_fw_path=config.scpu_fw_path,
            ncpu_fw_path=config.ncpu_fw_path,
            model_path=config.model_path,
            upload_fw=config.upload_fw,
            # auto_detect is optional on the config object; default to False.
            auto_detect=getattr(config, 'auto_detect', False),
            max_queue_size=config.max_queue_size,
        )
        # Stored for later use; not applied anywhere in this class.
        self.stage_preprocessor = config.stage_preprocessor
        self.stage_postprocessor = config.stage_postprocessor
        self.max_queue_size = config.max_queue_size
        # Inter-stage processors applied by _process_data
        self.input_preprocessor = config.input_preprocessor
        self.output_postprocessor = config.output_postprocessor
        # Threading for this stage
        self.input_queue = queue.Queue(maxsize=config.max_queue_size)
        self.output_queue = queue.Queue(maxsize=config.max_queue_size)
        self.worker_thread = None
        self.running = False
        self._stop_event = threading.Event()
        # Statistics
        self.processed_count = 0
        self.error_count = 0
        self.processing_times = []

    def initialize(self):
        """Initialize and start this stage's MultiDongle; re-raise on failure."""
        print(f"[Stage {self.stage_id}] Initializing...")
        try:
            self.multidongle.initialize()
            self.multidongle.start()
            print(f"[Stage {self.stage_id}] Initialized successfully")
        except Exception as e:
            print(f"[Stage {self.stage_id}] Initialization failed: {e}")
            raise

    def start(self):
        """Start the worker thread; a no-op if it is already running."""
        if self.worker_thread and self.worker_thread.is_alive():
            return
        self.running = True
        self._stop_event.clear()
        self.worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
        self.worker_thread.start()
        print(f"[Stage {self.stage_id}] Worker thread started")

    def stop(self):
        """Stop the worker thread (waiting up to 3 s), then stop the MultiDongle."""
        print(f"[Stage {self.stage_id}] Stopping...")
        self.running = False
        self._stop_event.set()
        # Put sentinel to unblock worker
        try:
            self.input_queue.put(None, timeout=1.0)
        except queue.Full:
            pass
        # Wait for worker thread
        if self.worker_thread and self.worker_thread.is_alive():
            self.worker_thread.join(timeout=3.0)
            if self.worker_thread.is_alive():
                print(f"[Stage {self.stage_id}] Warning: Worker thread didn't stop cleanly")
        # Stop MultiDongle
        self.multidongle.stop()
        print(f"[Stage {self.stage_id}] Stopped")

    def _worker_loop(self):
        """Worker thread body: process input_queue items until stop() is called.

        A ``None`` item is the wake-up sentinel posted by stop() and is skipped.
        """
        print(f"[Stage {self.stage_id}] Worker loop started")
        while self.running and not self._stop_event.is_set():
            try:
                try:
                    pipeline_data = self.input_queue.get(timeout=0.1)
                except queue.Empty:
                    continue
                if pipeline_data is None:  # Sentinel value
                    continue

                start_time = time.time()
                processed_data = self._process_data(pipeline_data)

                # Record processing time, keeping only recent samples
                self.processing_times.append(time.time() - start_time)
                if len(self.processing_times) > 1000:
                    self.processing_times = self.processing_times[-500:]
                self.processed_count += 1

                # _process_data never raises; it records failures as an
                # 'error' entry instead, so count those here.
                stage_result = processed_data.stage_results.get(self.stage_id)
                if isinstance(stage_result, dict) and 'error' in stage_result:
                    self.error_count += 1

                self._put_output(processed_data)
            except Exception as e:
                self.error_count += 1
                print(f"[Stage {self.stage_id}] Processing error: {e}")
                traceback.print_exc()
        print(f"[Stage {self.stage_id}] Worker loop stopped")

    def _put_output(self, item):
        """Enqueue ``item`` on output_queue, dropping the oldest entry if full."""
        try:
            self.output_queue.put(item, block=False)
            return
        except queue.Full:
            pass
        try:
            self.output_queue.get_nowait()
        except queue.Empty:
            pass  # Drained by a consumer in the meantime; there is room now.
        # Only this worker thread puts into output_queue, so room is available.
        self.output_queue.put(item, block=False)

    def _process_data(self, pipeline_data: PipelineData) -> PipelineData:
        """Run one PipelineData item through this stage.

        Steps: inter-stage input preprocessing, MultiDongle preprocessing,
        inference, and output postprocessing. The stage result is stored in
        ``pipeline_data.stage_results[stage_id]`` and also replaces
        ``pipeline_data.data`` for the next stage.

        Never raises. Any exception is recorded as an error result in
        ``stage_results[stage_id]``, and ``pipeline_data.data`` is left as it was.
        """
        try:
            processed_data = self._prepare_input(pipeline_data.data)

            # Step 3: MultiDongle inference
            if isinstance(processed_data, np.ndarray):
                print(f"[Stage {self.stage_id}] Sending to MultiDongle: shape={processed_data.shape}, dtype={processed_data.dtype}")
                self.multidongle.put_input(processed_data, 'BGR565')
                inference_result = self._wait_for_inference_result()
                if self._is_missing_result(inference_result):
                    print(f"[Stage {self.stage_id}] Warning: No inference result received after {self.RESULT_TIMEOUT_S:g} second timeout")
                    inference_result = {'probability': 0.0, 'result': 'No Result'}
                else:
                    print(f"[Stage {self.stage_id}] ✅ Successfully received inference result: {inference_result}")
            else:
                # Nothing was submitted, so polling would only return a result
                # belonging to some other frame.
                print(f"[Stage {self.stage_id}] Warning: No array input for inference; skipping MultiDongle")
                inference_result = {'probability': 0.0, 'result': 'No Result'}

            # Step 4: Output postprocessing (inter-stage)
            processed_result = self._apply_output_postprocessor(inference_result)

            # Step 5: Update pipeline data
            pipeline_data.stage_results[self.stage_id] = processed_result
            pipeline_data.data = processed_result  # Pass result as data to next stage
            pipeline_data.metadata[f'{self.stage_id}_timestamp'] = time.time()
            return pipeline_data
        except Exception as e:
            print(f"[Stage {self.stage_id}] Data processing error: {e}")
            # Return data with error info
            pipeline_data.stage_results[self.stage_id] = {
                'error': str(e),
                'probability': 0.0,
                'result': 'Processing Error'
            }
            return pipeline_data

    def _prepare_input(self, current_data):
        """Turn an incoming payload into what will be sent to MultiDongle.

        For an ndarray payload, ``input_preprocessor`` (if set) is applied first.
        3-D arrays are then converted with ``multidongle.preprocess_frame``.
        Anything else is returned unchanged with a warning.

        :raises ValueError: if preprocess_frame returns None or a non-ndarray.
        """
        if isinstance(current_data, np.ndarray):
            print(f"[Stage {self.stage_id}] Input data: shape={current_data.shape}, dtype={current_data.dtype}")
            # Step 1: Input preprocessing (inter-stage)
            if self.input_preprocessor:
                print(f"[Stage {self.stage_id}] Applying input preprocessor...")
                current_data = self.input_preprocessor.process(
                    current_data,
                    self.multidongle.model_input_shape,
                    'BGR565'  # Default format
                )
                if isinstance(current_data, np.ndarray):
                    print(f"[Stage {self.stage_id}] After input preprocess: shape={current_data.shape}, dtype={current_data.dtype}")
                else:
                    print(f"[Stage {self.stage_id}] After input preprocess: type={type(current_data)}")

        # Step 2: MultiDongle preprocessing for 3-D arrays
        if isinstance(current_data, np.ndarray) and current_data.ndim == 3:
            print(f"[Stage {self.stage_id}] Preprocessing frame for MultiDongle...")
            processed = self.multidongle.preprocess_frame(current_data, 'BGR565')
            # Validate before reading .shape, so a bad return raises the
            # intended ValueError rather than an AttributeError.
            if processed is None:
                raise ValueError("MultiDongle preprocess_frame returned None")
            if not isinstance(processed, np.ndarray):
                raise ValueError(f"MultiDongle preprocess_frame returned {type(processed)}, expected np.ndarray")
            print(f"[Stage {self.stage_id}] After MultiDongle preprocess: shape={processed.shape}, dtype={processed.dtype}")
            return processed

        if isinstance(current_data, dict) and 'raw_output' in current_data:
            # This is result from previous stage, not suitable for direct inference
            print(f"[Stage {self.stage_id}] Warning: Received processed result instead of image data")
        else:
            print(f"[Stage {self.stage_id}] Warning: Unexpected data type: {type(current_data)}")
        return current_data

    def _wait_for_inference_result(self):
        """Poll MultiDongle until a usable result arrives.

        Polling stops after RESULT_TIMEOUT_S or when stop() is requested.
        A 2-tuple with a ``None`` element and an empty dict are treated as
        "not ready" and polling continues. A valid 2-tuple, a non-empty dict,
        or a value of any other type is returned as-is. Returns None if
        nothing usable arrived.
        """
        deadline = time.time() + self.RESULT_TIMEOUT_S
        while time.time() < deadline and not self._stop_event.is_set():
            result = self.multidongle.get_latest_inference_result(timeout=self.POLL_TIMEOUT_S)
            print(f"[Stage {self.stage_id}] Got result from MultiDongle: {result}")
            if result is None:
                print(f"[Stage {self.stage_id}] No result yet, waiting...")
                time.sleep(0.01)
                continue
            if isinstance(result, tuple) and len(result) == 2:
                # Handle tuple results like (probability, result_string)
                prob, result_str = result
                if prob is not None and result_str is not None:
                    print(f"[Stage {self.stage_id}] Valid result: prob={prob}, result={result_str}")
                    return result
                print(f"[Stage {self.stage_id}] Invalid tuple result: prob={prob}, result={result_str}")
            elif isinstance(result, dict):
                if result:
                    print(f"[Stage {self.stage_id}] Valid dict result: {result}")
                    return result
                print(f"[Stage {self.stage_id}] Empty dict result")
            else:
                print(f"[Stage {self.stage_id}] Other result type: {type(result)}")
                return result
        return None

    @staticmethod
    def _is_missing_result(result) -> bool:
        """Return True for None, an empty dict, an empty tuple or ``(None, None)``."""
        if result is None:
            return True
        if isinstance(result, dict):
            return not result
        if isinstance(result, tuple):
            # Compare elements with ``is`` so tuples holding arrays don't
            # trigger NumPy's ambiguous-truth-value error.
            return not result or (len(result) == 2 and result[0] is None and result[1] is None)
        return False

    def _apply_output_postprocessor(self, inference_result):
        """Apply ``output_postprocessor`` to ``inference_result['raw_output']``.

        Applies only when a postprocessor is set and the result is a dict with
        a 'raw_output' key. The postprocessor output is then updated with the
        original result, so the original keys win on conflict. Otherwise
        ``inference_result`` is returned unchanged.
        """
        if not self.output_postprocessor:
            return inference_result
        if not (isinstance(inference_result, dict) and 'raw_output' in inference_result):
            return inference_result
        processed_result = self.output_postprocessor.process(inference_result['raw_output'])
        # Merge with original result
        processed_result.update(inference_result)
        return processed_result

    def put_data(self, data: PipelineData, timeout: float = 1.0) -> bool:
        """Put data into this stage's input queue; return False if it stayed full."""
        try:
            self.input_queue.put(data, timeout=timeout)
            return True
        except queue.Full:
            return False

    def get_result(self, timeout: float = 0.1) -> Optional[PipelineData]:
        """Get the next result from this stage's output queue, or None on timeout."""
        try:
            return self.output_queue.get(timeout=timeout)
        except queue.Empty:
            return None

    def get_statistics(self) -> Dict[str, Any]:
        """Return counters, queue sizes, mean processing time and MultiDongle stats."""
        avg_processing_time = (
            sum(self.processing_times) / len(self.processing_times)
            if self.processing_times else 0.0
        )
        multidongle_stats = self.multidongle.get_statistics()
        return {
            'stage_id': self.stage_id,
            'processed_count': self.processed_count,
            'error_count': self.error_count,
            'avg_processing_time': avg_processing_time,
            'input_queue_size': self.input_queue.qsize(),
            'output_queue_size': self.output_queue.qsize(),
            'multidongle_stats': multidongle_stats
        }
class InferencePipeline:
    """Multi-stage inference pipeline.

    One coordinator thread pulls raw inputs from ``pipeline_input_queue``,
    wraps each one in a PipelineData, and passes it through the stages in
    order, waiting for each stage's result before moving on. Completed items
    go to ``pipeline_output_queue`` (the oldest entry is dropped when it is
    full) and to the result callback; failed items go to the error callback.
    """

    # Seconds the coordinator waits for a single stage's result.
    STAGE_TIMEOUT_S = 10.0

    def __init__(self, stage_configs: List[StageConfig],
                 final_postprocessor: Optional[PostProcessor] = None,
                 pipeline_name: str = "InferencePipeline"):
        """
        Initialize inference pipeline
        :param stage_configs: List of stage configurations
        :param final_postprocessor: Final postprocessor after all stages
        :param pipeline_name: Name for this pipeline instance
        """
        self.pipeline_name = pipeline_name
        self.stage_configs = stage_configs
        self.final_postprocessor = final_postprocessor
        # Create stages
        self.stages: List[PipelineStage] = [PipelineStage(config) for config in stage_configs]
        # Pipeline coordinator
        self.coordinator_thread = None
        self.running = False
        self._stop_event = threading.Event()
        # Input/Output queues for the entire pipeline
        self.pipeline_input_queue = queue.Queue(maxsize=100)
        self.pipeline_output_queue = queue.Queue(maxsize=100)
        # Callbacks
        self.result_callback = None
        self.error_callback = None
        self.stats_callback = None
        # Statistics
        self.pipeline_counter = 0
        self.completed_counter = 0
        self.error_counter = 0

    def initialize(self):
        """Initialize all stages in order.

        If a stage fails, the stages initialized before it are stopped and
        the exception is re-raised.
        """
        print(f"[{self.pipeline_name}] Initializing pipeline with {len(self.stages)} stages...")
        for i, stage in enumerate(self.stages):
            try:
                stage.initialize()
                print(f"[{self.pipeline_name}] Stage {i+1}/{len(self.stages)} initialized")
            except Exception as e:
                print(f"[{self.pipeline_name}] Failed to initialize stage {stage.stage_id}: {e}")
                # Cleanup already initialized stages
                for initialized in self.stages[:i]:
                    initialized.stop()
                raise
        print(f"[{self.pipeline_name}] All stages initialized successfully")

    def start(self):
        """Start all stage workers and the coordinator thread.

        A no-op while a coordinator thread is still alive, so two coordinators
        never compete for the same stage queues.
        """
        if self.coordinator_thread and self.coordinator_thread.is_alive():
            return
        print(f"[{self.pipeline_name}] Starting pipeline...")
        # Start all stages
        for stage in self.stages:
            stage.start()
        # Start coordinator
        self.running = True
        self._stop_event.clear()
        self.coordinator_thread = threading.Thread(target=self._coordinator_loop, daemon=True)
        self.coordinator_thread.start()
        print(f"[{self.pipeline_name}] Pipeline started successfully")

    def stop(self):
        """Stop the coordinator (waiting up to 3 s), then every stage."""
        print(f"[{self.pipeline_name}] Stopping pipeline...")
        self.running = False
        self._stop_event.set()
        # Stop coordinator
        if self.coordinator_thread and self.coordinator_thread.is_alive():
            try:
                self.pipeline_input_queue.put(None, timeout=1.0)  # Wake-up sentinel
            except queue.Full:
                pass
            self.coordinator_thread.join(timeout=3.0)
        # Stop all stages
        for stage in self.stages:
            stage.stop()
        print(f"[{self.pipeline_name}] Pipeline stopped")

    def _coordinator_loop(self):
        """Coordinator thread body: drive each queued input through every stage."""
        print(f"[{self.pipeline_name}] Coordinator started")
        while self.running and not self._stop_event.is_set():
            try:
                try:
                    input_data = self.pipeline_input_queue.get(timeout=0.1)
                except queue.Empty:
                    continue
                if input_data is None:  # Sentinel
                    continue

                now = time.time()
                pipeline_data = PipelineData(
                    data=input_data,
                    metadata={'start_timestamp': now},
                    stage_results={},
                    pipeline_id=f"pipeline_{self.pipeline_counter}",
                    timestamp=now
                )
                self.pipeline_counter += 1

                current_data, success = self._run_stages(pipeline_data)

                if success and self.final_postprocessor:
                    self._apply_final_postprocessor(current_data)

                if success:
                    self._publish_result(current_data)
                else:
                    self.error_counter += 1
                    if self.error_callback:
                        self.error_callback(current_data)
            except Exception as e:
                print(f"[{self.pipeline_name}] Coordinator error: {e}")
                traceback.print_exc()
                self.error_counter += 1
        print(f"[{self.pipeline_name}] Coordinator stopped")

    def _run_stages(self, pipeline_data):
        """Send ``pipeline_data`` through each stage in order.

        Returns ``(data, success)``. On failure, ``data`` is the last value
        successfully produced (or the input itself).
        """
        current_data = pipeline_data
        for stage in self.stages:
            if not stage.put_data(current_data, timeout=1.0):
                print(f"[{self.pipeline_name}] Stage {stage.stage_id} input queue full, dropping data")
                return current_data, False
            result_data = self._wait_for_stage_result(stage, current_data.pipeline_id)
            if result_data is None:
                print(f"[{self.pipeline_name}] Stage {stage.stage_id} timeout")
                return current_data, False
            current_data = result_data
        return current_data, True

    def _wait_for_stage_result(self, stage, pipeline_id):
        """Wait up to STAGE_TIMEOUT_S for ``stage`` to return item ``pipeline_id``.

        Results with a different pipeline_id are late outputs of an earlier item
        that timed out. They are discarded so they cannot be taken for the
        current item's result. Returns None on timeout or when stop() is called.
        """
        deadline = time.time() + self.STAGE_TIMEOUT_S
        while time.time() < deadline:
            result_data = stage.get_result(timeout=0.1)
            if result_data is not None:
                result_id = getattr(result_data, 'pipeline_id', None)
                if result_id == pipeline_id:
                    return result_data
                print(f"[{self.pipeline_name}] Stage {stage.stage_id} discarding stale result {result_id}")
            if self._stop_event.is_set():
                return None
            time.sleep(0.01)
        return None

    def _apply_final_postprocessor(self, current_data):
        """Run final_postprocessor on the last stage's 'raw_output', if present.

        Best effort: errors are logged and leave ``current_data`` unchanged.
        """
        try:
            if isinstance(current_data.data, dict) and 'raw_output' in current_data.data:
                final_result = self.final_postprocessor.process(current_data.data['raw_output'])
                current_data.stage_results['final'] = final_result
                current_data.data = final_result
        except Exception as e:
            print(f"[{self.pipeline_name}] Final postprocessing error: {e}")

    def _publish_result(self, current_data):
        """Timestamp a completed item, enqueue it and invoke the result callback.

        If the output queue is full, the oldest queued item is dropped to make
        room. The new item still counts as completed and is still passed to
        the callback.
        """
        current_data.metadata['end_timestamp'] = time.time()
        current_data.metadata['total_processing_time'] = (
            current_data.metadata['end_timestamp'] -
            current_data.metadata['start_timestamp']
        )
        try:
            self.pipeline_output_queue.put(current_data, block=False)
        except queue.Full:
            # Drop oldest and add new
            try:
                self.pipeline_output_queue.get_nowait()
            except queue.Empty:
                pass  # A consumer drained it meanwhile; there is room now.
            self.pipeline_output_queue.put(current_data, block=False)
        self.completed_counter += 1
        if self.result_callback:
            self.result_callback(current_data)

    def put_data(self, data: Any, timeout: float = 1.0) -> bool:
        """Queue a raw input for the pipeline; return False if the queue stayed full."""
        try:
            self.pipeline_input_queue.put(data, timeout=timeout)
            return True
        except queue.Full:
            return False

    def get_result(self, timeout: float = 0.1) -> Optional[PipelineData]:
        """Get the next completed item, or None on timeout."""
        try:
            return self.pipeline_output_queue.get(timeout=timeout)
        except queue.Empty:
            return None

    def set_result_callback(self, callback: Callable[[PipelineData], None]):
        """Set callback for successful results"""
        self.result_callback = callback

    def set_error_callback(self, callback: Callable[[PipelineData], None]):
        """Set callback for errors"""
        self.error_callback = callback

    def set_stats_callback(self, callback: Callable[[Dict[str, Any]], None]):
        """Set callback for statistics"""
        self.stats_callback = callback

    def get_pipeline_statistics(self) -> Dict[str, Any]:
        """Return pipeline counters, queue sizes and each stage's statistics."""
        return {
            'pipeline_name': self.pipeline_name,
            'total_stages': len(self.stages),
            'pipeline_input_submitted': self.pipeline_counter,
            'pipeline_completed': self.completed_counter,
            'pipeline_errors': self.error_counter,
            'pipeline_input_queue_size': self.pipeline_input_queue.qsize(),
            'pipeline_output_queue_size': self.pipeline_output_queue.qsize(),
            'stage_statistics': [stage.get_statistics() for stage in self.stages]
        }

    def start_stats_reporting(self, interval: float = 5.0):
        """Start a daemon thread that reports statistics every ``interval`` seconds.

        The thread reports once immediately, then waits on the stop event, so
        stop() ends it promptly instead of after a full interval. It exits at
        once if the pipeline is not running. Returns the started thread.
        """
        def stats_loop():
            while self.running and not self._stop_event.is_set():
                if self.stats_callback:
                    self.stats_callback(self.get_pipeline_statistics())
                if self._stop_event.wait(interval):
                    break

        stats_thread = threading.Thread(target=stats_loop, daemon=True)
        stats_thread.start()
        return stats_thread
# Utility functions for common inter-stage processing
def create_feature_extractor_preprocessor() -> PreProcessor:
    """Create a preprocessor that turns a frame into a resized Canny edge map.

    The wrapped ``extract_features(frame, target_size)`` converts the frame to
    grayscale and computes Canny edges with thresholds 50/150. It then resizes
    the edge map to ``target_size``, which ``cv2.resize`` interprets as
    ``(width, height)``. Frames that are already single-channel (2-D or one
    channel) are used directly, and 4-channel frames are converted as BGRA,
    instead of crashing in ``cvtColor``.
    """
    def extract_features(frame, target_size):
        # Imported inside the function as before; OpenCV is only needed here.
        import cv2
        if frame.ndim == 2:
            gray = frame
        elif frame.shape[2] == 1:
            gray = frame[:, :, 0]
        elif frame.shape[2] == 4:
            gray = cv2.cvtColor(frame, cv2.COLOR_BGRA2GRAY)
        else:
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 50, 150)
        return cv2.resize(edges, target_size)
    return PreProcessor(resize_fn=extract_features)
def create_result_aggregator_postprocessor() -> PostProcessor:
    """Create a postprocessor that turns a raw model output into a labelled result.

    The wrapped ``aggregate_results(raw_output, **kwargs)``:
      * returns ``raw_output`` unchanged if it is already a dict;
      * otherwise treats the first element of the flattened output as a
        probability. It returns ``aggregated_probability``, a ``confidence``
        of High (>0.8) / Medium (>0.5) / Low, and a ``result`` of
        Detected (>0.5) / Not Detected;
      * returns a zero-probability "Not Detected" dict for ``None`` or an
        empty output.
    """
    def aggregate_results(raw_output, **kwargs):
        if isinstance(raw_output, dict):
            # Already-processed results pass through untouched
            return raw_output
        # Flatten so lists, tuples and multi-dimensional arrays (e.g. a (1, N)
        # output) all expose their first value at index 0.
        values = np.empty(0) if raw_output is None else np.asarray(raw_output).ravel()
        if values.size > 0:
            probability = float(values[0])
            return {
                'aggregated_probability': probability,
                'confidence': 'High' if probability > 0.8 else 'Medium' if probability > 0.5 else 'Low',
                'result': 'Detected' if probability > 0.5 else 'Not Detected'
            }
        return {'aggregated_probability': 0.0, 'confidence': 'Low', 'result': 'Not Detected'}
    return PostProcessor(process_fn=aggregate_results)