diff --git a/websocket.py b/websocket.py index 9d96a10..aabb419 100644 --- a/websocket.py +++ b/websocket.py @@ -15,6 +15,38 @@ WS_CLIENTS = set() +def tts_request(msgObj, websocket): + silero_wav, sample_rate = silero.tts.tts(msgObj["value"]["text"]) + if silero_wav is not None: + if msgObj["value"]["to_device"]: + if "device_index" in msgObj["value"]: + silero.tts.play_audio(silero_wav, msgObj["value"]["device_index"]) + else: + silero.tts.play_audio(silero_wav, settings.GetOption("device_out_index")) + else: + AnswerMessage(websocket, json.dumps({"type": "tts_result", "wav_data": silero_wav.tolist(), "sample_rate": sample_rate})) + if msgObj["value"]["download"]: + wav_data = silero.tts.return_wav_file_binary(silero_wav) + wav_data = base64.b64encode(wav_data).decode('utf-8') + AnswerMessage(websocket, json.dumps({"type": "tts_save", "wav_data": wav_data})) + else: + print("TTS failed") + + +def flan_req(msgObj, websocket): + flan_result = flanLanguageModel.flan.encode(msgObj["text"]) + BroadcastMessage(json.dumps({"type": "flan_result", "flan_result": flan_result})) + + +def ocr_req(msgObj, websocket): + window_name = settings.GetOption("ocr_window_name") + ocr_result = easyocr.run_image_processing(window_name, ['en', msgObj["value"]["ocr_lang"]]) + translate_result, txt_from_lang, txt_to_lang = (texttranslate.TranslateLanguage(" -- ".join(ocr_result), msgObj["value"]["from_lang"], msgObj["value"]["to_lang"])) + AnswerMessage(websocket, json.dumps( + {"type": "translate_result", "original_text": "\n".join(ocr_result), "translate_result": "\n".join(translate_result.split(" -- ")), "txt_from_lang": txt_from_lang, + "txt_to_lang": txt_to_lang})) + + def websocketMessageHandler(msgObj, websocket): if msgObj["type"] == "setting_change": settings.SetOption(msgObj["name"], msgObj["value"]) @@ -27,33 +59,16 @@ def websocketMessageHandler(msgObj, websocket): if msgObj["type"] == "translate_req": if msgObj["value"]["to_lang"] != "": # if to_lang is empty, don't translate translate_result, txt_from_lang, txt_to_lang = texttranslate.TranslateLanguage(msgObj["value"]["text"], msgObj["value"]["from_lang"], msgObj["value"]["to_lang"]) - BroadcastMessage(json.dumps({"type": "translate_result", "translate_result": translate_result, "txt_from_lang": txt_from_lang, "txt_to_lang": txt_to_lang})) + AnswerMessage(websocket, json.dumps({"type": "translate_result", "translate_result": translate_result, "txt_from_lang": txt_from_lang, "txt_to_lang": txt_to_lang})) if msgObj["type"] == "ocr_req": - window_name = settings.GetOption("ocr_window_name") - ocr_result = easyocr.run_image_processing(window_name, ['en', msgObj["value"]["ocr_lang"]]) - translate_result, txt_from_lang, txt_to_lang = (texttranslate.TranslateLanguage(" -- ".join(ocr_result), msgObj["value"]["from_lang"], msgObj["value"]["to_lang"])) - BroadcastMessage(json.dumps( - {"type": "translate_result", "original_text": "\n".join(ocr_result), "translate_result": "\n".join(translate_result.split(" -- ")), "txt_from_lang": txt_from_lang, - "txt_to_lang": txt_to_lang})) + ocr_thread = threading.Thread(target=ocr_req, args=(msgObj, websocket)) + ocr_thread.start() if msgObj["type"] == "tts_req": if silero.init(): - silero_wav, sample_rate = silero.tts.tts(msgObj["value"]["text"]) - if silero_wav is not None: - if msgObj["value"]["to_device"]: - if "device_index" in msgObj["value"]: - silero.tts.play_audio(silero_wav, msgObj["value"]["device_index"]) - else: - silero.tts.play_audio(silero_wav, settings.GetOption("device_out_index")) - else: - AnswerMessage(websocket, json.dumps({"type": "tts_result", "wav_data": silero_wav.tolist(), "sample_rate": sample_rate})) - if msgObj["value"]["download"]: - wav_data = silero.tts.return_wav_file_binary(silero_wav) - wav_data = base64.b64encode(wav_data).decode('utf-8') - AnswerMessage(websocket, json.dumps({"type": "tts_save", "wav_data": wav_data})) - else: - print("TTS failed") + tts_thread = threading.Thread(target=tts_request, args=(msgObj, websocket)) + tts_thread.start() if msgObj["type"] == "tts_voice_save_req": if silero.init(): @@ -61,8 +76,8 @@ def websocketMessageHandler(msgObj, websocket): if msgObj["type"] == "flan_req": if flanLanguageModel.init(): - flan_result = flanLanguageModel.flan.encode(msgObj["text"]) - BroadcastMessage(json.dumps({"type": "flan_result", "flan_result": flan_result})) + flan_thread = threading.Thread(target=flan_req, args=msgObj) + flan_thread.start() if msgObj["type"] == "get_windows_list": windows_list = WindowCapture.list_window_names()