diff --git a/oidc-controller/api/routers/socketio.py b/oidc-controller/api/routers/socketio.py index 53898519..6a08e9f7 100644 --- a/oidc-controller/api/routers/socketio.py +++ b/oidc-controller/api/routers/socketio.py @@ -11,10 +11,12 @@ sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*") sio_app = socketio.ASGIApp(socketio_server=sio, socketio_path="/ws/socket.io") + @sio.event async def connect(sid, socket): logger.info(f">>> connect : sid={sid}") + @sio.event async def initialize(sid, data): global connections, message_buffers @@ -24,6 +26,7 @@ async def initialize(sid, data): if pid not in message_buffers: message_buffers[pid] = [] + @sio.event async def disconnect(sid): global connections, message_buffers @@ -34,9 +37,10 @@ async def disconnect(sid): # Remove pid from connections del connections[pid] + async def buffered_emit(event, data, to_pid=None): global connections, message_buffers - + connections = connections_reload() sid = connections.get(to_pid) logger.debug(f"sid: {sid} found for pid: {to_pid}") @@ -51,6 +55,7 @@ async def buffered_emit(event, data, to_pid=None): # Buffer the message if the target is not connected buffer_message(to_pid, event, data) + def buffer_message(pid, event, data): global message_buffers current_time = time.time() @@ -60,10 +65,12 @@ def buffer_message(pid, event, data): message_buffers[pid].append((event, data, current_time)) # Clean up old messages message_buffers[pid] = [ - (msg_event, msg_data, timestamp) for msg_event, msg_data, timestamp in message_buffers[pid] + (msg_event, msg_data, timestamp) + for msg_event, msg_data, timestamp in message_buffers[pid] if current_time - timestamp <= buffer_timeout ] + @sio.event async def fetch_buffered_messages(sid, pid): global message_buffers @@ -71,7 +78,8 @@ async def fetch_buffered_messages(sid, pid): if pid in message_buffers: # Filter messages that are still valid (i.e., within the buffer_timeout) valid_messages = [ - (msg_event, msg_data, timestamp) for msg_event, msg_data, timestamp in message_buffers[pid] + (msg_event, msg_data, timestamp) + for msg_event, msg_data, timestamp in message_buffers[pid] if current_time - timestamp <= buffer_timeout ] # Emit each valid message @@ -80,6 +88,7 @@ async def fetch_buffered_messages(sid, pid): # Reassign the valid_messages back to message_buffers[pid] to clean up old messages message_buffers[pid] = valid_messages + def connections_reload(): global connections return connections