Skip to content

Commit

Permalink
simplified websocket-server class and improved naming
Browse files (browse the repository at this point in the history)
  • Loading branch information
janaab11 committed Jan 3, 2025
1 parent 7ba2f55 commit 3d3bb45
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 59 deletions.
17 changes: 4 additions & 13 deletions src/diart/console/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from diart import argdoc
from diart import models as m
from diart import utils
from diart.handler import StreamingHandlerConfig, StreamingHandler
from diart.websockets import WebSocketStreamingServer


def run():
Expand Down Expand Up @@ -98,24 +98,15 @@ def run():
pipeline_class = utils.get_pipeline_class(args.pipeline)
pipeline_config = pipeline_class.get_config_class()(**vars(args))

# Create handler configuration for inference
config = StreamingHandlerConfig(
# Initialize Websocket server
server = WebSocketStreamingServer(
pipeline_class=pipeline_class,
pipeline_config=pipeline_config,
batch_size=1,
do_profile=False,
do_plot=False,
show_progress=False,
)

# Initialize handler
handler = StreamingHandler(
config=config,
host=args.host,
port=args.port,
)

handler.run()
server.run()


if __name__ == "__main__":
Expand Down
66 changes: 20 additions & 46 deletions src/diart/handler.py → src/diart/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from . import blocks
from . import sources as src
from .inference import StreamingInference
from .progress import ProgressBar, RichProgressBar

# Configure logging
logging.basicConfig(
Expand All @@ -18,37 +17,6 @@
logger = logging.getLogger(__name__)


@dataclass
class StreamingHandlerConfig:
"""Configuration for streaming inference.
Parameters
----------
pipeline_class : type
Pipeline class
pipeline_config : blocks.PipelineConfig
Pipeline configuration
batch_size : int
Number of inputs to process at once
do_profile : bool
Enable processing time profiling
do_plot : bool
Enable real-time prediction plotting
show_progress : bool
Display progress bar
progress_bar : Optional[ProgressBar]
Custom progress bar implementation
"""

pipeline_class: type
pipeline_config: blocks.PipelineConfig
batch_size: int = 1
do_profile: bool = True
do_plot: bool = False
show_progress: bool = True
progress_bar: Optional[ProgressBar] = None


@dataclass
class ClientState:
"""Represents the state of a connected client."""
Expand All @@ -57,16 +25,18 @@ class ClientState:
inference: StreamingInference


class StreamingHandler:
class WebSocketStreamingServer:
"""Handles real-time speaker diarization inference for multiple audio sources over WebSocket.
This server manages WebSocket connections from multiple clients, processing
audio streams and performing speaker diarization in real-time.
Parameters
----------
config : StreamingHandlerConfig
Streaming inference configuration
pipeline_class : type
Pipeline class
pipeline_config : blocks.PipelineConfig
Pipeline configuration
host : str, optional
WebSocket server host, by default "127.0.0.1"
port : int, optional
Expand All @@ -79,17 +49,20 @@ class StreamingHandler:

def __init__(
self,
config: StreamingHandlerConfig,
pipeline_class: type,
pipeline_config: blocks.PipelineConfig,
host: Text = "127.0.0.1",
port: int = 7007,
key: Optional[Union[Text, Path]] = None,
certificate: Optional[Union[Text, Path]] = None,
):
self.config = config
self.host = host
self.port = port
# Pipeline configuration
self.pipeline_class = pipeline_class
self.pipeline_config = pipeline_config

# Server configuration
self.host = host
self.port = port
self.uri = f"{host}:{port}"
self._clients: Dict[Text, ClientState] = {}

Expand Down Expand Up @@ -118,21 +91,22 @@ def _create_client_state(self, client_id: Text) -> ClientState:
"""
# Create a new pipeline instance with the same config
# This ensures each client has its own state while sharing model weights
pipeline = self.config.pipeline_class(self.config.pipeline_config)
pipeline = self.pipeline_class(self.pipeline_config)

audio_source = src.WebSocketAudioSource(
uri=f"{self.uri}:{client_id}",
sample_rate=self.config.pipeline_config.sample_rate,
sample_rate=self.pipeline_config.sample_rate,
)

inference = StreamingInference(
pipeline=pipeline,
source=audio_source,
batch_size=self.config.batch_size,
do_profile=self.config.do_profile,
do_plot=self.config.do_plot,
show_progress=self.config.show_progress,
progress_bar=self.config.progress_bar,
# The following variables are fixed for a client
batch_size=1,
do_profile=False, # for minimal latency
do_plot=False,
show_progress=False,
progress_bar=None,
)

return ClientState(audio_source=audio_source, inference=inference)
Expand Down

0 comments on commit 3d3bb45

Please sign in to comment.