# ai_agents_sustainable / app/synchronizer.py
# Committed by: Chamin09
# Last change: "Update app/synchronizer.py" (commit 7bdd0ae, verified)
# app/synchronizer.py
import logging
import time
import threading
from typing import Dict, List, Optional, Tuple, Union, Any, Callable
from datetime import datetime, timedelta
import queue
import uuid
class Synchronizer:
    """Coordinate interactions between agents running in separate threads.

    Provides four primitives:

    * barriers -- a rendezvous point that completes once every listed
      participant has called :meth:`arrive_at_barrier`;
    * completion signals -- per ``(agent_id, task_id)`` results published with
      :meth:`signal_completion` and awaited with :meth:`wait_for_completion`;
    * dependencies -- a task declares which other tasks it waits on and is
      notified (through its event queue and/or a callback) once all are done;
    * per-agent event queues -- mailboxes read with :meth:`get_events`.

    Barrier state is guarded by ``barrier_lock``; completion signals and
    dependencies by ``signal_lock``. Both locks are re-entrant, and registered
    callbacks are invoked while the corresponding lock is held.
    """

    def __init__(self):
        """Initialize the Synchronizer for coordinating agent interactions."""
        self.logger = logging.getLogger(__name__)
        # Synchronization barriers: instance_id -> {"participants", "arrived",
        # "event", "created_at"} plus an optional "data" dict per agent.
        self.barriers = {}
        # Completion signals: "agent_id:task_id" -> {"completed_at", "result"}
        self.completion_signals = {}
        # Dependency tracking: "agent_id:task_id" -> {"agent_id", "task_id",
        # "waiting_on" (keys still pending), "registered_at"}
        self.dependencies = {}
        # Locks for thread safety
        self.barrier_lock = threading.RLock()
        self.signal_lock = threading.RLock()
        # Notified on every completion so wait_for_completion can block
        # instead of polling; shares signal_lock, which guards the signals.
        self._completion_cond = threading.Condition(self.signal_lock)
        # Event queues for agents: agent_id -> queue.Queue of event dicts
        self.event_queues = {}
        # Callback registry: "<event_type>:<event_id>" -> callable
        self.callbacks = {}

    def register_agent(self, agent_id: str) -> None:
        """Register an agent, creating its event queue if it has none yet."""
        with self.barrier_lock:
            if agent_id not in self.event_queues:
                self.event_queues[agent_id] = queue.Queue()
                self.logger.info(f"Registered agent: {agent_id}")

    def create_barrier(self, barrier_id: str, participants: List[str]) -> str:
        """
        Create a synchronization barrier that multiple agents must reach.

        Args:
            barrier_id: Human-readable prefix for the barrier.
            participants: Agent ids that must all arrive to complete it.

        Returns:
            A unique barrier instance ID (``barrier_id`` plus a random suffix)
            to pass to the other barrier methods.
        """
        instance_id = f"{barrier_id}_{uuid.uuid4().hex[:8]}"
        with self.barrier_lock:
            self.barriers[instance_id] = {
                "participants": set(participants),
                "arrived": set(),
                "event": threading.Event(),
                "created_at": datetime.now().isoformat()
            }
        self.logger.info(f"Created barrier {instance_id} with {len(participants)} participants")
        return instance_id

    def arrive_at_barrier(self, barrier_id: str, agent_id: str, data: Any = None) -> bool:
        """
        Signal that an agent has arrived at a barrier.

        Args:
            barrier_id: Instance ID returned by :meth:`create_barrier`.
            agent_id: The arriving agent; must be a barrier participant.
            data: Optional payload stored under ``agent_id`` and retrievable
                with :meth:`get_barrier_data`.

        Returns:
            True if all participants have arrived, False otherwise (including
            unknown barriers and non-participants).

        When the last participant arrives, waiters are released, every
        registered participant receives a ``barrier_complete`` event, and the
        ``barrier:<barrier_id>`` callback (if any) runs. This happens exactly
        once; later re-arrivals return True without re-notifying.
        """
        with self.barrier_lock:
            barrier = self.barriers.get(barrier_id)
            if barrier is None:
                self.logger.warning(f"Barrier {barrier_id} not found")
                return False

            if agent_id not in barrier["participants"]:
                self.logger.warning(f"Agent {agent_id} is not a participant in barrier {barrier_id}")
                return False

            barrier["arrived"].add(agent_id)
            if data is not None:
                barrier.setdefault("data", {})[agent_id] = data

            # "arrived" only ever holds participants, so equality means "all here".
            all_arrived = barrier["arrived"] == barrier["participants"]

            # Fire completion side effects only on the transition to complete,
            # so a participant re-arriving does not duplicate events/callbacks.
            if all_arrived and not barrier["event"].is_set():
                barrier["event"].set()
                self.logger.info(f"All participants arrived at barrier {barrier_id}")
                self.logger.info(f" barrier participants: {barrier['participants']}")

                for participant in barrier["participants"]:
                    participant_queue = self.event_queues.get(participant)
                    if participant_queue is not None:
                        participant_queue.put({
                            "type": "barrier_complete",
                            "barrier_id": barrier_id,
                            "timestamp": datetime.now().isoformat()
                        })

                callback_key = f"barrier:{barrier_id}"
                callback = self.callbacks.get(callback_key)
                if callback is not None:
                    try:
                        callback(barrier_id, barrier.get("data", {}))
                    except Exception as e:
                        self.logger.error(f"Error executing callback for {callback_key}: {e}")

            return all_arrived

    def wait_for_barrier(self, barrier_id: str, timeout: Optional[float] = None) -> bool:
        """
        Wait until all participants arrive at the barrier.

        Args:
            barrier_id: Instance ID returned by :meth:`create_barrier`.
            timeout: Seconds to wait; None waits indefinitely.

        Returns:
            True if all arrived, False on timeout or if the barrier is unknown.
        """
        with self.barrier_lock:
            barrier = self.barriers.get(barrier_id)
            if barrier is None:
                self.logger.warning(f"Barrier {barrier_id} not found")
                return False
            event = barrier["event"]

        # Wait outside the lock so arriving agents can still acquire it.
        return event.wait(timeout)

    def get_barrier_data(self, barrier_id: str) -> Dict[str, Any]:
        """Return a shallow copy of the per-agent data supplied at a barrier."""
        with self.barrier_lock:
            barrier = self.barriers.get(barrier_id)
            if barrier is None:
                self.logger.warning(f"Barrier {barrier_id} not found")
                return {}
            return barrier.get("data", {}).copy()

    def signal_completion(self, agent_id: str, task_id: str, result: Any = None) -> None:
        """
        Signal that an agent has completed a task.

        Records ``result`` for :meth:`wait_for_completion`, wakes blocked
        waiters, and for every registered task whose last outstanding
        dependency was this one, sends a ``dependencies_met`` event to the
        waiting agent and runs the ``dependency:<agent_id>:<task_id>`` callback.
        """
        key = f"{agent_id}:{task_id}"
        with self.signal_lock:
            self.completion_signals[key] = {
                "completed_at": datetime.now().isoformat(),
                "result": result
            }
            self._completion_cond.notify_all()

            # Iterate over a snapshot: callbacks run under the re-entrant lock
            # and may register further dependencies, which would otherwise
            # mutate the dict mid-iteration.
            for dep_key, deps in list(self.dependencies.items()):
                if key not in deps["waiting_on"]:
                    continue
                deps["waiting_on"].discard(key)
                if deps["waiting_on"]:
                    continue

                # Prefer the stored ids: splitting dep_key breaks when the
                # agent id itself contains ':'.
                waiting_agent = deps.get("agent_id")
                waiting_task = deps.get("task_id")
                if waiting_agent is None or waiting_task is None:
                    waiting_agent, waiting_task = dep_key.split(":", 1)

                waiting_queue = self.event_queues.get(waiting_agent)
                if waiting_queue is not None:
                    waiting_queue.put({
                        "type": "dependencies_met",
                        "task_id": waiting_task,
                        "timestamp": datetime.now().isoformat()
                    })

                callback_key = f"dependency:{dep_key}"
                callback = self.callbacks.get(callback_key)
                if callback is not None:
                    try:
                        callback(dep_key)
                    except Exception as e:
                        self.logger.error(f"Error executing callback for {callback_key}: {e}")

        self.logger.info(f"Agent {agent_id} completed task {task_id}")

    def wait_for_completion(self, agent_id: str, task_id: str,
                            timeout: Optional[float] = None) -> Tuple[bool, Any]:
        """
        Wait for an agent to complete a task.

        Args:
            agent_id: Agent expected to call :meth:`signal_completion`.
            task_id: The task to wait for.
            timeout: Seconds to wait; None waits indefinitely.

        Returns:
            ``(True, result)`` once completed, ``(False, None)`` on timeout.
        """
        key = f"{agent_id}:{task_id}"
        with self._completion_cond:
            # wait_for checks the predicate first, so an already-completed task
            # returns immediately; otherwise it sleeps until notified.
            completed = self._completion_cond.wait_for(
                lambda: key in self.completion_signals, timeout
            )
            if completed:
                return True, self.completion_signals[key].get("result")
            return False, None

    def register_dependencies(self, agent_id: str, task_id: str,
                              dependencies: List[Tuple[str, str]]) -> None:
        """
        Register that a task depends on completion of other tasks.

        Dependencies are specified as ``[(agent_id, task_id), ...]``. Ones that
        have already completed are not waited on. Re-registering the same task
        replaces its previous dependency set.
        """
        key = f"{agent_id}:{task_id}"
        # Check and record under a single lock acquisition; otherwise a
        # dependency completing in between would never be removed from
        # waiting_on and the task would never be notified.
        with self.signal_lock:
            waiting_on = set()
            for dep_agent, dep_task in dependencies:
                dep_key = f"{dep_agent}:{dep_task}"
                if dep_key not in self.completion_signals:
                    waiting_on.add(dep_key)

            self.dependencies[key] = {
                "agent_id": agent_id,
                "task_id": task_id,
                "waiting_on": waiting_on,
                "registered_at": datetime.now().isoformat()
            }

        self.logger.info(f"Registered {len(waiting_on)} dependencies for {key}")

    def are_dependencies_met(self, agent_id: str, task_id: str) -> bool:
        """Return True if the task has no outstanding dependencies (or none registered)."""
        key = f"{agent_id}:{task_id}"
        with self.signal_lock:
            deps = self.dependencies.get(key)
            if deps is None:
                return True  # No dependencies registered
            return not deps["waiting_on"]

    def register_callback(self, event_type: str, event_id: str, callback: Callable) -> None:
        """
        Register a callback function for a specific event.

        Keys used by this class are ``("barrier", <barrier instance id>)``,
        called as ``callback(barrier_id, data)``, and
        ``("dependency", "<agent_id>:<task_id>")``, called as
        ``callback("<agent_id>:<task_id>")``. A later registration for the
        same key replaces the earlier one.
        """
        key = f"{event_type}:{event_id}"
        self.callbacks[key] = callback
        self.logger.info(f"Registered callback for {key}")

    def get_events(self, agent_id: str, timeout: Optional[float] = 0) -> List[Dict[str, Any]]:
        """
        Get events for an agent from its queue.

        Args:
            agent_id: Agent whose queue to drain.
            timeout: ``0`` (default) polls without blocking; a positive value
                blocks up to that many seconds for the first event. Note that
                ``None`` also polls without blocking (it does NOT wait forever).

        Returns:
            All queued events (possibly empty); [] for unregistered agents.
        """
        # Named event_queue so it does not shadow the ``queue`` module, whose
        # ``Empty`` exception is caught below.
        event_queue = self.event_queues.get(agent_id)
        if event_queue is None:
            return []

        events = []
        try:
            events.append(event_queue.get(block=(timeout is not None), timeout=timeout))
        except queue.Empty:
            return events

        # Drain whatever else is already queued without blocking.
        while True:
            try:
                events.append(event_queue.get_nowait())
            except queue.Empty:
                break
        return events

    def send_event(self, agent_id: str, event_type: str, data: Any = None) -> bool:
        """Queue a custom event for an agent; return False if it is not registered."""
        event_queue = self.event_queues.get(agent_id)
        if event_queue is None:
            self.logger.warning(f"Agent {agent_id} not registered")
            return False

        event_queue.put({
            "type": event_type,
            "timestamp": datetime.now().isoformat(),
            "data": data
        })
        self.logger.info(f"Sent {event_type} event to {agent_id}")
        return True

    @staticmethod
    def _remove_older_than(records: Dict[str, Dict[str, Any]], field: str, cutoff_str: str) -> int:
        """Delete entries whose ISO timestamp ``field`` is before ``cutoff_str``; return the count."""
        # All timestamps come from naive datetime.now().isoformat(), so
        # lexicographic order matches chronological order.
        stale = [key for key, record in records.items() if record[field] < cutoff_str]
        for key in stale:
            del records[key]
        return len(stale)

    def cleanup(self, older_than_hours: Optional[int] = None) -> Dict[str, int]:
        """
        Clean up old synchronization artifacts.

        Args:
            older_than_hours: Remove barriers, completion signals and
                dependency records older than this many hours. None removes
                all of them. Agent event queues and callbacks are kept.

        Returns:
            Counts of items removed, keyed "barriers", "signals", "dependencies".
        """
        if older_than_hours is None:
            # Clear everything
            with self.barrier_lock:
                barriers_count = len(self.barriers)
                self.barriers = {}
            with self.signal_lock:
                signals_count = len(self.completion_signals)
                dependencies_count = len(self.dependencies)
                self.completion_signals = {}
                self.dependencies = {}
            return {
                "barriers": barriers_count,
                "signals": signals_count,
                "dependencies": dependencies_count
            }

        cutoff_str = (datetime.now() - timedelta(hours=older_than_hours)).isoformat()

        with self.barrier_lock:
            barriers_removed = self._remove_older_than(self.barriers, "created_at", cutoff_str)

        with self.signal_lock:
            signals_removed = self._remove_older_than(
                self.completion_signals, "completed_at", cutoff_str
            )
            dependencies_removed = self._remove_older_than(
                self.dependencies, "registered_at", cutoff_str
            )

        return {
            "barriers": barriers_removed,
            "signals": signals_removed,
            "dependencies": dependencies_removed
        }