# app/synchronizer.py
import logging
import queue
import threading
import time
import uuid
from datetime import datetime, timedelta
from typing import Any, Callable, Dict, List, Optional, Tuple, Union


class Synchronizer:
    """Coordinate agents through barriers, completion signals and event queues.

    State is kept in plain dictionaries guarded by two reentrant locks:
    ``barrier_lock`` protects barriers (and agent registration), while
    ``signal_lock`` protects completion signals and dependency records.
    Timestamps are stored as ``datetime.now().isoformat()`` strings.
    """

    # Seconds between polls in wait_for_completion.
    _POLL_INTERVAL_SECONDS = 0.1

    def __init__(self):
        """Initialize the Synchronizer for coordinating agent interactions."""
        self.logger = logging.getLogger(__name__)

        # Synchronization barriers, keyed by barrier instance ID
        self.barriers = {}

        # Completion signals, keyed by "agent_id:task_id"
        self.completion_signals = {}

        # Dependency tracking, keyed by "agent_id:task_id" of the waiting task
        self.dependencies = {}

        # Locks for thread safety
        self.barrier_lock = threading.RLock()
        self.signal_lock = threading.RLock()

        # Event queues for agents, keyed by agent ID
        self.event_queues = {}

        # Callback registry, keyed by "event_type:event_id"
        self.callbacks = {}

    def register_agent(self, agent_id: str) -> None:
        """Register an agent with the synchronizer.

        Creates an event queue for the agent if it does not have one yet;
        registering an already-known agent is a no-op.
        """
        with self.barrier_lock:
            if agent_id not in self.event_queues:
                self.event_queues[agent_id] = queue.Queue()
                self.logger.info(f"Registered agent: {agent_id}")

    def create_barrier(self, barrier_id: str, participants: List[str]) -> str:
        """
        Create a synchronization barrier that multiple agents must reach.

        Args:
            barrier_id: Human-readable prefix for the barrier.
            participants: IDs of the agents that must arrive.

        Returns:
            A unique barrier instance ID (``barrier_id`` plus a random
            8-hex-digit suffix) to use with the other barrier methods.
        """
        instance_id = f"{barrier_id}_{uuid.uuid4().hex[:8]}"

        with self.barrier_lock:
            self.barriers[instance_id] = {
                "participants": set(participants),
                "arrived": set(),
                "event": threading.Event(),
                "created_at": datetime.now().isoformat(),
            }

        self.logger.info(f"Created barrier {instance_id} with {len(participants)} participants")
        return instance_id

    def arrive_at_barrier(self, barrier_id: str, agent_id: str, data: Any = None) -> bool:
        """
        Signal that an agent has arrived at a barrier.

        When the last participant arrives, the barrier's event is set, a
        ``barrier_complete`` event is queued for every registered participant
        and the ``barrier:<barrier_id>`` callback (if any) is invoked with the
        barrier ID and the collected data. These notifications are emitted
        only once, on the arrival that completes the barrier; later repeat
        arrivals still return True but do not notify again.

        Args:
            barrier_id: Instance ID returned by ``create_barrier``.
            agent_id: ID of the arriving agent.
            data: Optional payload stored under the agent's ID.

        Returns:
            True if all participants have arrived, False otherwise (including
            when the barrier is unknown or the agent is not a participant).
        """
        with self.barrier_lock:
            barrier = self.barriers.get(barrier_id)
            if barrier is None:
                self.logger.warning(f"Barrier {barrier_id} not found")
                return False

            # Check if agent is a participant
            if agent_id not in barrier["participants"]:
                self.logger.warning(f"Agent {agent_id} is not a participant in barrier {barrier_id}")
                return False

            # Record arrival
            barrier["arrived"].add(agent_id)

            # Store data if provided
            if data is not None:
                barrier.setdefault("data", {})[agent_id] = data

            # Check if all participants have arrived
            all_arrived = barrier["arrived"] == barrier["participants"]

            # Notify only on the transition to "complete" so a participant
            # arriving again later does not trigger duplicate events/callbacks.
            if all_arrived and not barrier["event"].is_set():
                self._complete_barrier(barrier_id, barrier)

            return all_arrived

    def _complete_barrier(self, barrier_id: str, barrier: Dict[str, Any]) -> None:
        """Release waiters and notify participants of a completed barrier.

        Must be called with ``barrier_lock`` held.
        """
        # Signal all waiting threads
        barrier["event"].set()
        self.logger.info(f"All participants arrived at barrier {barrier_id}")
        self.logger.info(f" barrier participants: {barrier['participants']}")

        # Notify agents via their event queues
        for participant in barrier["participants"]:
            if participant in self.event_queues:
                self.event_queues[participant].put({
                    "type": "barrier_complete",
                    "barrier_id": barrier_id,
                    "timestamp": datetime.now().isoformat(),
                })

        # Execute callbacks if registered
        callback_key = f"barrier:{barrier_id}"
        if callback_key in self.callbacks:
            try:
                self.callbacks[callback_key](barrier_id, barrier.get("data", {}))
            except Exception as e:
                # Best effort: a failing callback must not break the barrier.
                self.logger.exception(f"Error executing callback for {callback_key}: {e}")

    def wait_for_barrier(self, barrier_id: str, timeout: Optional[float] = None) -> bool:
        """
        Wait until all participants arrive at the barrier.

        Args:
            barrier_id: Instance ID returned by ``create_barrier``.
            timeout: Maximum seconds to wait; None waits indefinitely.

        Returns:
            True if all arrived, False on timeout or if the barrier is unknown.
        """
        with self.barrier_lock:
            barrier = self.barriers.get(barrier_id)
            if barrier is None:
                self.logger.warning(f"Barrier {barrier_id} not found")
                return False
            event = barrier["event"]

        # Wait outside the lock so arriving agents are not blocked
        return event.wait(timeout)

    def get_barrier_data(self, barrier_id: str) -> Dict[str, Any]:
        """Get data provided by agents at a barrier.

        Returns a shallow copy of the per-agent data, or an empty dict if the
        barrier is unknown or no data was provided.
        """
        with self.barrier_lock:
            barrier = self.barriers.get(barrier_id)
            if barrier is None:
                self.logger.warning(f"Barrier {barrier_id} not found")
                return {}
            return barrier.get("data", {}).copy()

    def signal_completion(self, agent_id: str, task_id: str, result: Any = None) -> None:
        """Signal that an agent has completed a task.

        Records the completion, then removes it from the ``waiting_on`` set of
        every registered dependency. For each dependency that becomes fully
        satisfied, a ``dependencies_met`` event is queued for the waiting agent
        (if registered) and the ``dependency:<agent>:<task>`` callback, if
        any, is invoked with the dependency key.

        Args:
            agent_id: ID of the agent that finished.
            task_id: ID of the finished task.
            result: Optional result stored with the completion signal.
        """
        with self.signal_lock:
            key = f"{agent_id}:{task_id}"
            self.completion_signals[key] = {
                "completed_at": datetime.now().isoformat(),
                "result": result,
            }

            # Iterate over a snapshot: a callback may register new
            # dependencies (the lock is reentrant), which would otherwise
            # mutate the dict while it is being iterated.
            for dep_key, deps in list(self.dependencies.items()):
                if key not in deps["waiting_on"]:
                    continue
                deps["waiting_on"].remove(key)

                # Other dependencies are still outstanding
                if deps["waiting_on"]:
                    continue

                waiting_agent, waiting_task = dep_key.split(":", 1)

                # Notify the waiting agent
                if waiting_agent in self.event_queues:
                    self.event_queues[waiting_agent].put({
                        "type": "dependencies_met",
                        "task_id": waiting_task,
                        "timestamp": datetime.now().isoformat(),
                    })

                # Execute callback if registered
                callback_key = f"dependency:{dep_key}"
                if callback_key in self.callbacks:
                    try:
                        self.callbacks[callback_key](dep_key)
                    except Exception as e:
                        # Best effort: keep processing the other dependencies.
                        self.logger.exception(f"Error executing callback for {callback_key}: {e}")

        self.logger.info(f"Agent {agent_id} completed task {task_id}")

    def wait_for_completion(self, agent_id: str, task_id: str, timeout: Optional[float] = None) -> Tuple[bool, Any]:
        """
        Wait for an agent to complete a task.

        Polls the completion signals until the task is recorded or the
        timeout elapses. The completion check always runs at least once, so
        ``timeout=0`` performs a single non-blocking check.

        Args:
            agent_id: ID of the agent expected to finish the task.
            task_id: ID of the task.
            timeout: Maximum seconds to wait; None waits indefinitely.

        Returns:
            ``(completed, result)``; ``(False, None)`` on timeout.
        """
        key = f"{agent_id}:{task_id}"
        # Monotonic clock: immune to wall-clock adjustments while waiting.
        deadline = None if timeout is None else time.monotonic() + timeout

        while True:
            with self.signal_lock:
                signal = self.completion_signals.get(key)
            if signal is not None:
                return True, signal.get("result")

            if deadline is None:
                time.sleep(self._POLL_INTERVAL_SECONDS)
                continue

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False, None

            # Never sleep past the deadline, so short timeouts are honoured.
            time.sleep(min(self._POLL_INTERVAL_SECONDS, remaining))

    def register_dependencies(self, agent_id: str, task_id: str, dependencies: List[Tuple[str, str]]) -> None:
        """
        Register that a task depends on completion of other tasks.

        Dependencies are specified as [(agent_id, task_id), ...]. Those that
        are already completed are not waited on. Registering again for the
        same task replaces the previous record.
        """
        key = f"{agent_id}:{task_id}"

        # Check and register under a single lock acquisition; otherwise a
        # completion signalled in between would be missed and the dependency
        # would never be marked as met.
        with self.signal_lock:
            waiting_on = {
                f"{dep_agent}:{dep_task}"
                for dep_agent, dep_task in dependencies
                if f"{dep_agent}:{dep_task}" not in self.completion_signals
            }
            self.dependencies[key] = {
                "waiting_on": waiting_on,
                "registered_at": datetime.now().isoformat(),
            }

        self.logger.info(f"Registered {len(waiting_on)} dependencies for {key}")

    def are_dependencies_met(self, agent_id: str, task_id: str) -> bool:
        """Check if all dependencies for a task are met.

        A task with no registered dependencies is considered met.
        """
        key = f"{agent_id}:{task_id}"

        with self.signal_lock:
            record = self.dependencies.get(key)
            if record is None:
                return True  # No dependencies registered
            return not record["waiting_on"]

    def register_callback(self, event_type: str, event_id: str, callback: Callable) -> None:
        """Register a callback function for a specific event.

        The callback is stored under ``"<event_type>:<event_id>"``. This class
        looks up ``barrier:<barrier instance id>`` (called with the barrier ID
        and its data) and ``dependency:<agent_id>:<task_id>`` (called with the
        dependency key). A later registration for the same key replaces the
        earlier one.
        """
        key = f"{event_type}:{event_id}"
        self.callbacks[key] = callback
        self.logger.info(f"Registered callback for {key}")

    def get_events(self, agent_id: str, timeout: Optional[float] = 0) -> List[Dict[str, Any]]:
        """
        Get events for an agent from its queue.

        Non-blocking by default (timeout=0). A positive timeout waits up to
        that many seconds for the first event; ``timeout=None`` is also
        non-blocking. Once one event is obtained, all further events already
        queued are drained as well.

        Returns:
            The events in queue order; an empty list if the agent is not
            registered or no event arrived in time.
        """
        # Do not name this local "queue": that would shadow the queue module
        # and break the ``except queue.Empty`` clause below.
        agent_queue = self.event_queues.get(agent_id)
        if agent_queue is None:
            return []

        events = []
        try:
            # Try to get at least one event
            events.append(agent_queue.get(block=(timeout is not None), timeout=timeout))

            # Drain any additional events that are already in the queue
            while True:
                events.append(agent_queue.get_nowait())
        except queue.Empty:
            pass

        return events

    def send_event(self, agent_id: str, event_type: str, data: Any = None) -> bool:
        """Send a custom event to an agent.

        Returns:
            True if the event was queued, False if the agent is not registered.
        """
        if agent_id not in self.event_queues:
            self.logger.warning(f"Agent {agent_id} not registered")
            return False

        event = {
            "type": event_type,
            "timestamp": datetime.now().isoformat(),
            "data": data,
        }

        self.event_queues[agent_id].put(event)
        self.logger.info(f"Sent {event_type} event to {agent_id}")
        return True

    def cleanup(self, older_than_hours: Optional[int] = None) -> Dict[str, int]:
        """
        Clean up old synchronization artifacts.

        Args:
            older_than_hours: If None, remove all barriers, completion signals
                and dependency records. Otherwise remove only those whose
                timestamp is older than this many hours.

        Returns:
            Counts of items removed, keyed by ``"barriers"``, ``"signals"``
            and ``"dependencies"``.
        """
        if older_than_hours is None:
            # Clear everything
            with self.barrier_lock:
                barriers_count = len(self.barriers)
                self.barriers = {}

            with self.signal_lock:
                signals_count = len(self.completion_signals)
                dependencies_count = len(self.dependencies)
                self.completion_signals = {}
                self.dependencies = {}

            return {
                "barriers": barriers_count,
                "signals": signals_count,
                "dependencies": dependencies_count,
            }

        # Stored timestamps are ISO-8601 strings produced by the same
        # datetime.now().isoformat() call, so they order lexicographically.
        cutoff_str = (datetime.now() - timedelta(hours=older_than_hours)).isoformat()

        with self.barrier_lock:
            barriers_removed = self._remove_older_than(self.barriers, "created_at", cutoff_str)

        with self.signal_lock:
            signals_removed = self._remove_older_than(self.completion_signals, "completed_at", cutoff_str)
            dependencies_removed = self._remove_older_than(self.dependencies, "registered_at", cutoff_str)

        return {
            "barriers": barriers_removed,
            "signals": signals_removed,
            "dependencies": dependencies_removed,
        }

    @staticmethod
    def _remove_older_than(store: Dict[str, Dict[str, Any]], timestamp_field: str, cutoff_str: str) -> int:
        """Delete entries of ``store`` whose timestamp precedes the cutoff.

        Returns the number of entries removed. The caller must hold the lock
        that guards ``store``.
        """
        stale_keys = [key for key, entry in store.items() if entry[timestamp_field] < cutoff_str]
        for key in stale_keys:
            del store[key]
        return len(stale_keys)