# Updated app/orchestrator.py
import logging
import threading
import time
from datetime import datetime
from typing import Any, Dict, List, Optional

from app.error_handler import ErrorHandler, with_error_handling
from app.synchronizer import Synchronizer

class Orchestrator:
    def __init__(self, coordinator_agent, text_analysis_agent, image_processing_agent,
                 report_generation_agent, metrics_agent, text_model_manager=None,
                 image_model_manager=None, summary_model_manager=None,
                 token_manager=None, cache_manager=None, metrics_calculator=None):
"""Initialize the Orchestrator with required components.""" | |
self.logger = logging.getLogger(__name__) | |
self.coordinator_agent = coordinator_agent | |
self.text_model_manager = text_model_manager | |
self.image_model_manager = image_model_manager | |
self.summary_model_manager = summary_model_manager | |
self.token_manager = token_manager | |
self.cache_manager = cache_manager | |
self.metrics_calculator = metrics_calculator | |
# Store the agents directly | |
self.text_analysis_agent = text_analysis_agent | |
self.image_processing_agent = image_processing_agent | |
self.report_generation_agent = report_generation_agent | |
self.metrics_agent = metrics_agent | |
# Initialize error handler | |
self.error_handler = ErrorHandler(metrics_calculator=metrics_calculator) | |
# Register fallbacks | |
self._register_fallbacks() | |
# Track active sessions | |
self.active_sessions = {} | |
self.session_counter = 0 | |
self.synchronizer = Synchronizer() | |
self._register_agents_with_synchronizer() | |

    def _register_agents_with_synchronizer(self):
        """Register all agents with the synchronizer."""
        # Register the coordinator
        self.synchronizer.register_agent("coordinator_agent")
        # Register other agents if available
        if hasattr(self, "text_analysis_agent") and self.text_analysis_agent:
            self.synchronizer.register_agent("text_analysis_agent")
        if hasattr(self, "image_processing_agent") and self.image_processing_agent:
            self.synchronizer.register_agent("image_processing_agent")
        if hasattr(self, "report_generation_agent") and self.report_generation_agent:
            self.synchronizer.register_agent("report_generation_agent")
        if hasattr(self, "metrics_agent") and self.metrics_agent:
            self.synchronizer.register_agent("metrics_agent")
    def coordinate_workflow_with_synchronization(self, session_id: str, topic: str,
                                                 text_files: List[str],
                                                 image_files: List[str]) -> Dict[str, Any]:
        """
        Coordinate a workflow with explicit synchronization points.

        This provides more control over the workflow execution than the
        standard process_request.
        """
        if session_id not in self.active_sessions:
            return {"error": f"Session {session_id} not found. Please create a new session."}
        session = self.active_sessions[session_id]
        session["status"] = "processing"
        # Create a workflow ID
        workflow_id = f"workflow_{int(time.time())}"
        # Initialize workflow
        workflow_result = self.coordinator_agent.initialize_workflow(topic, text_files, image_files)
        # Store workflow ID
        if "workflow_id" in workflow_result:
            workflow_id = workflow_result["workflow_id"]
        session["workflows"].append(workflow_id)
        session["current_workflow"] = workflow_id
        # Create synchronization barriers
        analysis_barrier_id = self.synchronizer.create_barrier(
            "analysis_complete",
            ["text_analysis_agent", "image_processing_agent"]
        )
        report_barrier_id = self.synchronizer.create_barrier(
            "report_ready",
            ["report_generation_agent"]
        )
        # Set up dependencies
        if hasattr(self, "report_generation_agent") and self.report_generation_agent:
            self.synchronizer.register_dependencies(
                "report_generation_agent",
                f"generate_report_{workflow_id}",
                [
                    ("text_analysis_agent", f"analyze_text_{workflow_id}"),
                    ("image_processing_agent", f"process_images_{workflow_id}")
                ]
            )
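        # register_dependencies records that report generation should only run
        # after both analysis tasks have signaled completion for this workflow;
        # the barriers above are what actually block execution.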
        # Start text analysis in background
        if hasattr(self, "text_analysis_agent") and self.text_analysis_agent and text_files:
            def text_analysis_task():
                try:
                    # Process text files
                    result = self.text_analysis_agent.process_text_files(topic, text_files)
                    # Signal completion
                    self.synchronizer.signal_completion(
                        "text_analysis_agent",
                        f"analyze_text_{workflow_id}",
                        result
                    )
                    # Arrive at barrier
                    self.synchronizer.arrive_at_barrier(analysis_barrier_id, "text_analysis_agent", result)
                    return result
                except Exception as e:
                    self.logger.error(f"Error in text analysis: {str(e)}")
                    # Signal completion with error
                    self.synchronizer.signal_completion(
                        "text_analysis_agent",
                        f"analyze_text_{workflow_id}",
                        {"error": str(e)}
                    )
                    # Arrive at barrier with error
                    self.synchronizer.arrive_at_barrier(
                        analysis_barrier_id,
                        "text_analysis_agent",
                        {"error": str(e)}
                    )
                    return {"error": str(e)}
            # Start in background thread
            text_thread = threading.Thread(target=text_analysis_task)
            text_thread.daemon = True
            text_thread.start()
        else:
            # If no text analysis, signal completion with a skipped result
            self.synchronizer.signal_completion(
                "text_analysis_agent",
                f"analyze_text_{workflow_id}",
                {"status": "skipped", "reason": "No text files or text analysis agent"}
            )
            self.synchronizer.arrive_at_barrier(
                analysis_barrier_id,
                "text_analysis_agent",
                {"status": "skipped"}
            )
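            # Arriving at the barrier even on the skip path matters: the barrier
            # expects both analysis agents, so a missing arrival would leave
            # wait_for_barrier below to clear only via its 15-minute timeout.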
        # Start image processing in background
        if hasattr(self, "image_processing_agent") and self.image_processing_agent and image_files:
            def image_processing_task():
                try:
                    # Process images
                    result = self.image_processing_agent.process_image_files(topic, image_files)
                    # Signal completion
                    self.synchronizer.signal_completion(
                        "image_processing_agent",
                        f"process_images_{workflow_id}",
                        result
                    )
                    # Arrive at barrier
                    self.synchronizer.arrive_at_barrier(analysis_barrier_id, "image_processing_agent", result)
                    return result
                except Exception as e:
                    self.logger.error(f"Error in image processing: {str(e)}")
                    # Signal completion with error
                    self.synchronizer.signal_completion(
                        "image_processing_agent",
                        f"process_images_{workflow_id}",
                        {"error": str(e)}
                    )
                    # Arrive at barrier with error
                    self.synchronizer.arrive_at_barrier(
                        analysis_barrier_id,
                        "image_processing_agent",
                        {"error": str(e)}
                    )
                    return {"error": str(e)}
            # Start in background thread
            image_thread = threading.Thread(target=image_processing_task)
            image_thread.daemon = True
            image_thread.start()
        else:
            # If no image processing, signal completion with a skipped result
            self.synchronizer.signal_completion(
                "image_processing_agent",
                f"process_images_{workflow_id}",
                {"status": "skipped", "reason": "No image files or image processing agent"}
            )
            self.synchronizer.arrive_at_barrier(
                analysis_barrier_id,
                "image_processing_agent",
                {"status": "skipped"}
            )
        # Wait for analysis to complete
        if not self.synchronizer.wait_for_barrier(analysis_barrier_id, timeout=900):  # 15 minute timeout
            self.logger.error("Timeout waiting for analysis to complete")
            session["status"] = "error"
            return {
                "error": "Timeout waiting for analysis to complete",
                "status": "timeout",
                "workflow_id": workflow_id
            }
        # Get analysis results
        barrier_data = self.synchronizer.get_barrier_data(analysis_barrier_id)
        text_analysis = barrier_data.get("text_analysis_agent", {})
        image_analysis = barrier_data.get("image_processing_agent", {})
        self.logger.info(f"Analysis complete. Text analysis: {bool(text_analysis)}, Image analysis: {bool(image_analysis)}")
if not hasattr(self, "report_generation_agent") or not self.report_generation_agent: | |
self.logger.warning("Report generation agent not available, skipping report generation") | |
session["status"] = "completed" | |
return { | |
"status": "completed", | |
"workflow_id": workflow_id, | |
"topic": topic, | |
"results": { | |
"text_analysis": text_analysis, | |
"image_analysis": image_analysis | |
} | |
} | |
# Make sure the report agent is registered with the synchronizer | |
self.synchronizer.register_agent("report_generation_agent") | |
# Manually signal arrival for report agent if it's not responding | |
# This is a fallback in case the report thread is not starting properly | |
report_thread_started = False | |
        # Check for errors
        text_error = "error" in text_analysis
        image_error = "error" in image_analysis
        if text_error and image_error:
            session["status"] = "error"
            return {
                "error": "Both text and image analysis failed",
                "text_error": text_analysis.get("error", "Unknown error"),
                "image_error": image_analysis.get("error", "Unknown error"),
                "status": "error",
                "workflow_id": workflow_id
            }
        # Generate report
        if hasattr(self, "report_generation_agent") and self.report_generation_agent:
            def report_generation_task():
                nonlocal report_thread_started
                report_thread_started = True
                try:
                    # Both analyses have already arrived at the analysis barrier,
                    # so the dependencies should be met; log if they are not
                    if not self.synchronizer.are_dependencies_met(
                            "report_generation_agent", f"generate_report_{workflow_id}"):
                        self.logger.info("Dependencies not yet met for report generation")
                    # Generate report
                    result = self.report_generation_agent.generate_report(
                        topic, text_analysis, image_analysis)
                    # Signal completion
                    self.synchronizer.signal_completion(
                        "report_generation_agent",
                        f"generate_report_{workflow_id}",
                        result
                    )
                    # Arrive at barrier
                    self.synchronizer.arrive_at_barrier(report_barrier_id, "report_generation_agent", result)
                    return result
                except Exception as e:
                    self.logger.error(f"Error in report generation: {str(e)}")
                    # Signal completion with error
                    self.synchronizer.signal_completion(
                        "report_generation_agent",
                        f"generate_report_{workflow_id}",
                        {"error": str(e)}
                    )
                    # Arrive at barrier with error
                    self.synchronizer.arrive_at_barrier(
                        report_barrier_id,
                        "report_generation_agent",
                        {"error": str(e)}
                    )
                    return {"error": str(e)}
            # Start in background thread
            report_thread = threading.Thread(target=report_generation_task)
            report_thread.daemon = True
            report_thread.start()
            start_time = time.time()
            while not report_thread_started and time.time() - start_time < 10:  # 10 second timeout
                time.sleep(0.1)
            if not report_thread_started:
                self.logger.error("Report generation thread failed to start, manually signaling completion")
                # Manually arrive at the barrier
                self.synchronizer.arrive_at_barrier(
                    report_barrier_id,
                    "report_generation_agent",
                    {"error": "Report thread failed to start"}
                )
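            # Without that manual arrival, the report barrier would stay one
            # participant short and wait_for_barrier below could only clear via
            # its 5-minute timeout.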
self.logger.info(f"Final report data: {report.keys() if report else 'None'}") | |
self.logger.info("Workflow completed, returning results to UI") | |
            # Wait for report to be ready
            if not self.synchronizer.wait_for_barrier(report_barrier_id, timeout=300):  # 5 minute timeout
                self.logger.error("Timeout waiting for report generation")
                session["status"] = "error"
                return {
                    "error": "Timeout waiting for report generation",
                    "status": "timeout",
                    "workflow_id": workflow_id,
                    "partial_results": {
                        "text_analysis": text_analysis,
                        "image_analysis": image_analysis
                    }
                }
            # Get report
            barrier_data = self.synchronizer.get_barrier_data(report_barrier_id)
            report = barrier_data.get("report_generation_agent", {})
            self.logger.info(f"Report barrier data keys: {barrier_data.keys()}")
            self.logger.info(f"Report data keys: {report.keys() if isinstance(report, dict) else 'Not a dict'}")
            self.logger.info("Report generation completed, preparing to return results")
            # Check for errors
            if "error" in report:
                session["status"] = "error"
                return {
                    "error": "Report generation failed",
                    "report_error": report.get("error", "Unknown error"),
                    "status": "error",
                    "workflow_id": workflow_id,
                    "partial_results": {
                        "text_analysis": text_analysis,
                        "image_analysis": image_analysis
                    }
                }
            # Update the session in place: it is the same dict stored in
            # self.active_sessions, so these updates are immediately visible to
            # the UI, and keys such as "created_at" and "workflows" created by
            # create_session survive for later get_session_status calls
            session["status"] = "completed"
            session["last_result"] = report
            session["report"] = report
            session["topic"] = topic
            session["timestamp"] = datetime.now().isoformat()
            # Get sustainability metrics if available
            sustainability_metrics = None
            if hasattr(self, "metrics_agent") and self.metrics_agent:
                try:
                    sustainability_metrics = self.metrics_agent.generate_sustainability_report()
                except Exception as e:
                    self.logger.error(f"Error getting sustainability metrics: {str(e)}")
session["status"] = "completed" | |
self.logger.info("========== WORKFLOW COMPLETED ==========") | |
self.logger.info(f"Returning final result with keys: {list(result.keys()) if isinstance(result, dict) else 'Not a dict'}") | |
# Return final result | |
self.logger.info("#####################################################") | |
self.logger.info("#####################################################") | |
self.logger.info("#####################################################") | |
self.logger.info("#####################################################") | |
self.logger.info("#####################################################") | |
self.logger.info("#####################################################") | |
self.logger.info("Returning final result to caller") | |
print(f"RETURN DATA: status={result.get('status')}, keys={result.keys() if isinstance(result, dict) else 'Not a dict'}") | |
self.logger.info("#####################################################") | |
self.logger.info("#####################################################") | |
self.logger.info("#####################################################") | |
self.logger.info("#####################################################") | |
self.logger.info("#####################################################") | |
return { | |
"status": "completed", | |
"workflow_id": workflow_id, | |
"topic": topic, | |
"report": report, | |
"sustainability_metrics": sustainability_metrics | |
} | |
        else:
            # No report generation, return analysis results
            session["status"] = "completed"
            final_result = {
                "status": "completed",
                "workflow_id": workflow_id,
                "topic": topic,
                "results": {
                    "text_analysis": text_analysis,
                    "image_analysis": image_analysis
                }
            }
            self.logger.info("========== WORKFLOW COMPLETED (NO REPORT) ==========")
            self.logger.info(f"Returning final result with keys: {list(final_result.keys())}")
            return final_result
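
    # Hedged convenience sketch (an addition, not part of the original class):
    # shows how create_session, the synchronized workflow, and cleanup_session
    # defined in this module compose end to end.
    def run_synchronized(self, topic: str, text_files: List[str],
                         image_files: List[str]) -> Dict[str, Any]:
        """Create a session, run the synchronized workflow, and clean up."""
        session_id = self.create_session()
        try:
            return self.coordinate_workflow_with_synchronization(
                session_id, topic, text_files, image_files)
        finally:
            # Release workflow resources even on error or timeout paths
            self.cleanup_session(session_id)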

    def _register_fallbacks(self):
        """Register fallback functions for critical operations."""
        # Fallback for process_request
        self.error_handler.register_fallback(
            "orchestrator", "process_request",
            self._fallback_process_request
        )
        # Fallback for coordinator workflow execution
        self.error_handler.register_fallback(
            "coordinator_agent", "execute_workflow",
            self._fallback_execute_workflow
        )

    def _fallback_process_request(self, context):
        """Fallback function for processing requests."""
        # Extract what we can from the context
        kwargs = context.get("kwargs", {})
        topic = kwargs.get("topic", "unknown")
        session_id = kwargs.get("session_id", "unknown")
        # Check if we have a session
        if session_id in self.active_sessions:
            session = self.active_sessions[session_id]
            session["status"] = "error"
            session["error"] = "Request processing failed, using fallback"
        return {
            "status": "error",
            "message": "An error occurred while processing your request. Using simplified processing.",
            "topic": topic,
            "fallback": True,
            "result": {
                "confidence_level": "low",
                "summary": "Unable to process request fully. Please try again or simplify your query."
            }
        }

    def _fallback_execute_workflow(self, context):
        """Fallback function for workflow execution."""
        # We can attempt direct coordination as a fallback
        try:
            if hasattr(self.coordinator_agent, "_direct_coordination"):
                # Extract current topic and files from coordinator agent state
                topic = self.coordinator_agent.current_topic
                if topic and topic in self.coordinator_agent.workflow_state:
                    workflow = self.coordinator_agent.workflow_state[topic]
                    text_files = workflow.get("text_files", [])
                    image_files = workflow.get("image_files", [])
                    # Try direct coordination
                    return self.coordinator_agent._direct_coordination(topic, text_files, image_files)
            # If we can't do direct coordination, return a basic error response
            return {
                "status": "error",
                "message": "Workflow execution failed. Using fallback.",
                "fallback": True
            }
        except Exception as e:
            self.logger.error(f"Fallback for execute_workflow also failed: {str(e)}")
            return {
                "status": "critical_error",
                "message": "Both primary and fallback execution failed."
            }
#@with_error_handling("orchestrator", "create_session", lambda self: self.error_handler) | |
def create_session(self) -> str: | |
"""Create a new session and return session ID.""" | |
session_id = f"session_{int(time.time())}_{self.session_counter}" | |
self.session_counter += 1 | |
self.active_sessions[session_id] = { | |
"created_at": datetime.now().isoformat(), | |
"status": "initialized", | |
"workflows": [], | |
"current_workflow": None | |
} | |
self.logger.info(f"Created new session: {session_id}") | |
return session_id | |
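
    # Session ids combine a unix timestamp with a monotonically increasing
    # counter, so two sessions created within the same second still get
    # distinct ids.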
#@with_error_handling("orchestrator", "process_request", lambda self: self.error_handler) | |
def process_request(self, session_id: str, topic: str, text_files: List[str], | |
image_files: List[str]) -> Dict[str, Any]: | |
""" | |
Process a user request within a session. | |
Coordinates the workflow through the coordinator agent. | |
""" | |
if session_id not in self.active_sessions: | |
return {"error": f"Session {session_id} not found. Please create a new session."} | |
session = self.active_sessions[session_id] | |
session["status"] = "processing" | |
# Initialize workflow via coordinator | |
workflow_result = self.coordinator_agent.initialize_workflow(topic, text_files, image_files) | |
# Store workflow ID in session | |
workflow_id = workflow_result.get("workflow_id") | |
if workflow_id: | |
session["workflows"].append(workflow_id) | |
session["current_workflow"] = workflow_id | |
# Execute workflow with error handling | |
try: | |
# Try to execute with error handling | |
result = self._execute_workflow_with_error_handling() | |
except Exception as e: | |
# If that fails, try direct execution as a last resort | |
self.logger.error(f"Error executing workflow with error handling: {str(e)}") | |
result = self.coordinator_agent.execute_workflow() | |
# Update session status | |
session["status"] = "completed" if not result.get("error") else "error" | |
session["last_result"] = result | |
self.logger.info(f"Process request completed. Session status: {session['status']}") | |
self.logger.info(f"Active sessions: {list(self.active_sessions.keys())}") | |
#return result | |
return result | |

    def _execute_workflow_with_error_handling(self) -> Dict[str, Any]:
        """Execute workflow with error handling."""
        try:
            result = self.coordinator_agent.execute_workflow()
            self.error_handler.record_success("coordinator_agent", "execute_workflow")
            return result
        except Exception as e:
            # Create context
            context = {
                "orchestrator": self,
                "coordinator_agent": self.coordinator_agent
            }
            # Handle the error
            handled, fallback_result = self.error_handler.handle_error(
                "coordinator_agent", "execute_workflow", e, context)
            if handled:
                return fallback_result
            else:
                # Re-raise the exception if not handled
                raise
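
    # handle_error is used above as returning a (handled, fallback_result) pair:
    # when a fallback was registered for the failing operation its result is
    # returned to the caller, otherwise the original exception propagates.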
#@with_error_handling("orchestrator", "get_session_status", lambda self: self.error_handler) | |
def get_session_status(self, session_id: str) -> Dict[str, Any]: | |
"""Get the status of a session.""" | |
if session_id not in self.active_sessions: | |
return {"error": f"Session {session_id} not found"} | |
session = self.active_sessions[session_id] | |
# If there's an active workflow, get its status | |
if session.get("current_workflow"): | |
try: | |
workflow_status = self.coordinator_agent.get_workflow_status( | |
session["current_workflow"]) | |
return { | |
"session_id": session_id, | |
"status": session["status"], | |
"created_at": session["created_at"], | |
"workflows": session["workflows"], | |
"current_workflow": session["current_workflow"], | |
"workflow_status": workflow_status | |
} | |
except Exception as e: | |
# If getting workflow status fails, return basic session info | |
self.logger.error(f"Error getting workflow status: {str(e)}") | |
return { | |
"session_id": session_id, | |
"status": session["status"], | |
"created_at": session["created_at"], | |
"workflows": session["workflows"], | |
"error": "Failed to retrieve detailed workflow status" | |
} | |
else: | |
return { | |
"session_id": session_id, | |
"status": session["status"], | |
"created_at": session["created_at"], | |
"workflows": session["workflows"] | |
} | |
#@with_error_handling("orchestrator", "cleanup_session", lambda self: self.error_handler) | |
def cleanup_session(self, session_id: str) -> Dict[str, Any]: | |
"""Clean up resources for a session.""" | |
if session_id not in self.active_sessions: | |
return {"error": f"Session {session_id} not found"} | |
session = self.active_sessions[session_id] | |
# Clean up any active workflows | |
if session.get("current_workflow"): | |
try: | |
self.coordinator_agent.cleanup_workflow(session["current_workflow"]) | |
except Exception as e: | |
self.logger.error(f"Error cleaning up workflow: {str(e)}") | |
# Continue with session cleanup even if workflow cleanup fails | |
# Mark session as cleaned up | |
session["status"] = "cleaned_up" | |
return { | |
"session_id": session_id, | |
"status": "cleaned_up", | |
"message": "Session resources have been cleaned up" | |
} | |
#@with_error_handling("orchestrator", "get_sustainability_metrics", lambda self: self.error_handler) | |
def get_sustainability_metrics(self, session_id: Optional[str] = None) -> Dict[str, Any]: | |
""" | |
Get sustainability metrics for a session or the entire system. | |
If session_id is provided, returns metrics for that session only. | |
""" | |
if not self.metrics_calculator: | |
return {"error": "Metrics calculator not available"} | |
if session_id: | |
# TODO: Implement session-specific metrics | |
# For now, return global metrics | |
return self.metrics_calculator.get_all_metrics() | |
else: | |
# Return global metrics | |
return self.metrics_calculator.get_all_metrics() | |

    def get_error_report(self) -> Dict[str, Any]:
        """Get error report from the error handler."""
        if not self.error_handler:
            return {"error": "Error handler not available"}
        return self.error_handler.get_error_report()
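

# Hedged usage sketch: the stub agent below is a hypothetical placeholder, not
# one of the project's real agent classes; it exists only to show how the
# Orchestrator is constructed and driven. It assumes app.error_handler.ErrorHandler
# and app.synchronizer.Synchronizer behave as this module uses them.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    class _StubAgent:
        """Minimal stand-in implementing the calls the Orchestrator makes."""

        def initialize_workflow(self, topic, text_files, image_files):
            return {"workflow_id": f"workflow_{int(time.time())}"}

        def process_text_files(self, topic, files):
            return {"summary": f"analyzed {len(files)} text file(s) on {topic}"}

        def process_image_files(self, topic, files):
            return {"summary": f"processed {len(files)} image(s) for {topic}"}

        def generate_report(self, topic, text_analysis, image_analysis):
            return {"title": f"Report on {topic}",
                    "sections": [text_analysis, image_analysis]}

        def generate_sustainability_report(self):
            return {"energy_usage": "not tracked in this stub"}

    stub = _StubAgent()
    orchestrator = Orchestrator(
        coordinator_agent=stub,
        text_analysis_agent=stub,
        image_processing_agent=stub,
        report_generation_agent=stub,
        metrics_agent=stub,
    )
    session = orchestrator.create_session()
    outcome = orchestrator.coordinate_workflow_with_synchronization(
        session, "demo topic", ["notes.txt"], ["figure.png"])
    print(outcome.get("status"))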