Skip to content

Commit

Permalink
apply styling with black and isort
Browse files Browse the repository at this point in the history
  • Loading branch information
janaab11 committed Dec 22, 2024
1 parent fb9fecf commit 1975f75
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 38 deletions.
6 changes: 3 additions & 3 deletions src/diart/console/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
from pathlib import Path
from threading import Thread
from typing import Text, Optional
from typing import Optional, Text

import rx.operators as ops
from websocket import WebSocket
Expand Down Expand Up @@ -66,7 +66,7 @@ def run():
# Run websocket client
ws = WebSocket()
ws.connect(f"ws://{args.host}:{args.port}")

# Wait for READY signal from server
print("Waiting for server to be ready...", end="", flush=True)
while True:
Expand All @@ -75,7 +75,7 @@ def run():
print(" OK")
break
print(f"\nUnexpected message while waiting for READY: {message}")

sender = Thread(
target=send_audio, args=[ws, args.source, args.step, args.sample_rate]
)
Expand Down
2 changes: 1 addition & 1 deletion src/diart/console/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from diart import argdoc
from diart import models as m
from diart import utils
from diart.handler import StreamingInferenceHandler, StreamingInferenceConfig
from diart.handler import StreamingInferenceConfig, StreamingInferenceHandler


def run():
Expand Down
69 changes: 37 additions & 32 deletions src/diart/handler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
import socket
from dataclasses import dataclass
from pathlib import Path
from typing import Union, Text, Optional, AnyStr, Dict, Any, Callable
import logging
from typing import Any, AnyStr, Callable, Dict, Optional, Text, Union

from websocket_server import WebsocketServer
import socket

from . import blocks
from . import sources as src
Expand All @@ -12,8 +13,7 @@

# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

Expand All @@ -29,6 +29,7 @@ class WebSocketAudioSourceConfig:
sample_rate : int
Audio sample rate in Hz
"""

uri: str
sample_rate: int = 16000

Expand All @@ -52,6 +53,7 @@ class StreamingInferenceConfig:
progress_bar : Optional[ProgressBar]
Custom progress bar implementation
"""

pipeline: blocks.Pipeline
batch_size: int = 1
do_profile: bool = True
Expand All @@ -63,6 +65,7 @@ class StreamingInferenceConfig:
@dataclass
class ClientState:
"""Represents the state of a connected client."""

audio_source: src.WebSocketAudioSource
inference: StreamingInference

Expand Down Expand Up @@ -102,7 +105,7 @@ def __init__(
self.sample_rate = sample_rate
self.host = host
self.port = port

# Server configuration
self.uri = f"{host}:{port}"
self._clients: Dict[Text, ClientState] = {}
Expand Down Expand Up @@ -132,16 +135,16 @@ def _create_client_state(self, client_id: Text) -> ClientState:
"""
# Create a new pipeline instance with the same config
# This ensures each client has its own state while sharing model weights
pipeline = self.inference_config.pipeline.__class__(self.inference_config.pipeline.config)

pipeline = self.inference_config.pipeline.__class__(
self.inference_config.pipeline.config
)

audio_config = WebSocketAudioSourceConfig(
uri=f"{self.uri}:{client_id}",
sample_rate=self.sample_rate
uri=f"{self.uri}:{client_id}", sample_rate=self.sample_rate
)

audio_source = src.WebSocketAudioSource(
uri=audio_config.uri,
sample_rate=audio_config.sample_rate
uri=audio_config.uri, sample_rate=audio_config.sample_rate
)

inference = StreamingInference(
Expand All @@ -151,7 +154,7 @@ def _create_client_state(self, client_id: Text) -> ClientState:
do_profile=self.inference_config.do_profile,
do_plot=self.inference_config.do_plot,
show_progress=self.inference_config.show_progress,
progress_bar=self.inference_config.progress_bar
progress_bar=self.inference_config.progress_bar,
)

return ClientState(audio_source=audio_source, inference=inference)
Expand Down Expand Up @@ -182,7 +185,7 @@ def _on_connect(self, client: Dict[Text, Any], server: WebsocketServer) -> None:
# Start inference
client_state.inference()
logger.info(f"Started inference for client: {client_id}")

# Send ready notification to client
self.send(client_id, "READY")
except Exception as e:
Expand All @@ -204,10 +207,7 @@ def _on_disconnect(self, client: Dict[Text, Any], server: WebsocketServer) -> No
self.close(client_id)

def _on_message_received(
self,
client: Dict[Text, Any],
server: WebsocketServer,
message: AnyStr
self, client: Dict[Text, Any], server: WebsocketServer, message: AnyStr
) -> None:
"""Process incoming client messages.
Expand Down Expand Up @@ -245,16 +245,15 @@ def send(self, client_id: Text, message: AnyStr) -> None:
if not message:
return

client = next(
(c for c in self.server.clients if c["id"] == client_id),
None
)

client = next((c for c in self.server.clients if c["id"] == client_id), None)

if client is not None:
try:
self.server.send_message(client, message)
except (socket.error, ConnectionError) as e:
logger.warning(f"Client {client_id} disconnected while sending message: {e}")
logger.warning(
f"Client {client_id} disconnected while sending message: {e}"
)
self.close(client_id)
except Exception as e:
logger.error(f"Failed to send message to client {client_id}: {e}")
Expand All @@ -264,7 +263,7 @@ def run(self) -> None:
logger.info(f"Starting WebSocket server on {self.uri}")
max_retries = 3
retry_count = 0

while retry_count < max_retries:
try:
self.server.run_forever()
Expand All @@ -273,7 +272,9 @@ def run(self) -> None:
logger.warning(f"WebSocket connection error: {e}")
retry_count += 1
if retry_count < max_retries:
logger.info(f"Attempting to restart server (attempt {retry_count + 1}/{max_retries})")
logger.info(
f"Attempting to restart server (attempt {retry_count + 1}/{max_retries})"
)
else:
logger.error("Max retry attempts reached. Server shutting down.")
except Exception as e:
Expand All @@ -295,20 +296,24 @@ def close(self, client_id: Text) -> None:
# Clean up pipeline state using built-in reset method
client_state = self._clients[client_id]
client_state.inference.pipeline.reset()

# Close audio source and remove client
client_state.audio_source.close()
del self._clients[client_id]

# Try to send a close frame to the client
try:
client = next((c for c in self.server.clients if c["id"] == client_id), None)
client = next(
(c for c in self.server.clients if c["id"] == client_id), None
)
if client:
self.server.send_message(client, "CLOSE")
except Exception:
pass # Ignore errors when trying to send close message

logger.info(f"Closed connection and cleaned up state for client: {client_id}")

logger.info(
f"Closed connection and cleaned up state for client: {client_id}"
)
except Exception as e:
logger.error(f"Error closing client {client_id}: {e}")
# Ensure client is removed from dictionary even if cleanup fails
Expand Down
4 changes: 2 additions & 2 deletions src/diart/sources.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from pathlib import Path
from queue import SimpleQueue
from typing import Text, Optional, AnyStr, Dict, Any, Union, Tuple
from typing import Any, AnyStr, Dict, Optional, Text, Tuple, Union

import numpy as np
import sounddevice as sd
Expand All @@ -12,7 +12,7 @@
from websocket_server import WebsocketServer

from . import utils
from .audio import FilePath, AudioLoader
from .audio import AudioLoader, FilePath


class AudioSource(ABC):
Expand Down

0 comments on commit 1975f75

Please sign in to comment.