Spaces:
Running
Running
from fastapi import WebSocket | |
from typing import List, Dict, Optional | |
import logging | |
# Module-level logger, named after this module so log records can be filtered per module.
logger = logging.getLogger(__name__)
class ConnectionManager:
    """Track accepted WebSocket connections per user ID and broadcast to them.

    Two mirrored tables are maintained:

    * ``active_connections`` maps a user ID to the list of sockets registered
      under it. ``None`` is a valid key: it holds every socket that was
      connected without a user ID.
    * ``websocket_to_user`` maps each socket back to its user ID, so
      ``disconnect`` and ``broadcast`` can look the owner up directly.
    """

    # Marks "socket is not registered". ``None`` cannot play that role because
    # it is itself a legitimate user ID (see ``connect``).
    _NOT_REGISTERED = object()

    def __init__(self) -> None:
        self.active_connections: Dict[Optional[int], List[WebSocket]] = {}
        self.websocket_to_user: Dict[WebSocket, Optional[int]] = {}

    @staticmethod
    def _describe_client(websocket: WebSocket) -> str:
        """Return ``host:port`` of the peer, or ``"unknown"`` when unavailable."""
        client = websocket.client
        # ``client`` can be None (e.g. the server did not supply a peer
        # address); dereferencing it blindly would raise AttributeError.
        if client is None:
            return "unknown"
        return f"{client.host}:{client.port}"

    def _total_connections(self) -> int:
        """Return the number of registered sockets across all user IDs."""
        return sum(len(conns) for conns in self.active_connections.values())

    async def connect(self, websocket: WebSocket, user_id: Optional[int] = None) -> None:
        """Accept *websocket* and register it under *user_id*.

        Args:
            websocket: The incoming connection; ``accept()`` is awaited on it
                before anything is registered, so a failure there propagates
                and leaves the tables untouched.
            user_id: Owner of the connection. ``None`` registers the socket in
                the shared ``None`` bucket.
        """
        await websocket.accept()
        self.active_connections.setdefault(user_id, []).append(websocket)
        self.websocket_to_user[websocket] = user_id
        logger.info(
            "WebSocket connected: %s, User ID: %s",
            self._describe_client(websocket),
            user_id,
        )
        logger.info("Total connections: %s", self._total_connections())

    def disconnect(self, websocket: WebSocket) -> None:
        """Remove *websocket* from both tables.

        Safe to call more than once for the same socket (for example once from
        ``broadcast`` cleanup and again from the endpoint's own handler): a
        socket that is not registered is ignored apart from a debug log line.
        The socket itself is not closed here.
        """
        user_id = self.websocket_to_user.pop(websocket, self._NOT_REGISTERED)
        if user_id is self._NOT_REGISTERED:
            logger.debug(
                "disconnect() called for an unregistered WebSocket (%s); ignoring.",
                self._describe_client(websocket),
            )
            return

        connections = self.active_connections.get(user_id)
        if connections is not None:
            try:
                connections.remove(websocket)
            except ValueError:
                logger.warning(
                    "WebSocket not found in active list for user %s during disconnect.",
                    user_id,
                )
            else:
                # Drop the bucket once its last socket is gone so the dict
                # does not accumulate empty lists.
                if not connections:
                    del self.active_connections[user_id]

        logger.info(
            "WebSocket disconnected: %s, User ID: %s",
            self._describe_client(websocket),
            user_id,
        )
        logger.info("Total connections: %s", self._total_connections())

    async def broadcast(self, message: str, sender_id: Optional[int] = None) -> None:
        """Send *message* to every registered socket not owned by *sender_id*.

        Sockets are written to one after another. A socket whose ``send_text``
        raises is logged and unregistered after the loop; the error is not
        propagated, so one broken peer cannot stop delivery to the others.

        Args:
            message: Text frame payload.
            sender_id: User ID whose sockets are skipped. NOTE: the comparison
                is a plain equality check, so the default ``None`` skips every
                socket registered under user ID ``None``.
        """
        # Snapshot the sockets up front: each ``await`` below is a point where
        # other tasks may run and mutate the tables via connect/disconnect.
        all_websockets = [
            ws for user_conns in self.active_connections.values() for ws in user_conns
        ]
        logger.info(
            "Broadcasting to %s connections (excluding sender if ID matches). Sender ID: %s",
            len(all_websockets),
            sender_id,
        )

        failed_websockets = []
        for websocket in all_websockets:
            ws_user_id = self.websocket_to_user.get(websocket, self._NOT_REGISTERED)
            if ws_user_id is self._NOT_REGISTERED:
                # Unregistered while earlier sends were awaited; nothing to do.
                continue
            if ws_user_id == sender_id:
                continue
            try:
                await websocket.send_text(message)
            except Exception as e:  # best-effort fan-out: isolate per-socket failures
                logger.error(
                    "Error sending message to websocket %s: %s. Marking for disconnect.",
                    websocket.client,
                    e,
                )
                failed_websockets.append(websocket)
            else:
                logger.debug("Message sent to user %s", ws_user_id)

        # Clean up only after the loop so the tables are not mutated mid-iteration.
        for ws in failed_websockets:
            self.disconnect(ws)
# Module-level ConnectionManager instance, created once at import time.
# NOTE(review): presumably imported by the WebSocket route handlers so they all
# share the same connection tables — confirm against the callers.
manager = ConnectionManager()