Skip to content

Commit

Permalink
Add threading.Thread Class Override
Browse files Browse the repository at this point in the history
  • Loading branch information
gerrymeixiong committed Dec 16, 2024
1 parent 7df7b27 commit 000d45c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
10 changes: 7 additions & 3 deletions deepgram/clients/common/v1/abstract_sync_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import time
import logging
from typing import Dict, Union, Optional, cast, Any, Callable
from typing import Dict, Union, Optional, cast, Any, Callable, Type
from datetime import datetime
import threading
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -52,12 +52,14 @@ class AbstractSyncWebSocketClient(ABC): # pylint: disable=too-many-instance-att
_listen_thread: Union[threading.Thread, None]
_delegate: Optional[Speaker] = None

_thread_cls: Type[threading.Thread]

_kwargs: Optional[Dict] = None
_addons: Optional[Dict] = None
_options: Optional[Dict] = None
_headers: Optional[Dict] = None

def __init__(self, config: DeepgramClientOptions, endpoint: str = ""):
def __init__(self, config: DeepgramClientOptions, endpoint: str = "", thread_cls: Type[threading.Thread] = threading.Thread) -> None:
if config is None:
raise DeepgramError("Config is required")
if endpoint == "":
Expand All @@ -73,6 +75,8 @@ def __init__(self, config: DeepgramClientOptions, endpoint: str = ""):

self._listen_thread = None

self._thread_cls = thread_cls

# exit
self._exit_event = threading.Event()

Expand Down Expand Up @@ -152,7 +156,7 @@ def start(
self._delegate.set_push_callback(self._process_message)
else:
self._logger.notice("create _listening thread")
self._listen_thread = threading.Thread(target=self._listening)
self._listen_thread = self._thread_cls(target=self._listening)
self._listen_thread.start()

# debug the threads
Expand Down
12 changes: 8 additions & 4 deletions deepgram/clients/listen/v1/websocket/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import time
import logging
from typing import Dict, Union, Optional, cast, Any, Callable
from typing import Dict, Union, Optional, cast, Any, Callable, Type
from datetime import datetime
import threading

Expand Down Expand Up @@ -55,12 +55,14 @@ class ListenWebSocketClient(
_flush_thread: Union[threading.Thread, None]
_last_datagram: Optional[datetime] = None

_thread_cls: Type[threading.Thread]

_kwargs: Optional[Dict] = None
_addons: Optional[Dict] = None
_options: Optional[Dict] = None
_headers: Optional[Dict] = None

def __init__(self, config: DeepgramClientOptions):
def __init__(self, config: DeepgramClientOptions, thread_cls: Type[threading.Thread] = threading.Thread):
if config is None:
raise DeepgramError("Config is required")

Expand All @@ -78,6 +80,8 @@ def __init__(self, config: DeepgramClientOptions):
self._last_datagram = None
self._lock_flush = threading.Lock()

self._thread_cls = thread_cls

# init handlers
self._event_handlers = {
event: [] for event in LiveTranscriptionEvents.__members__.values()
Expand Down Expand Up @@ -154,15 +158,15 @@ def start(
# keepalive thread
if self._config.is_keep_alive_enabled():
self._logger.notice("keepalive is enabled")
self._keep_alive_thread = threading.Thread(target=self._keep_alive)
self._keep_alive_thread = self._thread_cls(target=self._keep_alive)
self._keep_alive_thread.start()
else:
self._logger.notice("keepalive is disabled")

# flush thread
if self._config.is_auto_flush_reply_enabled():
self._logger.notice("autoflush is enabled")
self._flush_thread = threading.Thread(target=self._flush)
self._flush_thread = self._thread_cls(target=self._flush)
self._flush_thread.start()
else:
self._logger.notice("autoflush is disabled")
Expand Down

0 comments on commit 000d45c

Please sign in to comment.