
feat: Add WebSocket server with multi-client support #263

Open · janaab11 wants to merge 33 commits into develop from develop-server

Conversation


@janaab11 janaab11 commented Dec 22, 2024

Overview

Implements a WebSocket server that can handle audio streams from multiple client connections.

Changes

  • Added multi-client support to WebSocket server
  • Created StreamingInferenceHandler for managing connections
  • Added Dockerfile for easier deployment

Testing

  • Tested with multiple concurrent clients
  • Verified Docker container functionality
  • Checked resource cleanup on disconnection

Please let me know if any changes or improvements are needed!


Fixes #252
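For reference, the feature can be exercised with the existing console entry points, for example (host/port values are illustrative, and the file-based client invocation is an assumption):

```shell
# Start the server (one process now accepts multiple clients)
diart.serve --host 0.0.0.0 --port 7007

# In separate terminals, connect several clients concurrently
diart.client microphone --host localhost --port 7007
diart.client /path/to/audio.wav --host localhost --port 7007
```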


janaab11 commented Dec 22, 2024

A couple of things I still want to work on:

  1. The way resources are shared between client connections. Currently, each connection shares the same config, but the models (for segmentation and embedding, in SpeakerDiarization) are initialised and maintained separately. This is a more flexible design, but quite wasteful (see the sketch after this list).

    • I have attempted a fix for this, but parallel connections ended up sharing state. From what I understood, I might need to dig deeper into the Aggregation steps in the pipeline.
  2. Related to the above point: I was also wondering if different configs could share the same underlying model resources at runtime. For example, I have seen performance differ a lot for the same config when the numbers of speakers are far apart, say 2 vs 10. And this is a parameter that a client is better suited to configure when setting up its connection to the server.
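To make point 1 concrete, here is a minimal sketch of the current pattern, assuming diart's public SpeakerDiarization / SpeakerDiarizationConfig API (the on_connect hook is hypothetical):

```python
from diart import SpeakerDiarization, SpeakerDiarizationConfig

# One config instance is shared across all connections...
config = SpeakerDiarizationConfig(max_speakers=2)

def on_connect(client_id: str) -> SpeakerDiarization:
    # ...but each client builds its own pipeline, so the segmentation and
    # embedding models are loaded and kept in memory once per client.
    return SpeakerDiarization(config)
```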


janaab11 commented Dec 22, 2024

I have also added a cleanup step in the server for when a client disconnects. This was mostly to ensure explicit memory management, since client streams are not sharing resources at the moment, but this should also address #255.
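In rough terms, the disconnect path looks like this (a simplified sketch; the names are illustrative of the new StreamingInferenceHandler rather than copied from it):

```python
def _on_disconnect(self, client, server) -> None:
    client_id = client["id"]
    state = self._clients.pop(client_id, None)
    if state is None:
        return  # already cleaned up
    # Closing the client's audio source ends its stream, so the
    # per-client pipeline and its models can be garbage collected.
    state.audio_source.close()
```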


janaab11 commented Dec 31, 2024

Moved to LazyModel for resource management, based on this comment. Client-specific Pipeline instances now share resources that are initialised in a common PipelineConfig instance.

Still unsure about how this would scale with client connections - would appreciate any thoughts on this!
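Assuming the loader-based model constructors that appear in the README diff further down, the sharing pattern is roughly:

```python
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.models import SegmentationModel, EmbeddingModel

# Models are created once on the shared config; as lazy models they are
# only materialised on first use (loader definitions omitted here).
config = SpeakerDiarizationConfig(
    segmentation=SegmentationModel(segmentation_loader),
    embedding=EmbeddingModel(embedding_loader),
)

# Every client-specific pipeline reuses the models held by the config.
pipeline_a = SpeakerDiarization(config)
pipeline_b = SpeakerDiarization(config)
```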

@janaab11 janaab11 force-pushed the develop-server branch 2 times, most recently from bbf2df2 to d4380c4 on January 1, 2025 06:50
@juanmc2005 juanmc2005 self-requested a review January 2, 2025 14:36
@juanmc2005 juanmc2005 added the feature label (New feature or request) Jan 2, 2025
@juanmc2005 juanmc2005 added this to the Version 0.9.2 milestone Jan 2, 2025
@juanmc2005 juanmc2005 (Owner) left a comment

Thank you for this PR! The feature is well designed, I'd just like to make a few adjustments to clean up pieces of code that should be discontinued and to make sure that the new API is clean and easy to understand.

The PR is also missing an update to the websocket section of the README with some usage examples. I think that will also give us some ideas on where the API can be improved.

Resolved review threads on: Dockerfile, src/diart/handler.py, src/diart/sources.py, src/diart/console/serve.py
juanmc2005 (Owner) commented:

@janaab11 about your question concerning the wasteful model copies, I fully agree with this limitation. However, I think it would be best suited for a separate PR, it's a pretty big amount of work and I would hate to delay the multi-client feature because of it. Glad to discuss it in a future PR if that interests you!

@janaab11 janaab11 (Author) left a comment

I have completed most of the suggested changes. A few items remain:

  • WebSocketAudioSource: I do agree it is acting as a proxy, but prefer keeping it as an AudioSource subclass - makes it consistent with other audio sources. Would like to think more about this and then propose changes.
  • Documentation: Planning to add WebSocket usage examples to the README


janaab11 commented Jan 3, 2025

> @janaab11 about your question concerning the wasteful model copies, I fully agree with this limitation. However, I think it would be best suited for a separate PR, it's a pretty big amount of work and I would hate to delay the multi-client feature because of it. Glad to discuss it in a future PR if that interests you!

Definitely interested in resolving this - happy to take it up after closing the work here!

juanmc2005 (Owner) commented:

> WebSocketAudioSource: I do agree it is acting as a proxy, but prefer keeping it as an AudioSource subclass - makes it consistent with other audio sources. Would like to think more about this and then propose changes.

Oh it would definitely still be a subclass of AudioSource; my suggestion was simply to move it to websockets.py so that we hide it from the end user. I don't think such a proxy audio source would be needed under normal circumstances.

Resolved review threads on: .dockerignore, Dockerfile, src/diart/console/serve.py, src/diart/websockets.py

janaab11 commented Jan 3, 2025

> Oh it would definitely still be a subclass of AudioSource; my suggestion was simply to move it to websockets.py so that we hide it from the end user. I don't think such a proxy audio source would be needed under normal circumstances.

Okay this does make sense to me too - not exposing such websocket-specific functionality. Made the changes!


janaab11 commented Jan 3, 2025

Added error handling for the following edge cases in the send, close and _on_message_received methods:

```python
# Guard clauses at the top of each method: bail out quietly when the
# client has already vanished from the server's registry.
if client is None:
    return
if client_id not in self._clients:
    return
```

These edge cases can occur due to race conditions in the client lifecycle (connect/disconnect/cleanup) or network issues that lead to client-state mismatches between the server and client. I added warnings to catch these async timing issues and documented the edge-case conditions in each method's docstring.
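As an example, the guard in send ends up looking roughly like this (a simplified sketch against websocket-server's client dicts, not the exact code):

```python
import logging

logger = logging.getLogger(__name__)

def send(self, client_id: str, message: str) -> None:
    """Send a message to a client, tolerating clients that vanished."""
    client = next(
        (c for c in self.server.clients if c["id"] == client_id), None
    )
    if client is None:
        # The client disconnected between lookup and send (race condition)
        logger.warning(f"Client {client_id} not found, skipping send")
        return
    self.server.send_message(client, message)
```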


janaab11 commented Jan 3, 2025

Modified client.py to handle disconnects properly on KeyboardInterrupt events. I believe this was referenced in another issue as well. Do let me know if the implementation is too involved; I had liked the simplicity of the client before this.

Apart from this, complete documentation in the README is pending. Will get to that next.
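The shape of the change is roughly this (a sketch only, using the websocket-client package for illustration; the real client.py may differ):

```python
from websocket import create_connection

def run_client(host: str, port: int) -> None:
    ws = create_connection(f"ws://{host}:{port}")
    try:
        while True:
            # Stream audio chunks and print server predictions (elided)
            ...
    except KeyboardInterrupt:
        # Ctrl+C should disconnect cleanly instead of crashing
        pass
    finally:
        # Always close the socket so the server can free this client's resources
        ws.close()
```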

@janaab11 janaab11 requested a review from juanmc2005 January 6, 2025 13:45
@@ -202,6 +202,7 @@ def embedding_loader():
segmentation = SegmentationModel(segmentation_loader)
embedding = EmbeddingModel(embedding_loader)
config = SpeakerDiarizationConfig(
# Set the segmentation model used in the paper
juanmc2005 (Owner):

This isn't correct. To remove

@@ -332,20 +333,57 @@ diart.client microphone --host <server-address> --port 7007

See `-h` for more options.

### From the Dockerfile
juanmc2005 (Owner):

Suggested change:
  - ### From the Dockerfile
  + ### From a Docker container


You can also run the server in a Docker container. First, build the image:
```shell
docker build -t diart -f Dockerfile .
```
juanmc2005 (Owner):

`-f Dockerfile` is not needed, as Docker will pick up the file with that name in the specified directory.
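The build command then reduces to:

```shell
docker build -t diart .
```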


Run the server with default configuration:
```shell
docker run -p 7007:7007 --gpus all -e HF_TOKEN=<token> diart
```
juanmc2005 (Owner):

We should probably add a note somewhere saying that for GPU usage they need to install `nvidia-container-toolkit`.

Also, is there a way to pick up the HF token from the `huggingface-cli` config? That way we avoid passing it directly and leaving it in the terminal history. This is possible when running outside Docker, and we shouldn't make passing it explicitly mandatory here, as this is an important security feature.
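One possible approach: if the token was saved with `huggingface-cli login` (stored under ~/.cache/huggingface by default in recent huggingface_hub versions), the cache directory can be mounted instead of passing the token explicitly:

```shell
docker run -p 7007:7007 --gpus all \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  diart
```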


Run with custom configuration:
juanmc2005 (Owner):

Suggested change:
  - Run with custom configuration:
  + Example with a custom configuration:

Raises
------
Warning
    If client not found in self._clients. Common cases:
juanmc2005 (Owner):

same as previous comment

```python
try:
    # Clean up pipeline state using built-in reset method
    client_state = self._clients[client_id]
    client_state.inference.pipeline.reset()
```
juanmc2005 (Owner):

Not sure a reset is required, because the pipeline will be removed from memory anyway.

```python
    # Ensure client is removed even if cleanup fails
    self._clients.pop(client_id, None)

def close_all(self) -> None:
```
juanmc2005 (Owner):

This should be called shutdown() because it shuts down the server after closing all clients.

Comment on lines +347 to +370:

```python
while retry_count < max_retries:
    try:
        self.server.run_forever()
        break  # If server exits normally, break the retry loop
    except OSError as e:
        logger.warning(f"WebSocket server connection error: {e}")
        retry_count += 1
        if retry_count < max_retries:
            delay = base_delay * (2 ** (retry_count - 1))  # Exponential backoff
            logger.info(
                f"Retrying in {delay} seconds... "
                f"(attempt {retry_count}/{max_retries})"
            )
            time.sleep(delay)
        else:
            logger.error(
                f"WebSocket server failed to start after {max_retries} attempts. "
                f"Last error: {e}"
            )
    except Exception as e:
        logger.error(f"Fatal server error: {e}")
        break
    finally:
        self.close_all()
```
juanmc2005 (Owner):

Now that I think about it, it's probably not required to retry starting the server, right? I mean if starting the server doesn't work, it's probably a configuration error that should be fixed by the developer, for example if the port is already in use. What do you think? What use case did you have in mind for retrying?


```python
    return ClientState(audio_source=audio_source, inference=inference)

def _on_connect(self, client: Dict[Text, Any], server: WebsocketServer) -> None:
```
juanmc2005 (Owner):

Maybe we should allow a maximum number of clients to connect? My reasoning is the following: if we have to copy StreamingInference instances (including models) for every new client, the server will most likely crash at some point (especially if sharing a GPU). However, given the system resources, we can probably estimate how many clients fit in the machine, or whether a new client fits in the remaining available resources.

If this is too complicated, we can simply add a parameter inside `__init__()` for the maximum number of simultaneous clients. Something like `client_pool_size: int = 4`.
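A minimal sketch of that simpler option (the rejection mechanics are illustrative, and _create_client_state stands in for the existing state-creation code):

```python
import logging

logger = logging.getLogger(__name__)

class StreamingInferenceHandler:
    def __init__(self, client_pool_size: int = 4):
        self.client_pool_size = client_pool_size
        self._clients = {}

    def _on_connect(self, client, server) -> None:
        if len(self._clients) >= self.client_pool_size:
            # Pool exhausted: refuse to build another StreamingInference
            # (and another copy of the models) for this client.
            logger.warning("Client pool full, rejecting client %s", client["id"])
            return
        self._clients[client["id"]] = self._create_client_state(client)
```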
