62 lines
1.6 KiB
Python
62 lines
1.6 KiB
Python
import asyncio
|
|
from collections import defaultdict
|
|
from typing import Any
|
|
|
|
try:
|
|
from fastapi import WebSocket
|
|
except ModuleNotFoundError:
|
|
WebSocket = Any
|
|
|
|
from .task_constants import current_timestamp
|
|
|
|
|
|
class TaskStreamManager:
    """Fan task events out to the WebSocket clients subscribed to each task.

    Sockets are grouped by task id. ``broadcast_event`` is a synchronous entry
    point that hands the actual sending over to the event loop registered via
    ``set_loop`` using ``asyncio.run_coroutine_threadsafe``; the remaining
    coroutine methods are expected to run on that loop.
    """

    def __init__(self) -> None:
        # Event loop used to schedule broadcasts; unset until set_loop() runs.
        self._loop: asyncio.AbstractEventLoop | None = None
        # task_id -> sockets currently subscribed to that task.
        self._connections: dict[str, set[WebSocket]] = defaultdict(set)

    def set_loop(self, loop: asyncio.AbstractEventLoop) -> None:
        """Register the event loop that broadcasts are scheduled on."""
        self._loop = loop

    async def connect(self, task_id: str, websocket: WebSocket) -> None:
        """Accept ``websocket`` and subscribe it to events for ``task_id``.

        The socket is only registered once ``accept()`` has succeeded, so a
        failed handshake leaves no entry behind.
        """
        await websocket.accept()
        self._connections[task_id].add(websocket)

    def disconnect(self, task_id: str, websocket: WebSocket) -> None:
        """Unsubscribe ``websocket`` from ``task_id``.

        Unknown task ids and sockets that were never registered are ignored.
        The task's entry is dropped once its last socket is removed.
        """
        # .get() rather than indexing: indexing the defaultdict would create
        # an empty entry for a task nobody is subscribed to.
        subscribers = self._connections.get(task_id)
        if not subscribers:
            return

        subscribers.discard(websocket)
        if not subscribers:
            self._connections.pop(task_id, None)

    def broadcast_event(
        self,
        task_id: str,
        event_type: str,
        stage: str,
        data: dict,
    ) -> None:
        """Schedule delivery of one event to every subscriber of ``task_id``.

        Best-effort: the call returns without doing anything when no loop has
        been registered or when the registered loop is already closed (for
        example during shutdown). The result of the scheduled broadcast is
        not awaited.

        Args:
            task_id: Task whose subscribers should receive the event.
            event_type: Value sent under the ``'type'`` key.
            stage: Value sent under the ``'stage'`` key.
            data: Value sent under the ``'data'`` key.
        """
        loop = self._loop
        # Scheduling onto a closed loop raises RuntimeError; treat it the same
        # as "no loop yet" so late events cannot crash the producer.
        if loop is None or loop.is_closed():
            return

        payload = {
            'type': event_type,
            'task_id': task_id,
            'stage': stage,
            'timestamp': current_timestamp(),
            'data': data,
        }
        pending = self._broadcast(task_id, payload)
        try:
            asyncio.run_coroutine_threadsafe(pending, loop)
        except RuntimeError:
            # The loop was closed between the check above and scheduling.
            # Close the coroutine so it is not left un-awaited.
            pending.close()

    async def _broadcast(self, task_id: str, payload: dict) -> None:
        """Send ``payload`` to each subscriber of ``task_id``, pruning failures.

        Any socket whose ``send_json`` raises is unsubscribed afterwards; the
        error itself is deliberately not propagated so that one broken client
        does not stop delivery to the others.
        """
        # Iterate over a snapshot: the live set may change while we await.
        recipients = list(self._connections.get(task_id, set()))
        failed: list[WebSocket] = []

        for websocket in recipients:
            try:
                await websocket.send_json(payload)
            except Exception:
                failed.append(websocket)

        for websocket in failed:
            self.disconnect(task_id, websocket)
|