Skip to content

Commit

Permalink
Per-stream wait API
Browse files Browse the repository at this point in the history
  • Loading branch information
nshmyrev committed Dec 23, 2021
1 parent 848b2dc commit cb0f8e6
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 14 deletions.
9 changes: 7 additions & 2 deletions python/example/batch/asr_server_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,19 @@ async def recognize(websocket, path):
continue

rec.AcceptWaveform(uid, message)
await asyncio.sleep(len(message) / 16000.0 / 2)

while rec.GetPendingChunks(uid) > 0:
await asyncio.sleep(0.1)

res = rec.Result(uid)
if len(res) == 0:
await websocket.send('{ "partial" : "" }')
else:
await websocket.send(res)

rec.Wait()
while rec.GetPendingChunks(uid) > 0:
await asyncio.sleep(0.1)

res = rec.Result(uid)
await websocket.send(res)

Expand Down
32 changes: 24 additions & 8 deletions python/example/batch/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import os
import wave
from time import sleep
import json
from timeit import default_timer as timer


from vosk import Model, BatchRecognizer, GpuInit

Expand All @@ -13,9 +16,16 @@

fnames = open("tedlium.list").readlines()
fds = [open(x.strip(), "rb") for x in fnames]
uids = [fname.strip().split('/')[-1][:-4] for fname in fnames]
results = [""] * len(fnames)
ended = set()
tot_samples = 0

start_time = timer()

while True:

# Feed in the data
for i, fd in enumerate(fds):
if i in ended:
continue
Expand All @@ -25,18 +35,24 @@
ended.add(i)
continue
rec.AcceptWaveform(i, data)
tot_samples += len(data)

sleep(0.3)
# Wait for results from CUDA
rec.Wait()

# Retrieve and add results
for i, fd in enumerate(fds):
res = rec.Result(i)
print (i, res)
if len(res) != 0:
results[i] = results[i] + " " + json.loads(res)['text']

if len(ended) == len(fds):
break

sleep(20)
print ("Done")
for i, fd in enumerate(fds):
res = rec.Result(i)
print (i, res)
rec.Wait()
end_time = timer()

for i in range(len(results)):
print (uids[i], results[i].strip())

print ("Processed %d seconds of audio in %d seconds (%f xRT)" % (tot_samples / 16000.0 / 2, end_time - start_time,
(tot_samples / 16000.0 / 2 / (end_time - start_time))), file=sys.stderr)
3 changes: 3 additions & 0 deletions python/vosk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,6 @@ def FinishStream(self, uid):

def Wait(self):
_c.vosk_batch_recognizer_wait(self._handle)

def GetPendingChunks(self, uid):
return _c.vosk_batch_recognizer_get_pending_chunks(self._handle, uid)
10 changes: 7 additions & 3 deletions src/batch_recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ BatchRecognizer::BatchRecognizer() {
batched_decoder_config.feature_opts.feature_type = "mfcc";
batched_decoder_config.feature_opts.mfcc_config = "model/conf/mfcc.conf";
batched_decoder_config.feature_opts.ivector_extraction_config = "model/conf/ivector.conf";
batched_decoder_config.decoder_opts.max_active = 5000;
batched_decoder_config.decoder_opts.default_beam = 10.0;
batched_decoder_config.decoder_opts.lattice_beam = 4.0;
batched_decoder_config.decoder_opts.max_active = 7000;
batched_decoder_config.decoder_opts.default_beam = 13.0;
batched_decoder_config.decoder_opts.lattice_beam = 6.0;
batched_decoder_config.compute_opts.acoustic_scale = 1.0;
batched_decoder_config.compute_opts.frame_subsampling_factor = 3;
batched_decoder_config.compute_opts.frames_per_chunk = 51;
Expand Down Expand Up @@ -239,3 +239,7 @@ void BatchRecognizer::WaitForCompletion()
dynamic_batcher_->WaitForCompletion();
}

int BatchRecognizer::GetPendingChunks(uint64_t id)
{
return dynamic_batcher_->GetPendingChunks(id);
}
1 change: 1 addition & 0 deletions src/batch_recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class BatchRecognizer {
const char *FrontResult(uint64_t id);
void Pop(uint64_t id);
void WaitForCompletion();
int GetPendingChunks(uint64_t id);

private:
void PushLattice(uint64_t id, CompactLattice &clat, BaseFloat offset);
Expand Down
7 changes: 6 additions & 1 deletion src/vosk_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,15 @@ const char *vosk_batch_recognizer_front_result(VoskBatchRecognizer *recognizer,

void vosk_batch_recognizer_pop(VoskBatchRecognizer *recognizer, int id)
{
return ((BatchRecognizer *)recognizer)->Pop(id);
((BatchRecognizer *)recognizer)->Pop(id);
}

void vosk_batch_recognizer_wait(VoskBatchRecognizer *recognizer)
{
((BatchRecognizer *)recognizer)->WaitForCompletion();
}

int vosk_batch_recognizer_get_pending_chunks(VoskBatchRecognizer *recognizer, int id)
{
return ((BatchRecognizer *)recognizer)->GetPendingChunks(id);
}
3 changes: 3 additions & 0 deletions src/vosk_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,9 @@ void vosk_batch_recognizer_pop(VoskBatchRecognizer *recognizer, int id);
/** Wait for the processing */
void vosk_batch_recognizer_wait(VoskBatchRecognizer *recognizer);

/** Get amount of pending chunks for more intelligent waiting */
int vosk_batch_recognizer_get_pending_chunks(VoskBatchRecognizer *recognizer, int id);

#ifdef __cplusplus
}
#endif
Expand Down

0 comments on commit cb0f8e6

Please sign in to comment.