From 87a47d7db4e8de003cb729f9a8e728f532004b42 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Mon, 27 Nov 2023 13:44:03 +0800
Subject: [PATCH] Release GIL to support multithreading in websocket servers.
 (#451)

---
 python-api-examples/non_streaming_server.py   | 12 ++++---
 python-api-examples/streaming_server.py       | 11 ++++---
 sherpa-onnx/python/csrc/circular-buffer.cc    | 10 +++---
 sherpa-onnx/python/csrc/offline-recognizer.cc | 21 +++++++-----
 sherpa-onnx/python/csrc/offline-stream.cc     | 13 +++++++-
 sherpa-onnx/python/csrc/offline-tts.cc        |  2 +-
 sherpa-onnx/python/csrc/online-recognizer.cc  | 32 ++++++++++++-------
 sherpa-onnx/python/csrc/online-stream.cc      |  6 ++--
 sherpa-onnx/python/csrc/vad-model.cc          | 16 ++++++----
 .../python/csrc/voice-activity-detector.cc    | 11 ++++---
 10 files changed, 87 insertions(+), 47 deletions(-)

diff --git a/python-api-examples/non_streaming_server.py b/python-api-examples/non_streaming_server.py
index 2ca45a76e..ac83d9a42 100755
--- a/python-api-examples/non_streaming_server.py
+++ b/python-api-examples/non_streaming_server.py
@@ -414,7 +414,7 @@ def get_args():
     parser.add_argument(
         "--max-batch-size",
         type=int,
-        default=25,
+        default=3,
         help="""Max batch size for computation. Note if there are not enough
         requests in the queue, it will wait for max_wait_ms time. After that,
         even if there are not enough requests, it still sends the
@@ -459,7 +459,7 @@ def get_args():
     parser.add_argument(
         "--max-active-connections",
         type=int,
-        default=500,
+        default=200,
         help="""Maximum number of active connections. The server will refuse
         to accept new connections once the current number of active
         connections equals to this limit.
@@ -533,6 +533,7 @@ def __init__(
         self.certificate = certificate
         self.http_server = HttpServer(doc_root)
 
+        self.nn_pool_size = nn_pool_size
         self.nn_pool = ThreadPoolExecutor(
             max_workers=nn_pool_size,
             thread_name_prefix="nn",
@@ -604,7 +605,9 @@ async def process_request(
 
     async def run(self, port: int):
         logging.info("started")
-        task = asyncio.create_task(self.stream_consumer_task())
+        tasks = []
+        for i in range(self.nn_pool_size):
+            tasks.append(asyncio.create_task(self.stream_consumer_task()))
 
         if self.certificate:
             logging.info(f"Using certificate: {self.certificate}")
@@ -636,7 +639,7 @@ async def run(self, port: int):
 
         await asyncio.Future()  # run forever
 
-        await task  # not reachable
+        await asyncio.gather(*tasks)  # not reachable
 
     async def recv_audio_samples(
         self,
@@ -722,6 +725,7 @@ async def stream_consumer_task(self):
                    batch.append(item)
            except asyncio.QueueEmpty:
                pass
+
            stream_list = [b[0] for b in batch]
            future_list = [b[1] for b in batch]
 
diff --git a/python-api-examples/streaming_server.py b/python-api-examples/streaming_server.py
index c40428a83..c47ea28a9 100755
--- a/python-api-examples/streaming_server.py
+++ b/python-api-examples/streaming_server.py
@@ -296,7 +296,7 @@ def get_args():
     parser.add_argument(
         "--max-batch-size",
         type=int,
-        default=50,
+        default=3,
         help="""Max batch size for computation. Note if there are not enough
         requests in the queue, it will wait for max_wait_ms time. After that,
         even if there are not enough requests, it still sends the
@@ -334,7 +334,7 @@ def get_args():
     parser.add_argument(
         "--max-active-connections",
         type=int,
-        default=500,
+        default=200,
         help="""Maximum number of active connections. The server will refuse
         to accept new connections once the current number of active
         connections equals to this limit.
@@ -478,6 +478,7 @@ def __init__(
         self.certificate = certificate
         self.http_server = HttpServer(doc_root)
 
+        self.nn_pool_size = nn_pool_size
         self.nn_pool = ThreadPoolExecutor(
             max_workers=nn_pool_size,
             thread_name_prefix="nn",
@@ -591,7 +592,9 @@ async def process_request(
         return status, header, response
 
     async def run(self, port: int):
-        task = asyncio.create_task(self.stream_consumer_task())
+        tasks = []
+        for i in range(self.nn_pool_size):
+            tasks.append(asyncio.create_task(self.stream_consumer_task()))
 
         if self.certificate:
             logging.info(f"Using certificate: {self.certificate}")
@@ -629,7 +632,7 @@ async def run(self, port: int):
 
         await asyncio.Future()  # run forever
 
-        await task  # not reachable
+        await asyncio.gather(*tasks)  # not reachable
 
     async def handle_connection(
         self,
diff --git a/sherpa-onnx/python/csrc/circular-buffer.cc b/sherpa-onnx/python/csrc/circular-buffer.cc
index 20ea4b519..f08a27e3b 100644
--- a/sherpa-onnx/python/csrc/circular-buffer.cc
+++ b/sherpa-onnx/python/csrc/circular-buffer.cc
@@ -19,10 +19,12 @@ void PybindCircularBuffer(py::module *m) {
           [](PyClass &self, const std::vector<float> &samples) {
             self.Push(samples.data(), samples.size());
           },
-          py::arg("samples"))
-      .def("get", &PyClass::Get, py::arg("start_index"), py::arg("n"))
-      .def("pop", &PyClass::Pop, py::arg("n"))
-      .def("reset", &PyClass::Reset)
+          py::arg("samples"), py::call_guard<py::gil_scoped_release>())
+      .def("get", &PyClass::Get, py::arg("start_index"), py::arg("n"),
+           py::call_guard<py::gil_scoped_release>())
+      .def("pop", &PyClass::Pop, py::arg("n"),
+           py::call_guard<py::gil_scoped_release>())
+      .def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>())
       .def_property_readonly("size", &PyClass::Size)
       .def_property_readonly("head", &PyClass::Head)
       .def_property_readonly("tail", &PyClass::Tail);
diff --git a/sherpa-onnx/python/csrc/offline-recognizer.cc b/sherpa-onnx/python/csrc/offline-recognizer.cc
index 3c3cc043d..5662a9e92 100644
--- a/sherpa-onnx/python/csrc/offline-recognizer.cc
+++ b/sherpa-onnx/python/csrc/offline-recognizer.cc
@@ -41,19 +41,24 @@ void PybindOfflineRecognizer(py::module *m) {
   using PyClass = OfflineRecognizer;
   py::class_<PyClass>(*m, "OfflineRecognizer")
       .def(py::init<const OfflineRecognizerConfig &>(), py::arg("config"))
-      .def("create_stream",
-           [](const PyClass &self) { return self.CreateStream(); })
+      .def(
+          "create_stream",
+          [](const PyClass &self) { return self.CreateStream(); },
+          py::call_guard<py::gil_scoped_release>())
       .def(
           "create_stream",
           [](PyClass &self, const std::string &hotwords) {
             return self.CreateStream(hotwords);
           },
-          py::arg("hotwords"))
-      .def("decode_stream", &PyClass::DecodeStream)
-      .def("decode_streams",
-           [](const PyClass &self, std::vector<OfflineStream *> ss) {
-             self.DecodeStreams(ss.data(), ss.size());
-           });
+          py::arg("hotwords"), py::call_guard<py::gil_scoped_release>())
+      .def("decode_stream", &PyClass::DecodeStream,
+           py::call_guard<py::gil_scoped_release>())
+      .def(
+          "decode_streams",
+          [](const PyClass &self, std::vector<OfflineStream *> ss) {
+            self.DecodeStreams(ss.data(), ss.size());
+          },
+          py::call_guard<py::gil_scoped_release>());
 }
 
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/python/csrc/offline-stream.cc b/sherpa-onnx/python/csrc/offline-stream.cc
index bf851fd53..e9eb9c57a 100644
--- a/sherpa-onnx/python/csrc/offline-stream.cc
+++ b/sherpa-onnx/python/csrc/offline-stream.cc
@@ -50,9 +50,20 @@ void PybindOfflineStream(py::module *m) {
       .def(
           "accept_waveform",
           [](PyClass &self, float sample_rate, py::array_t<float> waveform) {
+#if 0
+            auto report_gil_status = []() {
+              auto is_gil_held = false;
+              if (auto tstate = py::detail::get_thread_state_unchecked())
+                is_gil_held = (tstate == PyGILState_GetThisThreadState());
+
+              return is_gil_held ? "GIL held" : "GIL released";
+            };
+            std::cout << report_gil_status() << "\n";
+#endif
             self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
           },
-          py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage)
+          py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage,
+          py::call_guard<py::gil_scoped_release>())
       .def_property_readonly("result", &PyClass::GetResult);
 }
 
diff --git a/sherpa-onnx/python/csrc/offline-tts.cc b/sherpa-onnx/python/csrc/offline-tts.cc
index e58ca3113..538ceceed 100644
--- a/sherpa-onnx/python/csrc/offline-tts.cc
+++ b/sherpa-onnx/python/csrc/offline-tts.cc
@@ -45,7 +45,7 @@ void PybindOfflineTts(py::module *m) {
   py::class_<PyClass>(*m, "OfflineTts")
       .def(py::init<const OfflineTtsConfig &>(), py::arg("config"))
       .def("generate", &PyClass::Generate, py::arg("text"), py::arg("sid") = 0,
-           py::arg("speed") = 1.0);
+           py::arg("speed") = 1.0, py::call_guard<py::gil_scoped_release>());
 }
 
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc
index 9cfce8456..be68a104c 100644
--- a/sherpa-onnx/python/csrc/online-recognizer.cc
+++ b/sherpa-onnx/python/csrc/online-recognizer.cc
@@ -54,23 +54,31 @@ void PybindOnlineRecognizer(py::module *m) {
   using PyClass = OnlineRecognizer;
   py::class_<PyClass>(*m, "OnlineRecognizer")
       .def(py::init<const OnlineRecognizerConfig &>(), py::arg("config"))
-      .def("create_stream",
-           [](const PyClass &self) { return self.CreateStream(); })
+      .def(
+          "create_stream",
+          [](const PyClass &self) { return self.CreateStream(); },
+          py::call_guard<py::gil_scoped_release>())
       .def(
          "create_stream",
          [](PyClass &self, const std::string &hotwords) {
            return self.CreateStream(hotwords);
          },
-          py::arg("hotwords"))
-      .def("is_ready", &PyClass::IsReady)
-      .def("decode_stream", &PyClass::DecodeStream)
-      .def("decode_streams",
-           [](PyClass &self, std::vector<OnlineStream *> ss) {
-             self.DecodeStreams(ss.data(), ss.size());
-           })
-      .def("get_result", &PyClass::GetResult)
-      .def("is_endpoint", &PyClass::IsEndpoint)
-      .def("reset", &PyClass::Reset);
+          py::arg("hotwords"), py::call_guard<py::gil_scoped_release>())
+      .def("is_ready", &PyClass::IsReady,
+           py::call_guard<py::gil_scoped_release>())
+      .def("decode_stream", &PyClass::DecodeStream,
+           py::call_guard<py::gil_scoped_release>())
+      .def(
+          "decode_streams",
+          [](PyClass &self, std::vector<OnlineStream *> ss) {
+            self.DecodeStreams(ss.data(), ss.size());
+          },
+          py::call_guard<py::gil_scoped_release>())
+      .def("get_result", &PyClass::GetResult,
+           py::call_guard<py::gil_scoped_release>())
+      .def("is_endpoint", &PyClass::IsEndpoint,
+           py::call_guard<py::gil_scoped_release>())
+      .def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>());
 }
 
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/python/csrc/online-stream.cc b/sherpa-onnx/python/csrc/online-stream.cc
index 411354a03..9f8a17b9c 100644
--- a/sherpa-onnx/python/csrc/online-stream.cc
+++ b/sherpa-onnx/python/csrc/online-stream.cc
@@ -28,8 +28,10 @@ void PybindOnlineStream(py::module *m) {
          [](PyClass &self, float sample_rate, py::array_t<float> waveform) {
            self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
          },
-         py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage)
-      .def("input_finished", &PyClass::InputFinished);
+         py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage,
+         py::call_guard<py::gil_scoped_release>())
+      .def("input_finished", &PyClass::InputFinished,
+           py::call_guard<py::gil_scoped_release>());
 }
 
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/python/csrc/vad-model.cc b/sherpa-onnx/python/csrc/vad-model.cc
index 11c81cbba..f304fd0ae 100644
--- a/sherpa-onnx/python/csrc/vad-model.cc
+++ b/sherpa-onnx/python/csrc/vad-model.cc
@@ -13,17 +13,21 @@ namespace sherpa_onnx {
 void PybindVadModel(py::module *m) {
   using PyClass = VadModel;
   py::class_<PyClass>(*m, "VadModel")
-      .def_static("create", &PyClass::Create, py::arg("config"))
-      .def("reset", &PyClass::Reset)
+      .def_static("create", &PyClass::Create, py::arg("config"),
+                  py::call_guard<py::gil_scoped_release>())
+      .def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>())
       .def(
          "is_speech",
          [](PyClass &self, const std::vector<float> &samples) -> bool {
            return self.IsSpeech(samples.data(), samples.size());
          },
-          py::arg("samples"))
-      .def("window_size", &PyClass::WindowSize)
-      .def("min_silence_duration_samples", &PyClass::MinSilenceDurationSamples)
-      .def("min_speech_duration_samples", &PyClass::MinSpeechDurationSamples);
+          py::arg("samples"), py::call_guard<py::gil_scoped_release>())
+      .def("window_size", &PyClass::WindowSize,
+           py::call_guard<py::gil_scoped_release>())
+      .def("min_silence_duration_samples", &PyClass::MinSilenceDurationSamples,
+           py::call_guard<py::gil_scoped_release>())
+      .def("min_speech_duration_samples", &PyClass::MinSpeechDurationSamples,
+           py::call_guard<py::gil_scoped_release>());
 }
 
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/python/csrc/voice-activity-detector.cc b/sherpa-onnx/python/csrc/voice-activity-detector.cc
index 237e32abc..f3e683185 100644
--- a/sherpa-onnx/python/csrc/voice-activity-detector.cc
+++ b/sherpa-onnx/python/csrc/voice-activity-detector.cc
@@ -30,11 +30,12 @@ void PybindVoiceActivityDetector(py::module *m) {
          [](PyClass &self, const std::vector<float> &samples) {
            self.AcceptWaveform(samples.data(), samples.size());
          },
-         py::arg("samples"))
-      .def("empty", &PyClass::Empty)
-      .def("pop", &PyClass::Pop)
-      .def("is_speech_detected", &PyClass::IsSpeechDetected)
-      .def("reset", &PyClass::Reset)
+         py::arg("samples"), py::call_guard<py::gil_scoped_release>())
+      .def("empty", &PyClass::Empty, py::call_guard<py::gil_scoped_release>())
+      .def("pop", &PyClass::Pop, py::call_guard<py::gil_scoped_release>())
+      .def("is_speech_detected", &PyClass::IsSpeechDetected,
+           py::call_guard<py::gil_scoped_release>())
+      .def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>())
       .def_property_readonly("front", &PyClass::Front);
 }
 
 }  // namespace sherpa_onnx
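
What the change buys: every binding above now takes py::call_guard<py::gil_scoped_release>(),
so the GIL is dropped for the duration of the underlying C++ call and reacquired on return.
That is why the servers can now run several stream_consumer_task() coroutines against one
ThreadPoolExecutor: decoding work submitted via run_in_executor no longer serializes on the
GIL. Below is a minimal sketch (not part of the patch) of the effect, assuming a sherpa_onnx
wheel built from this commit and a transducer model loaded through the
OfflineRecognizer.from_transducer convenience constructor; the model file names are
placeholders.

#!/usr/bin/env python3
# Sketch only: demonstrates that decode_stream() now releases the GIL,
# so a thread pool achieves real parallelism. Model paths are placeholders.

from concurrent.futures import ThreadPoolExecutor

import numpy as np
import sherpa_onnx

recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
    tokens="tokens.txt",  # placeholder
    encoder="encoder.onnx",  # placeholder
    decoder="decoder.onnx",  # placeholder
    joiner="joiner.onnx",  # placeholder
)


def decode(samples: np.ndarray) -> str:
    stream = recognizer.create_stream()
    # accept_waveform() and decode_stream() drop the GIL while the C++
    # implementation runs, so these calls overlap across threads.
    stream.accept_waveform(16000, samples)
    recognizer.decode_stream(stream)
    return stream.result.text


# Four 1-second dummy "utterances", just to exercise the pipeline.
waves = [np.zeros(16000, dtype=np.float32) for _ in range(4)]

with ThreadPoolExecutor(max_workers=4) as pool:
    for text in pool.map(decode, waves):
        print(text)

Before this commit, the four decode() calls would have run one at a time even with four
workers, because each call held the GIL until it returned.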