# agents/coordinator_agent.py
import hashlib
import logging
import time
from datetime import datetime
from typing import Any, Dict, List, Optional

# LangChain imports (current package layout)
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.tools import Tool
# AgentExecutor / create_react_agent are only needed if the optional ReAct
# setup in _initialize_langchain_components is re-enabled.
from langchain.agents.agent import AgentExecutor
from langchain.agents.react.agent import create_react_agent
from langchain_community.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Import utility classes
from utils.token_manager import TokenManager
from utils.cache_manager import CacheManager
from utils.metrics_calculator import MetricsCalculator

# Import agent classes for type hints
from agents.text_analysis_agent import TextAnalysisAgent
from agents.image_processing_agent import ImageProcessingAgent
from agents.report_generation_agent import ReportGeneratorAgent
from agents.metrics_agent import MetricsAgent


class CoordinatorAgent:
    def __init__(self, text_analysis_agent=None, image_processing_agent=None,
                 report_generation_agent=None, metrics_agent=None,
                 token_manager=None, cache_manager=None, metrics_calculator=None):
        """Initialize the CoordinatorAgent with required agents and utilities."""
        self.logger = logging.getLogger(__name__)
        self.text_analysis_agent = text_analysis_agent
        self.image_processing_agent = image_processing_agent
        self.report_generation_agent = report_generation_agent
        self.metrics_agent = metrics_agent
        self.token_manager = token_manager
        self.cache_manager = cache_manager
        self.metrics_calculator = metrics_calculator

        # Track workflow state per topic
        self.workflow_state = {}
        self.current_topic = None
        self.workflow_id = None

        # Agent name for logging
        self.agent_name = "coordinator_agent"

        # Initialize LangChain components
        self._initialize_langchain_components()

    def _initialize_langchain_components(self):
        """Initialize LangChain components for coordination."""
        try:
            # Use a small local seq2seq model so no API keys are required
            tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
            model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
            pipe = pipeline("text2text-generation", model=model,
                            tokenizer=tokenizer, max_length=1024)
            self.llm = HuggingFacePipeline(pipeline=pipe)

            # Define tools for the agent
            self.tools = self._create_tools()

            # Create the coordination prompt (built inline; see also
            # _create_agent_prompt below for a reusable version).
            # The single input variable is "topic".
            self.agent_prompt = PromptTemplate.from_template(
                """You are an efficient workflow coordinator for a multi-agent AI system.
Your job is to orchestrate the analysis of text files and images related to a user's topic.

Topic: {topic}

Available Tools: analyze_text_files, process_images, generate_report, get_sustainability_metrics, allocate_token_budget

How would you approach analyzing this topic with the available tools?
"""
            )

            # Alternative ReAct agent setup, kept for reference:
            # self.agent = create_react_agent(self.llm, self.tools, self.agent_prompt)
            # self.agent_executor = AgentExecutor(
            #     agent=self.agent,
            #     tools=self.tools,
            #     verbose=True,
            #     handle_parsing_errors=True,
            #     max_iterations=10,
            # )

            # Use a simpler LCEL chain instead of a ReAct agent
            self.chain = (
                self.agent_prompt
                | self.llm
                | StrOutputParser()
            )
            # Set agent_executor to use the chain
            self.agent_executor = self.chain
            self.logger.info("LangChain components initialized successfully")
        except Exception as e:
            self.logger.error(f"Failed to initialize LangChain components: {e}")
            # Fall back to direct coordination if LangChain initialization fails
            self.agent_executor = None
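
    # Note on the LCEL chain built above: "prompt | llm | parser" composes a
    # RunnableSequence, so the dict passed to .invoke() fills the prompt's
    # variables, the formatted prompt is sent to the local flan-t5 pipeline,
    # and StrOutputParser returns the generation as a plain string:
    #
    #   plan = coordinator.agent_executor.invoke({"topic": "renewable energy"})
    #   # plan is a str describing the proposed approach
    #
    # ("coordinator" here is a hypothetical CoordinatorAgent instance.)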

    def _create_tools(self):
        """Create tools for the LangChain agent."""
        tools = []

        # Tool for analyzing text files
        def analyze_text_files(topic: str, file_paths: List[str]) -> str:
            """
            Analyze text files for relevance to the specified topic.

            Args:
                topic: The topic to analyze for
                file_paths: List of paths to text files

            Returns:
                Analysis results as a string summary
            """
            if not self.text_analysis_agent:
                return "Text analysis agent not available"
            try:
                result = self.text_analysis_agent.process_text_files(topic, file_paths)
                return f"Text analysis completed. Found {result.get('relevant_documents', 0)} relevant documents out of {result.get('total_documents', 0)}."
            except Exception as e:
                return f"Error analyzing text files: {str(e)}"

        # Tool for processing images
        def process_images(topic: str, file_paths: List[str]) -> str:
            """
            Process images for relevance to the specified topic.

            Args:
                topic: The topic to analyze for
                file_paths: List of paths to image files

            Returns:
                Processing results as a string summary
            """
            if not self.image_processing_agent:
                return "Image processing agent not available"
            try:
                result = self.image_processing_agent.process_image_files(topic, file_paths)
                return f"Image processing completed. Found {result.get('relevant_images', 0)} relevant images out of {result.get('total_images', 0)}."
            except Exception as e:
                return f"Error processing images: {str(e)}"

        # Tool for generating reports
        def generate_report(topic: str) -> str:
            """
            Generate a comprehensive report on the topic based on previous analyses.

            Args:
                topic: The topic of the report

            Returns:
                Report generation status
            """
            if not self.report_generation_agent:
                return "Report generation agent not available"
            if topic not in self.workflow_state:
                return f"No analyses found for topic: {topic}"
            try:
                text_analysis = self.workflow_state[topic].get("text_analysis")
                image_analysis = self.workflow_state[topic].get("image_analysis")
                result = self.report_generation_agent.generate_report(
                    topic, text_analysis, image_analysis)
                self.workflow_state[topic]["report"] = result
                return f"Report generated successfully with confidence level: {result.get('confidence_level', 'unknown')}"
            except Exception as e:
                return f"Error generating report: {str(e)}"

        # Tool for getting sustainability metrics
        def get_sustainability_metrics() -> str:
            """
            Get current sustainability metrics for the system.

            Returns:
                Sustainability metrics as a string summary
            """
            if not self.metrics_agent:
                return "Metrics agent not available"
            try:
                result = self.metrics_agent.generate_sustainability_report()
                energy_usage = result.get("sustainability_metrics", {}).get("energy_usage_wh", 0)
                carbon_footprint = result.get("sustainability_metrics", {}).get("carbon_footprint_kg", 0)
                energy_saved = result.get("optimization_results", {}).get("energy_saved_wh", 0)
                return f"Sustainability metrics: Energy used: {energy_usage:.6f} Wh, Carbon footprint: {carbon_footprint:.6f} kg CO2, Energy saved: {energy_saved:.6f} Wh"
            except Exception as e:
                return f"Error getting sustainability metrics: {str(e)}"

        # Tool for allocating token budget
        def allocate_token_budget(operation_type: str, budget: int) -> str:
            """
            Allocate token budget for a specific operation type.

            Args:
                operation_type: Type of operation (text_analysis, image_captioning, etc.)
                budget: Token budget to allocate

            Returns:
                Allocation status
            """
            if not self.token_manager:
                return "Token manager not available"
            try:
                self.token_manager.adjust_budget(operation_type, budget)
                return f"Token budget for {operation_type} adjusted to {budget}"
            except Exception as e:
                return f"Error allocating token budget: {str(e)}"

        # Collect all tool functions
        tools.extend([
            analyze_text_files,
            process_images,
            generate_report,
            get_sustainability_metrics,
            allocate_token_budget,
        ])

        # Wrap each plain function in a LangChain Tool. Plain functions have
        # no .name/.description attributes, so build the wrappers from
        # __name__ and the docstring instead.
        return [
            Tool(name=func.__name__, description=(func.__doc__ or "").strip(), func=func)
            for func in tools
        ]
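
    # If the ReAct path is ever re-enabled, the Tool objects returned above
    # are directly consumable by it. A hypothetical sketch (mirroring the
    # commented block in _initialize_langchain_components):
    #
    #   agent = create_react_agent(self.llm, self.tools, react_prompt)
    #   executor = AgentExecutor(agent=agent, tools=self.tools,
    #                            handle_parsing_errors=True, max_iterations=10)
    #   executor.invoke({"input": "Analyze the topic 'renewable energy'"})
    #
    # where react_prompt is a PromptTemplate exposing the standard ReAct
    # variables: {tools}, {tool_names}, and {agent_scratchpad}.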

    def _create_agent_prompt(self):
        """Create the prompt for the coordination agent.

        Note: currently unused; _initialize_langchain_components builds its
        prompt inline. Kept as a simpler PromptTemplate alternative.
        """
        return PromptTemplate.from_template(
            """You are an efficient workflow coordinator for a multi-agent AI system.
Your job is to orchestrate the analysis of text files and images related to a user's topic.

Topic: {topic}

Available Agents: Text Analysis, Image Processing, Report Generation, Metrics

What would you like to do?
"""
        )
        # Earlier ChatPromptTemplate version, kept for reference:
        # return ChatPromptTemplate.from_messages([
        #     ("system", """You are an efficient workflow coordinator for a multi-agent AI system.
        #     Your job is to orchestrate the analysis of text files and images related to a user's topic.
        #     Follow these steps in order:
        #     1. First analyze text files for relevance to the topic
        #     2. Then process images for relevance to the topic
        #     3. Generate a comprehensive report combining both analyses
        #     4. Check sustainability metrics
        #     Be efficient with resources and focus on finding information relevant to the user's topic.
        #     If one type of analysis fails, try to continue with the other type.
        #     Always provide clear updates on the progress of each step.
        #     """),
        #     ("user", "{input}"),
        # ])

    def initialize_workflow(self, topic: str, text_files: List[str], image_files: List[str]) -> Dict[str, Any]:
        """
        Initialize a new workflow for the given topic and files.
        Returns a workflow status dict.
        """
        # Generate a workflow ID
        self.workflow_id = f"workflow_{int(time.time())}"
        self.current_topic = topic

        # Initialize workflow state
        self.workflow_state[topic] = {
            "workflow_id": self.workflow_id,
            "topic": topic,
            "text_files": text_files,
            "image_files": image_files,
            "start_time": datetime.now().isoformat(),
            "status": "initialized",
            "steps_completed": [],
            "text_analysis": None,
            "image_analysis": None,
            "report": None,
        }
        self.logger.info(f"Initialized workflow {self.workflow_id} for topic: {topic}")

        # Record the initial token budget if available
        if self.token_manager:
            self.workflow_state[topic]["initial_token_budget"] = self.token_manager.get_usage_stats()

        return {
            "workflow_id": self.workflow_id,
            "topic": topic,
            "status": "initialized",
            "message": f"Workflow initialized with {len(text_files)} text files and {len(image_files)} image files",
        }

    def execute_workflow(self) -> Dict[str, Any]:
        """
        Execute the current workflow using either the LangChain chain or direct coordination.
        Returns the workflow results.
        """
        if not self.current_topic or self.current_topic not in self.workflow_state:
            return {"error": "No active workflow. Please initialize a workflow first."}

        topic = self.current_topic
        workflow = self.workflow_state[topic]
        text_files = workflow["text_files"]
        image_files = workflow["image_files"]
        self.logger.info(f"Executing workflow {workflow['workflow_id']} for topic: {topic}")

        # Update status
        workflow["status"] = "in_progress"
        try:
            # Try the LangChain chain if it initialized successfully
            if self.agent_executor:
                # The chain's prompt expects a single "topic" input variable
                # (invoking with any other key would raise a KeyError).
                plan = self.agent_executor.invoke({"topic": topic})
                # The simple LCEL chain only produces a planning narrative, so
                # record it and run the actual analyses via direct coordination.
                workflow["agent_output"] = plan
                result = self._direct_coordination(topic, text_files, image_files)
                result["message"] = "Workflow completed using LangChain planning and direct coordination"
                return result
            else:
                # Fall back to direct coordination
                return self._direct_coordination(topic, text_files, image_files)
        except Exception as e:
            self.logger.error(f"Error executing workflow: {e}")
            # Fall back to direct coordination
            self.logger.info("Falling back to direct coordination")
            return self._direct_coordination(topic, text_files, image_files)

    def _direct_coordination(self, topic: str, text_files: List[str], image_files: List[str]) -> Dict[str, Any]:
        """
        Directly coordinate the workflow without using LangChain.
        This is the fallback path if LangChain initialization fails or errors occur.
        """
        workflow = self.workflow_state[topic]
        start_time = time.time()

        # Step 1: Analyze text files
        if self.text_analysis_agent and text_files:
            try:
                self.logger.info(f"Analyzing {len(text_files)} text files")
                text_analysis = self.text_analysis_agent.process_text_files(topic, text_files)
                workflow["text_analysis"] = text_analysis
                workflow["steps_completed"].append("text_analysis")
                self.logger.info(f"Text analysis completed. Found {text_analysis.get('relevant_documents', 0)} relevant documents")
            except Exception as e:
                self.logger.error(f"Error in text analysis: {e}")
                workflow["text_analysis_error"] = str(e)

        # Step 2: Process images
        if self.image_processing_agent and image_files:
            try:
                self.logger.info(f"Processing {len(image_files)} images")
                image_analysis = self.image_processing_agent.process_image_files(topic, image_files)
                workflow["image_analysis"] = image_analysis
                workflow["steps_completed"].append("image_analysis")
                self.logger.info(f"Image processing completed. Found {image_analysis.get('relevant_images', 0)} relevant images")
            except Exception as e:
                self.logger.error(f"Error in image processing: {e}")
                workflow["image_analysis_error"] = str(e)

        # Step 3: Generate report
        if self.report_generation_agent:
            try:
                self.logger.info("Generating report")
                report = self.report_generation_agent.generate_report(
                    topic,
                    workflow.get("text_analysis"),
                    workflow.get("image_analysis"),
                )
                workflow["report"] = report
                workflow["steps_completed"].append("report_generation")
                self.logger.info(f"Report generated with confidence level: {report.get('confidence_level', 'unknown')}")
            except Exception as e:
                self.logger.error(f"Error in report generation: {e}")
                workflow["report_generation_error"] = str(e)

        # Step 4: Get sustainability metrics
        if self.metrics_agent:
            try:
                self.logger.info("Getting sustainability metrics")
                metrics = self.metrics_agent.generate_sustainability_report()
                workflow["sustainability_metrics"] = metrics
                workflow["steps_completed"].append("metrics_collection")
            except Exception as e:
                self.logger.error(f"Error getting sustainability metrics: {e}")
                workflow["metrics_error"] = str(e)

        # Update workflow status
        workflow["status"] = "completed"
        workflow["end_time"] = datetime.now().isoformat()
        workflow["processing_time"] = time.time() - start_time

        return {
            "workflow_id": workflow["workflow_id"],
            "topic": topic,
            "status": "completed",
            "message": "Workflow completed successfully using direct coordination",
            "steps_completed": workflow["steps_completed"],
            "processing_time": workflow["processing_time"],
            "report": workflow.get("report", {}),
        }

    def get_workflow_status(self, workflow_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Get the status of a workflow.
        If workflow_id is not provided, returns the status of the current workflow.
        """
        if workflow_id:
            # Find workflow by ID
            for topic, workflow in self.workflow_state.items():
                if workflow.get("workflow_id") == workflow_id:
                    return {
                        "workflow_id": workflow_id,
                        "topic": topic,
                        "status": workflow.get("status", "unknown"),
                        "steps_completed": workflow.get("steps_completed", []),
                        "processing_time": workflow.get("processing_time", 0) if workflow.get("status") == "completed" else None,
                    }
            # Workflow not found
            return {"error": f"Workflow {workflow_id} not found"}

        # Return current workflow status
        if not self.current_topic or self.current_topic not in self.workflow_state:
            return {"error": "No active workflow"}
        workflow = self.workflow_state[self.current_topic]
        return {
            "workflow_id": workflow.get("workflow_id"),
            "topic": self.current_topic,
            "status": workflow.get("status", "unknown"),
            "steps_completed": workflow.get("steps_completed", []),
            "processing_time": workflow.get("processing_time", 0) if workflow.get("status") == "completed" else None,
        }

    def _store_workflow_results(self, topic: str) -> None:
        """
        Store workflow results in cache for future reuse.
        """
        if not self.cache_manager or topic not in self.workflow_state:
            return

        workflow = self.workflow_state[topic]

        # Only cache completed workflows
        if workflow.get("status") != "completed":
            return

        # Store text analysis results
        if "text_analysis" in workflow and workflow["text_analysis"]:
            text_key = f"text_analysis:{topic}"
            self.cache_manager.put(
                text_key,
                workflow["text_analysis"],
                namespace="workflow_results",
            )

        # Store image analysis results
        if "image_analysis" in workflow and workflow["image_analysis"]:
            image_key = f"image_analysis:{topic}"
            self.cache_manager.put(
                image_key,
                workflow["image_analysis"],
                namespace="workflow_results",
            )

        # Store report
        if "report" in workflow and workflow["report"]:
            report_key = f"report:{topic}"
            self.cache_manager.put(
                report_key,
                workflow["report"],
                namespace="workflow_results",
            )

        self.logger.info(f"Stored workflow results for topic '{topic}' in cache")

    def _get_cached_results(self, topic: str, file_paths: List[str]) -> Dict[str, Any]:
        """
        Try to retrieve cached results for the given topic and files.
        Returns a dict of cached components, or an empty dict if nothing is cached.
        """
        if not self.cache_manager:
            return {}

        # Build a stable, file-aware cache key. The built-in hash() is salted
        # per process, so use hashlib for a key that survives restarts. Note
        # that this key is not used yet: the lookups below are keyed by topic
        # only, so results for the same topic but a different file set collide.
        files_hash = hashlib.md5("|".join(sorted(file_paths)).encode("utf-8")).hexdigest()
        cache_key = f"workflow:{topic}:{files_hash}"

        cached_results = {}

        # Try to get text analysis from cache
        text_key = f"text_analysis:{topic}"
        cache_hit, text_analysis = self.cache_manager.get(text_key, namespace="workflow_results")
        if cache_hit:
            cached_results["text_analysis"] = text_analysis
            self.logger.info(f"Retrieved cached text analysis for topic '{topic}'")

        # Try to get image analysis from cache
        image_key = f"image_analysis:{topic}"
        cache_hit, image_analysis = self.cache_manager.get(image_key, namespace="workflow_results")
        if cache_hit:
            cached_results["image_analysis"] = image_analysis
            self.logger.info(f"Retrieved cached image analysis for topic '{topic}'")

        # Try to get report from cache
        report_key = f"report:{topic}"
        cache_hit, report = self.cache_manager.get(report_key, namespace="workflow_results")
        if cache_hit:
            cached_results["report"] = report
            self.logger.info(f"Retrieved cached report for topic '{topic}'")

        return cached_results

    def cleanup_workflow(self, workflow_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Clean up resources for a completed workflow.
        If workflow_id is not provided, cleans up the current workflow.
        """
        if workflow_id:
            # Find workflow by ID
            target_topic = None
            for topic, workflow in self.workflow_state.items():
                if workflow.get("workflow_id") == workflow_id:
                    target_topic = topic
                    break
            if not target_topic:
                return {"error": f"Workflow {workflow_id} not found"}
        else:
            target_topic = self.current_topic

        if not target_topic or target_topic not in self.workflow_state:
            return {"error": "No workflow to clean up"}

        # Store results in cache before cleanup
        self._store_workflow_results(target_topic)

        # Get workflow for reporting
        workflow = self.workflow_state[target_topic]
        workflow_id = workflow.get("workflow_id")

        # Clean up large data structures but keep metadata. Use .get() because
        # these entries are initialized to None and may never be populated.
        if workflow.get("text_analysis"):
            # Keep summaries but drop large document content
            if "processed_documents" in workflow["text_analysis"]:
                for doc in workflow["text_analysis"]["processed_documents"]:
                    if "content" in doc:
                        doc["content"] = f"[CLEANED] {len(doc['content'])} characters"
        if workflow.get("image_analysis"):
            # Remove any raw image data that might be stored
            if "processed_images" in workflow["image_analysis"]:
                for img in workflow["image_analysis"]["processed_images"]:
                    if "image" in img:
                        del img["image"]

        self.logger.info(f"Cleaned up workflow {workflow_id} for topic '{target_topic}'")
        return {
            "workflow_id": workflow_id,
            "topic": target_topic,
            "status": "cleaned_up",
            "message": "Workflow resources have been cleaned up",
        }
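

# Minimal usage sketch (an illustration, not part of the original workflow):
# exercises the workflow bookkeeping end to end. With no sub-agents wired in,
# every analysis step is skipped, so this is only a smoke test; pass real
# TextAnalysisAgent / ImageProcessingAgent / ReportGeneratorAgent /
# MetricsAgent instances for a full run. Constructing CoordinatorAgent
# downloads google/flan-t5-small on first use, and the file paths below are
# placeholders.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    coordinator = CoordinatorAgent()
    print(coordinator.initialize_workflow(
        topic="renewable energy",
        text_files=["data/report.txt"],
        image_files=["data/solar_farm.jpg"],
    ))
    print(coordinator.execute_workflow())
    print(coordinator.get_workflow_status())
    print(coordinator.cleanup_workflow())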