import asyncio
import logging
from collections import defaultdict
from typing import Any

try:
    from fastapi import WebSocket
except ModuleNotFoundError:
    # Fall back to a permissive alias so this module still imports when
    # FastAPI is not installed.
    WebSocket = Any

from .task_constants import current_timestamp

logger = logging.getLogger(__name__)


class TaskStreamManager:
    """Track WebSocket subscribers per task and push JSON events to them.

    Sockets are registered with :meth:`connect` and removed with
    :meth:`disconnect`. Events are published with :meth:`broadcast_event`,
    which schedules the actual sends on the event loop registered via
    :meth:`set_loop`.
    """

    def __init__(self):
        # Loop on which broadcasts are executed; None until set_loop() is called.
        self._loop: asyncio.AbstractEventLoop | None = None
        # task_id -> sockets currently subscribed to that task's events.
        self._connections: dict[str, set[WebSocket]] = defaultdict(set)

    def set_loop(self, loop: asyncio.AbstractEventLoop):
        """Register the event loop on which broadcasts will be scheduled."""
        self._loop = loop

    async def connect(self, task_id: str, websocket: WebSocket):
        """Accept *websocket* and subscribe it to events for *task_id*."""
        await websocket.accept()
        self._connections[task_id].add(websocket)

    def disconnect(self, task_id: str, websocket: WebSocket):
        """Unsubscribe *websocket* from *task_id*; unknown pairs are ignored.

        The task's entry is removed once its last subscriber leaves, so the
        mapping does not accumulate empty sets.
        """
        task_connections = self._connections.get(task_id)
        if not task_connections:
            return
        task_connections.discard(websocket)
        if not task_connections:
            self._connections.pop(task_id, None)

    def broadcast_event(self, task_id: str, event_type: str, stage: str, data: dict):
        """Schedule an event for every subscriber of *task_id*.

        The payload (including ``timestamp``) is built in the calling thread.
        The sends are then handed to the registered loop with
        ``asyncio.run_coroutine_threadsafe``, so this may be called from a
        thread other than the loop's. It does not wait for delivery.

        This is a no-op when no loop has been registered or when the loop
        has already been closed.
        """
        loop = self._loop
        if loop is None or loop.is_closed():
            return
        payload = {
            'type': event_type,
            'task_id': task_id,
            'stage': stage,
            'timestamp': current_timestamp(),
            'data': data,
        }
        coro = self._broadcast(task_id, payload)
        try:
            asyncio.run_coroutine_threadsafe(coro, loop)
        except RuntimeError:
            # The loop closed between the check above and scheduling (e.g.
            # during shutdown). Drop the event instead of raising into the
            # producer thread, and close the coroutine so Python does not
            # warn that it was never awaited.
            coro.close()

    async def _broadcast(self, task_id: str, payload: dict):
        """Send *payload* to each subscriber of *task_id*, pruning failed sockets.

        The loop iterates over a snapshot list because connect()/disconnect()
        may modify the underlying set while a send is being awaited.
        """
        task_connections = list(self._connections.get(task_id, set()))
        disconnected: list[WebSocket] = []
        for websocket in task_connections:
            try:
                await websocket.send_json(payload)
            except Exception:
                # Best effort: a failed send marks the socket as dead instead
                # of aborting delivery to the remaining subscribers.
                logger.debug(
                    "Dropping WebSocket for task %s after send failure",
                    task_id,
                    exc_info=True,
                )
                disconnected.append(websocket)
        for websocket in disconnected:
            self.disconnect(task_id, websocket)