diff --git a/src/diart/websockets.py b/src/diart/websockets.py index 7ef5765..542237a 100644 --- a/src/diart/websockets.py +++ b/src/diart/websockets.py @@ -124,24 +124,26 @@ def _on_connect(self, client: Dict[Text, Any], server: WebsocketServer) -> None: client_id = client["id"] logger.info(f"New client connected: {client_id}") - if client_id not in self._clients: - try: - self._clients[client_id] = self._create_client_state(client_id) + if client_id in self._clients: + return - # Setup RTTM response hook - self._clients[client_id].inference.attach_hooks( - lambda ann_wav: self.send(client_id, ann_wav[0].to_rttm()) - ) + try: + self._clients[client_id] = self._create_client_state(client_id) - # Start inference - self._clients[client_id].inference() - logger.info(f"Started inference for client: {client_id}") + # Setup RTTM response hook + self._clients[client_id].inference.attach_hooks( + lambda ann_wav: self.send(client_id, ann_wav[0].to_rttm()) + ) - # Send ready notification to client - self.send(client_id, "READY") - except Exception as e: - logger.error(f"Failed to initialize client {client_id}: {e}") - self.close(client_id) + # Start inference + self._clients[client_id].inference() + logger.info(f"Started inference for client: {client_id}") + + # Send ready notification to client + self.send(client_id, "READY") + except Exception as e: + logger.error(f"Failed to initialize client {client_id}: {e}") + self.close(client_id) def _on_disconnect(self, client: Dict[Text, Any], server: WebsocketServer) -> None: """Handle client disconnection. @@ -172,16 +174,19 @@ def _on_message_received( Received message data """ client_id = client["id"] - if client_id in self._clients: - try: - self._clients[client_id].audio_source.process_message(message) - except (socket.error, ConnectionError) as e: - logger.warning(f"Client {client_id} disconnected: {e}") - self.close(client_id) - except Exception as e: - logger.error(f"Error processing message from client {client_id}: {e}") - # Don't close the connection for non-connection related errors - # This allows the client to retry sending the message + + if client_id not in self._clients: + return + + try: + self._clients[client_id].audio_source.process_message(message) + except (socket.error, ConnectionError) as e: + logger.warning(f"Client {client_id} disconnected: {e}") + self.close(client_id) + except Exception as e: + logger.error(f"Error processing message from client {client_id}: {e}") + # Don't close the connection for non-connection related errors + # This allows the client to retry sending the message def send(self, client_id: Text, message: AnyStr) -> None: """Send a message to a specific client. @@ -198,16 +203,18 @@ def send(self, client_id: Text, message: AnyStr) -> None: client = next((c for c in self.server.clients if c["id"] == client_id), None) - if client is not None: - try: - self.server.send_message(client, message) - except (socket.error, ConnectionError) as e: - logger.warning( - f"Client {client_id} disconnected while sending message: {e}" - ) - self.close(client_id) - except Exception as e: - logger.error(f"Failed to send message to client {client_id}: {e}") + if client is None: + return + + try: + self.server.send_message(client, message) + except (socket.error, ConnectionError) as e: + logger.warning( + f"Client {client_id} disconnected while sending message: {e}" + ) + self.close(client_id) + except Exception as e: + logger.error(f"Failed to send message to client {client_id}: {e}") def run(self) -> None: """Start the WebSocket server.""" @@ -242,39 +249,47 @@ def close(self, client_id: Text) -> None: client_id : Text Client identifier to close """ - if client_id in self._clients: - try: - # Clean up pipeline state using built-in reset method - client_state = self._clients[client_id] - client_state.inference.pipeline.reset() - - # Close audio source and remove client - client_state.audio_source.close() - del self._clients[client_id] - - # Try to send a close frame to the client - try: - client = next( - (c for c in self.server.clients if c["id"] == client_id), None - ) - if client: - self.server.send_message(client, "CLOSE") - except Exception: - pass # Ignore errors when trying to send close message + if client_id not in self._clients: + return - logger.info( - f"Closed connection and cleaned up state for client: {client_id}" + try: + # Clean up pipeline state using built-in reset method + client_state = self._clients[client_id] + client_state.inference.pipeline.reset() + + # Close audio source and remove client + client_state.audio_source.close() + del self._clients[client_id] + + # Try to send a close frame to the client + client = next((c for c in self.server.clients if c["id"] == client_id), None) + + if client is None: + return + + try: + self.server.send_message(client, "CLOSE") + except (socket.error, ConnectionError) as e: + logger.warning( + f"Client {client_id} disconnected while sending message: {e}" ) + self.close(client_id) except Exception as e: - logger.error(f"Error closing client {client_id}: {e}") - # Ensure client is removed from dictionary even if cleanup fails - self._clients.pop(client_id, None) + logger.error(f"Failed to send message to client {client_id}: {e}") + + logger.info( + f"Closed connection and cleaned up state for client: {client_id}" + ) + except Exception as e: + logger.error(f"Error closing client {client_id}: {e}") + # Ensure client is removed from dictionary even if cleanup fails + self._clients.pop(client_id, None) def close_all(self) -> None: """Shutdown the server and cleanup all client connections.""" logger.info("Shutting down server...") try: - for client_id in list(self._clients.keys()): + for client_id in self._clients.keys(): self.close(client_id) if self.server is not None: self.server.shutdown_gracefully()