Skip to content

Commit

Permalink
Merge pull request alphacep#800 from alphacep/batch
Browse files Browse the repository at this point in the history
Batch GPU decoding
  • Loading branch information
nshmyrev authored Dec 24, 2021
2 parents a4721de + 525b722 commit ed4c15b
Show file tree
Hide file tree
Showing 13 changed files with 583 additions and 62 deletions.
58 changes: 58 additions & 0 deletions python/example/test_gpu_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#!/usr/bin/env python3

import sys
import os
import wave
from time import sleep
import json
from timeit import default_timer as timer


from vosk import Model, BatchRecognizer, GpuInit

GpuInit()

rec = BatchRecognizer()

fnames = open("tedlium.list").readlines()
fds = [open(x.strip(), "rb") for x in fnames]
uids = [fname.strip().split('/')[-1][:-4] for fname in fnames]
results = [""] * len(fnames)
ended = set()
tot_samples = 0

start_time = timer()

while True:

# Feed in the data
for i, fd in enumerate(fds):
if i in ended:
continue
data = fd.read(16000)
if len(data) == 0:
rec.FinishStream(i)
ended.add(i)
continue
rec.AcceptWaveform(i, data)
tot_samples += len(data)

# Wait for results from CUDA
rec.Wait()

# Retrieve and add results
for i, fd in enumerate(fds):
res = rec.Result(i)
if len(res) != 0:
results[i] = results[i] + " " + json.loads(res)['text']

if len(ended) == len(fds):
break

end_time = timer()

for i in range(len(results)):
print (uids[i], results[i].strip())

print ("Processed %d seconds of audio in %d seconds (%f xRT)" % (tot_samples / 16000.0 / 2, end_time - start_time,
(tot_samples / 16000.0 / 2 / (end_time - start_time))), file=sys.stderr)
29 changes: 29 additions & 0 deletions python/vosk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,32 @@ def GpuInit():

def GpuThreadInit():
_c.vosk_gpu_thread_init()

class BatchRecognizer(object):

def __init__(self, *args):
self._handle = _c.vosk_batch_recognizer_new()

if self._handle == _ffi.NULL:
raise Exception("Failed to create a recognizer")

def __del__(self):
_c.vosk_batch_recognizer_free(self._handle)

def AcceptWaveform(self, uid, data):
res = _c.vosk_batch_recognizer_accept_waveform(self._handle, uid, data, len(data))

def Result(self, uid):
ptr = _c.vosk_batch_recognizer_front_result(self._handle, uid)
res = _ffi.string(ptr).decode('utf-8')
_c.vosk_batch_recognizer_pop(self._handle, uid)
return res

def FinishStream(self, uid):
_c.vosk_batch_recognizer_finish_stream(self._handle, uid)

def Wait(self):
_c.vosk_batch_recognizer_wait(self._handle)

def GetPendingChunks(self, uid):
return _c.vosk_batch_recognizer_get_pending_chunks(self._handle, uid)
18 changes: 13 additions & 5 deletions src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ EXTRA_LDFLAGS?=
OUTDIR?=.

VOSK_SOURCES= \
kaldi_recognizer.cc \
recognizer.cc \
language_model.cc \
model.cc \
spk_model.cc \
vosk_api.cc

VOSK_HEADERS= \
kaldi_recognizer.h \
recognizer.h \
language_model.h \
model.h \
spk_model.h \
Expand All @@ -39,13 +39,13 @@ LIBS= \
$(KALDI_ROOT)/src/decoder/kaldi-decoder.a \
$(KALDI_ROOT)/src/ivector/kaldi-ivector.a \
$(KALDI_ROOT)/src/gmm/kaldi-gmm.a \
$(KALDI_ROOT)/src/nnet3/kaldi-nnet3.a \
$(KALDI_ROOT)/src/tree/kaldi-tree.a \
$(KALDI_ROOT)/src/feat/kaldi-feat.a \
$(KALDI_ROOT)/src/lat/kaldi-lat.a \
$(KALDI_ROOT)/src/lm/kaldi-lm.a \
$(KALDI_ROOT)/src/rnnlm/kaldi-rnnlm.a \
$(KALDI_ROOT)/src/hmm/kaldi-hmm.a \
$(KALDI_ROOT)/src/nnet3/kaldi-nnet3.a \
$(KALDI_ROOT)/src/transform/kaldi-transform.a \
$(KALDI_ROOT)/src/cudamatrix/kaldi-cudamatrix.a \
$(KALDI_ROOT)/src/matrix/kaldi-matrix.a \
Expand All @@ -66,7 +66,7 @@ ifeq ($(HAVE_OPENBLAS_CLAPACK), 1)
endif

ifeq ($(HAVE_MKL), 1)
CFLAGS += -I$(MKL_ROOT)/include
CFLAGS += -DHAVE_MKL=1 -I$(MKL_ROOT)/include
LIBS += -L$(MKL_ROOT)/lib/intel64 -Wl,-rpath=$(MKL_ROOT)/lib/intel64 -lmkl_rt -lmkl_intel_lp64 -lmkl_core -lmkl_sequential
endif

Expand All @@ -75,8 +75,16 @@ ifeq ($(HAVE_ACCELERATE), 1)
endif

ifeq ($(HAVE_CUDA), 1)
VOSK_SOURCES += batch_recognizer.cc
VOSK_HEADERS += batch_recognizer.h

CFLAGS+=-DHAVE_CUDA=1 -I$(CUDA_ROOT)/include
LIBS+=-L$(CUDA_ROOT)/lib64 -lcuda -lcublas -lcusparse -lcudart -lcurand -lcufft -lcusolver -lnvToolsExt

LIBS := \
$(KALDI_ROOT)/src/cudadecoder/kaldi-cudadecoder.a \
$(KALDI_ROOT)/src/cudafeat/kaldi-cudafeat.a \
$(LIBS) \
-L$(CUDA_ROOT)/lib64 -lcuda -lcublas -lcusparse -lcudart -lcurand -lcufft -lcusolver -lnvToolsExt
endif

all: $(OUTDIR)/libvosk.$(EXT)
Expand Down
246 changes: 246 additions & 0 deletions src/batch_recognizer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
// Copyright 2019-2020 Alpha Cephei Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "batch_recognizer.h"

#include "fstext/fstext-utils.h"
#include "lat/sausages.h"
#include "json.h"

#include <sys/stat.h>

using namespace fst;
using namespace kaldi::nnet3;
using CorrelationID = CudaOnlinePipelineDynamicBatcher::CorrelationID;

BatchRecognizer::BatchRecognizer() {
BatchedThreadedNnet3CudaOnlinePipelineConfig batched_decoder_config;

kaldi::ParseOptions po("something");
batched_decoder_config.Register(&po);
po.ReadConfigFile("model/conf/model.conf");

batched_decoder_config.num_worker_threads = -1;
batched_decoder_config.max_batch_size = 200;
batched_decoder_config.reset_on_endpoint = true;
batched_decoder_config.use_gpu_feature_extraction = true;

batched_decoder_config.feature_opts.feature_type = "mfcc";
batched_decoder_config.feature_opts.mfcc_config = "model/conf/mfcc.conf";
batched_decoder_config.feature_opts.ivector_extraction_config = "model/conf/ivector.conf";
batched_decoder_config.decoder_opts.max_active = 7000;
batched_decoder_config.decoder_opts.default_beam = 13.0;
batched_decoder_config.decoder_opts.lattice_beam = 6.0;
batched_decoder_config.compute_opts.acoustic_scale = 1.0;
batched_decoder_config.compute_opts.frame_subsampling_factor = 3;
batched_decoder_config.compute_opts.frames_per_chunk = 180;

struct stat buffer;

string nnet3_rxfilename_ = "model/am/final.mdl";
string hclg_fst_rxfilename_ = "model/graph/HCLG.fst";
string word_syms_rxfilename_ = "model/graph/words.txt";
string winfo_rxfilename_ = "model/graph/phones/word_boundary.int";
string std_fst_rxfilename_ = "model/rescore/G.fst";
string carpa_rxfilename_ = "model/rescore/G.carpa";

trans_model_ = new kaldi::TransitionModel();
nnet_ = new kaldi::nnet3::AmNnetSimple();
{
bool binary;
kaldi::Input ki(nnet3_rxfilename_, &binary);
trans_model_->Read(ki.Stream(), binary);
nnet_->Read(ki.Stream(), binary);
SetBatchnormTestMode(true, &(nnet_->GetNnet()));
SetDropoutTestMode(true, &(nnet_->GetNnet()));
nnet3::CollapseModel(nnet3::CollapseModelConfig(), &(nnet_->GetNnet()));
}

if (stat(hclg_fst_rxfilename_.c_str(), &buffer) == 0) {
KALDI_LOG << "Loading HCLG from " << hclg_fst_rxfilename_;
hclg_fst_ = fst::ReadFstKaldiGeneric(hclg_fst_rxfilename_);
}

KALDI_LOG << "Loading words from " << word_syms_rxfilename_;
if (!(word_syms_ = fst::SymbolTable::ReadText(word_syms_rxfilename_))) {
KALDI_ERR << "Could not read symbol table from file "
<< word_syms_rxfilename_;
}
KALDI_ASSERT(word_syms_);

if (stat(winfo_rxfilename_.c_str(), &buffer) == 0) {
KALDI_LOG << "Loading winfo " << winfo_rxfilename_;
kaldi::WordBoundaryInfoNewOpts opts;
winfo_ = new kaldi::WordBoundaryInfo(opts, winfo_rxfilename_);
}

if (stat(carpa_rxfilename_.c_str(), &buffer) == 0) {
KALDI_LOG << "Loading subtract G.fst model from " << std_fst_rxfilename_;
graph_lm_fst_ = fst::ReadAndPrepareLmFst(std_fst_rxfilename_);
KALDI_LOG << "Loading CARPA model from " << carpa_rxfilename_;
ReadKaldiObject(carpa_rxfilename_, &const_arpa_);
}



cuda_pipeline_ = new BatchedThreadedNnet3CudaOnlinePipeline
(batched_decoder_config, *hclg_fst_, *nnet_, *trans_model_);
cuda_pipeline_->SetSymbolTable(*word_syms_);

CudaOnlinePipelineDynamicBatcherConfig dynamic_batcher_config;
dynamic_batcher_ = new CudaOnlinePipelineDynamicBatcher(dynamic_batcher_config,
*cuda_pipeline_);
}

BatchRecognizer::~BatchRecognizer() {

delete trans_model_;
delete nnet_;
delete word_syms_;
delete winfo_;
delete hclg_fst_;
delete graph_lm_fst_;

delete lm_to_subtract_;
delete carpa_to_add_;
delete carpa_to_add_scale_;

delete cuda_pipeline_;
delete dynamic_batcher_;
}

void BatchRecognizer::FinishStream(uint64_t id)
{
Vector<BaseFloat> wave;
SubVector<BaseFloat> chunk(wave.Data(), 0);
dynamic_batcher_->Push(id, false, true, chunk);
streams_.erase(id);
}


void BatchRecognizer::PushLattice(uint64_t id, CompactLattice &clat, BaseFloat offset)
{
fst::ScaleLattice(fst::GraphLatticeScale(0.9), &clat);

CompactLattice aligned_lat;
WordAlignLattice(clat, *trans_model_, *winfo_, 0, &aligned_lat);

MinimumBayesRisk mbr(aligned_lat);
const vector<BaseFloat> &conf = mbr.GetOneBestConfidences();
const vector<int32> &words = mbr.GetOneBest();
const vector<pair<BaseFloat, BaseFloat> > &times =
mbr.GetOneBestTimes();

int size = words.size();

json::JSON obj;
stringstream text;

// Create JSON object
for (int i = 0; i < size; i++) {
json::JSON word;

word["word"] = word_syms_->Find(words[i]);
word["start"] = times[i].first * 0.03 + offset;
word["end"] = times[i].second * 0.03 + offset;
word["conf"] = conf[i];
obj["result"].append(word);

if (i) {
text << " ";
}
text << word_syms_->Find(words[i]);
}
obj["text"] = text.str();

// KALDI_LOG << "Result " << id << " " << obj.dump();

results_[id].push(obj.dump());
}

void BatchRecognizer::AcceptWaveform(uint64_t id, const char *data, int len)
{
bool first = false;

if (streams_.find(id) == streams_.end()) {
first = true;
streams_.insert(id);

// Define the callback for results.
#if 0
cuda_pipeline_->SetBestPathCallback(
id,
[&, id](const std::string &str, bool partial,
bool endpoint_detected) {
if (partial) {
KALDI_LOG << "id #" << id << " [partial] : " << str << ":";
}

if (endpoint_detected) {
KALDI_LOG << "id #" << id << " [endpoint detected]";
}

if (!partial) {
KALDI_LOG << "id #" << id << " : " << str;
}
});
#endif
cuda_pipeline_->SetLatticeCallback(
id,
[&, id](SegmentedLatticeCallbackParams& params) {
if (params.results.empty()) {
KALDI_WARN << "Empty result for callback";
return;
}
CompactLattice *clat = params.results[0].GetLatticeResult();
BaseFloat offset = params.results[0].GetTimeOffsetSeconds();
PushLattice(id, *clat, offset);
},
CudaPipelineResult::RESULT_TYPE_LATTICE);
}

Vector<BaseFloat> wave;
wave.Resize(len / 2, kUndefined);
for (int i = 0; i < len / 2; i++)
wave(i) = *(((short *)data) + i);
SubVector<BaseFloat> chunk(wave.Data(), wave.Dim());

dynamic_batcher_->Push(id, first, false, chunk);
}

const char* BatchRecognizer::FrontResult(uint64_t id)
{
if (results_[id].empty()) {
return "";
}
return results_[id].front().c_str();
}

void BatchRecognizer::Pop(uint64_t id)
{
if (results_[id].empty()) {
return;
}
results_[id].pop();
}

void BatchRecognizer::WaitForCompletion()
{
dynamic_batcher_->WaitForCompletion();
}

int BatchRecognizer::GetPendingChunks(uint64_t id)
{
return dynamic_batcher_->GetPendingChunks(id);
}
Loading

0 comments on commit ed4c15b

Please sign in to comment.