Skip to content

Commit

Permalink
[TASK] Make most long-running tasks threaded.
Browse files Browse the repository at this point in the history
[TASK] Only send OCR + Translation results to requesting client
  • Loading branch information
Sharrnah committed Dec 13, 2022
1 parent 0a384ea commit 61a1fc4
Showing 1 changed file with 39 additions and 24 deletions.
63 changes: 39 additions & 24 deletions websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,38 @@
WS_CLIENTS = set()


def tts_request(msgObj, websocket):
silero_wav, sample_rate = silero.tts.tts(msgObj["value"]["text"])
if silero_wav is not None:
if msgObj["value"]["to_device"]:
if "device_index" in msgObj["value"]:
silero.tts.play_audio(silero_wav, msgObj["value"]["device_index"])
else:
silero.tts.play_audio(silero_wav, settings.GetOption("device_out_index"))
else:
AnswerMessage(websocket, json.dumps({"type": "tts_result", "wav_data": silero_wav.tolist(), "sample_rate": sample_rate}))
if msgObj["value"]["download"]:
wav_data = silero.tts.return_wav_file_binary(silero_wav)
wav_data = base64.b64encode(wav_data).decode('utf-8')
AnswerMessage(websocket, json.dumps({"type": "tts_save", "wav_data": wav_data}))
else:
print("TTS failed")


def flan_req(msgObj, websocket):
flan_result = flanLanguageModel.flan.encode(msgObj["text"])
BroadcastMessage(json.dumps({"type": "flan_result", "flan_result": flan_result}))


def ocr_req(msgObj, websocket):
window_name = settings.GetOption("ocr_window_name")
ocr_result = easyocr.run_image_processing(window_name, ['en', msgObj["value"]["ocr_lang"]])
translate_result, txt_from_lang, txt_to_lang = (texttranslate.TranslateLanguage(" -- ".join(ocr_result), msgObj["value"]["from_lang"], msgObj["value"]["to_lang"]))
AnswerMessage(websocket, json.dumps(
{"type": "translate_result", "original_text": "\n".join(ocr_result), "translate_result": "\n".join(translate_result.split(" -- ")), "txt_from_lang": txt_from_lang,
"txt_to_lang": txt_to_lang}))


def websocketMessageHandler(msgObj, websocket):
if msgObj["type"] == "setting_change":
settings.SetOption(msgObj["name"], msgObj["value"])
Expand All @@ -27,42 +59,25 @@ def websocketMessageHandler(msgObj, websocket):
if msgObj["type"] == "translate_req":
if msgObj["value"]["to_lang"] != "": # if to_lang is empty, don't translate
translate_result, txt_from_lang, txt_to_lang = texttranslate.TranslateLanguage(msgObj["value"]["text"], msgObj["value"]["from_lang"], msgObj["value"]["to_lang"])
BroadcastMessage(json.dumps({"type": "translate_result", "translate_result": translate_result, "txt_from_lang": txt_from_lang, "txt_to_lang": txt_to_lang}))
AnswerMessage(websocket, json.dumps({"type": "translate_result", "translate_result": translate_result, "txt_from_lang": txt_from_lang, "txt_to_lang": txt_to_lang}))

if msgObj["type"] == "ocr_req":
window_name = settings.GetOption("ocr_window_name")
ocr_result = easyocr.run_image_processing(window_name, ['en', msgObj["value"]["ocr_lang"]])
translate_result, txt_from_lang, txt_to_lang = (texttranslate.TranslateLanguage(" -- ".join(ocr_result), msgObj["value"]["from_lang"], msgObj["value"]["to_lang"]))
BroadcastMessage(json.dumps(
{"type": "translate_result", "original_text": "\n".join(ocr_result), "translate_result": "\n".join(translate_result.split(" -- ")), "txt_from_lang": txt_from_lang,
"txt_to_lang": txt_to_lang}))
ocr_thread = threading.Thread(target=ocr_req, args=(msgObj, websocket))
ocr_thread.start()

if msgObj["type"] == "tts_req":
if silero.init():
silero_wav, sample_rate = silero.tts.tts(msgObj["value"]["text"])
if silero_wav is not None:
if msgObj["value"]["to_device"]:
if "device_index" in msgObj["value"]:
silero.tts.play_audio(silero_wav, msgObj["value"]["device_index"])
else:
silero.tts.play_audio(silero_wav, settings.GetOption("device_out_index"))
else:
AnswerMessage(websocket, json.dumps({"type": "tts_result", "wav_data": silero_wav.tolist(), "sample_rate": sample_rate}))
if msgObj["value"]["download"]:
wav_data = silero.tts.return_wav_file_binary(silero_wav)
wav_data = base64.b64encode(wav_data).decode('utf-8')
AnswerMessage(websocket, json.dumps({"type": "tts_save", "wav_data": wav_data}))
else:
print("TTS failed")
tts_thread = threading.Thread(target=tts_request, args=(msgObj, websocket))
tts_thread.start()

if msgObj["type"] == "tts_voice_save_req":
if silero.init():
silero.tts.save_voice()

if msgObj["type"] == "flan_req":
if flanLanguageModel.init():
flan_result = flanLanguageModel.flan.encode(msgObj["text"])
BroadcastMessage(json.dumps({"type": "flan_result", "flan_result": flan_result}))
flan_thread = threading.Thread(target=flan_req, args=msgObj)
flan_thread.start()

if msgObj["type"] == "get_windows_list":
windows_list = WindowCapture.list_window_names()
Expand Down

0 comments on commit 61a1fc4

Please sign in to comment.