From 1975f759f99912f3d26742b57d1d7f842c5998b6 Mon Sep 17 00:00:00 2001 From: Manas Date: Sun, 22 Dec 2024 18:14:49 +0000 Subject: [PATCH] apply styling with black and isort --- src/diart/console/client.py | 6 ++-- src/diart/console/serve.py | 2 +- src/diart/handler.py | 69 ++++++++++++++++++++----------------- src/diart/sources.py | 4 +-- 4 files changed, 43 insertions(+), 38 deletions(-) diff --git a/src/diart/console/client.py b/src/diart/console/client.py index 578cf82..86c9e29 100644 --- a/src/diart/console/client.py +++ b/src/diart/console/client.py @@ -1,7 +1,7 @@ import argparse from pathlib import Path from threading import Thread -from typing import Text, Optional +from typing import Optional, Text import rx.operators as ops from websocket import WebSocket @@ -66,7 +66,7 @@ def run(): # Run websocket client ws = WebSocket() ws.connect(f"ws://{args.host}:{args.port}") - + # Wait for READY signal from server print("Waiting for server to be ready...", end="", flush=True) while True: @@ -75,7 +75,7 @@ def run(): print(" OK") break print(f"\nUnexpected message while waiting for READY: {message}") - + sender = Thread( target=send_audio, args=[ws, args.source, args.step, args.sample_rate] ) diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py index e7bac14..78eb50b 100644 --- a/src/diart/console/serve.py +++ b/src/diart/console/serve.py @@ -6,7 +6,7 @@ from diart import argdoc from diart import models as m from diart import utils -from diart.handler import StreamingInferenceHandler, StreamingInferenceConfig +from diart.handler import StreamingInferenceConfig, StreamingInferenceHandler def run(): diff --git a/src/diart/handler.py b/src/diart/handler.py index 788c435..baf72e0 100644 --- a/src/diart/handler.py +++ b/src/diart/handler.py @@ -1,9 +1,10 @@ +import logging +import socket from dataclasses import dataclass from pathlib import Path -from typing import Union, Text, Optional, AnyStr, Dict, Any, Callable -import logging +from typing import Any, AnyStr, Callable, Dict, Optional, Text, Union + from websocket_server import WebsocketServer -import socket from . import blocks from . import sources as src @@ -12,8 +13,7 @@ # Configure logging logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s' + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) @@ -29,6 +29,7 @@ class WebSocketAudioSourceConfig: sample_rate : int Audio sample rate in Hz """ + uri: str sample_rate: int = 16000 @@ -52,6 +53,7 @@ class StreamingInferenceConfig: progress_bar : Optional[ProgressBar] Custom progress bar implementation """ + pipeline: blocks.Pipeline batch_size: int = 1 do_profile: bool = True @@ -63,6 +65,7 @@ class StreamingInferenceConfig: @dataclass class ClientState: """Represents the state of a connected client.""" + audio_source: src.WebSocketAudioSource inference: StreamingInference @@ -102,7 +105,7 @@ def __init__( self.sample_rate = sample_rate self.host = host self.port = port - + # Server configuration self.uri = f"{host}:{port}" self._clients: Dict[Text, ClientState] = {} @@ -132,16 +135,16 @@ def _create_client_state(self, client_id: Text) -> ClientState: """ # Create a new pipeline instance with the same config # This ensures each client has its own state while sharing model weights - pipeline = self.inference_config.pipeline.__class__(self.inference_config.pipeline.config) - + pipeline = self.inference_config.pipeline.__class__( + self.inference_config.pipeline.config + ) + audio_config = WebSocketAudioSourceConfig( - uri=f"{self.uri}:{client_id}", - sample_rate=self.sample_rate + uri=f"{self.uri}:{client_id}", sample_rate=self.sample_rate ) - + audio_source = src.WebSocketAudioSource( - uri=audio_config.uri, - sample_rate=audio_config.sample_rate + uri=audio_config.uri, sample_rate=audio_config.sample_rate ) inference = StreamingInference( @@ -151,7 +154,7 @@ def _create_client_state(self, client_id: Text) -> ClientState: do_profile=self.inference_config.do_profile, do_plot=self.inference_config.do_plot, show_progress=self.inference_config.show_progress, - progress_bar=self.inference_config.progress_bar + progress_bar=self.inference_config.progress_bar, ) return ClientState(audio_source=audio_source, inference=inference) @@ -182,7 +185,7 @@ def _on_connect(self, client: Dict[Text, Any], server: WebsocketServer) -> None: # Start inference client_state.inference() logger.info(f"Started inference for client: {client_id}") - + # Send ready notification to client self.send(client_id, "READY") except Exception as e: @@ -204,10 +207,7 @@ def _on_disconnect(self, client: Dict[Text, Any], server: WebsocketServer) -> No self.close(client_id) def _on_message_received( - self, - client: Dict[Text, Any], - server: WebsocketServer, - message: AnyStr + self, client: Dict[Text, Any], server: WebsocketServer, message: AnyStr ) -> None: """Process incoming client messages. @@ -245,16 +245,15 @@ def send(self, client_id: Text, message: AnyStr) -> None: if not message: return - client = next( - (c for c in self.server.clients if c["id"] == client_id), - None - ) - + client = next((c for c in self.server.clients if c["id"] == client_id), None) + if client is not None: try: self.server.send_message(client, message) except (socket.error, ConnectionError) as e: - logger.warning(f"Client {client_id} disconnected while sending message: {e}") + logger.warning( + f"Client {client_id} disconnected while sending message: {e}" + ) self.close(client_id) except Exception as e: logger.error(f"Failed to send message to client {client_id}: {e}") @@ -264,7 +263,7 @@ def run(self) -> None: logger.info(f"Starting WebSocket server on {self.uri}") max_retries = 3 retry_count = 0 - + while retry_count < max_retries: try: self.server.run_forever() @@ -273,7 +272,9 @@ def run(self) -> None: logger.warning(f"WebSocket connection error: {e}") retry_count += 1 if retry_count < max_retries: - logger.info(f"Attempting to restart server (attempt {retry_count + 1}/{max_retries})") + logger.info( + f"Attempting to restart server (attempt {retry_count + 1}/{max_retries})" + ) else: logger.error("Max retry attempts reached. Server shutting down.") except Exception as e: @@ -295,20 +296,24 @@ def close(self, client_id: Text) -> None: # Clean up pipeline state using built-in reset method client_state = self._clients[client_id] client_state.inference.pipeline.reset() - + # Close audio source and remove client client_state.audio_source.close() del self._clients[client_id] - + # Try to send a close frame to the client try: - client = next((c for c in self.server.clients if c["id"] == client_id), None) + client = next( + (c for c in self.server.clients if c["id"] == client_id), None + ) if client: self.server.send_message(client, "CLOSE") except Exception: pass # Ignore errors when trying to send close message - - logger.info(f"Closed connection and cleaned up state for client: {client_id}") + + logger.info( + f"Closed connection and cleaned up state for client: {client_id}" + ) except Exception as e: logger.error(f"Error closing client {client_id}: {e}") # Ensure client is removed from dictionary even if cleanup fails diff --git a/src/diart/sources.py b/src/diart/sources.py index 5c93ff9..5939b39 100644 --- a/src/diart/sources.py +++ b/src/diart/sources.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from pathlib import Path from queue import SimpleQueue -from typing import Text, Optional, AnyStr, Dict, Any, Union, Tuple +from typing import Any, AnyStr, Dict, Optional, Text, Tuple, Union import numpy as np import sounddevice as sd @@ -12,7 +12,7 @@ from websocket_server import WebsocketServer from . import utils -from .audio import FilePath, AudioLoader +from .audio import AudioLoader, FilePath class AudioSource(ABC):