diff --git a/python/example/batch/asr_server_gpu.py b/python/example/batch/asr_server_gpu.py
deleted file mode 100755
index 11885e9f..00000000
--- a/python/example/batch/asr_server_gpu.py
+++ /dev/null
@@ -1,90 +0,0 @@
-#!/usr/bin/env python3
-
-import json
-import os
-import sys
-import asyncio
-import pathlib
-import websockets
-import logging
-
-from vosk import BatchRecognizer, GpuInit
-
-
-async def recognize(websocket, path):
-    global args
-    global loop
-    global pool
-    global rec
-    global client_cnt
-
-    uid = client_cnt
-    client_cnt += 1
-
-    logging.info('Connection %d from %s', uid, websocket.remote_address);
-
-    while True:
-
-        message = await websocket.recv()
-
-        if message == '{"eof" : 1}':
-            rec.FinishStream(uid)
-            break
-
-        if isinstance(message, str) and 'config' in message:
-            continue
-
-        rec.AcceptWaveform(uid, message)
-
-        while rec.GetPendingChunks(uid) > 0:
-            await asyncio.sleep(0.1)
-
-        res = rec.Result(uid)
-        if len(res) == 0:
-            await websocket.send('{ "partial" : "" }')
-        else:
-            await websocket.send(res)
-
-    while rec.GetPendingChunks(uid) > 0:
-        await asyncio.sleep(0.1)
-
-    res = rec.Result(uid)
-    await websocket.send(res)
-
-def start():
-
-    global rec
-    global args
-    global loop
-    global client_cnt
-
-    # Enable loging if needed
-    #
-    # logger = logging.getLogger('websockets')
-    # logger.setLevel(logging.INFO)
-    # logger.addHandler(logging.StreamHandler())
-    logging.basicConfig(level=logging.INFO)
-
-    args = type('', (), {})()
-
-    args.interface = os.environ.get('VOSK_SERVER_INTERFACE', '0.0.0.0')
-    args.port = int(os.environ.get('VOSK_SERVER_PORT', 2700))
-
-    GpuInit()
-
-    rec = BatchRecognizer()
-
-    client_cnt = 0
-
-    loop = asyncio.get_event_loop()
-
-    start_server = websockets.serve(
-        recognize, args.interface, args.port)
-
-    logging.info("Listening on %s:%d", args.interface, args.port)
-    loop.run_until_complete(start_server)
-    loop.run_forever()
-
-
-if __name__ == '__main__':
-    start()
diff --git a/python/example/batch/test_batch.py b/python/example/test_gpu_batch.py
similarity index 97%
rename from python/example/batch/test_batch.py
rename to python/example/test_gpu_batch.py
index 8737a746..3a65bda8 100755
--- a/python/example/batch/test_batch.py
+++ b/python/example/test_gpu_batch.py
@@ -29,7 +29,7 @@ for i, fd in enumerate(fds):
         if i in ended:
             continue
 
-        data = fd.read(8000)
+        data = fd.read(16000)
         if len(data) == 0:
             rec.FinishStream(i)
             ended.add(i)
diff --git a/src/batch_recognizer.cc b/src/batch_recognizer.cc
index 78cfc6f2..3337ee10 100644
--- a/src/batch_recognizer.cc
+++ b/src/batch_recognizer.cc
@@ -31,9 +31,10 @@ BatchRecognizer::BatchRecognizer() {
     batched_decoder_config.Register(&po);
     po.ReadConfigFile("model/conf/model.conf");
 
-    batched_decoder_config.num_worker_threads = 4;
-    batched_decoder_config.max_batch_size = 100;
+    batched_decoder_config.num_worker_threads = -1;
+    batched_decoder_config.max_batch_size = 200;
     batched_decoder_config.reset_on_endpoint = true;
+    batched_decoder_config.use_gpu_feature_extraction = true;
 
     batched_decoder_config.feature_opts.feature_type = "mfcc";
     batched_decoder_config.feature_opts.mfcc_config = "model/conf/mfcc.conf";
diff --git a/src/vosk_api.cc b/src/vosk_api.cc
index 1f77eb6c..3f740d7b 100644
--- a/src/vosk_api.cc
+++ b/src/vosk_api.cc
@@ -171,6 +171,8 @@ void vosk_set_log_level(int log_level)
 void vosk_gpu_init()
 {
 #if HAVE_CUDA
+//    kaldi::CuDevice::EnableTensorCores(true);
+//    kaldi::CuDevice::EnableTf32Compute(true);
     kaldi::CuDevice::Instantiate().SelectGpuId("yes");
     kaldi::CuDevice::Instantiate().AllowMultithreading();
 #endif