Skip to content

Commit

Permalink
Expose results in Python
Browse files Browse the repository at this point in the history
  • Loading branch information
nshmyrev committed Dec 17, 2021
1 parent 60f0396 commit 848b2dc
Show file tree
Hide file tree
Showing 8 changed files with 217 additions and 25 deletions.
85 changes: 85 additions & 0 deletions python/example/batch/asr_server_gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#!/usr/bin/env python3

import json
import os
import sys
import asyncio
import pathlib
import websockets
import logging

from vosk import BatchRecognizer, GpuInit


async def recognize(websocket, path):
global args
global loop
global pool
global rec
global client_cnt

uid = client_cnt
client_cnt += 1

logging.info('Connection %d from %s', uid, websocket.remote_address);

while True:

message = await websocket.recv()

if message == '{"eof" : 1}':
rec.FinishStream(uid)
break

if isinstance(message, str) and 'config' in message:
continue

rec.AcceptWaveform(uid, message)
await asyncio.sleep(len(message) / 16000.0 / 2)
res = rec.Result(uid)
if len(res) == 0:
await websocket.send('{ "partial" : "" }')
else:
await websocket.send(res)

rec.Wait()
res = rec.Result(uid)
await websocket.send(res)

def start():

global rec
global args
global loop
global client_cnt

# Enable loging if needed
#
# logger = logging.getLogger('websockets')
# logger.setLevel(logging.INFO)
# logger.addHandler(logging.StreamHandler())
logging.basicConfig(level=logging.INFO)

args = type('', (), {})()

args.interface = os.environ.get('VOSK_SERVER_INTERFACE', '0.0.0.0')
args.port = int(os.environ.get('VOSK_SERVER_PORT', 2700))

GpuInit()

rec = BatchRecognizer()

client_cnt = 0

loop = asyncio.get_event_loop()

start_server = websockets.serve(
recognize, args.interface, args.port)

logging.info("Listening on %s:%d", args.interface, args.port)
loop.run_until_complete(start_server)
loop.run_forever()


if __name__ == '__main__':
start()
24 changes: 19 additions & 5 deletions python/example/batch/test_batch.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,42 @@
#!/usr/bin/env python3

from vosk import Model, BatchRecognizer, GpuInit, GpuThreadInit
import sys
import os
import wave
from time import sleep

from vosk import Model, BatchRecognizer, GpuInit

GpuInit()
GpuThreadInit()

rec = BatchRecognizer()

fnames = open("tedlium.list").readlines()
fds = [open(x.strip(), "rb") for x in fnames]
ended = set()
while True:

for i, fd in enumerate(fds):
if i in ended:
continue
data = fd.read(8000)
if len(data) == 0:
rec.FinishStream(i)
ended.add(i)
else:
rec.AcceptWaveform(i, data)
rec.Results()
continue
rec.AcceptWaveform(i, data)

sleep(0.3)
for i, fd in enumerate(fds):
res = rec.Result(i)
print (i, res)

if len(ended) == len(fds):
break

sleep(20)
print ("Done")
for i, fd in enumerate(fds):
res = rec.Result(i)
print (i, res)
rec.Wait()
10 changes: 8 additions & 2 deletions python/vosk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,14 @@ def __del__(self):
def AcceptWaveform(self, uid, data):
res = _c.vosk_batch_recognizer_accept_waveform(self._handle, uid, data, len(data))

def Results(self):
return _ffi.string(_c.vosk_batch_recognizer_results(self._handle)).decode('utf-8')
def Result(self, uid):
ptr = _c.vosk_batch_recognizer_front_result(self._handle, uid)
res = _ffi.string(ptr).decode('utf-8')
_c.vosk_batch_recognizer_pop(self._handle, uid)
return res

def FinishStream(self, uid):
_c.vosk_batch_recognizer_finish_stream(self._handle, uid)

def Wait(self):
_c.vosk_batch_recognizer_wait(self._handle)
87 changes: 77 additions & 10 deletions src/batch_recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "fstext/fstext-utils.h"
#include "lat/sausages.h"
#include "json.h"

#include <sys/stat.h>

Expand All @@ -37,12 +38,12 @@ BatchRecognizer::BatchRecognizer() {
batched_decoder_config.feature_opts.feature_type = "mfcc";
batched_decoder_config.feature_opts.mfcc_config = "model/conf/mfcc.conf";
batched_decoder_config.feature_opts.ivector_extraction_config = "model/conf/ivector.conf";
batched_decoder_config.decoder_opts.max_active = 7000;
batched_decoder_config.decoder_opts.default_beam = 13.0;
batched_decoder_config.decoder_opts.lattice_beam = 8.0;
batched_decoder_config.decoder_opts.max_active = 5000;
batched_decoder_config.decoder_opts.default_beam = 10.0;
batched_decoder_config.decoder_opts.lattice_beam = 4.0;
batched_decoder_config.compute_opts.acoustic_scale = 1.0;
batched_decoder_config.compute_opts.frame_subsampling_factor = 3;
batched_decoder_config.compute_opts.frames_per_chunk = 312;
batched_decoder_config.compute_opts.frames_per_chunk = 51;

struct stat buffer;

Expand Down Expand Up @@ -126,6 +127,47 @@ void BatchRecognizer::FinishStream(uint64_t id)
streams_.erase(id);
}


void BatchRecognizer::PushLattice(uint64_t id, CompactLattice &clat, BaseFloat offset)
{
fst::ScaleLattice(fst::GraphLatticeScale(0.9), &clat);

CompactLattice aligned_lat;
WordAlignLattice(clat, *trans_model_, *winfo_, 0, &aligned_lat);

MinimumBayesRisk mbr(aligned_lat);
const vector<BaseFloat> &conf = mbr.GetOneBestConfidences();
const vector<int32> &words = mbr.GetOneBest();
const vector<pair<BaseFloat, BaseFloat> > &times =
mbr.GetOneBestTimes();

int size = words.size();

json::JSON obj;
stringstream text;

// Create JSON object
for (int i = 0; i < size; i++) {
json::JSON word;

word["word"] = word_syms_->Find(words[i]);
word["start"] = times[i].first * 0.03 + offset;
word["end"] = times[i].second * 0.03 + offset;
word["conf"] = conf[i];
obj["result"].append(word);

if (i) {
text << " ";
}
text << word_syms_->Find(words[i]);
}
obj["text"] = text.str();

// KALDI_LOG << "Result " << id << " " << obj.dump();

results_[id].push(obj.dump());
}

void BatchRecognizer::AcceptWaveform(uint64_t id, const char *data, int len)
{
bool first = false;
Expand All @@ -135,7 +177,8 @@ void BatchRecognizer::AcceptWaveform(uint64_t id, const char *data, int len)
streams_.insert(id);

// Define the callback for results.
cuda_pipeline_->SetBestPathCallback(
#if 0
cuda_pipeline_->SetBestPathCallback(
id,
[&, id](const std::string &str, bool partial,
bool endpoint_detected) {
Expand All @@ -151,11 +194,19 @@ void BatchRecognizer::AcceptWaveform(uint64_t id, const char *data, int len)
KALDI_LOG << "id #" << id << " : " << str;
}
});
#endif
cuda_pipeline_->SetLatticeCallback(
id,
[&, id](CompactLattice &clat) {
KALDI_LOG << "Got lattice from the stream " << id;
});
[&, id](SegmentedLatticeCallbackParams& params) {
if (params.results.empty()) {
KALDI_WARN << "Empty result for callback";
return;
}
CompactLattice *clat = params.results[0].GetLatticeResult();
BaseFloat offset = params.results[0].GetTimeOffsetSeconds();
PushLattice(id, *clat, offset);
},
CudaPipelineResult::RESULT_TYPE_LATTICE);
}

Vector<BaseFloat> wave;
Expand All @@ -167,8 +218,24 @@ void BatchRecognizer::AcceptWaveform(uint64_t id, const char *data, int len)
dynamic_batcher_->Push(id, first, false, chunk);
}

const char* BatchRecognizer::PullResults()
const char* BatchRecognizer::FrontResult(uint64_t id)
{
if (results_[id].empty()) {
return "";
}
return results_[id].front().c_str();
}

void BatchRecognizer::Pop(uint64_t id)
{
if (results_[id].empty()) {
return;
}
results_[id].pop();
}

void BatchRecognizer::WaitForCompletion()
{
dynamic_batcher_->WaitForCompletion();
return "";
}

6 changes: 5 additions & 1 deletion src/batch_recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,12 @@ class BatchRecognizer {

void FinishStream(uint64_t id);
void AcceptWaveform(uint64_t id, const char *data, int len);
const char* PullResults();
const char *FrontResult(uint64_t id);
void Pop(uint64_t id);
void WaitForCompletion();

private:
void PushLattice(uint64_t id, CompactLattice &clat, BaseFloat offset);

kaldi::TransitionModel *trans_model_ = nullptr;
kaldi::nnet3::AmNnetSimple *nnet_ = nullptr;
Expand All @@ -64,6 +67,7 @@ class BatchRecognizer {


std::set<int> streams_;
std::map<int, std::queue<std::string> > results_;

// Rescoring
fst::ArcMapFst<fst::StdArc, LatticeArc, fst::StdToLatticeMapper<BaseFloat> > *lm_to_subtract_ = nullptr;
Expand Down
8 changes: 4 additions & 4 deletions src/json.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ class JSON
Class Type = Class::Null;
};

JSON Array() {
inline JSON Array() {
return JSON::Make( JSON::Class::Array );
}

Expand All @@ -435,11 +435,11 @@ JSON Array( T... args ) {
return arr;
}

JSON Object() {
inline JSON Object() {
return JSON::Make( JSON::Class::Object );
}

std::ostream& operator<<( std::ostream &os, const JSON &json ) {
inline std::ostream& operator<<( std::ostream &os, const JSON &json ) {
os << json.dump();
return os;
}
Expand Down Expand Up @@ -647,7 +647,7 @@ namespace {
}
}

JSON JSON::Load( const string &str ) {
inline JSON JSON::Load( const string &str ) {
size_t offset = 0;
return parse_next( str, offset );
}
Expand Down
14 changes: 12 additions & 2 deletions src/vosk_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,17 @@ void vosk_batch_recognizer_finish_stream(VoskBatchRecognizer *recognizer, int id
((BatchRecognizer *)recognizer)->FinishStream(id);
}

const char *vosk_batch_recognizer_results(VoskBatchRecognizer *recognizer)
const char *vosk_batch_recognizer_front_result(VoskBatchRecognizer *recognizer, int id)
{
return ((BatchRecognizer *)recognizer)->PullResults();
return ((BatchRecognizer *)recognizer)->FrontResult(id);
}

void vosk_batch_recognizer_pop(VoskBatchRecognizer *recognizer, int id)
{
return ((BatchRecognizer *)recognizer)->Pop(id);
}

void vosk_batch_recognizer_wait(VoskBatchRecognizer *recognizer)
{
((BatchRecognizer *)recognizer)->WaitForCompletion();
}
8 changes: 7 additions & 1 deletion src/vosk_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,13 @@ void vosk_batch_recognizer_accept_waveform(VoskBatchRecognizer *recognizer, int
void vosk_batch_recognizer_finish_stream(VoskBatchRecognizer *recognizer, int id);

/** Return results */
const char *vosk_batch_recognizer_results(VoskBatchRecognizer *recognizer);
const char *vosk_batch_recognizer_front_result(VoskBatchRecognizer *recognizer, int id);

/** Release and free first retrieved result */
void vosk_batch_recognizer_pop(VoskBatchRecognizer *recognizer, int id);

/** Wait for the processing */
void vosk_batch_recognizer_wait(VoskBatchRecognizer *recognizer);

#ifdef __cplusplus
}
Expand Down

0 comments on commit 848b2dc

Please sign in to comment.