From 848b2dc753a823c2a3f1ca6e2bb4fd4f1d7eab31 Mon Sep 17 00:00:00 2001 From: Nickolay Shmyrev Date: Fri, 17 Dec 2021 22:22:30 +0100 Subject: [PATCH] Expose results in Python --- python/example/batch/asr_server_gpu.py | 85 +++++++++++++++++++++++++ python/example/batch/test_batch.py | 24 +++++-- python/vosk/__init__.py | 10 ++- src/batch_recognizer.cc | 87 +++++++++++++++++++++++--- src/batch_recognizer.h | 6 +- src/json.h | 8 +-- src/vosk_api.cc | 14 ++++- src/vosk_api.h | 8 ++- 8 files changed, 217 insertions(+), 25 deletions(-) create mode 100755 python/example/batch/asr_server_gpu.py diff --git a/python/example/batch/asr_server_gpu.py b/python/example/batch/asr_server_gpu.py new file mode 100755 index 00000000..f58587c9 --- /dev/null +++ b/python/example/batch/asr_server_gpu.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 + +import json +import os +import sys +import asyncio +import pathlib +import websockets +import logging + +from vosk import BatchRecognizer, GpuInit + + +async def recognize(websocket, path): + global args + global loop + global pool + global rec + global client_cnt + + uid = client_cnt + client_cnt += 1 + + logging.info('Connection %d from %s', uid, websocket.remote_address); + + while True: + + message = await websocket.recv() + + if message == '{"eof" : 1}': + rec.FinishStream(uid) + break + + if isinstance(message, str) and 'config' in message: + continue + + rec.AcceptWaveform(uid, message) + await asyncio.sleep(len(message) / 16000.0 / 2) + res = rec.Result(uid) + if len(res) == 0: + await websocket.send('{ "partial" : "" }') + else: + await websocket.send(res) + + rec.Wait() + res = rec.Result(uid) + await websocket.send(res) + +def start(): + + global rec + global args + global loop + global client_cnt + + # Enable loging if needed + # + # logger = logging.getLogger('websockets') + # logger.setLevel(logging.INFO) + # logger.addHandler(logging.StreamHandler()) + logging.basicConfig(level=logging.INFO) + + args = type('', (), {})() + + args.interface = os.environ.get('VOSK_SERVER_INTERFACE', '0.0.0.0') + args.port = int(os.environ.get('VOSK_SERVER_PORT', 2700)) + + GpuInit() + + rec = BatchRecognizer() + + client_cnt = 0 + + loop = asyncio.get_event_loop() + + start_server = websockets.serve( + recognize, args.interface, args.port) + + logging.info("Listening on %s:%d", args.interface, args.port) + loop.run_until_complete(start_server) + loop.run_forever() + + +if __name__ == '__main__': + start() diff --git a/python/example/batch/test_batch.py b/python/example/batch/test_batch.py index f93eb6ea..32aa021e 100755 --- a/python/example/batch/test_batch.py +++ b/python/example/batch/test_batch.py @@ -1,12 +1,13 @@ #!/usr/bin/env python3 -from vosk import Model, BatchRecognizer, GpuInit, GpuThreadInit import sys import os import wave +from time import sleep + +from vosk import Model, BatchRecognizer, GpuInit GpuInit() -GpuThreadInit() rec = BatchRecognizer() @@ -14,6 +15,7 @@ fds = [open(x.strip(), "rb") for x in fnames] ended = set() while True: + for i, fd in enumerate(fds): if i in ended: continue @@ -21,8 +23,20 @@ if len(data) == 0: rec.FinishStream(i) ended.add(i) - else: - rec.AcceptWaveform(i, data) - rec.Results() + continue + rec.AcceptWaveform(i, data) + + sleep(0.3) + for i, fd in enumerate(fds): + res = rec.Result(i) + print (i, res) + if len(ended) == len(fds): break + +sleep(20) +print ("Done") +for i, fd in enumerate(fds): + res = rec.Result(i) + print (i, res) +rec.Wait() diff --git a/python/vosk/__init__.py b/python/vosk/__init__.py index 964a0ac2..c83a7e34 100644 --- a/python/vosk/__init__.py +++ b/python/vosk/__init__.py @@ -116,8 +116,14 @@ def __del__(self): def AcceptWaveform(self, uid, data): res = _c.vosk_batch_recognizer_accept_waveform(self._handle, uid, data, len(data)) - def Results(self): - return _ffi.string(_c.vosk_batch_recognizer_results(self._handle)).decode('utf-8') + def Result(self, uid): + ptr = _c.vosk_batch_recognizer_front_result(self._handle, uid) + res = _ffi.string(ptr).decode('utf-8') + _c.vosk_batch_recognizer_pop(self._handle, uid) + return res def FinishStream(self, uid): _c.vosk_batch_recognizer_finish_stream(self._handle, uid) + + def Wait(self): + _c.vosk_batch_recognizer_wait(self._handle) diff --git a/src/batch_recognizer.cc b/src/batch_recognizer.cc index 184fb8a2..1773fc0e 100644 --- a/src/batch_recognizer.cc +++ b/src/batch_recognizer.cc @@ -16,6 +16,7 @@ #include "fstext/fstext-utils.h" #include "lat/sausages.h" +#include "json.h" #include @@ -37,12 +38,12 @@ BatchRecognizer::BatchRecognizer() { batched_decoder_config.feature_opts.feature_type = "mfcc"; batched_decoder_config.feature_opts.mfcc_config = "model/conf/mfcc.conf"; batched_decoder_config.feature_opts.ivector_extraction_config = "model/conf/ivector.conf"; - batched_decoder_config.decoder_opts.max_active = 7000; - batched_decoder_config.decoder_opts.default_beam = 13.0; - batched_decoder_config.decoder_opts.lattice_beam = 8.0; + batched_decoder_config.decoder_opts.max_active = 5000; + batched_decoder_config.decoder_opts.default_beam = 10.0; + batched_decoder_config.decoder_opts.lattice_beam = 4.0; batched_decoder_config.compute_opts.acoustic_scale = 1.0; batched_decoder_config.compute_opts.frame_subsampling_factor = 3; - batched_decoder_config.compute_opts.frames_per_chunk = 312; + batched_decoder_config.compute_opts.frames_per_chunk = 51; struct stat buffer; @@ -126,6 +127,47 @@ void BatchRecognizer::FinishStream(uint64_t id) streams_.erase(id); } + +void BatchRecognizer::PushLattice(uint64_t id, CompactLattice &clat, BaseFloat offset) +{ + fst::ScaleLattice(fst::GraphLatticeScale(0.9), &clat); + + CompactLattice aligned_lat; + WordAlignLattice(clat, *trans_model_, *winfo_, 0, &aligned_lat); + + MinimumBayesRisk mbr(aligned_lat); + const vector &conf = mbr.GetOneBestConfidences(); + const vector &words = mbr.GetOneBest(); + const vector > × = + mbr.GetOneBestTimes(); + + int size = words.size(); + + json::JSON obj; + stringstream text; + + // Create JSON object + for (int i = 0; i < size; i++) { + json::JSON word; + + word["word"] = word_syms_->Find(words[i]); + word["start"] = times[i].first * 0.03 + offset; + word["end"] = times[i].second * 0.03 + offset; + word["conf"] = conf[i]; + obj["result"].append(word); + + if (i) { + text << " "; + } + text << word_syms_->Find(words[i]); + } + obj["text"] = text.str(); + +// KALDI_LOG << "Result " << id << " " << obj.dump(); + + results_[id].push(obj.dump()); +} + void BatchRecognizer::AcceptWaveform(uint64_t id, const char *data, int len) { bool first = false; @@ -135,7 +177,8 @@ void BatchRecognizer::AcceptWaveform(uint64_t id, const char *data, int len) streams_.insert(id); // Define the callback for results. - cuda_pipeline_->SetBestPathCallback( +#if 0 + cuda_pipeline_->SetBestPathCallback( id, [&, id](const std::string &str, bool partial, bool endpoint_detected) { @@ -151,11 +194,19 @@ void BatchRecognizer::AcceptWaveform(uint64_t id, const char *data, int len) KALDI_LOG << "id #" << id << " : " << str; } }); +#endif cuda_pipeline_->SetLatticeCallback( id, - [&, id](CompactLattice &clat) { - KALDI_LOG << "Got lattice from the stream " << id; - }); + [&, id](SegmentedLatticeCallbackParams& params) { + if (params.results.empty()) { + KALDI_WARN << "Empty result for callback"; + return; + } + CompactLattice *clat = params.results[0].GetLatticeResult(); + BaseFloat offset = params.results[0].GetTimeOffsetSeconds(); + PushLattice(id, *clat, offset); + }, + CudaPipelineResult::RESULT_TYPE_LATTICE); } Vector wave; @@ -167,8 +218,24 @@ void BatchRecognizer::AcceptWaveform(uint64_t id, const char *data, int len) dynamic_batcher_->Push(id, first, false, chunk); } -const char* BatchRecognizer::PullResults() +const char* BatchRecognizer::FrontResult(uint64_t id) +{ + if (results_[id].empty()) { + return ""; + } + return results_[id].front().c_str(); +} + +void BatchRecognizer::Pop(uint64_t id) +{ + if (results_[id].empty()) { + return; + } + results_[id].pop(); +} + +void BatchRecognizer::WaitForCompletion() { dynamic_batcher_->WaitForCompletion(); - return ""; } + diff --git a/src/batch_recognizer.h b/src/batch_recognizer.h index c8045d53..0082a364 100644 --- a/src/batch_recognizer.h +++ b/src/batch_recognizer.h @@ -45,9 +45,12 @@ class BatchRecognizer { void FinishStream(uint64_t id); void AcceptWaveform(uint64_t id, const char *data, int len); - const char* PullResults(); + const char *FrontResult(uint64_t id); + void Pop(uint64_t id); + void WaitForCompletion(); private: + void PushLattice(uint64_t id, CompactLattice &clat, BaseFloat offset); kaldi::TransitionModel *trans_model_ = nullptr; kaldi::nnet3::AmNnetSimple *nnet_ = nullptr; @@ -64,6 +67,7 @@ class BatchRecognizer { std::set streams_; + std::map > results_; // Rescoring fst::ArcMapFst > *lm_to_subtract_ = nullptr; diff --git a/src/json.h b/src/json.h index 463912ec..2159392b 100644 --- a/src/json.h +++ b/src/json.h @@ -424,7 +424,7 @@ class JSON Class Type = Class::Null; }; -JSON Array() { +inline JSON Array() { return JSON::Make( JSON::Class::Array ); } @@ -435,11 +435,11 @@ JSON Array( T... args ) { return arr; } -JSON Object() { +inline JSON Object() { return JSON::Make( JSON::Class::Object ); } -std::ostream& operator<<( std::ostream &os, const JSON &json ) { +inline std::ostream& operator<<( std::ostream &os, const JSON &json ) { os << json.dump(); return os; } @@ -647,7 +647,7 @@ namespace { } } -JSON JSON::Load( const string &str ) { +inline JSON JSON::Load( const string &str ) { size_t offset = 0; return parse_next( str, offset ); } diff --git a/src/vosk_api.cc b/src/vosk_api.cc index a53dbf87..b2a7a6a4 100644 --- a/src/vosk_api.cc +++ b/src/vosk_api.cc @@ -203,7 +203,17 @@ void vosk_batch_recognizer_finish_stream(VoskBatchRecognizer *recognizer, int id ((BatchRecognizer *)recognizer)->FinishStream(id); } -const char *vosk_batch_recognizer_results(VoskBatchRecognizer *recognizer) +const char *vosk_batch_recognizer_front_result(VoskBatchRecognizer *recognizer, int id) { - return ((BatchRecognizer *)recognizer)->PullResults(); + return ((BatchRecognizer *)recognizer)->FrontResult(id); +} + +void vosk_batch_recognizer_pop(VoskBatchRecognizer *recognizer, int id) +{ + return ((BatchRecognizer *)recognizer)->Pop(id); +} + +void vosk_batch_recognizer_wait(VoskBatchRecognizer *recognizer) +{ + ((BatchRecognizer *)recognizer)->WaitForCompletion(); } diff --git a/src/vosk_api.h b/src/vosk_api.h index c5b92f1c..7177009c 100644 --- a/src/vosk_api.h +++ b/src/vosk_api.h @@ -305,7 +305,13 @@ void vosk_batch_recognizer_accept_waveform(VoskBatchRecognizer *recognizer, int void vosk_batch_recognizer_finish_stream(VoskBatchRecognizer *recognizer, int id); /** Return results */ -const char *vosk_batch_recognizer_results(VoskBatchRecognizer *recognizer); +const char *vosk_batch_recognizer_front_result(VoskBatchRecognizer *recognizer, int id); + +/** Release and free first retrieved result */ +void vosk_batch_recognizer_pop(VoskBatchRecognizer *recognizer, int id); + +/** Wait for the processing */ +void vosk_batch_recognizer_wait(VoskBatchRecognizer *recognizer); #ifdef __cplusplus }