From 465c2602bfb79a7b72117b7914c090c28ac02f4b Mon Sep 17 00:00:00 2001
From: Sharrnah <55756126+Sharrnah@users.noreply.github.com>
Date: Fri, 4 Nov 2022 21:49:11 +0100
Subject: [PATCH 1/2] [TASK] Added max queue size and timeout, in case too many
 tasks are queued

[FEATURE] Added FLAN-T5 A.I. Model
[FEATURE] Added settings.yaml file
[TASK] Default ASCII conversion is now disabled, since VRChat now supports
 non-ASCII characters.
[TASK] Added VRChat chat notification signal
---
 README.md                                     |   4 +-
 VRC_OSCLib.py                                 |   3 +-
 audioWhisper.py                               |  47 +++--
 audioprocessor.py                             |  79 +++++++--
 flanLanguageModel.py                          | 128 ++++++++++++++
 requirements.txt                              |   2 +-
 settings.py                                   |  82 ++++++++-
 texttranslate.py                              |   2 +-
 websocket.py                                  |   9 +
 .../streaming-overlay-01/index.html           |  50 ++++--
 .../streaming-overlay-01/main.css             |  10 +-
 websocket_clients/websocket-remote/index.html | 167 +++++++++++-------
 websocket_clients/websocket-remote/main.css   |  10 +-
 13 files changed, 462 insertions(+), 131 deletions(-)
 create mode 100644 flanLanguageModel.py

diff --git a/README.md b/README.md
index 955cc9e..ba47f09 100644
--- a/README.md
+++ b/README.md
@@ -62,14 +62,16 @@ _(because of the 2 GB Limit, no direct release files on GitHub)_
 | `--osc_ip` | 0 | IP to send OSC messages to. Set to '0' to disable. (For VRChat this should mostly be 127.0.0.1) |
 | `--osc_port` | 9000 | Port to send OSC message to. ('9000' as default for VRChat) |
 | `--osc_address` | /chatbox/input | The Address the OSC messages are send to. ('/chatbox/input' as default for VRChat) |
-| `--osc_convert_ascii` | True | Convert Text to ASCII compatible when sending over OSC. (Can be set to 'False' as soon as VRChat supports non-ASCII characters) |
+| `--osc_convert_ascii` | False | Convert Text to ASCII compatible when sending over OSC. |
 | `--websocket_ip` | 0 | IP where Websocket Server listens on. Set to '0' to disable. |
 | `--websocket_port` | 5000 | Port where Websocket Server listens on. |
 | `--txt_translator` | M2M100 | The Model the AI is loading for text translations. can be 'M2M100', 'ARGOS' or 'None'. |
 | `--m2m100_size` | small | The Model size if M2M100 text translator is used. can be 'small' or 'large'. (has no effect with --txt_translator ARGOS) |
 | `--m2m100_device` | auto | The device used for M2M100 translation. can be 'auto', 'cuda' or 'cpu' (has no effect with --txt_translator ARGOS) |
 | `--ocr_window_name` | VRChat | Window name of the application for OCR translations. |
+| `--flan_enabled` | False | Enable FLAN-T5 A.I. (General A.I. which can be used for Question Answering.) |
 | `--open_browser` | False | Open default Browser with websocket-remote on start. (requires --websocket_ip to be set as well) |
+| `--config` | None | Use the specified config file instead of the default 'settings.yaml' (relative to the current path) [overwrites without asking!!!] |
 | `--verbose` | False | Whether to print verbose output. |
 
 ## Usage with 3rd Party Applications
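The new `--config` option above is resolved relative to the current working directory. A minimal standalone sketch of that resolution logic, mirroring the patch's `Path(Path.cwd() / config)` call (the function name and the file name used here are only illustrations, not part of the patch):

    from pathlib import Path

    def resolve_config(config_arg, default="settings.yaml"):
        # fall back to the default settings.yaml in the working directory
        if config_arg is None:
            return Path(Path.cwd() / default)
        # a relative path like "my_settings.yaml" resolves against the cwd
        return Path(Path.cwd() / config_arg)

    print(resolve_config("my_settings.yaml"))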
diff --git a/VRC_OSCLib.py b/VRC_OSCLib.py
index 3ddd519..d49c380 100644
--- a/VRC_OSCLib.py
+++ b/VRC_OSCLib.py
@@ -106,7 +106,7 @@ def Message(data="example", address="/example", IP='127.0.0.1', PORT=9000):
 
 
 # OSC Send Chat
-def Chat(data="example", send=True, address="/chatbox/input", IP='127.0.0.1', PORT=9000, convert_ascii=True):
+def Chat(data="example", send=True, notify=True, address="/chatbox/input", IP='127.0.0.1', PORT=9000, convert_ascii=True):
     # OSC Bild
     client = udp_client.UDPClient(IP, PORT)
     msg = OscMessageBuilder(address=address)
@@ -115,6 +115,7 @@ def Chat(data="example", send=True, address="/chatbox/input", IP='127.0.0.1', PO
     else:
         msg.add_arg(data)
     msg.add_arg(send)
+    msg.add_arg(notify)
     m = msg.build()
 
     # OSC Send
diff --git a/audioWhisper.py b/audioWhisper.py
index 79bffdf..af1242f 100644
--- a/audioWhisper.py
+++ b/audioWhisper.py
@@ -3,6 +3,7 @@
 import speech_recognition as sr
 import audioprocessor
 import os
+from pathlib import Path
 import click
 import VRC_OSCLib
 import websocket
@@ -31,7 +32,7 @@
 @click.option("--osc_ip", default="0", help="IP to send OSC message to. Set to '0' to disable", type=str)
 @click.option("--osc_port", default=9000, help="Port to send OSC message to. ('9000' as default for VRChat)", type=int)
 @click.option("--osc_address", default="/chatbox/input", help="The Address the OSC messages are send to. ('/chatbox/input' as default for VRChat)", type=str)
-@click.option("--osc_convert_ascii", default='True', help="Convert Text to ASCII compatible when sending over OSC.", type=str)
+@click.option("--osc_convert_ascii", default='False', help="Convert Text to ASCII compatible when sending over OSC.", type=str)
 @click.option("--websocket_ip", default="0", help="IP where Websocket Server listens on. Set to '0' to disable", type=str)
 @click.option("--websocket_port", default=5000, help="Port where Websocket Server listens on. ('5000' as default)", type=int)
 @click.option("--ai_device", default=None, help="The Device the AI is loaded on. can be 'cuda' or 'cpu'. default does autodetect", type=click.Choice(["cuda", "cpu"]))
@@ -39,10 +40,18 @@
 @click.option("--m2m100_size", default="small", help="The Model size if M2M100 text translator is used. can be 'small' or 'large'. default is small. (has no effect with ARGOS)", type=click.Choice(["small", "large"]))
 @click.option("--m2m100_device", default="auto", help="The device used for M2M100 translation. (has no effect with ARGOS)", type=click.Choice(["auto", "cuda", "cpu"]))
 @click.option("--ocr_window_name", default="VRChat", help="Window name of the application for OCR translations. (Default: 'VRChat')", type=str)
+@click.option("--flan_enabled", default=False, help="Enable FLAN-T5 A.I. (General A.I. which can be used for Question Answering.)", type=bool)
 @click.option("--open_browser", default=False, help="Open default Browser with websocket-remote on start. (requires --websocket_ip to be set as well)", is_flag=True, type=bool)
+@click.option("--config", default=None, help="Use the specified config file instead of the default 'settings.yaml' (relative to the current path) [overwrites without asking!!!]", type=str)
 @click.option("--verbose", default=False, help="Whether to print verbose output", is_flag=True, type=bool)
-def main(devices, device_index, sample_rate, task, model, language, condition_on_previous_text, energy, pause, dynamic_energy, phrase_time_limit, osc_ip, osc_port,
-         osc_address, osc_convert_ascii, websocket_ip, websocket_port, ai_device, txt_translator, m2m100_size, m2m100_device, ocr_window_name, open_browser, verbose):
+@click.pass_context
+def main(ctx, devices, device_index, sample_rate, task, model, language, condition_on_previous_text, energy, dynamic_energy, pause, phrase_time_limit, osc_ip, osc_port,
+         osc_address, osc_convert_ascii, websocket_ip, websocket_port, ai_device, txt_translator, m2m100_size, m2m100_device, ocr_window_name, flan_enabled, open_browser, config, verbose):
+
+    # Load settings from file
+    if config is not None:
+        settings.SETTINGS_PATH = Path(Path.cwd() / config)
+    settings.LoadYaml(settings.SETTINGS_PATH)
 
     if str2bool(devices):
         audio = pyaudio.PyAudio()
@@ -64,11 +73,12 @@ def main(devices, device_index, sample_rate, task, model, language, condition_on
     print("###################################")
 
     # set initial settings
-    settings.SetOption("whisper_task", task)
-    settings.SetOption("condition_on_previous_text", condition_on_previous_text)
-    settings.SetOption("model", model)
+    settings.IsArgumentSetting(ctx, "task") and settings.SetOption("whisper_task", task)
+
+    settings.IsArgumentSetting(ctx, "condition_on_previous_text") and settings.SetOption("condition_on_previous_text", condition_on_previous_text)
+    settings.IsArgumentSetting(ctx, "model") and settings.SetOption("model", model)
 
-    settings.SetOption("current_language", language)
+    settings.IsArgumentSetting(ctx, "language") and settings.SetOption("current_language", language)
 
     # check if english only model is loaded, and configure whisper languages accordingly.
     if model.endswith(".en") and language not in {"en", "English"}:
@@ -81,23 +91,24 @@ def main(devices, device_index, sample_rate, task, model, language, condition_on
     else:
         settings.SetOption("whisper_languages", audioprocessor.whisper_get_languages())
 
-    settings.SetOption("ai_device", ai_device)
+    settings.IsArgumentSetting(ctx, "ai_device") and settings.SetOption("ai_device", ai_device)
     settings.SetOption("verbose", verbose)
 
-    settings.SetOption("osc_ip", osc_ip)
-    settings.SetOption("osc_port", osc_port)
-    settings.SetOption("osc_address", osc_address)
-    settings.SetOption("osc_convert_ascii", str2bool(osc_convert_ascii))
+    settings.IsArgumentSetting(ctx, "osc_ip") and settings.SetOption("osc_ip", osc_ip)
+    settings.IsArgumentSetting(ctx, "osc_port") and settings.SetOption("osc_port", osc_port)
+    settings.IsArgumentSetting(ctx, "osc_address") and settings.SetOption("osc_address", osc_address)
+    settings.IsArgumentSetting(ctx, "osc_convert_ascii") and settings.SetOption("osc_convert_ascii", str2bool(osc_convert_ascii))
 
-    settings.SetOption("websocket_ip", websocket_ip)
-    settings.SetOption("websocket_port", websocket_port)
+    settings.IsArgumentSetting(ctx, "websocket_ip") and settings.SetOption("websocket_ip", websocket_ip)
+    settings.IsArgumentSetting(ctx, "websocket_port") and settings.SetOption("websocket_port", websocket_port)
 
-    settings.SetOption("txt_translator", txt_translator)
-    settings.SetOption("m2m100_size", m2m100_size)
+    settings.IsArgumentSetting(ctx, "txt_translator") and settings.SetOption("txt_translator", txt_translator)
+    settings.IsArgumentSetting(ctx, "m2m100_size") and settings.SetOption("m2m100_size", m2m100_size)
     texttranslate.SetDevice(m2m100_device)
 
-    settings.SetOption("ocr_window_name", ocr_window_name)
+    settings.IsArgumentSetting(ctx, "ocr_window_name") and settings.SetOption("ocr_window_name", ocr_window_name)
+    settings.IsArgumentSetting(ctx, "flan_enabled") and settings.SetOption("flan_enabled", flan_enabled)
 
     if websocket_ip != "0":
         websocket.StartWebsocketServer(websocket_ip, websocket_port)
@@ -137,7 +148,7 @@ def main(devices, device_index, sample_rate, task, model, language, condition_on
             audioprocessor.q.put(audio_data)
 
             # set typing indicator for VRChat
-            if osc_ip != "0":
+            if osc_ip != "0" and settings.GetOption("osc_typing_indicator"):
                 VRC_OSCLib.Bool(True, "/chatbox/typing", IP=osc_ip, PORT=osc_port)
             # send start info for processing indicator in websocket client
             websocket.BroadcastMessage(json.dumps({"type": "processing_start", "data": True}))
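The `settings.IsArgumentSetting(ctx, ...) and settings.SetOption(...)` guards above make a command-line flag override the value loaded from settings.yaml only when the flag was explicitly passed. A minimal standalone sketch of that pattern, assuming click >= 8.0 (which provides `Context.get_parameter_source`); the `stored` dict stands in for values loaded from the YAML file:

    import click
    from click.core import ParameterSource

    stored = {"model": "small"}  # stands in for settings loaded from settings.yaml

    @click.command()
    @click.option("--model", default="small")
    @click.pass_context
    def main(ctx, model):
        # only override the stored value if --model came from the command line,
        # not from the option's default
        if ctx.get_parameter_source("model") == ParameterSource.COMMANDLINE:
            stored["model"] = model
        click.echo(stored["model"])

    if __name__ == "__main__":
        main()

Run without arguments this keeps the stored value; run with `--model large` it overrides it, which is exactly the precedence the patch wants between CLI flags and the persisted settings file.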
diff --git a/audioprocessor.py b/audioprocessor.py
index c302c03..a0762b2 100644
--- a/audioprocessor.py
+++ b/audioprocessor.py
@@ -10,6 +10,7 @@
 from pydub import AudioSegment
 from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
 import io
+import flanLanguageModel
 
 # some regular mistakenly recognized words/sentences on mostly silence audio, which are ignored in processing
 blacklist = [
@@ -23,7 +24,10 @@
 # make all list entries lowercase for later comparison
 blacklist = list((map(lambda x: x.lower(), blacklist)))
 
-q = queue.Queue()
+max_queue_size = 5
+queue_timeout = 5
+
+q = queue.Queue(maxsize=max_queue_size)
 
 
 def whisper_get_languages_list_keys():
@@ -46,11 +50,10 @@
 def whisper_result_handling(result):
     verbose = settings.GetOption("verbose")
     osc_ip = settings.GetOption("osc_ip")
-    osc_address = settings.GetOption("osc_address")
-    osc_port = settings.GetOption("osc_port")
-    websocket_ip = settings.GetOption("websocket_ip")
+    flan_whisper_answer = settings.GetOption("flan_whisper_answer")
 
     predicted_text = result.get('text').strip()
+    result["type"] = "transcript"
 
     if not predicted_text.lower() in blacklist:
         if not verbose:
@@ -68,13 +71,48 @@
             result["txt_translation"] = predicted_text
             result["txt_translation_target"] = to_lang
 
-        # Send over OSC
-        if osc_ip != "0":
-            VRC_OSCLib.Chat(predicted_text, True, osc_address, IP=osc_ip, PORT=osc_port,
-                            convert_ascii=settings.GetOption("osc_convert_ascii"))
-        # Send to Websocket
-        if websocket_ip != "0":
-            websocket.BroadcastMessage(json.dumps(result))
+        # replace predicted_text with FLAN response
+        flan_loaded = False
+        # check if FLAN is enabled
+        if flan_whisper_answer and flanLanguageModel.init():
+            flan_osc_prefix = settings.GetOption("flan_osc_prefix")
+            flan_loaded = True
+            result["type"] = "flan_answer"
+            # Only process using FLAN if question is asked
+            if settings.GetOption("flan_process_only_questions"):
+                prompted_text, prompt_change = flanLanguageModel.flan.whisper_result_prompter(predicted_text)
+                if prompt_change:
+                    predicted_text = flanLanguageModel.flan.encode(prompted_text)
+                    result['flan_answer'] = predicted_text
+                    print("FLAN question: " + prompted_text)
+                    print("FLAN result: " + predicted_text)
+                    send_message(flan_osc_prefix + predicted_text, result)
+            # otherwise process every text with FLAN
+            else:
+                print("flan general processing")
+                predicted_text = flanLanguageModel.flan.encode(predicted_text)
+                result['text'] = predicted_text
+                print("FLAN result: " + predicted_text)
+                send_message(flan_osc_prefix + predicted_text, result)
+
+        # send regular message if flan was not loaded
+        if not flan_loaded:
+            send_message(predicted_text, result)
+
+
+def send_message(predicted_text, result_obj):
+    osc_ip = settings.GetOption("osc_ip")
+    osc_address = settings.GetOption("osc_address")
+    osc_port = settings.GetOption("osc_port")
+    websocket_ip = settings.GetOption("websocket_ip")
+
+    # Send over OSC
+    if osc_ip != "0":
+        VRC_OSCLib.Chat(predicted_text, True, True, osc_address, IP=osc_ip, PORT=osc_port,
+                        convert_ascii=settings.GetOption("osc_convert_ascii"))
+    # Send to Websocket
+    if websocket_ip != "0":
+        websocket.BroadcastMessage(json.dumps(result_obj))
 
 
 def load_whisper(model, ai_device):
@@ -97,10 +135,24 @@ def whisper_worker():
     whisper_ai_device = settings.GetOption("ai_device")
     audio_model = load_whisper(whisper_model, whisper_ai_device)
 
-    print("Say something!")
+    print("Whisper AI Ready. You can now say something!")
 
     while True:
-        audio_sample = convert_audio(q.get())
+        try:
+            audio = q.get(timeout=queue_timeout)
+        except queue.Empty:
+            # print("Queue processing timed out. Skipping...")
+            continue
+        except queue.Full:
+            print("Queue is full. Skipping...")
+            continue
+
+        # skip if queue is full
+        if q.qsize() >= max_queue_size:
+            q.task_done()
+            continue
+
+        audio_sample = convert_audio(audio)
 
         whisper_task = settings.GetOption("whisper_task")
 
@@ -112,6 +164,7 @@
             condition_on_previous_text=whisper_condition_on_previous_text)
 
         whisper_result_handling(result)
+        q.task_done()
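The bounded queue above keeps slow transcription from building an ever-growing backlog: at most `max_queue_size` audio chunks wait for the worker, and excess work is dropped instead of delaying results further. (As a side note, `queue.Queue.get(timeout=...)` only ever raises `queue.Empty`, so the `except queue.Full` branch in the worker is defensive and should not trigger.) A self-contained toy sketch of the pattern, with illustrative sizes and sleep times:

    import queue
    import threading
    import time

    q = queue.Queue(maxsize=5)

    def producer():
        for i in range(20):
            try:
                q.put_nowait(i)          # drop instead of blocking when full
            except queue.Full:
                print(f"dropped item {i}")
            time.sleep(0.01)

    def consumer():
        while True:
            try:
                item = q.get(timeout=5)  # q.get raises only queue.Empty on timeout
            except queue.Empty:
                break                    # idle long enough; stop the worker
            time.sleep(0.05)             # simulate slow transcription
            q.task_done()

    threading.Thread(target=producer).start()
    consumer()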
Skipping...") + continue + + # skip if queue is full + if q.qsize() >= max_queue_size: + q.task_done() + continue + + audio_sample = convert_audio(audio) whisper_task = settings.GetOption("whisper_task") @@ -112,6 +164,7 @@ def whisper_worker(): condition_on_previous_text=whisper_condition_on_previous_text) whisper_result_handling(result) + q.task_done() diff --git a/flanLanguageModel.py b/flanLanguageModel.py new file mode 100644 index 0000000..7c20114 --- /dev/null +++ b/flanLanguageModel.py @@ -0,0 +1,128 @@ +# pip install accelerate +from transformers import T5Tokenizer, T5ForConditionalGeneration +import torch +from pathlib import Path +import os +import settings +import random +import downloader + +# MODEL_LINKS = { +# "small": "google/flan-t5-small", +# "base": "google/flan-t5-base", +# "large": "google/flan-t5-large", +# "xl": "google/flan-t5-xl", +# "xxl": "google/flan-t5-xxl" +# } + +MODEL_LINKS = { + "small": "https://eu2.contabostorage.com/bf1a89517e2643359087e5d8219c0c67:ai-models/FLAN-T5%2Fsmall.zip", + "base": "https://eu2.contabostorage.com/bf1a89517e2643359087e5d8219c0c67:ai-models/FLAN-T5%2Fbase.zip", + "large": "https://eu2.contabostorage.com/bf1a89517e2643359087e5d8219c0c67:ai-models/FLAN-T5%2Flarge.zip", + "xl": "https://eu2.contabostorage.com/bf1a89517e2643359087e5d8219c0c67:ai-models/FLAN-T5%2Fxl.zip", + "xxl": "https://eu2.contabostorage.com/bf1a89517e2643359087e5d8219c0c67:ai-models/FLAN-T5%2Fxxl.zip" +} + +cache_path = Path(Path.cwd() / ".cache" / "flan-t5-cache") +os.makedirs(cache_path, exist_ok=True) +weight_offload_folder = Path(cache_path / "weight_offload") +os.makedirs(weight_offload_folder, exist_ok=True) + +flan = None + +PROMPT_FORMATTING = { + "question": ["about ", "across ", "after ", "against ", "along ", "am ", "amn't ", "among ", "are ", "aren't ", "around ", "at ", "before ", "behind ", "between ", + "beyond ", "but ", "by ", "can ", "can't ", "concerning ", "could ", "couldn't ", "despite ", "did ", "didn't ", "do ", "does ", "doesn't ", "don't ", + "down ", "during ", "except ", "following ", "for ", "from ", "had ", "hadn't ", "has ", "hasn't ", "have ", "haven't ", "how ", "how's ", "in ", + "including ", "into ", "is ", "isn't ", "like ", "may ", "mayn't ", "might ", "mightn't ", "must ", "mustn't ", "near ", "of ", "off ", "on ", "out ", + "over ", "plus ", "shall ", "shan't ", "should ", "shouldn't ", "since ", "through ", "throughout ", "to ", "towards ", "under ", "until ", "up ", "upon ", + "was ", "wasn't ", "were ", "weren't ", "what ", "what's ", "when ", "when's ", "where ", "where's ", "which ", "which's ", "who ", "who's ", "why ", + "why's ", "will ", "with ", "within ", "without ", "won't ", "would ", "wouldn't "] +} + + +class FlanLanguageModel: + tokenizer = None + model = None + model_size = "large" + max_length = 50 # max result token length. default is 20 + bit_length = 32 # can be 32 = 32 float, 16 = 16 float or 8 = 8 int + device_map = "auto" # can be "auto" or None + low_cpu_mem_usage = True + + # Set the device. 
"cuda" for GPU or None for CPU + def __init__(self, model_size, device="auto", bit_length=32): + self.model_size = model_size + self.device_map = device + self.bit_length = bit_length + + model_path = Path(cache_path / model_size) + + if not model_path.exists(): + print(f"Downloading {model_size} FLAN-T5 model...") + downloader.download_extract(MODEL_LINKS[model_size], str(cache_path.resolve())) + + model_path_string = str(model_path.resolve()) + + self.tokenizer = T5Tokenizer.from_pretrained(model_path_string, cache_dir=str(cache_path.resolve())) + + match self.bit_length: + case 16: # 16 bit float + self.model = T5ForConditionalGeneration.from_pretrained(model_path_string, device_map=self.device_map, torch_dtype=torch.float16, + offload_folder=str(weight_offload_folder.resolve())) + case 8: # 8 bit int + self.model = T5ForConditionalGeneration.from_pretrained(model_path_string, device_map=self.device_map, load_in_8bit=True, + offload_folder=str(weight_offload_folder.resolve())) + case _: # 32 bit float + self.model = T5ForConditionalGeneration.from_pretrained(model_path_string, device_map=self.device_map, + offload_folder=str(weight_offload_folder.resolve())) + + # Try to modify prompts to get better results + @staticmethod + def whisper_result_prompter(whisper_result: str): + prompt_change = False + question = whisper_result.strip().lower() + question_prompt = whisper_result.strip() + + possible_prompt_prefixes = [] + # looks like a question + if "?" in question and any(ele in question for ele in PROMPT_FORMATTING['question']): + possible_prompt_prefixes.append("Answer the following question by reasoning step-by-step. ") + possible_prompt_prefixes.append("Answer the following question. ") + possible_prompt_prefixes.append("Question: ") + possible_prompt_prefixes.append("Q: ") + prompt_change = True + + if prompt_change: + question_prompt = random.choice(possible_prompt_prefixes) + question_prompt + + return question_prompt, prompt_change + + def encode(self, input_text, token_length=max_length): + if self.device_map == "auto": + input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids.to("cuda") + else: + input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids + + outputs = self.model.generate(input_ids, max_new_tokens=token_length) + result = self.tokenizer.decode(outputs[0]).replace("", "").replace("", "").replace("", "").strip() + + return result + + +def init(): + global flan + if settings.GetOption("flan_enabled") and flan is None: + model_size = settings.GetOption("flan_size") + flan_bits = settings.GetOption("flan_bits") + flan_device = "auto" if settings.GetOption("flan_device") == "cuda" or settings.GetOption("flan_device") == "auto" else None + print(f"Flan {model_size} is Loading to {('GPU' if flan_device == 'auto' else 'CPU')} using {flan_bits} bit {('INT' if flan_bits == 8 else 'float')} precision...") + + flan = FlanLanguageModel(model_size, bit_length=flan_bits, device=flan_device) + print("Flan loaded.") + return True + else: + if flan is not None: + return True + else: + return False diff --git a/requirements.txt b/requirements.txt index 1a568fa..f729cd3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,7 +18,7 @@ pykakasi>=2.2.1 ctranslate2>=2.17.0 fairseq sentencepiece>=0.1.96 -protobuf +protobuf==3.20.1 progressbar2 pywin32 diff --git a/settings.py b/settings.py index 6e82518..0f199bb 100644 --- a/settings.py +++ b/settings.py @@ -1,22 +1,88 @@ +# noinspection PyPackageRequirements +import yaml +import os +from pathlib import Path 
diff --git a/settings.py b/settings.py
index 6e82518..0f199bb 100644
--- a/settings.py
+++ b/settings.py
@@ -1,22 +1,88 @@
+# noinspection PyPackageRequirements
+import yaml
+import os
+from pathlib import Path
+from click import core
+
+SETTINGS_PATH = Path(Path.cwd() / 'settings.yaml')
+
 TRANSLATE_SETTINGS = {
-    # argostranslate settings
-    "txt_translate": False,
-    "src_lang": "en",
-    "trg_lang": "fr",
-    "txt_ascii": False,
+    # text translate settings
+    "txt_translate": False,  # if enabled, pipes whisper A.I. results through text translator
+    "src_lang": "en",  # source language for text translator (Whisper A.I. in translation mode always translates to "en")
+    "trg_lang": "fr",  # target language for text translator
+    "txt_ascii": False,  # if enabled, text translator will convert text to romaji.
     "txt_translator": "M2M100",  # can be "M2M100" or "ARGOS"
+    "m2m100_size": "small",  # M2M100 model size. Can be "small" or "large"
 
     # ocr settings
-    "ocr_lang": "en",
+    "ocr_lang": "en",  # language for OCR image to text recognition.
+    "ocr_window_name": "VRChat",  # window name for OCR image to text recognition.
 
     # whisper settings
-    "whisper_task": "transcribe"
+    "ai_device": None,  # can be None (auto), "cuda" or "cpu".
+    "whisper_task": "transcribe",  # Whisper A.I. Can do "transcribe" or "translate"
+    "current_language": None,  # can be None (auto) or any Whisper supported language.
+    "model": "small",  # Whisper model size. Can be "tiny", "base", "small", "medium" or "large"
+    "condition_on_previous_text": False,  # if enabled, Whisper will condition on previous text. (more prone to loops or getting stuck)
+
+    # OSC settings
+    "osc_ip": "0",
+    "osc_port": 9000,
+    "osc_address": "/chatbox/input",
+    "osc_typing_indicator": True,
+    "osc_convert_ascii": "False",
+
+    # websocket settings
+    "websocket_ip": "0",
+    "websocket_port": 5000,
+
+    # FLAN settings
+    "flan_enabled": False,  # Enable FLAN A.I.
+    "flan_size": "xl",  # FLAN model size. Can be "small", "base", "large", "xl" or "xxl"
+    "flan_bits": 32,  # precision can be set to 32 (float), 16 (float) or 8 (int) bits. 8 bits is the fastest but least precise
+    "flan_device": "cpu",  # can be "cpu", "cuda" or "auto". ("cuda" and "auto" doing the same)
+    "flan_whisper_answer": True,  # if True, the FLAN A.I. will answer to results from the Whisper A.I.
+    "flan_process_only_questions": True,  # if True, the FLAN A.I. will only answer to questions
+    "flan_osc_prefix": "AI: "  # prefix for OSC messages
 }
 
 
 def SetOption(setting, value):
-    TRANSLATE_SETTINGS[setting] = value
+    if setting in TRANSLATE_SETTINGS:
+        if TRANSLATE_SETTINGS[setting] != value:
+            TRANSLATE_SETTINGS[setting] = value
+            # Save settings
+            SaveYaml(SETTINGS_PATH)
+    else:
+        TRANSLATE_SETTINGS[setting] = value
+        # Save settings
+        SaveYaml(SETTINGS_PATH)
 
 
 def GetOption(setting):
     return TRANSLATE_SETTINGS[setting]
+
+
+def LoadYaml(path):
+    print(path)
+    if os.path.exists(path):
+        with open(path, "r") as f:
+            TRANSLATE_SETTINGS.update(yaml.safe_load(f))
+
+
+def SaveYaml(path):
+    to_save_settings = TRANSLATE_SETTINGS.copy()
+    if "whisper_languages" in to_save_settings:
+        del to_save_settings['whisper_languages']
+    if "lang_swap" in to_save_settings:
+        del to_save_settings['lang_swap']
+    if "verbose" in to_save_settings:
+        del to_save_settings['verbose']
+
+    with open(path, "w") as f:
+        yaml.dump(to_save_settings, f)
+
+
+def IsArgumentSetting(ctx, argument_name):
+    return ctx.get_parameter_source(argument_name) == core.ParameterSource.COMMANDLINE
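With this change, `SetOption` persists every changed value straight back to disk and `LoadYaml` merges the file's contents over the built-in defaults, so settings.yaml and the in-memory dict stay in sync (volatile keys such as `whisper_languages`, `lang_swap` and `verbose` are excluded from saving). A round-trip sketch of that load/save behaviour, assuming PyYAML is installed and using a throwaway file name:

    import yaml
    from pathlib import Path

    defaults = {"osc_port": 9000, "flan_enabled": False}
    path = Path("demo_settings.yaml")  # illustrative file name

    # save: dump the current settings dict as YAML
    with open(path, "w") as f:
        yaml.dump(defaults, f)

    # load: merge whatever the file holds over the in-memory defaults
    with open(path, "r") as f:
        defaults.update(yaml.safe_load(f))

    print(defaults)  # unchanged here, since we saved and reloaded the same dict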
diff --git a/texttranslate.py b/texttranslate.py
index 90f127f..341bd40 100644
--- a/texttranslate.py
+++ b/texttranslate.py
@@ -1,7 +1,7 @@
 import settings
 import pykakasi
 import texttranslateARGOS
-#import texttranslateM2M100
+# import texttranslateM2M100
 import texttranslateM2M100_CTranslate2
 
 
diff --git a/websocket.py b/websocket.py
index 28acf5f..7ddba8a 100644
--- a/websocket.py
+++ b/websocket.py
@@ -2,16 +2,20 @@
 import asyncio
 import websockets
 import json
+
 import texttranslate
 import imagetranslate
 from windowcapture import WindowCapture
 import settings
 import VRC_OSCLib
+import flanLanguageModel
+
 
 WS_CLIENTS = set()
 
 
 def websocketMessageHandler(msgObj):
+
     if msgObj["type"] == "setting_change":
         settings.SetOption(msgObj["name"], msgObj["value"])
         BroadcastMessage(json.dumps({"type": "translate_settings", "data": settings.TRANSLATE_SETTINGS}))  # broadcast updated settings to all clients
@@ -26,6 +30,11 @@ def websocketMessageHandler(msgObj):
         translate_result = (texttranslate.TranslateLanguage(" -- ".join(ocr_result), msgObj["from_lang"], msgObj["to_lang"]))
         BroadcastMessage(json.dumps({"type": "translate_result", "original_text": "\n".join(ocr_result), "translate_result": "\n".join(translate_result.split(" -- "))}))
 
+    if msgObj["type"] == "flan_req":
+        if flanLanguageModel.init():
+            flan_result = flanLanguageModel.flan.encode(msgObj["text"])
+            BroadcastMessage(json.dumps({"type": "flan_result", "flan_result": flan_result}))
+
     if msgObj["type"] == "get_windows_list":
         windows_list = WindowCapture.list_window_names()
         BroadcastMessage(json.dumps({"type": "windows_list", "data": windows_list}))
diff --git a/websocket_clients/streaming-overlay-01/index.html b/websocket_clients/streaming-overlay-01/index.html
index 14466b0..913be10 100644
--- a/websocket_clients/streaming-overlay-01/index.html
+++ b/websocket_clients/streaming-overlay-01/index.html
@@ -13,7 +13,10 @@
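The new `flan_req` message type lets any websocket client query the FLAN model directly. A hypothetical client sketch using the `websockets` library, assuming the server was started with `--websocket_ip 127.0.0.1 --websocket_port 5000` (the `ask_flan` helper is an illustration, not part of the patch):

    import asyncio
    import json
    import websockets

    async def ask_flan(text):
        async with websockets.connect("ws://127.0.0.1:5000") as ws:
            await ws.send(json.dumps({"type": "flan_req", "text": text}))
            # the server broadcasts the answer back as a "flan_result" message
            async for message in ws:
                msg = json.loads(message)
                if msg.get("type") == "flan_result":
                    return msg["flan_result"]

    print(asyncio.run(ask_flan("What is the capital of France?")))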