Spaces:
Sleeping
Sleeping
# app/synchronizer.py | |
import logging | |
import time | |
import threading | |
from typing import Dict, List, Optional, Tuple, Union, Any, Callable | |
from datetime import datetime, timedelta | |
import queue | |
import uuid | |
class Synchronizer:
    """Coordinate interactions between agents running in separate threads.

    Provides four mechanisms:

    * barriers that a fixed set of participants must all reach,
    * per-task completion signals (keyed ``"<agent_id>:<task_id>"``) that
      other threads can wait on,
    * task dependencies that emit a ``"dependencies_met"`` event once every
      prerequisite task has signalled completion,
    * per-agent event queues plus optional callbacks keyed
      ``"<event_type>:<event_id>"`` (``"barrier:<barrier_id>"`` or
      ``"dependency:<agent_id>:<task_id>"``).
    """

    def __init__(self):
        """Initialize the Synchronizer for coordinating agent interactions."""
        self.logger = logging.getLogger(__name__)
        # Synchronization barriers, keyed by barrier instance ID
        self.barriers = {}
        # Completion signals, keyed by "<agent_id>:<task_id>"
        self.completion_signals = {}
        # Dependency tracking, keyed by the dependent task's "<agent_id>:<task_id>"
        self.dependencies = {}
        # Locks for thread safety. Re-entrant because callbacks run while a
        # lock is held and may call back into this object on the same thread.
        self.barrier_lock = threading.RLock()
        self.signal_lock = threading.RLock()
        # Condition sharing signal_lock; notified whenever a completion is
        # recorded so wait_for_completion can block instead of polling.
        self.signal_cond = threading.Condition(self.signal_lock)
        # Event queues for agents, keyed by agent ID
        self.event_queues = {}
        # Callback registry, keyed by "<event_type>:<event_id>"
        self.callbacks = {}

    def register_agent(self, agent_id: str) -> None:
        """Register an agent with the synchronizer.

        Creates the agent's event queue. Registering an already-known agent
        keeps its existing queue (and any pending events).
        """
        with self.barrier_lock:
            if agent_id not in self.event_queues:
                self.event_queues[agent_id] = queue.Queue()
                self.logger.info(f"Registered agent: {agent_id}")

    def create_barrier(self, barrier_id: str, participants: List[str]) -> str:
        """
        Create a synchronization barrier that multiple agents must reach.

        Args:
            barrier_id: base name; a random suffix is appended to make the
                returned instance ID unique.
            participants: agent IDs that must all arrive. An empty list yields
                a barrier that is already complete.

        Returns:
            The unique barrier instance ID to pass to the other barrier methods.
        """
        instance_id = f"{barrier_id}_{uuid.uuid4().hex[:8]}"
        participant_set = set(participants)
        event = threading.Event()
        if not participant_set:
            # Nobody to wait for: the barrier is trivially complete, so
            # wait_for_barrier must not block forever.
            event.set()
        with self.barrier_lock:
            self.barriers[instance_id] = {
                "participants": participant_set,
                "arrived": set(),
                "event": event,
                "created_at": datetime.now().isoformat(),
            }
        self.logger.info(f"Created barrier {instance_id} with {len(participants)} participants")
        return instance_id

    def arrive_at_barrier(self, barrier_id: str, agent_id: str, data: Any = None) -> bool:
        """
        Signal that an agent has arrived at a barrier.

        Args:
            barrier_id: instance ID returned by create_barrier.
            agent_id: arriving agent; must be one of the barrier's participants.
            data: optional payload stored per agent (see get_barrier_data).

        Returns:
            True if all participants have arrived, False otherwise (including
            unknown barrier or non-participant agent).

        The arrival that completes the barrier sets its event, queues a
        "barrier_complete" event for each registered participant and runs the
        "barrier:<barrier_id>" callback (while barrier_lock is held). Later
        arrivals still return True but do not repeat those notifications.
        """
        with self.barrier_lock:
            barrier = self.barriers.get(barrier_id)
            if barrier is None:
                self.logger.warning(f"Barrier {barrier_id} not found")
                return False
            if agent_id not in barrier["participants"]:
                self.logger.warning(f"Agent {agent_id} is not a participant in barrier {barrier_id}")
                return False

            barrier["arrived"].add(agent_id)
            if data is not None:
                barrier.setdefault("data", {})[agent_id] = data

            all_arrived = barrier["arrived"] == barrier["participants"]
            # Only the transition to "complete" notifies; an already-set event
            # means notifications were sent by an earlier arrival.
            if all_arrived and not barrier["event"].is_set():
                barrier["event"].set()
                self.logger.info(f"All participants arrived at barrier {barrier_id}")
                self.logger.info(f" barrier participants: {barrier['participants']}")

                for participant in barrier["participants"]:
                    participant_queue = self.event_queues.get(participant)
                    if participant_queue is not None:
                        participant_queue.put({
                            "type": "barrier_complete",
                            "barrier_id": barrier_id,
                            "timestamp": datetime.now().isoformat(),
                        })

                callback_key = f"barrier:{barrier_id}"
                callback = self.callbacks.get(callback_key)
                if callback is not None:
                    try:
                        callback(barrier_id, barrier.get("data", {}))
                    except Exception as e:
                        # Best-effort: a failing callback must not break the barrier.
                        self.logger.exception(f"Error executing callback for {callback_key}: {e}")

            return all_arrived

    def wait_for_barrier(self, barrier_id: str, timeout: Optional[float] = None) -> bool:
        """
        Wait until all participants arrive at the barrier.

        Returns True if all arrived, False on timeout or unknown barrier.
        ``timeout=None`` waits indefinitely.
        """
        with self.barrier_lock:
            barrier = self.barriers.get(barrier_id)
            if barrier is None:
                self.logger.warning(f"Barrier {barrier_id} not found")
                return False
            event = barrier["event"]
        # Wait outside the lock so arriving agents can acquire it.
        return event.wait(timeout)

    def get_barrier_data(self, barrier_id: str) -> Dict[str, Any]:
        """Return a shallow copy of the per-agent data supplied at a barrier ({} if unknown)."""
        with self.barrier_lock:
            barrier = self.barriers.get(barrier_id)
            if barrier is None:
                self.logger.warning(f"Barrier {barrier_id} not found")
                return {}
            return barrier.get("data", {}).copy()

    def signal_completion(self, agent_id: str, task_id: str, result: Any = None) -> None:
        """
        Signal that an agent has completed a task.

        Records the completion (with *result*), wakes threads blocked in
        wait_for_completion, and for every registered dependent task whose
        last outstanding prerequisite this was, queues a "dependencies_met"
        event for the waiting agent and runs its
        "dependency:<agent_id>:<task_id>" callback (while signal_lock is held).
        """
        key = f"{agent_id}:{task_id}"
        with self.signal_lock:
            self.completion_signals[key] = {
                "completed_at": datetime.now().isoformat(),
                "result": result,
            }
            self.signal_cond.notify_all()

            # Iterate a snapshot: callbacks below may register new
            # dependencies, which would otherwise mutate the dict mid-iteration.
            for dep_key, deps in list(self.dependencies.items()):
                waiting_on = deps["waiting_on"]
                if key not in waiting_on:
                    continue
                waiting_on.remove(key)
                if waiting_on:
                    continue

                # Prefer the IDs stored at registration; splitting the key
                # is ambiguous when an agent ID itself contains ':'.
                waiting_agent = deps.get("agent_id")
                waiting_task = deps.get("task_id")
                if waiting_agent is None or waiting_task is None:
                    waiting_agent, waiting_task = dep_key.split(":", 1)

                waiting_queue = self.event_queues.get(waiting_agent)
                if waiting_queue is not None:
                    waiting_queue.put({
                        "type": "dependencies_met",
                        "task_id": waiting_task,
                        "timestamp": datetime.now().isoformat(),
                    })

                callback_key = f"dependency:{dep_key}"
                callback = self.callbacks.get(callback_key)
                if callback is not None:
                    try:
                        callback(dep_key)
                    except Exception as e:
                        self.logger.exception(f"Error executing callback for {callback_key}: {e}")

        self.logger.info(f"Agent {agent_id} completed task {task_id}")

    def wait_for_completion(self, agent_id: str, task_id: str,
                            timeout: Optional[float] = None) -> Tuple[bool, Any]:
        """
        Wait for an agent to complete a task.

        Args:
            timeout: seconds to wait; None waits indefinitely, 0 just checks.

        Returns:
            (completed, result) tuple; (False, None) on timeout.
        """
        key = f"{agent_id}:{task_id}"
        with self.signal_cond:
            # wait_for releases signal_lock while blocked and re-checks the
            # predicate each time signal_completion calls notify_all.
            completed = self.signal_cond.wait_for(
                lambda: key in self.completion_signals, timeout
            )
            if not completed:
                return False, None
            return True, self.completion_signals[key].get("result")

    def register_dependencies(self, agent_id: str, task_id: str,
                              dependencies: List[Tuple[str, str]]) -> None:
        """
        Register that a task depends on completion of other tasks.

        Dependencies are specified as [(agent_id, task_id), ...]. Those
        already completed are not waited on. Re-registering a task replaces
        its previous dependency set.
        """
        key = f"{agent_id}:{task_id}"
        dep_keys = [f"{dep_agent}:{dep_task}" for dep_agent, dep_task in dependencies]
        # Check and register under one lock acquisition so a dependency that
        # completes in between cannot be missed (it would never be removed).
        with self.signal_lock:
            waiting_on = {dk for dk in dep_keys if dk not in self.completion_signals}
            self.dependencies[key] = {
                "waiting_on": waiting_on,
                "registered_at": datetime.now().isoformat(),
                "agent_id": agent_id,
                "task_id": task_id,
            }
        self.logger.info(f"Registered {len(waiting_on)} dependencies for {key}")

    def are_dependencies_met(self, agent_id: str, task_id: str) -> bool:
        """Return True if no dependencies are registered or all of them have completed."""
        key = f"{agent_id}:{task_id}"
        with self.signal_lock:
            deps = self.dependencies.get(key)
            if deps is None:
                return True  # No dependencies registered
            return not deps["waiting_on"]

    def register_callback(self, event_type: str, event_id: str, callback: Callable) -> None:
        """Register *callback* under the key "<event_type>:<event_id>", replacing any previous one."""
        key = f"{event_type}:{event_id}"
        self.callbacks[key] = callback
        self.logger.info(f"Registered callback for {key}")

    def get_events(self, agent_id: str, timeout: Optional[float] = 0) -> List[Dict[str, Any]]:
        """
        Get events for an agent from its queue.

        Non-blocking by default (timeout=0). With ``timeout=None`` the call
        does not block at all; with a positive number it waits up to that many
        seconds for the first event. Any further events already queued are
        drained without blocking.

        Returns:
            The collected events, or [] for an unknown agent or empty queue.
        """
        # Named agent_queue (not "queue") so the queue module stays visible
        # for the queue.Empty handlers below.
        agent_queue = self.event_queues.get(agent_id)
        if agent_queue is None:
            return []

        events = []
        try:
            events.append(agent_queue.get(block=(timeout is not None), timeout=timeout))
        except queue.Empty:
            return events

        while True:
            try:
                events.append(agent_queue.get_nowait())
            except queue.Empty:
                break
        return events

    def send_event(self, agent_id: str, event_type: str, data: Any = None) -> bool:
        """Send a custom event to an agent. Returns False if the agent is not registered."""
        if agent_id not in self.event_queues:
            self.logger.warning(f"Agent {agent_id} not registered")
            return False
        event = {
            "type": event_type,
            "timestamp": datetime.now().isoformat(),
            "data": data,
        }
        self.event_queues[agent_id].put(event)
        self.logger.info(f"Sent {event_type} event to {agent_id}")
        return True

    def cleanup(self, older_than_hours: Optional[int] = None) -> Dict[str, int]:
        """
        Clean up old synchronization artifacts.

        Args:
            older_than_hours: remove only items whose timestamp is older than
                this many hours; None (default) removes everything.

        Returns:
            Counts of items removed under "barriers", "signals" and
            "dependencies".

        NOTE: removing a barrier does not set its event, so threads already
        blocked in wait_for_barrier keep waiting until their own timeout.
        """
        if older_than_hours is None:
            # Clear everything
            with self.barrier_lock:
                barriers_count = len(self.barriers)
                self.barriers = {}
            with self.signal_lock:
                signals_count = len(self.completion_signals)
                dependencies_count = len(self.dependencies)
                self.completion_signals = {}
                self.dependencies = {}
            return {
                "barriers": barriers_count,
                "signals": signals_count,
                "dependencies": dependencies_count,
            }

        # Stored timestamps are naive datetime.now().isoformat() strings; for
        # strings produced this way lexicographic order matches chronological
        # order, so they are compared with the cutoff string directly.
        cutoff_str = (datetime.now() - timedelta(hours=older_than_hours)).isoformat()

        with self.barrier_lock:
            barriers_removed = self._purge_older_than(self.barriers, "created_at", cutoff_str)
        with self.signal_lock:
            signals_removed = self._purge_older_than(
                self.completion_signals, "completed_at", cutoff_str)
            dependencies_removed = self._purge_older_than(
                self.dependencies, "registered_at", cutoff_str)

        return {
            "barriers": barriers_removed,
            "signals": signals_removed,
            "dependencies": dependencies_removed,
        }

    @staticmethod
    def _purge_older_than(table: Dict[str, Dict[str, Any]], field: str, cutoff: str) -> int:
        """Delete entries of *table* whose *field* timestamp sorts before *cutoff*; return how many."""
        stale = [key for key, entry in table.items() if entry[field] < cutoff]
        for key in stale:
            del table[key]
        return len(stale)