diff --git a/README.md b/README.md index a7733fa..1f43b53 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,9 @@ Estonian demo: http://bark.phon.ioc.ee/dikteeri/ Changelog --------- + * 2020-03-06: Quite big changes. Upgraded to Python 3 (Python 2.7 is not supported any more). Also, migrated to + use Tornado 6. Also, worker.py and client.py now use Tornado's websocket client and ws4py is not needed any more. + Post-processing should also work fine. * 2019-06-17: The postprocessing mechanism doesn't work properly with Tornado 5+. Use Tornado 4.5.3 if you need it. * 2018-04-25: Server should now work with Tornado 5 (thanks to @Gastron). If using Python 2, you might need to install the `futures` package (`pip install futures`). * 2017-12-27: Somewhat big changes in the way post-processor is invoked. The problem was that in some use cases, the program that is used for @@ -69,19 +72,13 @@ all the prerequisites manually, one could use the Dockerfile created by José Ed ### Requirements -#### Python 2.7 with the following packages: +#### Python 3.5.2 or newer with the following packages: - * Tornado 4, see http://www.tornadoweb.org/en/stable/ - * ws4py (0.3.0 .. 0.3.2) + * Tornado 6, see http://www.tornadoweb.org/en/stable/ * YAML * JSON -*NB!*: The server doesn't work quite correctly with ws4py 0.3.5 because of a bug I reported here: https://github.com/Lawouach/WebSocket-for-Python/issues/152. -Use ws4py 0.3.2 instead. To install ws4py 0.3.2 using `pip`, run: - - pip install ws4py==0.3.2 - -In addition, you need Python 2.x bindings for gobject-introspection libraries, provided by the `python-gi` +In addition, you need Python bindings for gobject-introspection libraries, provided by the `python-gi` package on Debian and Ubuntu. #### Kaldi @@ -98,7 +95,7 @@ English models are based on Voxforge acoustic models and the CMU Sphinx 2013 ge The language models were heavily pruned so that the resulting FST cascade would be less than the 100 MB GitHub file size limit. 
-*Update:* the server also supports Kaldi's new "online2" online decoder that uses DNN-based acoustic models with i-vector input. See below on +*Update:* the server also supports Kaldi's "online2" online decoder that uses DNN-based acoustic models with i-vector input. See below on how to use it. According to experiments on two Estonian online decoding setups, the DNN-based models result in about 20% (or more) relatively less errors than GMM-based models (e.g., WER dropped from 13% to 9%). diff --git a/kaldigstserver/client.py b/kaldigstserver/client.py index 9a8e8fe..9985d65 100644 --- a/kaldigstserver/client.py +++ b/kaldigstserver/client.py @@ -1,103 +1,127 @@ __author__ = 'tanel' import argparse -from ws4py.client.threadedclient import WebSocketClient +#from ws4py.client.threadedclient import WebSocketClient import time import threading import sys import urllib -import Queue +import queue import json import time import os +from tornado.ioloop import IOLoop +from tornado import gen +from tornado.websocket import websocket_connect +from concurrent.futures import ThreadPoolExecutor +from tornado.concurrent import run_on_executor + def rate_limited(maxPerSecond): - minInterval = 1.0 / float(maxPerSecond) + min_interval = 1.0 / float(maxPerSecond) def decorate(func): - lastTimeCalled = [0.0] + last_time_called = [0.0] def rate_limited_function(*args,**kargs): - elapsed = time.clock() - lastTimeCalled[0] - leftToWait = minInterval - elapsed - if leftToWait>0: - time.sleep(leftToWait) + elapsed = time.perf_counter() - last_time_called[0] + left_to_wait = min_interval - elapsed + if left_to_wait > 0: + yield gen.sleep(left_to_wait) ret = func(*args,**kargs) - lastTimeCalled[0] = time.clock() + last_time_called[0] = time.perf_counter() return ret return rate_limited_function return decorate +executor = ThreadPoolExecutor(max_workers=5) -class MyClient(WebSocketClient): +class MyClient(): - def __init__(self, audiofile, url, protocols=None, extensions=None, 
heartbeat_freq=None, byterate=32000, + def __init__(self, audiofile, url, byterate=32000, save_adaptation_state_filename=None, send_adaptation_state_filename=None): - super(MyClient, self).__init__(url, protocols, extensions, heartbeat_freq) + self.url = url self.final_hyps = [] self.audiofile = audiofile self.byterate = byterate - self.final_hyp_queue = Queue.Queue() + self.final_hyp_queue = queue.Queue() self.save_adaptation_state_filename = save_adaptation_state_filename self.send_adaptation_state_filename = send_adaptation_state_filename - + self.ioloop = IOLoop.instance() + self.run() + self.ioloop.start() + + + @gen.coroutine + def run(self): + self.ws = yield websocket_connect(self.url, on_message_callback=self.received_message) + if self.send_adaptation_state_filename is not None: + print("Sending adaptation state from " + self.send_adaptation_state_filename) + try: + adaptation_state_props = json.load(open(self.send_adaptation_state_filename, "r")) + self.ws.write_message(json.dumps(dict(adaptation_state=adaptation_state_props))) + except: + e = sys.exc_info()[0] + print("Failed to send adaptation state: " + e) + + # In Python 3, stdin is always opened as text by argparse + if type(self.audiofile).__name__ == 'TextIOWrapper': + self.audiofile = self.audiofile.buffer + + with self.audiofile as audiostream: + while True: + block = yield from self.ioloop.run_in_executor(executor, audiostream.read, int(self.byterate/4)) + if block == b"": + break + yield self.send_data(block) + self.ws.write_message("EOS") + + + @gen.coroutine @rate_limited(4) def send_data(self, data): - self.send(data, binary=True) - - def opened(self): - #print "Socket opened!" 
- def send_data_to_ws(): - if self.send_adaptation_state_filename is not None: - print >> sys.stderr, "Sending adaptation state from %s" % self.send_adaptation_state_filename - try: - adaptation_state_props = json.load(open(self.send_adaptation_state_filename, "r")) - self.send(json.dumps(dict(adaptation_state=adaptation_state_props))) - except: - e = sys.exc_info()[0] - print >> sys.stderr, "Failed to send adaptation state: ", e - with self.audiofile as audiostream: - for block in iter(lambda: audiostream.read(self.byterate/4), ""): - self.send_data(block) - print >> sys.stderr, "Audio sent, now sending EOS" - self.send("EOS") - - t = threading.Thread(target=send_data_to_ws) - t.start() + self.ws.write_message(data, binary=True) def received_message(self, m): + if m is None: + #print("Websocket closed() called") + self.final_hyp_queue.put(" ".join(self.final_hyps)) + self.ioloop.stop() + + return + + #print("Received message ...") + #print(str(m) + "\n") response = json.loads(str(m)) - #print >> sys.stderr, "RESPONSE:", response - #print >> sys.stderr, "JSON was:", m + if response['status'] == 0: + #print(response) if 'result' in response: - trans = response['result']['hypotheses'][0]['transcript'].encode('utf-8') + trans = response['result']['hypotheses'][0]['transcript'] if response['result']['final']: - #print >> sys.stderr, trans, self.final_hyps.append(trans) - print >> sys.stderr, '\r%s' % trans.replace("\n", "\\n") + print(trans.replace("\n", "\\n"), end="\n") else: print_trans = trans.replace("\n", "\\n") if len(print_trans) > 80: print_trans = "... 
%s" % print_trans[-76:] - print >> sys.stderr, '\r%s' % print_trans, + print(print_trans, end="\r") if 'adaptation_state' in response: if self.save_adaptation_state_filename: - print >> sys.stderr, "Saving adaptation state to %s" % self.save_adaptation_state_filename + print("Saving adaptation state to " + self.save_adaptation_state_filename) with open(self.save_adaptation_state_filename, "w") as f: f.write(json.dumps(response['adaptation_state'])) else: - print >> sys.stderr, "Received error from server (status %d)" % response['status'] + print("Received error from server (status %d)" % response['status']) if 'message' in response: - print >> sys.stderr, "Error message:", response['message'] + print("Error message:" + response['message']) def get_full_hyp(self, timeout=60): return self.final_hyp_queue.get(timeout) - def closed(self, code, reason=None): - #print "Websocket closed() called" - #print >> sys.stderr - self.final_hyp_queue.put(" ".join(self.final_hyps)) + # def closed(self, code, reason=None): + # print("Websocket closed() called") + # self.final_hyp_queue.put(" ".join(self.final_hyps)) def main(): @@ -108,20 +132,19 @@ def main(): parser.add_argument('--save-adaptation-state', help="Save adaptation state to file") parser.add_argument('--send-adaptation-state', help="Send adaptation state from file") parser.add_argument('--content-type', default='', help="Use the specified content type (empty by default, for raw files the default is audio/x-raw, layout=(string)interleaved, rate=(int), format=(string)S16LE, channels=(int)1") - parser.add_argument('audiofile', help="Audio file to be sent to the server", type=argparse.FileType('rb'), default=sys.stdin) + parser.add_argument('audiofile', help="Audio file to be sent to the server", type=argparse.FileType('rb')) args = parser.parse_args() content_type = args.content_type if content_type == '' and args.audiofile.name.endswith(".raw"): content_type = "audio/x-raw, layout=(string)interleaved, rate=(int)%d, 
format=(string)S16LE, channels=(int)1" %(args.rate/2) - - - ws = MyClient(args.audiofile, args.uri + '?%s' % (urllib.urlencode([("content-type", content_type)])), byterate=args.rate, + ws = MyClient(args.audiofile, args.uri + '?%s' % (urllib.parse.urlencode([("content-type", content_type)])), byterate=args.rate, save_adaptation_state_filename=args.save_adaptation_state, send_adaptation_state_filename=args.send_adaptation_state) - ws.connect() + result = ws.get_full_hyp() - print result + print(result) + if __name__ == "__main__": main() diff --git a/kaldigstserver/decoder.py b/kaldigstserver/decoder.py index 811b650..d61b644 100644 --- a/kaldigstserver/decoder.py +++ b/kaldigstserver/decoder.py @@ -11,16 +11,18 @@ GObject.threads_init() Gst.init(None) import logging -import thread +import _thread import os +import sys logger = logging.getLogger(__name__) import pdb class DecoderPipeline(object): - def __init__(self, conf={}): + def __init__(self, ioloop, conf={}): logger.info("Creating decoder using conf: %s" % conf) + self.ioloop = ioloop self.use_cutter = conf.get("use-vad", False) self.create_pipeline(conf) self.outdir = conf.get("out-dir", None) @@ -51,22 +53,24 @@ def create_pipeline(self, conf): self.fakesink = Gst.ElementFactory.make("fakesink", "fakesink") if not self.asr: - print >> sys.stderr, "ERROR: Couldn't create the onlinegmmdecodefaster element!" + print("ERROR: Couldn't create the onlinegmmdecodefaster element!", file=sys.stderr) gst_plugin_path = os.environ.get("GST_PLUGIN_PATH") if gst_plugin_path: - print >> sys.stderr, \ - "Couldn't find onlinegmmdecodefaster element at %s. " \ - "If it's not the right path, try to set GST_PLUGIN_PATH to the right one, and retry. " \ - "You can also try to run the following command: " \ - "'GST_PLUGIN_PATH=%s gst-inspect-1.0 onlinegmmdecodefaster'." \ - % (gst_plugin_path, gst_plugin_path) + print( + "Couldn't find onlinegmmdecodefaster element at %s. 
" + "If it's not the right path, try to set GST_PLUGIN_PATH to the right one, and retry. " + "You can also try to run the following command: " + "'GST_PLUGIN_PATH=%s gst-inspect-1.0 onlinegmmdecodefaster'." + % (gst_plugin_path, gst_plugin_path), + file=sys.stderr) else: - print >> sys.stderr, \ + print( "The environment variable GST_PLUGIN_PATH wasn't set or it's empty. " \ - "Try to set GST_PLUGIN_PATH environment variable, and retry." - sys.exit(-1); + "Try to set GST_PLUGIN_PATH environment variable, and retry.", + file=sys.stderr) + sys.exit(-1) - for (key, val) in conf.get("decoder", {}).iteritems(): + for (key, val) in conf.get("decoder", {}).items(): logger.info("Setting decoder property: %s = %s" % (key, val)) self.asr.set_property(key, val) @@ -148,23 +152,22 @@ def _on_element_message(self, bus, message): self.asr.set_property("silent", True) def _on_word(self, asr, word): - logger.info("%s: Got word: %s" % (self.request_id, word.decode('utf8'))) + logger.info("%s: Got word: %s" % (self.request_id, word)) if self.word_handler: - self.word_handler(word) - + self.ioloop.add_callback(self.word_handler, word) def _on_error(self, bus, msg): self.error = msg.parse_error() logger.error(self.error) self.finish_request() if self.error_handler: - self.error_handler(self.error[0].message) + self.ioloop.add_callback(self.error_handler, self.error[0].message) def _on_eos(self, bus, msg): logger.info('%s: Pipeline received eos signal' % self.request_id) self.finish_request() if self.eos_handler: - self.eos_handler[0](self.eos_handler[1]) + self.ioloop.add_callback(self.eos_handler[0], self.eos_handler[1]) def finish_request(self): logger.info('%s: Finishing request' % self.request_id) @@ -235,4 +238,4 @@ def cancel(self): #logger.debug("Sending EOS to pipeline") #self.pipeline.send_event(Gst.Event.new_eos()) #self.pipeline.set_state(Gst.State.READY) - logger.info("%s: Cancelled pipeline" % self.request_id) \ No newline at end of file + logger.info("%s: Cancelled 
pipeline" % self.request_id) diff --git a/kaldigstserver/decoder2.py b/kaldigstserver/decoder2.py index befc152..2159a39 100644 --- a/kaldigstserver/decoder2.py +++ b/kaldigstserver/decoder2.py @@ -11,16 +11,18 @@ GObject.threads_init() Gst.init(None) import logging -import thread +import _thread import os +import sys from collections import OrderedDict +import tornado.ioloop logger = logging.getLogger(__name__) import pdb class DecoderPipeline2(object): - def __init__(self, conf={}): + def __init__(self, ioloop, conf={}): logger.info("Creating decoder using conf: %s" % conf) self.create_pipeline(conf) self.outdir = conf.get("out-dir", None) @@ -35,6 +37,7 @@ def __init__(self, conf={}): self.eos_handler = None self.error_handler = None self.request_id = "" + self.ioloop = ioloop def create_pipeline(self, conf): @@ -51,20 +54,22 @@ def create_pipeline(self, conf): self.fakesink = Gst.ElementFactory.make("fakesink", "fakesink") if not self.asr: - print >> sys.stderr, "ERROR: Couldn't create the kaldinnet2onlinedecoder element!" + print("ERROR: Couldn't create the kaldinnet2onlinedecoder element!", file=sys.stderr) gst_plugin_path = os.environ.get("GST_PLUGIN_PATH") if gst_plugin_path: - print >> sys.stderr, \ + print( "Couldn't find kaldinnet2onlinedecoder element at %s. " \ "If it's not the right path, try to set GST_PLUGIN_PATH to the right one, and retry. " \ "You can also try to run the following command: " \ "'GST_PLUGIN_PATH=%s gst-inspect-1.0 kaldinnet2onlinedecoder'." \ - % (gst_plugin_path, gst_plugin_path) + % (gst_plugin_path, gst_plugin_path), + file=sys.stderr) else: - print >> sys.stderr, \ + print( "The environment variable GST_PLUGIN_PATH wasn't set or it's empty. " \ - "Try to set GST_PLUGIN_PATH environment variable, and retry." 
- sys.exit(-1); + "Try to set GST_PLUGIN_PATH environment variable, and retry.", + file=sys.stderr) + sys.exit(-1) # This needs to be set first if "use-threaded-decoder" in conf["decoder"]: @@ -83,7 +88,7 @@ def create_pipeline(self, conf): if "model" in decoder_config: decoder_config["model"] = decoder_config.pop("model") - for (key, val) in decoder_config.iteritems(): + for (key, val) in decoder_config.items(): if key != "use-threaded-decoder": logger.info("Setting decoder property: %s = %s" % (key, val)) self.asr.set_property(key, val) @@ -139,33 +144,34 @@ def _connect_decoder(self, element, pad): def _on_partial_result(self, asr, hyp): - logger.info("%s: Got partial result: %s" % (self.request_id, hyp.decode('utf8'))) + logger.info("%s: Got partial result: %s" % (self.request_id, hyp)) if self.result_handler: - self.result_handler(hyp.decode('utf8'), False) + self.ioloop.add_callback(self.result_handler, hyp, False) def _on_final_result(self, asr, hyp): - logger.info("%s: Got final result: %s" % (self.request_id, hyp.decode('utf8'))) + logger.info("%s: Got final result: %s" % (self.request_id, hyp)) if self.result_handler: - self.result_handler(hyp.decode('utf8'), True) + self.ioloop.add_callback(self.result_handler, hyp, True) def _on_full_final_result(self, asr, result_json): - logger.info("%s: Got full final result: %s" % (self.request_id, result_json.decode('utf8'))) + logger.info("%s: Got full final result: %s" % (self.request_id, result_json)) if self.full_result_handler: - self.full_result_handler(result_json) + self.ioloop.add_callback(self.full_result_handler, result_json) def _on_error(self, bus, msg): self.error = msg.parse_error() logger.error(self.error) self.finish_request() if self.error_handler: - self.error_handler(self.error[0].message) + self.ioloop.add_callback(self.error_handler, self.error[0].message) def _on_eos(self, bus, msg): logger.info('%s: Pipeline received eos signal' % self.request_id) #self.decodebin.unlink(self.audioconvert) 
self.finish_request() if self.eos_handler: - self.eos_handler[0](self.eos_handler[1]) + self.ioloop.add_callback(self.eos_handler[0], self.eos_handler[1]) + def get_adaptation_state(self): return self.asr.get_property("adaptation-state") diff --git a/kaldigstserver/master_server.py b/kaldigstserver/master_server.py index 338b37f..a50967d 100644 --- a/kaldigstserver/master_server.py +++ b/kaldigstserver/master_server.py @@ -15,7 +15,7 @@ import time import threading import functools -from Queue import Queue +from tornado.locks import Condition import tornado.ioloop import tornado.options @@ -102,11 +102,10 @@ class HttpChunkedRecognizeHandler(tornado.web.RequestHandler): Provides a HTTP POST/PUT interface supporting chunked transfer requests, similar to that provided by http://github.com/alumae/ruby-pocketsphinx-server. """ - def prepare(self): self.id = str(uuid.uuid4()) self.final_hyp = "" - self.final_result_queue = Queue() + self.worker_done = Condition() self.user_id = self.request.headers.get("device-id", "none") self.content_id = self.request.headers.get("content-id", "none") logging.info("%s: OPEN: user='%s', content='%s'" % (self.id, self.user_id, self.content_id)) @@ -114,7 +113,7 @@ def prepare(self): self.error_status = 0 self.error_message = None #Waiter thread for final hypothesis: - self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + #self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) try: self.worker = self.application.available_workers.pop() self.application.send_status_update() @@ -132,33 +131,30 @@ def prepare(self): self.set_status(503) self.finish("No workers available") + @tornado.gen.coroutine def data_received(self, chunk): assert self.worker is not None logging.debug("%s: Forwarding client message of length %d to worker" % (self.id, len(chunk))) self.worker.write_message(chunk, binary=True) - + + @tornado.gen.coroutine def post(self, *args, **kwargs): - self.end_request(args, kwargs) + yield 
self.end_request(args, kwargs) + @tornado.gen.coroutine def put(self, *args, **kwargs): - self.end_request(args, kwargs) - - @tornado.concurrent.run_on_executor - def get_final_hyp(self): - logging.info("%s: Waiting for final result..." % self.id) - return self.final_result_queue.get(block=True) + yield self.end_request(args, kwargs) - @tornado.web.asynchronous @tornado.gen.coroutine def end_request(self, *args, **kwargs): logging.info("%s: Handling the end of chunked recognize request" % self.id) assert self.worker is not None - self.worker.write_message("EOS", binary=True) - logging.info("%s: yielding..." % self.id) - hyp = yield self.get_final_hyp() + self.worker.write_message("EOS", binary=False) + logging.info("%s: Waiting for worker to finish" % self.id) + yield self.worker_done.wait() if self.error_status == 0: - logging.info("%s: Final hyp: %s" % (self.id, hyp)) - response = {"status" : 0, "id": self.id, "hypotheses": [{"utterance" : hyp}]} + logging.info("%s: Final hyp: %s" % (self.id, self.final_hyp)) + response = {"status" : 0, "id": self.id, "hypotheses": [{"utterance" : self.final_hyp}]} self.write(response) else: logging.info("%s: Error (status=%d) processing HTTP request: %s" % (self.id, self.error_status, self.error_message)) @@ -171,6 +167,7 @@ def end_request(self, *args, **kwargs): self.finish() logging.info("Everything done") + @tornado.gen.coroutine def send_event(self, event): event_str = str(event) if len(event_str) > 100: @@ -189,9 +186,10 @@ def send_event(self, event): self.error_status = event["status"] self.error_message = event.get("message", "") + @tornado.gen.coroutine def close(self): logging.info("%s: Receiving 'close' from worker" % (self.id)) - self.final_result_queue.put(self.final_hyp) + self.worker_done.notify() class ReferenceHandler(tornado.web.RequestHandler): @@ -249,6 +247,7 @@ def on_close(self): logging.info("Worker " + self.__str__() + " leaving") self.application.available_workers.discard(self) if self.client_socket: + 
logging.info("Closing client connection") self.client_socket.close() self.application.send_status_update() @@ -272,7 +271,7 @@ def send_event(self, event): if len(event_str) > 100: event_str = event_str[:97] + "..." logging.info("%s: Sending event %s to client" % (self.id, event_str)) - self.write_message(json.dumps(event)) + self.write_message(json.dumps(event).replace('False', 'false').replace('\'', '\"')) def open(self): self.id = str(uuid.uuid4()) @@ -313,7 +312,7 @@ def on_connection_close(self): def on_message(self, message): assert self.worker is not None logging.info("%s: Forwarding client message (%s) of length %d to worker" % (self.id, type(message), len(message))) - if isinstance(message, unicode): + if isinstance(message, str): self.worker.write_message(message, binary=False) else: self.worker.write_message(message, binary=True) diff --git a/kaldigstserver/worker.py b/kaldigstserver/worker.py index 1f06b70..c2f95dd 100644 --- a/kaldigstserver/worker.py +++ b/kaldigstserver/worker.py @@ -3,8 +3,6 @@ import logging import logging.config import time -import thread -import threading import os import argparse from subprocess import Popen, PIPE @@ -18,26 +16,35 @@ import base64 import time +import asyncio +from tornado.platform.asyncio import AnyThreadEventLoopPolicy import tornado.gen import tornado.process import tornado.ioloop import tornado.locks -from ws4py.client.threadedclient import WebSocketClient -import ws4py.messaging +import tornado.websocket +#from ws4py.client.threadedclient import WebSocketClient +#import ws4py.messaging from decoder import DecoderPipeline from decoder2 import DecoderPipeline2 import common +from concurrent.futures import ThreadPoolExecutor +from tornado.concurrent import run_on_executor + + logger = logging.getLogger(__name__) +executor = ThreadPoolExecutor(max_workers=5) + CONNECT_TIMEOUT = 5 SILENCE_TIMEOUT = 5 USE_NNET2 = False -class ServerWebsocket(WebSocketClient): +class Worker(): STATE_CREATED = 0 STATE_CONNECTED = 1 
STATE_INITIALIZED = 2 @@ -51,7 +58,6 @@ def __init__(self, uri, decoder_pipeline, post_processor, full_post_processor=No self.decoder_pipeline = decoder_pipeline self.post_processor = post_processor self.full_post_processor = full_post_processor - WebSocketClient.__init__(self, url=uri, heartbeat_freq=10) self.pipeline_initialized = False self.partial_transcript = "" if USE_NNET2: @@ -63,7 +69,6 @@ def __init__(self, uri, decoder_pipeline, post_processor, full_post_processor=No self.decoder_pipeline.set_error_handler(self._on_error) self.decoder_pipeline.set_eos_handler(self._on_eos) self.state = self.STATE_CREATED - self.last_decoder_message = time.time() self.request_id = "" self.timeout_decoder = 5 self.num_segments = 0 @@ -72,11 +77,21 @@ def __init__(self, uri, decoder_pipeline, post_processor, full_post_processor=No self.processing_condition = tornado.locks.Condition() self.num_processing_threads = 0 - - def opened(self): + @tornado.gen.coroutine + def connect_and_run(self): + logger.info("Opening websocket connection to master server") + self.ws = yield tornado.websocket.websocket_connect(self.uri, ping_interval=10) logger.info("Opened websocket connection to server") self.state = self.STATE_CONNECTED self.last_partial_result = "" + self.last_decoder_message = time.time() + while True: + msg = yield self.ws.read_message() + self.received_message(msg) + if msg is None: + self.closed() + break + logger.info("Finished decoding run") def guard_timeout(self): global SILENCE_TIMEOUT @@ -86,10 +101,10 @@ def guard_timeout(self): self.finish_request() event = dict(status=common.STATUS_NO_SPEECH) try: - self.send(json.dumps(event)) + self.ws.write_message(json.dumps(event)) except: logger.warning("%s: Failed to send error event to master" % (self.request_id)) - self.close() + self.ws.close() return logger.debug("%s: Checking that decoder hasn't been silent for more than %d seconds" % (self.request_id, SILENCE_TIMEOUT)) time.sleep(1) @@ -97,17 +112,17 @@ def 
guard_timeout(self): def received_message(self, m): logger.debug("%s: Got message from server of type %s" % (self.request_id, str(type(m)))) if self.state == self.__class__.STATE_CONNECTED: - props = json.loads(str(m)) + props = json.loads(m) content_type = props['content_type'] self.request_id = props['id'] self.num_segments = 0 self.decoder_pipeline.init_request(self.request_id, content_type) self.last_decoder_message = time.time() - thread.start_new_thread(self.guard_timeout, ()) + tornado.ioloop.IOLoop.current().run_in_executor(executor, self.guard_timeout) logger.info("%s: Started timeout guard" % self.request_id) logger.info("%s: Initialized request" % self.request_id) self.state = self.STATE_INITIALIZED - elif m.data == "EOS": + elif m == "EOS": if self.state != self.STATE_CANCELLING and self.state != self.STATE_EOS_RECEIVED and self.state != self.STATE_FINISHED: self.decoder_pipeline.end_request() self.state = self.STATE_EOS_RECEIVED @@ -115,15 +130,15 @@ def received_message(self, m): logger.info("%s: Ignoring EOS, worker already in state %d" % (self.request_id, self.state)) else: if self.state != self.STATE_CANCELLING and self.state != self.STATE_EOS_RECEIVED and self.state != self.STATE_FINISHED: - if isinstance(m, ws4py.messaging.BinaryMessage): - self.decoder_pipeline.process_data(m.data) + if isinstance(m, bytes): + self.decoder_pipeline.process_data(m) self.state = self.STATE_PROCESSING - elif isinstance(m, ws4py.messaging.TextMessage): + elif isinstance(m, str): props = json.loads(str(m)) if 'adaptation_state' in props: as_props = props['adaptation_state'] if as_props.get('type', "") == "string+gzip+base64": - adaptation_state = zlib.decompress(base64.b64decode(as_props.get('value', ''))) + adaptation_state = zlib.decompress(base64.b64decode(as_props.get('value', ''))).decode("utf-8") logger.info("%s: Setting adaptation state to user-provided value" % (self.request_id)) self.decoder_pipeline.set_adaptation_state(adaptation_state) else: @@ -163,7 
+178,7 @@ def finish_request(self): logger.info("%s: Finished waiting for EOS" % self.request_id) - def closed(self, code, reason=None): + def closed(self): logger.debug("%s: Websocket closed() called" % self.request_id) self.finish_request() logger.debug("%s: Websocket closed() finished" % self.request_id) @@ -177,6 +192,7 @@ def _increment_num_processing(self, delta): def _on_result(self, result, final): try: self._increment_num_processing(1) + if final: # final results are handled by _on_full_result() return @@ -187,12 +203,12 @@ def _on_result(self, result, final): logger.info("%s: Postprocessing (final=%s) result.." % (self.request_id, final)) processed_transcripts = yield self.post_process([result], blocking=False) if processed_transcripts: - logger.info("%s: Postprocessing done." % self.request_id) + logger.info("%s: Postprocessing done." % (self.request_id)) event = dict(status=common.STATUS_SUCCESS, segment=self.num_segments, result=dict(hypotheses=[dict(transcript=processed_transcripts[0])], final=final)) try: - self.send(json.dumps(event)) + self.ws.write_message(json.dumps(event)) except: e = sys.exc_info()[1] logger.warning("Failed to send event to master: %s" % e) @@ -203,19 +219,19 @@ def _on_result(self, result, final): def _on_full_result(self, full_result_json): try: self._increment_num_processing(1) - + self.last_decoder_message = time.time() full_result = json.loads(full_result_json) full_result['segment'] = self.num_segments full_result['id'] = self.request_id if full_result.get("status", -1) == common.STATUS_SUCCESS: - logger.debug(u"%s: Before postprocessing: %s" % (self.request_id, repr(full_result).decode("unicode-escape"))) + logger.debug(u"%s: Before postprocessing: %s" % (self.request_id, repr(full_result))) full_result = yield self.post_process_full(full_result) logger.info("%s: Postprocessing done." 
% self.request_id) - logger.debug(u"%s: After postprocessing: %s" % (self.request_id, repr(full_result).decode("unicode-escape"))) + logger.debug(u"%s: After postprocessing: %s" % (self.request_id, repr(full_result))) try: - self.send(json.dumps(full_result)) + self.ws.write_message(json.dumps(full_result)) except: e = sys.exc_info()[1] logger.warning("Failed to send event to master: %s" % e) @@ -225,7 +241,7 @@ def _on_full_result(self, full_result_json): else: logger.info("%s: Result status is %d, forwarding the result to the server anyway" % (self.request_id, full_result.get("status", -1))) try: - self.send(json.dumps(full_result)) + self.ws.write_message(json.dumps(full_result)) except: e = sys.exc_info()[1] logger.warning("Failed to send event to master: %s" % e) @@ -236,29 +252,29 @@ def _on_full_result(self, full_result_json): def _on_word(self, word): try: self._increment_num_processing(1) - + self.last_decoder_message = time.time() if word != "<#s>": if len(self.partial_transcript) > 0: self.partial_transcript += " " self.partial_transcript += word logger.debug("%s: Postprocessing partial result.." % self.request_id) - processed_transcript = (yield self.post_process([self.partial_transcript], blocking=False))[0] - if processed_transcript: + processed_transcripts = (yield self.post_process([self.partial_transcript], blocking=False)) + if processed_transcripts: logger.debug("%s: Postprocessing done." % self.request_id) event = dict(status=common.STATUS_SUCCESS, segment=self.num_segments, - result=dict(hypotheses=[dict(transcript=processed_transcript)], final=False)) - self.send(json.dumps(event)) + result=dict(hypotheses=[dict(transcript=processed_transcripts[0])], final=False)) + self.ws.write_message(json.dumps(event)) else: logger.info("%s: Postprocessing final result.." 
% self.request_id) - processed_transcript = (yield self.post_process(self.partial_transcript, blocking=True)) + processed_transcripts = (yield self.post_process([self.partial_transcript], blocking=True)) logger.info("%s: Postprocessing done." % self.request_id) event = dict(status=common.STATUS_SUCCESS, segment=self.num_segments, - result=dict(hypotheses=[dict(transcript=processed_transcript)], final=True)) - self.send(json.dumps(event)) + result=dict(hypotheses=[dict(transcript=processed_transcripts[0])], final=True)) + self.ws.write_message(json.dumps(event)) self.partial_transcript = "" self.num_segments += 1 finally: @@ -275,17 +291,17 @@ def _on_eos(self, data=None): self.state = self.STATE_FINISHED self.send_adaptation_state() - self.close() + self.ws.close() def _on_error(self, error): self.state = self.STATE_FINISHED event = dict(status=common.STATUS_NOT_ALLOWED, message=error) try: - self.send(json.dumps(event)) + self.ws.write_message(json.dumps(event)) except: e = sys.exc_info()[1] logger.warning("Failed to send event to master: %s" % e) - self.close() + self.ws.close() def send_adaptation_state(self): if hasattr(self.decoder_pipeline, 'get_adaptation_state'): @@ -293,11 +309,11 @@ def send_adaptation_state(self): adaptation_state = self.decoder_pipeline.get_adaptation_state() event = dict(status=common.STATUS_SUCCESS, adaptation_state=dict(id=self.request_id, - value=base64.b64encode(zlib.compress(adaptation_state)), + value=base64.b64encode(zlib.compress(adaptation_state.encode())).decode("utf-8"), type="string+gzip+base64", time=time.strftime("%Y-%m-%dT%H:%M:%S"))) try: - self.send(json.dumps(event)) + self.ws.write_message(json.dumps(event)) except: e = sys.exc_info()[1] logger.warning("Failed to send event to master: " + str(e)) @@ -307,30 +323,33 @@ def send_adaptation_state(self): @tornado.gen.coroutine def post_process(self, texts, blocking=False): if self.post_processor: - logging.debug("%s: Waiting for postprocessor lock" % self.request_id) + 
logging.debug("%s: Waiting for postprocessor lock with blocking=%d" % (self.request_id, blocking)) if blocking: - timeout=None + timeout = None else: - timeout=0.0 + timeout = 0.1 try: with (yield self.post_processor_lock.acquire(timeout)): result = [] for text in texts: - self.post_processor.stdin.write("%s\n" % text.encode("utf-8")) - self.post_processor.stdin.flush() - logging.debug("%s: Starting postprocessing: %s" % (self.request_id, text)) - text = yield self.post_processor.stdout.read_until('\n') - text = text.decode("utf-8") - logging.debug("%s: Postprocessing returned: %s" % (self.request_id, text)) - text = text.strip() - text = text.replace("\\n", "\n") - result.append(text) - raise tornado.gen.Return(result) - except tornado.gen.TimeoutError: - logging.debug("%s: Skipping postprocessing since post-processor already in use" % (self.request_id)) - raise tornado.gen.Return(None) + try: + logging.debug("%s: Starting postprocessing: %s" % (self.request_id, text)) + self.post_processor.stdin.write((text + "\n").encode("utf-8")) + self.post_processor.stdin.flush() + logging.debug("%s: Reading from postpocessor" % (self.request_id)) + text = yield self.post_processor.stdout.read_until(b'\n') + text = text.decode("utf-8").strip() + logging.debug("%s: Postprocessing returned: %s" % (self.request_id, text)) + text = text.replace("\\n", "\n") + result.append(text) + except Exception as ex: + logging.exception("Error when postprocessing") + return result + except tornado.util.TimeoutError: + logging.info("%s: Skipping postprocessing since post-processor already in use" % (self.request_id)) + return None else: - raise tornado.gen.Return(texts) + return texts @tornado.gen.coroutine def post_process_full(self, full_result): @@ -354,20 +373,19 @@ def post_process_full(self, full_result): for (i, hyp) in enumerate(full_result.get("result", {}).get("hypotheses", [])): hyp["original-transcript"] = hyp["transcript"] hyp["transcript"] = processed_transcripts[i] - raise 
tornado.gen.Return(full_result) + return full_result +@tornado.gen.coroutine def main_loop(uri, decoder_pipeline, post_processor, full_post_processor=None): while True: - ws = ServerWebsocket(uri, decoder_pipeline, post_processor, full_post_processor=full_post_processor) - try: - logger.info("Opening websocket connection to master server") - ws.connect() - ws.run_forever() + worker = Worker(uri, decoder_pipeline, post_processor, full_post_processor=full_post_processor) + try: + yield worker.connect_and_run() except Exception: logger.error("Couldn't connect to server, waiting for %d seconds", CONNECT_TIMEOUT) - time.sleep(CONNECT_TIMEOUT) + yield tornado.gen.sleep(CONNECT_TIMEOUT) # fixes a race condition - time.sleep(1) + yield tornado.gen.sleep(1) @@ -398,7 +416,7 @@ def main(): post_processor = None if "post-processor" in conf: STREAM = tornado.process.Subprocess.STREAM - post_processor = tornado.process.Subprocess(conf["post-processor"], shell=True, stdin=PIPE, stdout=STREAM) + post_processor = tornado.process.Subprocess(conf["post-processor"], shell=True, stdin=PIPE, stdout=STREAM, ) full_post_processor = None @@ -411,13 +429,13 @@ def main(): global SILENCE_TIMEOUT SILENCE_TIMEOUT = conf.get("silence-timeout", 5) if USE_NNET2: - decoder_pipeline = DecoderPipeline2(conf) + decoder_pipeline = DecoderPipeline2(tornado.ioloop.IOLoop.current(), conf) else: - decoder_pipeline = DecoderPipeline(conf) + decoder_pipeline = DecoderPipeline(tornado.ioloop.IOLoop.current(), conf) - loop = GObject.MainLoop() - thread.start_new_thread(loop.run, ()) - thread.start_new_thread(main_loop, (args.uri, decoder_pipeline, post_processor, full_post_processor)) + gobject_loop = GObject.MainLoop() + tornado.ioloop.IOLoop.current().run_in_executor(executor, gobject_loop.run) + tornado.ioloop.IOLoop.current().spawn_callback(main_loop, args.uri, decoder_pipeline, post_processor, full_post_processor) tornado.ioloop.IOLoop.current().start()