Skip to content

Commit

Permalink
improved code quality and style
Browse files Browse the repository at this point in the history
  • Loading branch information
janaab11 committed Jan 3, 2025
1 parent 3d3bb45 commit b2c9293
Showing 1 changed file with 75 additions and 60 deletions.
135 changes: 75 additions & 60 deletions src/diart/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,24 +124,26 @@ def _on_connect(self, client: Dict[Text, Any], server: WebsocketServer) -> None:
client_id = client["id"]
logger.info(f"New client connected: {client_id}")

if client_id not in self._clients:
try:
self._clients[client_id] = self._create_client_state(client_id)
if client_id in self._clients:
return

# Setup RTTM response hook
self._clients[client_id].inference.attach_hooks(
lambda ann_wav: self.send(client_id, ann_wav[0].to_rttm())
)
try:
self._clients[client_id] = self._create_client_state(client_id)

# Start inference
self._clients[client_id].inference()
logger.info(f"Started inference for client: {client_id}")
# Setup RTTM response hook
self._clients[client_id].inference.attach_hooks(
lambda ann_wav: self.send(client_id, ann_wav[0].to_rttm())
)

# Send ready notification to client
self.send(client_id, "READY")
except Exception as e:
logger.error(f"Failed to initialize client {client_id}: {e}")
self.close(client_id)
# Start inference
self._clients[client_id].inference()
logger.info(f"Started inference for client: {client_id}")

# Send ready notification to client
self.send(client_id, "READY")
except Exception as e:
logger.error(f"Failed to initialize client {client_id}: {e}")
self.close(client_id)

def _on_disconnect(self, client: Dict[Text, Any], server: WebsocketServer) -> None:
"""Handle client disconnection.
Expand Down Expand Up @@ -172,16 +174,19 @@ def _on_message_received(
Received message data
"""
client_id = client["id"]
if client_id in self._clients:
try:
self._clients[client_id].audio_source.process_message(message)
except (socket.error, ConnectionError) as e:
logger.warning(f"Client {client_id} disconnected: {e}")
self.close(client_id)
except Exception as e:
logger.error(f"Error processing message from client {client_id}: {e}")
# Don't close the connection for non-connection related errors
# This allows the client to retry sending the message

if client_id not in self._clients:
return

try:
self._clients[client_id].audio_source.process_message(message)
except (socket.error, ConnectionError) as e:
logger.warning(f"Client {client_id} disconnected: {e}")
self.close(client_id)
except Exception as e:
logger.error(f"Error processing message from client {client_id}: {e}")
# Don't close the connection for non-connection related errors
# This allows the client to retry sending the message

def send(self, client_id: Text, message: AnyStr) -> None:
"""Send a message to a specific client.
Expand All @@ -198,16 +203,18 @@ def send(self, client_id: Text, message: AnyStr) -> None:

client = next((c for c in self.server.clients if c["id"] == client_id), None)

if client is not None:
try:
self.server.send_message(client, message)
except (socket.error, ConnectionError) as e:
logger.warning(
f"Client {client_id} disconnected while sending message: {e}"
)
self.close(client_id)
except Exception as e:
logger.error(f"Failed to send message to client {client_id}: {e}")
if client is None:
return

try:
self.server.send_message(client, message)
except (socket.error, ConnectionError) as e:
logger.warning(
f"Client {client_id} disconnected while sending message: {e}"
)
self.close(client_id)
except Exception as e:
logger.error(f"Failed to send message to client {client_id}: {e}")

def run(self) -> None:
"""Start the WebSocket server."""
Expand Down Expand Up @@ -242,39 +249,47 @@ def close(self, client_id: Text) -> None:
client_id : Text
Client identifier to close
"""
if client_id in self._clients:
try:
# Clean up pipeline state using built-in reset method
client_state = self._clients[client_id]
client_state.inference.pipeline.reset()

# Close audio source and remove client
client_state.audio_source.close()
del self._clients[client_id]

# Try to send a close frame to the client
try:
client = next(
(c for c in self.server.clients if c["id"] == client_id), None
)
if client:
self.server.send_message(client, "CLOSE")
except Exception:
pass # Ignore errors when trying to send close message
if client_id not in self._clients:
return

logger.info(
f"Closed connection and cleaned up state for client: {client_id}"
try:
# Clean up pipeline state using built-in reset method
client_state = self._clients[client_id]
client_state.inference.pipeline.reset()

# Close audio source and remove client
client_state.audio_source.close()
del self._clients[client_id]

# Try to send a close frame to the client
client = next((c for c in self.server.clients if c["id"] == client_id), None)

if client is None:
return

try:
self.server.send_message(client, "CLOSE")
except (socket.error, ConnectionError) as e:
logger.warning(
f"Client {client_id} disconnected while sending message: {e}"
)
self.close(client_id)
except Exception as e:
logger.error(f"Error closing client {client_id}: {e}")
# Ensure client is removed from dictionary even if cleanup fails
self._clients.pop(client_id, None)
logger.error(f"Failed to send message to client {client_id}: {e}")

logger.info(
f"Closed connection and cleaned up state for client: {client_id}"
)
except Exception as e:
logger.error(f"Error closing client {client_id}: {e}")
# Ensure client is removed from dictionary even if cleanup fails
self._clients.pop(client_id, None)

def close_all(self) -> None:
"""Shutdown the server and cleanup all client connections."""
logger.info("Shutting down server...")
try:
for client_id in list(self._clients.keys()):
for client_id in self._clients.keys():
self.close(client_id)
if self.server is not None:
self.server.shutdown_gracefully()
Expand Down

0 comments on commit b2c9293

Please sign in to comment.