diff --git a/python/example/test_gpu_batch.py b/python/example/test_gpu_batch.py new file mode 100755 index 00000000..3a65bda8 --- /dev/null +++ b/python/example/test_gpu_batch.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 + +import sys +import os +import wave +from time import sleep +import json +from timeit import default_timer as timer + + +from vosk import Model, BatchRecognizer, GpuInit + +GpuInit() + +rec = BatchRecognizer() + +fnames = open("tedlium.list").readlines() +fds = [open(x.strip(), "rb") for x in fnames] +uids = [fname.strip().split('/')[-1][:-4] for fname in fnames] +results = [""] * len(fnames) +ended = set() +tot_samples = 0 + +start_time = timer() + +while True: + + # Feed in the data + for i, fd in enumerate(fds): + if i in ended: + continue + data = fd.read(16000) + if len(data) == 0: + rec.FinishStream(i) + ended.add(i) + continue + rec.AcceptWaveform(i, data) + tot_samples += len(data) + + # Wait for results from CUDA + rec.Wait() + + # Retrieve and add results + for i, fd in enumerate(fds): + res = rec.Result(i) + if len(res) != 0: + results[i] = results[i] + " " + json.loads(res)['text'] + + if len(ended) == len(fds): + break + +end_time = timer() + +for i in range(len(results)): + print (uids[i], results[i].strip()) + +print ("Processed %d seconds of audio in %d seconds (%f xRT)" % (tot_samples / 16000.0 / 2, end_time - start_time, + (tot_samples / 16000.0 / 2 / (end_time - start_time))), file=sys.stderr) diff --git a/python/vosk/__init__.py b/python/vosk/__init__.py index cf39a472..0e60c2ba 100644 --- a/python/vosk/__init__.py +++ b/python/vosk/__init__.py @@ -101,3 +101,32 @@ def GpuInit(): def GpuThreadInit(): _c.vosk_gpu_thread_init() + +class BatchRecognizer(object): + + def __init__(self, *args): + self._handle = _c.vosk_batch_recognizer_new() + + if self._handle == _ffi.NULL: + raise Exception("Failed to create a recognizer") + + def __del__(self): + _c.vosk_batch_recognizer_free(self._handle) + + def AcceptWaveform(self, uid, 
data): + res = _c.vosk_batch_recognizer_accept_waveform(self._handle, uid, data, len(data)) + + def Result(self, uid): + ptr = _c.vosk_batch_recognizer_front_result(self._handle, uid) + res = _ffi.string(ptr).decode('utf-8') + _c.vosk_batch_recognizer_pop(self._handle, uid) + return res + + def FinishStream(self, uid): + _c.vosk_batch_recognizer_finish_stream(self._handle, uid) + + def Wait(self): + _c.vosk_batch_recognizer_wait(self._handle) + + def GetPendingChunks(self, uid): + return _c.vosk_batch_recognizer_get_pending_chunks(self._handle, uid) diff --git a/src/Makefile b/src/Makefile index 54e96ca7..9965db65 100644 --- a/src/Makefile +++ b/src/Makefile @@ -18,14 +18,14 @@ EXTRA_LDFLAGS?= OUTDIR?=. VOSK_SOURCES= \ - kaldi_recognizer.cc \ + recognizer.cc \ language_model.cc \ model.cc \ spk_model.cc \ vosk_api.cc VOSK_HEADERS= \ - kaldi_recognizer.h \ + recognizer.h \ language_model.h \ model.h \ spk_model.h \ @@ -39,13 +39,13 @@ LIBS= \ $(KALDI_ROOT)/src/decoder/kaldi-decoder.a \ $(KALDI_ROOT)/src/ivector/kaldi-ivector.a \ $(KALDI_ROOT)/src/gmm/kaldi-gmm.a \ - $(KALDI_ROOT)/src/nnet3/kaldi-nnet3.a \ $(KALDI_ROOT)/src/tree/kaldi-tree.a \ $(KALDI_ROOT)/src/feat/kaldi-feat.a \ $(KALDI_ROOT)/src/lat/kaldi-lat.a \ $(KALDI_ROOT)/src/lm/kaldi-lm.a \ $(KALDI_ROOT)/src/rnnlm/kaldi-rnnlm.a \ $(KALDI_ROOT)/src/hmm/kaldi-hmm.a \ + $(KALDI_ROOT)/src/nnet3/kaldi-nnet3.a \ $(KALDI_ROOT)/src/transform/kaldi-transform.a \ $(KALDI_ROOT)/src/cudamatrix/kaldi-cudamatrix.a \ $(KALDI_ROOT)/src/matrix/kaldi-matrix.a \ @@ -66,7 +66,7 @@ ifeq ($(HAVE_OPENBLAS_CLAPACK), 1) endif ifeq ($(HAVE_MKL), 1) - CFLAGS += -I$(MKL_ROOT)/include + CFLAGS += -DHAVE_MKL=1 -I$(MKL_ROOT)/include LIBS += -L$(MKL_ROOT)/lib/intel64 -Wl,-rpath=$(MKL_ROOT)/lib/intel64 -lmkl_rt -lmkl_intel_lp64 -lmkl_core -lmkl_sequential endif @@ -75,8 +75,16 @@ ifeq ($(HAVE_ACCELERATE), 1) endif ifeq ($(HAVE_CUDA), 1) + VOSK_SOURCES += batch_recognizer.cc + VOSK_HEADERS += batch_recognizer.h + CFLAGS+=-DHAVE_CUDA=1 
-I$(CUDA_ROOT)/include - LIBS+=-L$(CUDA_ROOT)/lib64 -lcuda -lcublas -lcusparse -lcudart -lcurand -lcufft -lcusolver -lnvToolsExt + + LIBS := \ + $(KALDI_ROOT)/src/cudadecoder/kaldi-cudadecoder.a \ + $(KALDI_ROOT)/src/cudafeat/kaldi-cudafeat.a \ + $(LIBS) \ + -L$(CUDA_ROOT)/lib64 -lcuda -lcublas -lcusparse -lcudart -lcurand -lcufft -lcusolver -lnvToolsExt endif all: $(OUTDIR)/libvosk.$(EXT) diff --git a/src/batch_recognizer.cc b/src/batch_recognizer.cc new file mode 100644 index 00000000..3337ee10 --- /dev/null +++ b/src/batch_recognizer.cc @@ -0,0 +1,246 @@ +// Copyright 2019-2020 Alpha Cephei Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "batch_recognizer.h" + +#include "fstext/fstext-utils.h" +#include "lat/sausages.h" +#include "json.h" + +#include + +using namespace fst; +using namespace kaldi::nnet3; +using CorrelationID = CudaOnlinePipelineDynamicBatcher::CorrelationID; + +BatchRecognizer::BatchRecognizer() { + BatchedThreadedNnet3CudaOnlinePipelineConfig batched_decoder_config; + + kaldi::ParseOptions po("something"); + batched_decoder_config.Register(&po); + po.ReadConfigFile("model/conf/model.conf"); + + batched_decoder_config.num_worker_threads = -1; + batched_decoder_config.max_batch_size = 200; + batched_decoder_config.reset_on_endpoint = true; + batched_decoder_config.use_gpu_feature_extraction = true; + + batched_decoder_config.feature_opts.feature_type = "mfcc"; + batched_decoder_config.feature_opts.mfcc_config = "model/conf/mfcc.conf"; + batched_decoder_config.feature_opts.ivector_extraction_config = "model/conf/ivector.conf"; + batched_decoder_config.decoder_opts.max_active = 7000; + batched_decoder_config.decoder_opts.default_beam = 13.0; + batched_decoder_config.decoder_opts.lattice_beam = 6.0; + batched_decoder_config.compute_opts.acoustic_scale = 1.0; + batched_decoder_config.compute_opts.frame_subsampling_factor = 3; + batched_decoder_config.compute_opts.frames_per_chunk = 180; + + struct stat buffer; + + string nnet3_rxfilename_ = "model/am/final.mdl"; + string hclg_fst_rxfilename_ = "model/graph/HCLG.fst"; + string word_syms_rxfilename_ = "model/graph/words.txt"; + string winfo_rxfilename_ = "model/graph/phones/word_boundary.int"; + string std_fst_rxfilename_ = "model/rescore/G.fst"; + string carpa_rxfilename_ = "model/rescore/G.carpa"; + + trans_model_ = new kaldi::TransitionModel(); + nnet_ = new kaldi::nnet3::AmNnetSimple(); + { + bool binary; + kaldi::Input ki(nnet3_rxfilename_, &binary); + trans_model_->Read(ki.Stream(), binary); + nnet_->Read(ki.Stream(), binary); + SetBatchnormTestMode(true, &(nnet_->GetNnet())); + SetDropoutTestMode(true, 
&(nnet_->GetNnet())); + nnet3::CollapseModel(nnet3::CollapseModelConfig(), &(nnet_->GetNnet())); + } + + if (stat(hclg_fst_rxfilename_.c_str(), &buffer) == 0) { + KALDI_LOG << "Loading HCLG from " << hclg_fst_rxfilename_; + hclg_fst_ = fst::ReadFstKaldiGeneric(hclg_fst_rxfilename_); + } + + KALDI_LOG << "Loading words from " << word_syms_rxfilename_; + if (!(word_syms_ = fst::SymbolTable::ReadText(word_syms_rxfilename_))) { + KALDI_ERR << "Could not read symbol table from file " + << word_syms_rxfilename_; + } + KALDI_ASSERT(word_syms_); + + if (stat(winfo_rxfilename_.c_str(), &buffer) == 0) { + KALDI_LOG << "Loading winfo " << winfo_rxfilename_; + kaldi::WordBoundaryInfoNewOpts opts; + winfo_ = new kaldi::WordBoundaryInfo(opts, winfo_rxfilename_); + } + + if (stat(carpa_rxfilename_.c_str(), &buffer) == 0) { + KALDI_LOG << "Loading subtract G.fst model from " << std_fst_rxfilename_; + graph_lm_fst_ = fst::ReadAndPrepareLmFst(std_fst_rxfilename_); + KALDI_LOG << "Loading CARPA model from " << carpa_rxfilename_; + ReadKaldiObject(carpa_rxfilename_, &const_arpa_); + } + + + + cuda_pipeline_ = new BatchedThreadedNnet3CudaOnlinePipeline + (batched_decoder_config, *hclg_fst_, *nnet_, *trans_model_); + cuda_pipeline_->SetSymbolTable(*word_syms_); + + CudaOnlinePipelineDynamicBatcherConfig dynamic_batcher_config; + dynamic_batcher_ = new CudaOnlinePipelineDynamicBatcher(dynamic_batcher_config, + *cuda_pipeline_); +} + +BatchRecognizer::~BatchRecognizer() { + + delete trans_model_; + delete nnet_; + delete word_syms_; + delete winfo_; + delete hclg_fst_; + delete graph_lm_fst_; + + delete lm_to_subtract_; + delete carpa_to_add_; + delete carpa_to_add_scale_; + + delete cuda_pipeline_; + delete dynamic_batcher_; +} + +void BatchRecognizer::FinishStream(uint64_t id) +{ + Vector wave; + SubVector chunk(wave.Data(), 0); + dynamic_batcher_->Push(id, false, true, chunk); + streams_.erase(id); +} + + +void BatchRecognizer::PushLattice(uint64_t id, CompactLattice &clat, BaseFloat 
offset) +{ + fst::ScaleLattice(fst::GraphLatticeScale(0.9), &clat); + + CompactLattice aligned_lat; + WordAlignLattice(clat, *trans_model_, *winfo_, 0, &aligned_lat); + + MinimumBayesRisk mbr(aligned_lat); + const vector &conf = mbr.GetOneBestConfidences(); + const vector &words = mbr.GetOneBest(); + const vector > × = + mbr.GetOneBestTimes(); + + int size = words.size(); + + json::JSON obj; + stringstream text; + + // Create JSON object + for (int i = 0; i < size; i++) { + json::JSON word; + + word["word"] = word_syms_->Find(words[i]); + word["start"] = times[i].first * 0.03 + offset; + word["end"] = times[i].second * 0.03 + offset; + word["conf"] = conf[i]; + obj["result"].append(word); + + if (i) { + text << " "; + } + text << word_syms_->Find(words[i]); + } + obj["text"] = text.str(); + +// KALDI_LOG << "Result " << id << " " << obj.dump(); + + results_[id].push(obj.dump()); +} + +void BatchRecognizer::AcceptWaveform(uint64_t id, const char *data, int len) +{ + bool first = false; + + if (streams_.find(id) == streams_.end()) { + first = true; + streams_.insert(id); + + // Define the callback for results. 
+#if 0 + cuda_pipeline_->SetBestPathCallback( + id, + [&, id](const std::string &str, bool partial, + bool endpoint_detected) { + if (partial) { + KALDI_LOG << "id #" << id << " [partial] : " << str << ":"; + } + + if (endpoint_detected) { + KALDI_LOG << "id #" << id << " [endpoint detected]"; + } + + if (!partial) { + KALDI_LOG << "id #" << id << " : " << str; + } + }); +#endif + cuda_pipeline_->SetLatticeCallback( + id, + [&, id](SegmentedLatticeCallbackParams& params) { + if (params.results.empty()) { + KALDI_WARN << "Empty result for callback"; + return; + } + CompactLattice *clat = params.results[0].GetLatticeResult(); + BaseFloat offset = params.results[0].GetTimeOffsetSeconds(); + PushLattice(id, *clat, offset); + }, + CudaPipelineResult::RESULT_TYPE_LATTICE); + } + + Vector wave; + wave.Resize(len / 2, kUndefined); + for (int i = 0; i < len / 2; i++) + wave(i) = *(((short *)data) + i); + SubVector chunk(wave.Data(), wave.Dim()); + + dynamic_batcher_->Push(id, first, false, chunk); +} + +const char* BatchRecognizer::FrontResult(uint64_t id) +{ + if (results_[id].empty()) { + return ""; + } + return results_[id].front().c_str(); +} + +void BatchRecognizer::Pop(uint64_t id) +{ + if (results_[id].empty()) { + return; + } + results_[id].pop(); +} + +void BatchRecognizer::WaitForCompletion() +{ + dynamic_batcher_->WaitForCompletion(); +} + +int BatchRecognizer::GetPendingChunks(uint64_t id) +{ + return dynamic_batcher_->GetPendingChunks(id); +} diff --git a/src/batch_recognizer.h b/src/batch_recognizer.h new file mode 100644 index 00000000..f26dd54b --- /dev/null +++ b/src/batch_recognizer.h @@ -0,0 +1,81 @@ +// Copyright 2019 Alpha Cephei Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef VOSK_GPU_RECOGNIZER_H +#define VOSK_GPU_RECOGNIZER_H + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "fstext/fstext-lib.h" +#include "fstext/fstext-utils.h" +#include "decoder/lattice-faster-decoder.h" +#include "feat/feature-mfcc.h" +#include "lat/kaldi-lattice.h" +#include "lat/word-align-lattice.h" +#include "lat/compose-lattice-pruned.h" +#include "nnet3/am-nnet-simple.h" +#include "nnet3/nnet-am-decodable-simple.h" +#include "nnet3/nnet-utils.h" + +#include "cudadecoder/cuda-online-pipeline-dynamic-batcher.h" +#include "cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h" +#include "cudadecoder/batched-threaded-nnet3-cuda-pipeline2.h" +#include "cudadecoder/cuda-pipeline-common.h" + +#include "model.h" + +using namespace kaldi; +using namespace kaldi::cuda_decoder; + +class BatchRecognizer { + public: + BatchRecognizer(); + ~BatchRecognizer(); + + void FinishStream(uint64_t id); + void AcceptWaveform(uint64_t id, const char *data, int len); + const char *FrontResult(uint64_t id); + void Pop(uint64_t id); + void WaitForCompletion(); + int GetPendingChunks(uint64_t id); + + private: + void PushLattice(uint64_t id, CompactLattice &clat, BaseFloat offset); + + kaldi::TransitionModel *trans_model_ = nullptr; + kaldi::nnet3::AmNnetSimple *nnet_ = nullptr; + const fst::SymbolTable *word_syms_ = nullptr; + + fst::Fst *hclg_fst_ = nullptr; + kaldi::WordBoundaryInfo *winfo_ = nullptr; + + fst::VectorFst *graph_lm_fst_ = nullptr; + kaldi::ConstArpaLm const_arpa_; + + BatchedThreadedNnet3CudaOnlinePipeline 
*cuda_pipeline_ = nullptr; + CudaOnlinePipelineDynamicBatcher *dynamic_batcher_ = nullptr; + + + std::set streams_; + std::map > results_; + + // Rescoring + fst::ArcMapFst > *lm_to_subtract_ = nullptr; + kaldi::ConstArpaLmDeterministicFst *carpa_to_add_ = nullptr; + fst::ScaleDeterministicOnDemandFst *carpa_to_add_scale_ = nullptr; + + float sample_frequency_; +}; + +#endif /* VOSK_GPU_RECOGNIZER_H */ diff --git a/src/json.h b/src/json.h index 463912ec..2159392b 100644 --- a/src/json.h +++ b/src/json.h @@ -424,7 +424,7 @@ class JSON Class Type = Class::Null; }; -JSON Array() { +inline JSON Array() { return JSON::Make( JSON::Class::Array ); } @@ -435,11 +435,11 @@ JSON Array( T... args ) { return arr; } -JSON Object() { +inline JSON Object() { return JSON::Make( JSON::Class::Object ); } -std::ostream& operator<<( std::ostream &os, const JSON &json ) { +inline std::ostream& operator<<( std::ostream &os, const JSON &json ) { os << json.dump(); return os; } @@ -647,7 +647,7 @@ namespace { } } -JSON JSON::Load( const string &str ) { +inline JSON JSON::Load( const string &str ) { size_t offset = 0; return parse_next( str, offset ); } diff --git a/src/model.cc b/src/model.cc index 8b5e12cc..c83d07a8 100644 --- a/src/model.cc +++ b/src/model.cc @@ -241,9 +241,9 @@ void Model::ReadDataFiles() SetDropoutTestMode(true, &(nnet_->GetNnet())); nnet3::CollapseModel(nnet3::CollapseModelConfig(), &(nnet_->GetNnet())); } + decodable_info_ = new nnet3::DecodableNnetSimpleLoopedInfo(decodable_opts_, nnet_); - if (stat(final_ie_rxfilename_.c_str(), &buffer) == 0) { KALDI_LOG << "Loading i-vector extractor from " << final_ie_rxfilename_; diff --git a/src/model.h b/src/model.h index d5feedd0..c36a96aa 100644 --- a/src/model.h +++ b/src/model.h @@ -36,7 +36,8 @@ using namespace kaldi; using namespace std; -class KaldiRecognizer; +class Recognizer; +class BatchRecognizer; class Model { @@ -52,7 +53,8 @@ class Model { void ConfigureV2(); void ReadDataFiles(); - friend class 
KaldiRecognizer; + friend class Recognizer; + friend class BatchRecognizer; string model_path_str_; string nnet3_rxfilename_; diff --git a/src/kaldi_recognizer.cc b/src/recognizer.cc similarity index 93% rename from src/kaldi_recognizer.cc rename to src/recognizer.cc index 86cf9bdd..f25ff0ee 100644 --- a/src/kaldi_recognizer.cc +++ b/src/recognizer.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "kaldi_recognizer.h" +#include "recognizer.h" #include "json.h" #include "fstext/fstext-utils.h" #include "lat/sausages.h" @@ -21,7 +21,7 @@ using namespace fst; using namespace kaldi::nnet3; -KaldiRecognizer::KaldiRecognizer(Model *model, float sample_frequency) : model_(model), spk_model_(0), sample_frequency_(sample_frequency) { +Recognizer::Recognizer(Model *model, float sample_frequency) : model_(model), spk_model_(0), sample_frequency_(sample_frequency) { model_->Ref(); @@ -46,7 +46,7 @@ KaldiRecognizer::KaldiRecognizer(Model *model, float sample_frequency) : model_( InitRescoring(); } -KaldiRecognizer::KaldiRecognizer(Model *model, float sample_frequency, char const *grammar) : model_(model), spk_model_(0), sample_frequency_(sample_frequency) +Recognizer::Recognizer(Model *model, float sample_frequency, char const *grammar) : model_(model), spk_model_(0), sample_frequency_(sample_frequency) { model_->Ref(); @@ -107,7 +107,7 @@ KaldiRecognizer::KaldiRecognizer(Model *model, float sample_frequency, char cons InitRescoring(); } -KaldiRecognizer::KaldiRecognizer(Model *model, float sample_frequency, SpkModel *spk_model) : model_(model), spk_model_(spk_model), sample_frequency_(sample_frequency) { +Recognizer::Recognizer(Model *model, float sample_frequency, SpkModel *spk_model) : model_(model), spk_model_(spk_model), sample_frequency_(sample_frequency) { model_->Ref(); spk_model->Ref(); @@ -135,7 +135,7 @@ KaldiRecognizer::KaldiRecognizer(Model *model, float sample_frequency, SpkModel 
InitRescoring(); } -KaldiRecognizer::~KaldiRecognizer() { +Recognizer::~Recognizer() { delete decoder_; delete feature_pipeline_; delete silence_weighting_; @@ -155,7 +155,7 @@ KaldiRecognizer::~KaldiRecognizer() { spk_model_->Unref(); } -void KaldiRecognizer::InitState() +void Recognizer::InitState() { frame_offset_ = 0; samples_processed_ = 0; @@ -164,7 +164,7 @@ void KaldiRecognizer::InitState() state_ = RECOGNIZER_INITIALIZED; } -void KaldiRecognizer::InitRescoring() +void Recognizer::InitRescoring() { if (model_->graph_lm_fst_) { @@ -185,7 +185,7 @@ void KaldiRecognizer::InitRescoring() } } -void KaldiRecognizer::CleanUp() +void Recognizer::CleanUp() { delete silence_weighting_; silence_weighting_ = new kaldi::OnlineSilenceWeighting(*model_->trans_model_, model_->feature_info_.silence_weighting_config, 3); @@ -223,7 +223,7 @@ void KaldiRecognizer::CleanUp() } } -void KaldiRecognizer::UpdateSilenceWeights() +void Recognizer::UpdateSilenceWeights() { if (silence_weighting_->Active() && feature_pipeline_->NumFramesReady() > 0 && feature_pipeline_->IvectorFeature() != nullptr) { @@ -236,17 +236,17 @@ void KaldiRecognizer::UpdateSilenceWeights() } } -void KaldiRecognizer::SetMaxAlternatives(int max_alternatives) +void Recognizer::SetMaxAlternatives(int max_alternatives) { max_alternatives_ = max_alternatives; } -void KaldiRecognizer::SetWords(bool words) +void Recognizer::SetWords(bool words) { words_ = words; } -void KaldiRecognizer::SetSpkModel(SpkModel *spk_model) +void Recognizer::SetSpkModel(SpkModel *spk_model) { if (state_ == RECOGNIZER_RUNNING) { KALDI_ERR << "Can't add speaker model to already running recognizer"; @@ -257,7 +257,7 @@ void KaldiRecognizer::SetSpkModel(SpkModel *spk_model) spk_feature_ = new OnlineMfcc(spk_model_->spkvector_mfcc_opts); } -bool KaldiRecognizer::AcceptWaveform(const char *data, int len) +bool Recognizer::AcceptWaveform(const char *data, int len) { Vector wave; wave.Resize(len / 2, kUndefined); @@ -266,7 +266,7 @@ bool 
KaldiRecognizer::AcceptWaveform(const char *data, int len) return AcceptWaveform(wave); } -bool KaldiRecognizer::AcceptWaveform(const short *sdata, int len) +bool Recognizer::AcceptWaveform(const short *sdata, int len) { Vector wave; wave.Resize(len, kUndefined); @@ -275,7 +275,7 @@ bool KaldiRecognizer::AcceptWaveform(const short *sdata, int len) return AcceptWaveform(wave); } -bool KaldiRecognizer::AcceptWaveform(const float *fdata, int len) +bool Recognizer::AcceptWaveform(const float *fdata, int len) { Vector wave; wave.Resize(len, kUndefined); @@ -284,7 +284,7 @@ bool KaldiRecognizer::AcceptWaveform(const float *fdata, int len) return AcceptWaveform(wave); } -bool KaldiRecognizer::AcceptWaveform(Vector &wdata) +bool Recognizer::AcceptWaveform(Vector &wdata) { // Cleanup if we finalized previous utterance or the whole feature pipeline if (!(state_ == RECOGNIZER_RUNNING || state_ == RECOGNIZER_INITIALIZED)) { @@ -343,7 +343,7 @@ static void RunNnetComputation(const MatrixBase &features, #define MIN_SPK_FEATS 50 -bool KaldiRecognizer::GetSpkVector(Vector &out_xvector, int *num_spk_frames) +bool Recognizer::GetSpkVector(Vector &out_xvector, int *num_spk_frames) { vector nonsilence_frames; if (silence_weighting_->Active() && feature_pipeline_->NumFramesReady() > 0) { @@ -409,7 +409,7 @@ bool KaldiRecognizer::GetSpkVector(Vector &out_xvector, int *num_spk_ } -const char *KaldiRecognizer::MbrResult(CompactLattice &rlat) +const char *Recognizer::MbrResult(CompactLattice &rlat) { CompactLattice aligned_lat; if (model_->winfo_) { @@ -523,7 +523,7 @@ static bool CompactLatticeToWordAlignmentWeight(const CompactLattice &clat, } -const char *KaldiRecognizer::NbestResult(CompactLattice &clat) +const char *Recognizer::NbestResult(CompactLattice &clat) { Lattice lat; Lattice nbest_lat; @@ -584,7 +584,7 @@ const char *KaldiRecognizer::NbestResult(CompactLattice &clat) return StoreReturn(obj.dump()); } -const char* KaldiRecognizer::GetResult() +const char* 
Recognizer::GetResult() { if (decoder_->NumFramesDecoded() == 0) { return StoreEmptyReturn(); @@ -645,7 +645,7 @@ const char* KaldiRecognizer::GetResult() } -const char* KaldiRecognizer::PartialResult() +const char* Recognizer::PartialResult() { if (state_ != RECOGNIZER_RUNNING) { return StoreEmptyReturn(); @@ -676,7 +676,7 @@ const char* KaldiRecognizer::PartialResult() return StoreReturn(res.dump()); } -const char* KaldiRecognizer::Result() +const char* Recognizer::Result() { if (state_ != RECOGNIZER_RUNNING) { return StoreEmptyReturn(); @@ -686,7 +686,7 @@ const char* KaldiRecognizer::Result() return GetResult(); } -const char* KaldiRecognizer::FinalResult() +const char* Recognizer::FinalResult() { if (state_ != RECOGNIZER_RUNNING) { return StoreEmptyReturn(); @@ -714,7 +714,7 @@ const char* KaldiRecognizer::FinalResult() return last_result_.c_str(); } -void KaldiRecognizer::Reset() +void Recognizer::Reset() { if (state_ == RECOGNIZER_RUNNING) { decoder_->FinalizeDecoding(); @@ -723,7 +723,7 @@ void KaldiRecognizer::Reset() state_ = RECOGNIZER_ENDPOINT; } -const char *KaldiRecognizer::StoreEmptyReturn() +const char *Recognizer::StoreEmptyReturn() { if (!max_alternatives_) { return StoreReturn("{\"text\": \"\"}"); @@ -733,7 +733,7 @@ const char *KaldiRecognizer::StoreEmptyReturn() } // Store result in recognizer and return as const string -const char *KaldiRecognizer::StoreReturn(const string &res) +const char *Recognizer::StoreReturn(const string &res) { last_result_ = res; return last_result_.c_str(); diff --git a/src/kaldi_recognizer.h b/src/recognizer.h similarity index 91% rename from src/kaldi_recognizer.h rename to src/recognizer.h index 934e237e..e5a733d1 100644 --- a/src/kaldi_recognizer.h +++ b/src/recognizer.h @@ -33,19 +33,19 @@ using namespace kaldi; -enum KaldiRecognizerState { +enum RecognizerState { RECOGNIZER_INITIALIZED, RECOGNIZER_RUNNING, RECOGNIZER_ENDPOINT, RECOGNIZER_FINALIZED }; -class KaldiRecognizer { +class Recognizer { public: - 
KaldiRecognizer(Model *model, float sample_frequency); - KaldiRecognizer(Model *model, float sample_frequency, SpkModel *spk_model); - KaldiRecognizer(Model *model, float sample_frequency, char const *grammar); - ~KaldiRecognizer(); + Recognizer(Model *model, float sample_frequency); + Recognizer(Model *model, float sample_frequency, SpkModel *spk_model); + Recognizer(Model *model, float sample_frequency, char const *grammar); + ~Recognizer(); void SetMaxAlternatives(int max_alternatives); void SetSpkModel(SpkModel *spk_model); void SetWords(bool words); @@ -101,7 +101,7 @@ class KaldiRecognizer { int64 samples_processed_; int64 samples_round_start_; - KaldiRecognizerState state_; + RecognizerState state_; string last_result_; }; diff --git a/src/spk_model.h b/src/spk_model.h index 07cbd4b0..9a76c62a 100644 --- a/src/spk_model.h +++ b/src/spk_model.h @@ -22,7 +22,7 @@ using namespace kaldi; -class KaldiRecognizer; +class Recognizer; class SpkModel { @@ -32,7 +32,7 @@ class SpkModel { void Unref(); protected: - friend class KaldiRecognizer; + friend class Recognizer; ~SpkModel() {}; kaldi::nnet3::Nnet speaker_nnet; diff --git a/src/vosk_api.cc b/src/vosk_api.cc index ba76a73b..65356038 100644 --- a/src/vosk_api.cc +++ b/src/vosk_api.cc @@ -13,12 +13,14 @@ // limitations under the License. #include "vosk_api.h" -#include "kaldi_recognizer.h" + +#include "recognizer.h" #include "model.h" #include "spk_model.h" #if HAVE_CUDA #include "cudamatrix/cu-device.h" +#include "batch_recognizer.h" #endif #include @@ -67,7 +69,7 @@ void vosk_spk_model_free(VoskSpkModel *model) VoskRecognizer *vosk_recognizer_new(VoskModel *model, float sample_rate) { try { - return (VoskRecognizer *)new KaldiRecognizer((Model *)model, sample_rate); + return (VoskRecognizer *)new Recognizer((Model *)model, sample_rate); } catch (...) 
{ return nullptr; } @@ -76,7 +78,7 @@ VoskRecognizer *vosk_recognizer_new(VoskModel *model, float sample_rate) VoskRecognizer *vosk_recognizer_new_spk(VoskModel *model, float sample_rate, VoskSpkModel *spk_model) { try { - return (VoskRecognizer *)new KaldiRecognizer((Model *)model, sample_rate, (SpkModel *)spk_model); + return (VoskRecognizer *)new Recognizer((Model *)model, sample_rate, (SpkModel *)spk_model); } catch (...) { return nullptr; } @@ -85,7 +87,7 @@ VoskRecognizer *vosk_recognizer_new_spk(VoskModel *model, float sample_rate, Vos VoskRecognizer *vosk_recognizer_new_grm(VoskModel *model, float sample_rate, const char *grammar) { try { - return (VoskRecognizer *)new KaldiRecognizer((Model *)model, sample_rate, grammar); + return (VoskRecognizer *)new Recognizer((Model *)model, sample_rate, grammar); } catch (...) { return nullptr; } @@ -93,12 +95,12 @@ VoskRecognizer *vosk_recognizer_new_grm(VoskModel *model, float sample_rate, con void vosk_recognizer_set_max_alternatives(VoskRecognizer *recognizer, int max_alternatives) { - ((KaldiRecognizer *)recognizer)->SetMaxAlternatives(max_alternatives); + ((Recognizer *)recognizer)->SetMaxAlternatives(max_alternatives); } void vosk_recognizer_set_words(VoskRecognizer *recognizer, int words) { - ((KaldiRecognizer *)recognizer)->SetWords((bool)words); + ((Recognizer *)recognizer)->SetWords((bool)words); } void vosk_recognizer_set_spk_model(VoskRecognizer *recognizer, VoskSpkModel *spk_model) @@ -106,13 +108,13 @@ void vosk_recognizer_set_spk_model(VoskRecognizer *recognizer, VoskSpkModel *spk if (recognizer == nullptr || spk_model == nullptr) { return; } - ((KaldiRecognizer *)recognizer)->SetSpkModel((SpkModel *)spk_model); + ((Recognizer *)recognizer)->SetSpkModel((SpkModel *)spk_model); } int vosk_recognizer_accept_waveform(VoskRecognizer *recognizer, const char *data, int length) { try { - return ((KaldiRecognizer *)(recognizer))->AcceptWaveform(data, length); + return ((Recognizer 
*)(recognizer))->AcceptWaveform(data, length); } catch (...) { return -1; } @@ -121,7 +123,7 @@ int vosk_recognizer_accept_waveform(VoskRecognizer *recognizer, const char *data int vosk_recognizer_accept_waveform_s(VoskRecognizer *recognizer, const short *data, int length) { try { - return ((KaldiRecognizer *)(recognizer))->AcceptWaveform(data, length); + return ((Recognizer *)(recognizer))->AcceptWaveform(data, length); } catch (...) { return -1; } @@ -130,7 +132,7 @@ int vosk_recognizer_accept_waveform_s(VoskRecognizer *recognizer, const short *d int vosk_recognizer_accept_waveform_f(VoskRecognizer *recognizer, const float *data, int length) { try { - return ((KaldiRecognizer *)(recognizer))->AcceptWaveform(data, length); + return ((Recognizer *)(recognizer))->AcceptWaveform(data, length); } catch (...) { return -1; } @@ -138,27 +140,27 @@ int vosk_recognizer_accept_waveform_f(VoskRecognizer *recognizer, const float *d const char *vosk_recognizer_result(VoskRecognizer *recognizer) { - return ((KaldiRecognizer *)recognizer)->Result(); + return ((Recognizer *)recognizer)->Result(); } const char *vosk_recognizer_partial_result(VoskRecognizer *recognizer) { - return ((KaldiRecognizer *)recognizer)->PartialResult(); + return ((Recognizer *)recognizer)->PartialResult(); } const char *vosk_recognizer_final_result(VoskRecognizer *recognizer) { - return ((KaldiRecognizer *)recognizer)->FinalResult(); + return ((Recognizer *)recognizer)->FinalResult(); } void vosk_recognizer_reset(VoskRecognizer *recognizer) { - ((KaldiRecognizer *)recognizer)->Reset(); + ((Recognizer *)recognizer)->Reset(); } void vosk_recognizer_free(VoskRecognizer *recognizer) { - delete (KaldiRecognizer *)(recognizer); + delete (Recognizer *)(recognizer); } void vosk_set_log_level(int log_level) @@ -169,6 +171,8 @@ void vosk_set_log_level(int log_level) void vosk_gpu_init() { #if HAVE_CUDA +// kaldi::CuDevice::EnableTensorCores(true); +// kaldi::CuDevice::EnableTf32Compute(true); 
kaldi::CuDevice::Instantiate().SelectGpuId("yes"); kaldi::CuDevice::Instantiate().AllowMultithreading(); #endif @@ -180,3 +184,65 @@ void vosk_gpu_thread_init() kaldi::CuDevice::Instantiate(); #endif } + +VoskBatchRecognizer *vosk_batch_recognizer_new() +{ +#if HAVE_CUDA + return (VoskBatchRecognizer *)(new BatchRecognizer()); +#else + return NULL; +#endif +} + +void vosk_batch_recognizer_free(VoskBatchRecognizer *recognizer) +{ +#if HAVE_CUDA + delete ((BatchRecognizer *)recognizer); +#endif +} + +void vosk_batch_recognizer_accept_waveform(VoskBatchRecognizer *recognizer, int id, const char *data, int length) +{ +#if HAVE_CUDA + ((BatchRecognizer *)recognizer)->AcceptWaveform(id, data, length); +#endif +} + +void vosk_batch_recognizer_finish_stream(VoskBatchRecognizer *recognizer, int id) +{ +#if HAVE_CUDA + ((BatchRecognizer *)recognizer)->FinishStream(id); +#endif +} + +const char *vosk_batch_recognizer_front_result(VoskBatchRecognizer *recognizer, int id) +{ +#if HAVE_CUDA + return ((BatchRecognizer *)recognizer)->FrontResult(id); +#else + return NULL; +#endif +} + +void vosk_batch_recognizer_pop(VoskBatchRecognizer *recognizer, int id) +{ +#if HAVE_CUDA + ((BatchRecognizer *)recognizer)->Pop(id); +#endif +} + +void vosk_batch_recognizer_wait(VoskBatchRecognizer *recognizer) +{ +#if HAVE_CUDA + ((BatchRecognizer *)recognizer)->WaitForCompletion(); +#endif +} + +int vosk_batch_recognizer_get_pending_chunks(VoskBatchRecognizer *recognizer, int id) +{ +#if HAVE_CUDA + return ((BatchRecognizer *)recognizer)->GetPendingChunks(id); +#else + return 0; +#endif +} diff --git a/src/vosk_api.h b/src/vosk_api.h index 7636caa6..f6a981cb 100644 --- a/src/vosk_api.h +++ b/src/vosk_api.h @@ -39,6 +39,10 @@ typedef struct VoskSpkModel VoskSpkModel; * speaker information and so on */ typedef struct VoskRecognizer VoskRecognizer; +/** + * Batch recognizer object + */ +typedef struct VoskBatchRecognizer VoskBatchRecognizer; /** Loads model data from the file and returns the model 
object * @@ -285,6 +289,33 @@ void vosk_gpu_init(); */ void vosk_gpu_thread_init(); +/** Creates the batch recognizer object + * + * @returns recognizer object or NULL if problem occurred */ +VoskBatchRecognizer *vosk_batch_recognizer_new(); + +/** Releases batch recognizer object + * Underlying model is also unreferenced and if needed released */ +void vosk_batch_recognizer_free(VoskBatchRecognizer *recognizer); + +/** Accept batch voice data */ +void vosk_batch_recognizer_accept_waveform(VoskBatchRecognizer *recognizer, int id, const char *data, int length); + +/** Closes the stream */ +void vosk_batch_recognizer_finish_stream(VoskBatchRecognizer *recognizer, int id); + +/** Return results */ +const char *vosk_batch_recognizer_front_result(VoskBatchRecognizer *recognizer, int id); + +/** Release and free first retrieved result */ +void vosk_batch_recognizer_pop(VoskBatchRecognizer *recognizer, int id); + +/** Wait for the processing */ +void vosk_batch_recognizer_wait(VoskBatchRecognizer *recognizer); + +/** Get amount of pending chunks for more intelligent waiting */ +int vosk_batch_recognizer_get_pending_chunks(VoskBatchRecognizer *recognizer, int id); + #ifdef __cplusplus } #endif