Skip to content

Commit

Permalink
Release GIL to support multithreading in websocket servers. (#451)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Nov 27, 2023
1 parent 8dc08a9 commit 87a47d7
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 47 deletions.
12 changes: 8 additions & 4 deletions python-api-examples/non_streaming_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def get_args():
parser.add_argument(
"--max-batch-size",
type=int,
default=25,
default=3,
help="""Max batch size for computation. Note if there are not enough
requests in the queue, it will wait for max_wait_ms time. After that,
even if there are not enough requests, it still sends the
Expand Down Expand Up @@ -459,7 +459,7 @@ def get_args():
parser.add_argument(
"--max-active-connections",
type=int,
default=500,
default=200,
help="""Maximum number of active connections. The server will refuse
to accept new connections once the current number of active connections
equals to this limit.
Expand Down Expand Up @@ -533,6 +533,7 @@ def __init__(
self.certificate = certificate
self.http_server = HttpServer(doc_root)

self.nn_pool_size = nn_pool_size
self.nn_pool = ThreadPoolExecutor(
max_workers=nn_pool_size,
thread_name_prefix="nn",
Expand Down Expand Up @@ -604,7 +605,9 @@ async def process_request(
async def run(self, port: int):
logging.info("started")

task = asyncio.create_task(self.stream_consumer_task())
tasks = []
for i in range(self.nn_pool_size):
tasks.append(asyncio.create_task(self.stream_consumer_task()))

if self.certificate:
logging.info(f"Using certificate: {self.certificate}")
Expand Down Expand Up @@ -636,7 +639,7 @@ async def run(self, port: int):

await asyncio.Future() # run forever

await task # not reachable
await asyncio.gather(*tasks) # not reachable

async def recv_audio_samples(
self,
Expand Down Expand Up @@ -722,6 +725,7 @@ async def stream_consumer_task(self):
batch.append(item)
except asyncio.QueueEmpty:
pass

stream_list = [b[0] for b in batch]
future_list = [b[1] for b in batch]

Expand Down
11 changes: 7 additions & 4 deletions python-api-examples/streaming_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def get_args():
parser.add_argument(
"--max-batch-size",
type=int,
default=50,
default=3,
help="""Max batch size for computation. Note if there are not enough
requests in the queue, it will wait for max_wait_ms time. After that,
even if there are not enough requests, it still sends the
Expand Down Expand Up @@ -334,7 +334,7 @@ def get_args():
parser.add_argument(
"--max-active-connections",
type=int,
default=500,
default=200,
help="""Maximum number of active connections. The server will refuse
to accept new connections once the current number of active connections
equals to this limit.
Expand Down Expand Up @@ -478,6 +478,7 @@ def __init__(
self.certificate = certificate
self.http_server = HttpServer(doc_root)

self.nn_pool_size = nn_pool_size
self.nn_pool = ThreadPoolExecutor(
max_workers=nn_pool_size,
thread_name_prefix="nn",
Expand Down Expand Up @@ -591,7 +592,9 @@ async def process_request(
return status, header, response

async def run(self, port: int):
task = asyncio.create_task(self.stream_consumer_task())
tasks = []
for i in range(self.nn_pool_size):
tasks.append(asyncio.create_task(self.stream_consumer_task()))

if self.certificate:
logging.info(f"Using certificate: {self.certificate}")
Expand Down Expand Up @@ -629,7 +632,7 @@ async def run(self, port: int):

await asyncio.Future() # run forever

await task # not reachable
await asyncio.gather(*tasks) # not reachable

async def handle_connection(
self,
Expand Down
10 changes: 6 additions & 4 deletions sherpa-onnx/python/csrc/circular-buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ void PybindCircularBuffer(py::module *m) {
[](PyClass &self, const std::vector<float> &samples) {
self.Push(samples.data(), samples.size());
},
py::arg("samples"))
.def("get", &PyClass::Get, py::arg("start_index"), py::arg("n"))
.def("pop", &PyClass::Pop, py::arg("n"))
.def("reset", &PyClass::Reset)
py::arg("samples"), py::call_guard<py::gil_scoped_release>())
.def("get", &PyClass::Get, py::arg("start_index"), py::arg("n"),
py::call_guard<py::gil_scoped_release>())
.def("pop", &PyClass::Pop, py::arg("n"),
py::call_guard<py::gil_scoped_release>())
.def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>())
.def_property_readonly("size", &PyClass::Size)
.def_property_readonly("head", &PyClass::Head)
.def_property_readonly("tail", &PyClass::Tail);
Expand Down
21 changes: 13 additions & 8 deletions sherpa-onnx/python/csrc/offline-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,24 @@ void PybindOfflineRecognizer(py::module *m) {
using PyClass = OfflineRecognizer;
py::class_<PyClass>(*m, "OfflineRecognizer")
.def(py::init<const OfflineRecognizerConfig &>(), py::arg("config"))
.def("create_stream",
[](const PyClass &self) { return self.CreateStream(); })
.def(
"create_stream",
[](const PyClass &self) { return self.CreateStream(); },
py::call_guard<py::gil_scoped_release>())
.def(
"create_stream",
[](PyClass &self, const std::string &hotwords) {
return self.CreateStream(hotwords);
},
py::arg("hotwords"))
.def("decode_stream", &PyClass::DecodeStream)
.def("decode_streams",
[](const PyClass &self, std::vector<OfflineStream *> ss) {
self.DecodeStreams(ss.data(), ss.size());
});
py::arg("hotwords"), py::call_guard<py::gil_scoped_release>())
.def("decode_stream", &PyClass::DecodeStream,
py::call_guard<py::gil_scoped_release>())
.def(
"decode_streams",
[](const PyClass &self, std::vector<OfflineStream *> ss) {
self.DecodeStreams(ss.data(), ss.size());
},
py::call_guard<py::gil_scoped_release>());
}

} // namespace sherpa_onnx
13 changes: 12 additions & 1 deletion sherpa-onnx/python/csrc/offline-stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,20 @@ void PybindOfflineStream(py::module *m) {
.def(
"accept_waveform",
[](PyClass &self, float sample_rate, py::array_t<float> waveform) {
#if 0
auto report_gil_status = []() {
auto is_gil_held = false;
if (auto tstate = py::detail::get_thread_state_unchecked())
is_gil_held = (tstate == PyGILState_GetThisThreadState());

return is_gil_held ? "GIL held" : "GIL released";
};
std::cout << report_gil_status() << "\n";
#endif
self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
},
py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage)
py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage,
py::call_guard<py::gil_scoped_release>())
.def_property_readonly("result", &PyClass::GetResult);
}

Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/python/csrc/offline-tts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void PybindOfflineTts(py::module *m) {
py::class_<PyClass>(*m, "OfflineTts")
.def(py::init<const OfflineTtsConfig &>(), py::arg("config"))
.def("generate", &PyClass::Generate, py::arg("text"), py::arg("sid") = 0,
py::arg("speed") = 1.0);
py::arg("speed") = 1.0, py::call_guard<py::gil_scoped_release>());
}

} // namespace sherpa_onnx
32 changes: 20 additions & 12 deletions sherpa-onnx/python/csrc/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,31 @@ void PybindOnlineRecognizer(py::module *m) {
using PyClass = OnlineRecognizer;
py::class_<PyClass>(*m, "OnlineRecognizer")
.def(py::init<const OnlineRecognizerConfig &>(), py::arg("config"))
.def("create_stream",
[](const PyClass &self) { return self.CreateStream(); })
.def(
"create_stream",
[](const PyClass &self) { return self.CreateStream(); },
py::call_guard<py::gil_scoped_release>())
.def(
"create_stream",
[](PyClass &self, const std::string &hotwords) {
return self.CreateStream(hotwords);
},
py::arg("hotwords"))
.def("is_ready", &PyClass::IsReady)
.def("decode_stream", &PyClass::DecodeStream)
.def("decode_streams",
[](PyClass &self, std::vector<OnlineStream *> ss) {
self.DecodeStreams(ss.data(), ss.size());
})
.def("get_result", &PyClass::GetResult)
.def("is_endpoint", &PyClass::IsEndpoint)
.def("reset", &PyClass::Reset);
py::arg("hotwords"), py::call_guard<py::gil_scoped_release>())
.def("is_ready", &PyClass::IsReady,
py::call_guard<py::gil_scoped_release>())
.def("decode_stream", &PyClass::DecodeStream,
py::call_guard<py::gil_scoped_release>())
.def(
"decode_streams",
[](PyClass &self, std::vector<OnlineStream *> ss) {
self.DecodeStreams(ss.data(), ss.size());
},
py::call_guard<py::gil_scoped_release>())
.def("get_result", &PyClass::GetResult,
py::call_guard<py::gil_scoped_release>())
.def("is_endpoint", &PyClass::IsEndpoint,
py::call_guard<py::gil_scoped_release>())
.def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>());
}

} // namespace sherpa_onnx
6 changes: 4 additions & 2 deletions sherpa-onnx/python/csrc/online-stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ void PybindOnlineStream(py::module *m) {
[](PyClass &self, float sample_rate, py::array_t<float> waveform) {
self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
},
py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage)
.def("input_finished", &PyClass::InputFinished);
py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage,
py::call_guard<py::gil_scoped_release>())
.def("input_finished", &PyClass::InputFinished,
py::call_guard<py::gil_scoped_release>());
}

} // namespace sherpa_onnx
16 changes: 10 additions & 6 deletions sherpa-onnx/python/csrc/vad-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,21 @@ namespace sherpa_onnx {
void PybindVadModel(py::module *m) {
using PyClass = VadModel;
py::class_<PyClass>(*m, "VadModel")
.def_static("create", &PyClass::Create, py::arg("config"))
.def("reset", &PyClass::Reset)
.def_static("create", &PyClass::Create, py::arg("config"),
py::call_guard<py::gil_scoped_release>())
.def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>())
.def(
"is_speech",
[](PyClass &self, const std::vector<float> &samples) -> bool {
return self.IsSpeech(samples.data(), samples.size());
},
py::arg("samples"))
.def("window_size", &PyClass::WindowSize)
.def("min_silence_duration_samples", &PyClass::MinSilenceDurationSamples)
.def("min_speech_duration_samples", &PyClass::MinSpeechDurationSamples);
py::arg("samples"), py::call_guard<py::gil_scoped_release>())
.def("window_size", &PyClass::WindowSize,
py::call_guard<py::gil_scoped_release>())
.def("min_silence_duration_samples", &PyClass::MinSilenceDurationSamples,
py::call_guard<py::gil_scoped_release>())
.def("min_speech_duration_samples", &PyClass::MinSpeechDurationSamples,
py::call_guard<py::gil_scoped_release>());
}

} // namespace sherpa_onnx
11 changes: 6 additions & 5 deletions sherpa-onnx/python/csrc/voice-activity-detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@ void PybindVoiceActivityDetector(py::module *m) {
[](PyClass &self, const std::vector<float> &samples) {
self.AcceptWaveform(samples.data(), samples.size());
},
py::arg("samples"))
.def("empty", &PyClass::Empty)
.def("pop", &PyClass::Pop)
.def("is_speech_detected", &PyClass::IsSpeechDetected)
.def("reset", &PyClass::Reset)
py::arg("samples"), py::call_guard<py::gil_scoped_release>())
.def("empty", &PyClass::Empty, py::call_guard<py::gil_scoped_release>())
.def("pop", &PyClass::Pop, py::call_guard<py::gil_scoped_release>())
.def("is_speech_detected", &PyClass::IsSpeechDetected,
py::call_guard<py::gil_scoped_release>())
.def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>())
.def_property_readonly("front", &PyClass::Front);
}

Expand Down

0 comments on commit 87a47d7

Please sign in to comment.