# Source-viewer header (original paste): 98 lines, 3.6 KiB, Python
import csv
import io
import json
import os
import time
from typing import Any, Dict, Iterable, List, Optional


class ResultSerializer:
    """
    Serializes inference results into various formats.
    """

    def to_json(self, data: Any) -> str:
        """
        Serializes data to a JSON string.

        Args:
            data: Any value ``json.dumps`` can serialize (typically a dict of
                results; lists work as well).

        Returns:
            str: The JSON text, pretty-printed with a 2-space indent.

        Raises:
            TypeError: If ``data`` contains a value ``json`` cannot serialize.
        """
        return json.dumps(data, indent=2)

    def to_csv(self, data: Iterable[Dict[str, Any]], fieldnames: Optional[List[str]] = None) -> str:
        """
        Serializes data to a CSV string.

        Args:
            data: Rows to write, one dict per row. Any iterable is accepted.
            fieldnames: Column order for the header row. When omitted, the union
                of all row keys is used, in first-seen order.

        Returns:
            str: CSV text with a header row. Line terminators are the csv
            module default (CRLF).

        Raises:
            ValueError: If a row contains a key that is not in ``fieldnames``.
        """
        # Materialize once: the rows may be a one-shot iterator, and deriving
        # the header requires a full pass before writing.
        rows = list(data)
        if fieldnames is None:
            fieldnames = list(dict.fromkeys(key for row in rows for key in row))

        output = io.StringIO()
        writer = csv.DictWriter(output, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
        return output.getvalue()


class FileOutputManager:
    """
    Manages writing results to files with timestamped names and directory organization.

    Files are written to
    ``<base_path>/<pipeline_name>/<YYYY-MM-DD>/<YYYYmmdd_HHMMSS>_<pipeline_id>[_N].<format>``,
    where ``_N`` is appended only when an earlier file already took the name.
    """

    _DEFAULT_BASE_PATH = "./output"

    def __init__(self, base_path: str = "./output"):
        """
        Initializes the FileOutputManager.

        Args:
            base_path (str): The base directory to save output files. ``os.PathLike``
                values are accepted too. An empty or non-path value makes
                ``save_result`` fall back to ``"./output"``.
        """
        self.base_path = base_path
        self.serializer = ResultSerializer()

    @staticmethod
    def _sanitize_dir_name(name: Any) -> str:
        """Reduce *name* to alphanumerics, spaces and underscores; ``default_pipeline`` if nothing is left."""
        text = "" if name is None else str(name)
        cleaned = "".join(c for c in text if c.isalnum() or c in (" ", "_")).strip()
        return cleaned or "default_pipeline"

    @staticmethod
    def _sanitize_file_stem(stem: Any) -> str:
        """Make *stem* safe as a file-name component; ``result`` if nothing is left.

        Path separators and other unsafe characters are dropped. Leading dots and
        spaces are stripped, so the value can never be ``.``/``..`` or climb out
        of the output directory.
        """
        text = "" if stem is None else str(stem)
        cleaned = "".join(c for c in text if c.isalnum() or c in (" ", "_", "-", "."))
        cleaned = cleaned.lstrip(". ").rstrip()
        return cleaned or "result"

    @staticmethod
    def _csv_fieldnames(rows: List[Dict[str, Any]]) -> List[str]:
        """Return the union of keys across *rows*, in first-seen order."""
        return list(dict.fromkeys(key for row in rows for key in row))

    @staticmethod
    def _write_unique(directory: str, stem: str, extension: str, content: str,
                      newline: Optional[str] = None) -> str:
        """Write *content* to a new file ``<stem>[_N].<extension>`` in *directory*; return its path.

        Mode ``"x"`` fails rather than truncating when the name is already taken,
        so saves that share a timestamp never overwrite each other.
        """
        attempt = 0
        while True:
            suffix = f"_{attempt}" if attempt else ""
            path = os.path.join(directory, f"{stem}{suffix}.{extension}")
            try:
                with open(path, "x", encoding="utf-8", newline=newline) as handle:
                    handle.write(content)
            except FileExistsError:
                attempt += 1
                continue
            return path

    def save_result(self, result_data: Dict[str, Any], pipeline_name: str, format: str = 'json') -> Optional[str]:
        """
        Saves a single result to a file.

        Args:
            result_data (Dict[str, Any]): The result data to save. For 'csv' a list
                (or tuple) of dicts is also accepted, one row per dict; a single
                dict becomes one row. When it is a dict, its 'pipeline_id' value
                (sanitized) is used in the file name.
            pipeline_name (str): The name of the pipeline that generated the result.
            format (str): The format to save the result in ('json' or 'csv').

        Returns:
            Optional[str]: The path of the written file, or None when nothing was
            written. Failures (unsupported format, empty or non-dict CSV rows,
            serialization or I/O errors) are reported with print, not raised.
        """
        try:
            # Validate and serialize before touching the filesystem, so bad input
            # never leaves empty directories behind.
            if format == 'json':
                content = self.serializer.to_json(result_data)
                newline = None
            elif format == 'csv':
                # For CSV, we expect a list of dicts. If it's a single dict, wrap it.
                rows = list(result_data) if isinstance(result_data, (list, tuple)) else [result_data]
                if not rows:
                    print("Error: No CSV data to save.")
                    return None
                if not all(isinstance(row, dict) for row in rows):
                    print("Error: CSV data must be a list of dictionaries.")
                    return None
                # Union of all keys: DictWriter raises on keys missing from fieldnames.
                content = self.serializer.to_csv(rows, self._csv_fieldnames(rows))
                # csv already emits CRLF; newline="" stops text mode re-translating it.
                newline = ""
            else:
                print(f"Error: Unsupported format '{format}'")
                return None

            # Fall back to the default location without mutating self.base_path.
            base_path = self.base_path
            if not base_path or not isinstance(base_path, (str, os.PathLike)):
                base_path = self._DEFAULT_BASE_PATH

            # Read the clock once so the date directory and file timestamp agree.
            now = time.localtime()
            output_dir = os.path.join(
                base_path,
                self._sanitize_dir_name(pipeline_name),
                time.strftime("%Y-%m-%d", now),
            )
            os.makedirs(output_dir, exist_ok=True)

            # pipeline_id is caller-supplied, so sanitize it before using it in a path.
            pipeline_id = result_data.get('pipeline_id', 'result') if isinstance(result_data, dict) else 'result'
            stem = f"{time.strftime('%Y%m%d_%H%M%S', now)}_{self._sanitize_file_stem(pipeline_id)}"
            file_path = self._write_unique(output_dir, stem, format, content, newline)

            print(f"Result saved to {file_path}")
            return file_path
        except Exception as e:
            # Deliberate best-effort boundary: report and signal failure via None.
            print(f"Error saving result: {e}")
            return None