Spaces:
Sleeping
Sleeping
File size: 13,578 Bytes
95d4fad 7bdd0ae 95d4fad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 |
# app/synchronizer.py
import logging
import time
import threading
from typing import Dict, List, Optional, Tuple, Union, Any, Callable
from datetime import datetime, timedelta
import queue
import uuid
class Synchronizer:
    """Coordinate agent interactions via barriers, completion signals and dependencies.

    State is kept in plain dicts guarded by two reentrant locks:
    ``barrier_lock`` for ``barriers`` and ``event_queues`` registration, and
    ``signal_lock`` for ``completion_signals`` and ``dependencies``.
    Each registered agent has a ``queue.Queue`` that receives notification
    events (``barrier_complete``, ``dependencies_met``, or custom events sent
    with :meth:`send_event`). Registered callbacks are invoked while the
    relevant lock is held; because the locks are ``RLock``s, a callback may
    call back into this object from the same thread.
    """

    def __init__(self):
        """Initialize the Synchronizer for coordinating agent interactions."""
        self.logger = logging.getLogger(__name__)
        # Synchronization barriers: instance_id -> {"participants", "arrived",
        # "event", "created_at"[, "data"]}
        self.barriers = {}
        # Completion signals: "agent_id:task_id" -> {"completed_at", "result"}
        self.completion_signals = {}
        # Dependency tracking: "agent_id:task_id" -> {"waiting_on", "registered_at"}
        self.dependencies = {}
        # Locks for thread safety (reentrant so callbacks can re-enter)
        self.barrier_lock = threading.RLock()
        self.signal_lock = threading.RLock()
        # Event queues for agents: agent_id -> queue.Queue
        self.event_queues = {}
        # Callback registry: "event_type:event_id" -> callable
        self.callbacks = {}

    def register_agent(self, agent_id: str) -> None:
        """Register an agent, creating its event queue if it does not exist yet."""
        with self.barrier_lock:
            if agent_id not in self.event_queues:
                self.event_queues[agent_id] = queue.Queue()
                self.logger.info(f"Registered agent: {agent_id}")

    def create_barrier(self, barrier_id: str, participants: List[str]) -> str:
        """
        Create a synchronization barrier that multiple agents must reach.

        Args:
            barrier_id: Base name; a random 8-hex-char suffix is appended.
            participants: Agent IDs that must all arrive to complete the barrier.

        Returns:
            The unique barrier instance ID (``"<barrier_id>_<hex8>"``), which is
            the ID every other barrier method expects.
        """
        instance_id = f"{barrier_id}_{uuid.uuid4().hex[:8]}"
        with self.barrier_lock:
            self.barriers[instance_id] = {
                "participants": set(participants),
                "arrived": set(),
                "event": threading.Event(),
                "created_at": datetime.now().isoformat(),
            }
        self.logger.info(f"Created barrier {instance_id} with {len(participants)} participants")
        return instance_id

    def arrive_at_barrier(self, barrier_id: str, agent_id: str, data: Any = None) -> bool:
        """
        Signal that an agent has arrived at a barrier.

        Args:
            barrier_id: Barrier instance ID returned by :meth:`create_barrier`.
            agent_id: The arriving agent; must be a participant.
            data: Optional payload stored per agent (``None`` is not stored).

        Returns:
            True if all participants have arrived, False otherwise (including
            unknown barrier or non-participant agent).

        When the barrier completes, waiting threads are released, every
        registered participant receives a ``barrier_complete`` event, and any
        ``"barrier:<id>"`` callback is invoked with ``(barrier_id, data_dict)``.
        These notifications happen exactly once, on the completing arrival.
        """
        with self.barrier_lock:
            barrier = self.barriers.get(barrier_id)
            if barrier is None:
                self.logger.warning(f"Barrier {barrier_id} not found")
                return False

            if agent_id not in barrier["participants"]:
                self.logger.warning(f"Agent {agent_id} is not a participant in barrier {barrier_id}")
                return False

            barrier["arrived"].add(agent_id)
            if data is not None:
                barrier.setdefault("data", {})[agent_id] = data

            all_arrived = barrier["arrived"] == barrier["participants"]

            # Only the arrival that completes the barrier notifies; a repeated
            # arrival after completion must not re-send events or re-run callbacks.
            if all_arrived and not barrier["event"].is_set():
                barrier["event"].set()
                self.logger.info(f"All participants arrived at barrier {barrier_id}")
                self.logger.info(f" barrier participants: {barrier['participants']}")

                for participant in barrier["participants"]:
                    participant_queue = self.event_queues.get(participant)
                    if participant_queue is not None:
                        participant_queue.put({
                            "type": "barrier_complete",
                            "barrier_id": barrier_id,
                            "timestamp": datetime.now().isoformat(),
                        })

                callback_key = f"barrier:{barrier_id}"
                if callback_key in self.callbacks:
                    try:
                        self.callbacks[callback_key](barrier_id, barrier.get("data", {}))
                    except Exception as e:
                        # Best-effort: a failing callback must not break the barrier.
                        self.logger.error(f"Error executing callback for {callback_key}: {e}")

        return all_arrived

    def wait_for_barrier(self, barrier_id: str, timeout: Optional[float] = None) -> bool:
        """
        Wait until all participants arrive at the barrier.

        Args:
            barrier_id: Barrier instance ID.
            timeout: Seconds to wait; ``None`` waits indefinitely.

        Returns:
            True if all arrived, False on timeout or unknown barrier.
        """
        with self.barrier_lock:
            barrier = self.barriers.get(barrier_id)
            if barrier is None:
                self.logger.warning(f"Barrier {barrier_id} not found")
                return False
            event = barrier["event"]
        # Wait outside the lock so other agents can still arrive.
        return event.wait(timeout)

    def get_barrier_data(self, barrier_id: str) -> Dict[str, Any]:
        """Return a shallow copy of the per-agent data stored at a barrier ({} if unknown)."""
        with self.barrier_lock:
            barrier = self.barriers.get(barrier_id)
            if barrier is None:
                self.logger.warning(f"Barrier {barrier_id} not found")
                return {}
            return barrier.get("data", {}).copy()

    def signal_completion(self, agent_id: str, task_id: str, result: Any = None) -> None:
        """
        Signal that an agent has completed a task.

        Records the result, removes this task from every dependency set that
        waits on it, and for each dependency set that becomes empty sends a
        ``dependencies_met`` event to the waiting agent and invokes any
        ``"dependency:<agent_id>:<task_id>"`` callback with the dependent key.
        """
        with self.signal_lock:
            key = f"{agent_id}:{task_id}"
            self.completion_signals[key] = {
                "completed_at": datetime.now().isoformat(),
                "result": result,
            }

            # Iterate over a snapshot: callbacks run under this reentrant lock
            # and may register new dependencies, which would otherwise raise
            # "dictionary changed size during iteration".
            for dep_key, deps in list(self.dependencies.items()):
                if key not in deps["waiting_on"]:
                    continue
                deps["waiting_on"].remove(key)
                if deps["waiting_on"]:
                    continue

                waiting_agent, waiting_task = dep_key.split(":", 1)
                waiting_queue = self.event_queues.get(waiting_agent)
                if waiting_queue is not None:
                    waiting_queue.put({
                        "type": "dependencies_met",
                        "task_id": waiting_task,
                        "timestamp": datetime.now().isoformat(),
                    })

                callback_key = f"dependency:{dep_key}"
                if callback_key in self.callbacks:
                    try:
                        self.callbacks[callback_key](dep_key)
                    except Exception as e:
                        # Best-effort: a failing callback must not block other dependents.
                        self.logger.error(f"Error executing callback for {callback_key}: {e}")

        self.logger.info(f"Agent {agent_id} completed task {task_id}")

    def wait_for_completion(self, agent_id: str, task_id: str,
                            timeout: Optional[float] = None) -> Tuple[bool, Any]:
        """
        Wait for an agent to complete a task by polling every 0.1 s.

        Args:
            agent_id: Agent expected to complete the task.
            task_id: Task identifier.
            timeout: Seconds to wait; ``None`` waits indefinitely.

        Returns:
            ``(True, result)`` once completed, ``(False, None)`` on timeout.
        """
        key = f"{agent_id}:{task_id}"
        # Monotonic clock: immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout if timeout is not None else None

        while True:
            with self.signal_lock:
                signal = self.completion_signals.get(key)
                if signal is not None:
                    return True, signal.get("result")

            if deadline is None:
                delay = 0.1
            else:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    return False, None
                # Never sleep past the deadline.
                delay = min(0.1, remaining)
            time.sleep(delay)

    def register_dependencies(self, agent_id: str, task_id: str,
                              dependencies: List[Tuple[str, str]]) -> None:
        """
        Register that a task depends on completion of other tasks.

        Dependencies are specified as ``[(agent_id, task_id), ...]``. Ones that
        are already completed are not waited on. Replaces any previous
        registration for the same task.
        """
        key = f"{agent_id}:{task_id}"
        with self.signal_lock:
            # Check and register under one lock acquisition: a dependency that
            # completed between a separate check and the registration would
            # otherwise stay in waiting_on forever.
            waiting_on = set()
            for dep_agent, dep_task in dependencies:
                dep_key = f"{dep_agent}:{dep_task}"
                if dep_key not in self.completion_signals:
                    waiting_on.add(dep_key)

            self.dependencies[key] = {
                "waiting_on": waiting_on,
                "registered_at": datetime.now().isoformat(),
            }

        self.logger.info(f"Registered {len(waiting_on)} dependencies for {key}")

    def are_dependencies_met(self, agent_id: str, task_id: str) -> bool:
        """Return True if the task has no registered dependencies or all are completed."""
        key = f"{agent_id}:{task_id}"
        with self.signal_lock:
            deps = self.dependencies.get(key)
            if deps is None:
                return True  # No dependencies registered
            return not deps["waiting_on"]

    def register_callback(self, event_type: str, event_id: str, callback: Callable) -> None:
        """
        Register a callback function for a specific event.

        Keys used elsewhere in this class are ``"barrier:<barrier_instance_id>"``
        (called with ``(barrier_id, data_dict)``) and
        ``"dependency:<agent_id>:<task_id>"`` (called with the dependent key).
        """
        key = f"{event_type}:{event_id}"
        self.callbacks[key] = callback
        self.logger.info(f"Registered callback for {key}")

    def get_events(self, agent_id: str, timeout: Optional[float] = 0) -> List[Dict[str, Any]]:
        """
        Get events for an agent from its queue.

        Non-blocking by default (timeout=0). A positive ``timeout`` blocks up
        to that many seconds for the first event. Note that ``timeout=None``
        is also non-blocking here (unlike ``wait_for_barrier``).
        After the first event, everything already queued is drained without
        blocking.

        Returns:
            The list of events, or ``[]`` if none or the agent is unregistered.
        """
        # Named event_queue, not `queue`, so the module's queue.Empty stays reachable.
        event_queue = self.event_queues.get(agent_id)
        if event_queue is None:
            return []

        events = []
        try:
            events.append(event_queue.get(block=(timeout is not None), timeout=timeout))
        except queue.Empty:
            return events

        while True:
            try:
                events.append(event_queue.get_nowait())
            except queue.Empty:
                break
        return events

    def send_event(self, agent_id: str, event_type: str, data: Any = None) -> bool:
        """Send a custom event to an agent; return False if the agent is not registered."""
        event_queue = self.event_queues.get(agent_id)
        if event_queue is None:
            self.logger.warning(f"Agent {agent_id} not registered")
            return False

        event_queue.put({
            "type": event_type,
            "timestamp": datetime.now().isoformat(),
            "data": data,
        })
        self.logger.info(f"Sent {event_type} event to {agent_id}")
        return True

    def cleanup(self, older_than_hours: Optional[int] = None) -> Dict[str, int]:
        """
        Clean up old synchronization artifacts.

        Args:
            older_than_hours: Remove barriers, signals and dependency records
                whose timestamp is older than this many hours. ``None`` removes
                all of them. Event queues and callbacks are never touched.

        Returns:
            Counts of items removed, keyed ``"barriers"``, ``"signals"``,
            ``"dependencies"``.
        """
        if older_than_hours is None:
            with self.barrier_lock:
                barriers_count = len(self.barriers)
                self.barriers = {}
            with self.signal_lock:
                signals_count = len(self.completion_signals)
                dependencies_count = len(self.dependencies)
                self.completion_signals = {}
                self.dependencies = {}
            return {
                "barriers": barriers_count,
                "signals": signals_count,
                "dependencies": dependencies_count,
            }

        # Timestamps are naive local isoformat() strings, which sort
        # chronologically as strings.
        cutoff_str = (datetime.now() - timedelta(hours=older_than_hours)).isoformat()

        with self.barrier_lock:
            stale_barriers = [bid for bid, barrier in self.barriers.items()
                              if barrier["created_at"] < cutoff_str]
            for bid in stale_barriers:
                del self.barriers[bid]

        with self.signal_lock:
            stale_signals = [key for key, signal in self.completion_signals.items()
                             if signal["completed_at"] < cutoff_str]
            for key in stale_signals:
                del self.completion_signals[key]

            stale_deps = [key for key, dep in self.dependencies.items()
                          if dep["registered_at"] < cutoff_str]
            for key in stale_deps:
                del self.dependencies[key]

        return {
            "barriers": len(stale_barriers),
            "signals": len(stale_signals),
            "dependencies": len(stale_deps),
        }
|