From 224957158f3524b0ce6ce521c603bf430f6d9998 Mon Sep 17 00:00:00 2001 From: SunXiaoye <31361630+JingofXin@users.noreply.github.com> Date: Thu, 1 Aug 2024 18:56:14 +0800 Subject: [PATCH] Support bark for TTS (#100) --- lazyllm/common/globals.py | 2 +- lazyllm/components/auto/autodeploy.py | 3 + lazyllm/components/text_to_speech/__init__.py | 0 lazyllm/components/text_to_speech/bark.py | 70 +++++++++++++++++++ .../utils/downloader/model_mapping.py | 6 +- lazyllm/tools/webpages/webmodule.py | 21 ++++-- tests/advanced_tests/test_deploy.py | 6 ++ 7 files changed, 101 insertions(+), 7 deletions(-) create mode 100644 lazyllm/components/text_to_speech/__init__.py create mode 100644 lazyllm/components/text_to_speech/bark.py diff --git a/lazyllm/common/globals.py b/lazyllm/common/globals.py index 50ce126f..c5022429 100644 --- a/lazyllm/common/globals.py +++ b/lazyllm/common/globals.py @@ -66,7 +66,7 @@ def __getattr__(self, __name: str) -> Any: raise AttributeError(f'Attr {__name} not found in globals') def clear(self): - self.__data.pop(self._sid) + self.__data.pop(self._sid, None) def _clear_all(self): self.__data.clear() diff --git a/lazyllm/components/auto/autodeploy.py b/lazyllm/components/auto/autodeploy.py index 184a2cdd..afe7dfaa 100644 --- a/lazyllm/components/auto/autodeploy.py +++ b/lazyllm/components/auto/autodeploy.py @@ -5,6 +5,7 @@ from .auto_helper import model_map, get_model_name, check_requirements from lazyllm.components.embedding.embed import EmbeddingDeploy from lazyllm.components.stable_diffusion.stable_diffusion3 import StableDiffusionDeploy +from lazyllm.components.text_to_speech.bark import BarkDeploy from ..utils.downloader import ModelManager class AutoDeploy(LazyLLMDeployBase): @@ -20,6 +21,8 @@ def __new__(cls, base_model, source=lazyllm.config['model_source'], trust_remote return EmbeddingDeploy(trust_remote_code, launcher) elif type == 'sd' or ModelManager.get_model_type(model_name) == 'sd': return StableDiffusionDeploy(launcher) + elif type == 'tts' or ModelManager.get_model_type(model_name) == 'tts': + return BarkDeploy(launcher) map_name = model_map(model_name) candidates = get_configer().query_deploy(lazyllm.config['gpu_type'], launcher.ngpus, map_name, max_token_num) diff --git a/lazyllm/components/text_to_speech/__init__.py b/lazyllm/components/text_to_speech/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lazyllm/components/text_to_speech/bark.py b/lazyllm/components/text_to_speech/bark.py new file mode 100644 index 00000000..b0e32c52 --- /dev/null +++ b/lazyllm/components/text_to_speech/bark.py @@ -0,0 +1,70 @@ +import os +import json + +import lazyllm +from lazyllm import LOG +from lazyllm.thirdparty import torch +from lazyllm.thirdparty import transformers as tf +from ..utils.downloader import ModelManager + +class Bark(object): + + def __init__(self, base_sd, source=None, trust_remote_code=True, init=False): + source = lazyllm.config['model_source'] if not source else source + self.base_sd = ModelManager(source).download(base_sd) + self.trust_remote_code = trust_remote_code + self.processor, self.bark = None, None + self.init_flag = lazyllm.once_flag() + self.device = 'cpu' + if init: + lazyllm.call_once(self.init_flag, self.load_bark) + + def load_bark(self): + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.processor = tf.AutoProcessor.from_pretrained(self.base_sd) + self.processor.speaker_embeddings['repo_or_path'] = self.base_sd + self.bark = tf.BarkModel.from_pretrained(self.base_sd, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2").to(self.device) + + def __call__(self, string): + lazyllm.call_once(self.init_flag, self.load_bark) + if isinstance(string, str): + query = string + voice_preset = "v2/zh_speaker_9" + elif isinstance(string, dict): + query = string['inputs'] + voice_preset = string['voice_preset'] + else: + raise TypeError(f"Not support input type:{type(string)}, requires str or dict.") + inputs = self.processor(query, voice_preset=voice_preset).to(self.device) + speech = self.bark.generate(**inputs) * 32767 + res = {'sounds': ( + self.bark.generation_config.sample_rate, + speech.cpu().numpy().squeeze().tolist() + )} + return json.dumps(res) + + +class BarkDeploy(object): + keys_name_handle = { + 'inputs': 'inputs', + } + message_format = { + 'inputs': 'Who are you ?', + 'voice_preset': None, + } + default_headers = {'Content-Type': 'application/json'} + + def __init__(self, launcher=None): + self.launcher = launcher + + def __call__(self, finetuned_model=None, base_model=None): + if not os.path.exists(finetuned_model) or \ + not any(filename.endswith('.bin') or filename.endswith('.safetensors') + for _, _, filename in os.walk(finetuned_model) if filename): + if not finetuned_model: + LOG.warning(f"Note! That finetuned_model({finetuned_model}) is an invalid path, " + f"base_model({base_model}) will be used") + finetuned_model = base_model + return lazyllm.deploy.RelayServer(func=Bark(finetuned_model), launcher=self.launcher)() diff --git a/lazyllm/components/utils/downloader/model_mapping.py b/lazyllm/components/utils/downloader/model_mapping.py index 461e1b6b..ecb5b38d 100644 --- a/lazyllm/components/utils/downloader/model_mapping.py +++ b/lazyllm/components/utils/downloader/model_mapping.py @@ -285,7 +285,11 @@ "stable-diffusion-3-medium": { "source": {"huggingface": "stabilityai/stable-diffusion-3-medium", "modelscope": "AI-ModelScope/stable-diffusion-3-medium-diffusers"}, "type": "sd", - }, + }, + "bark":{ + "source": {"huggingface": "suno/bark", "modelscope": "mapjack/bark"}, + "type": "tts", + }, "llava-1.5-7b": { "source": {"huggingface": "llava-hf/llava-1.5-7b-hf", "modelscope": "huangjintao/llava-1.5-7b-hf"}, "type": "vlm", diff --git a/lazyllm/tools/webpages/webmodule.py b/lazyllm/tools/webpages/webmodule.py index fcde6c03..ad4f764c 100644 --- a/lazyllm/tools/webpages/webmodule.py +++ b/lazyllm/tools/webpages/webmodule.py @@ -12,6 +12,7 @@ from PIL import Image from io import BytesIO from types import GeneratorType +import numpy as np import lazyllm from lazyllm import LOG, globals @@ -93,6 +94,9 @@ def init_web(self, component_descs): for _, gname, name, ctype, value in component_descs: if ctype in ('Checkbox', 'Text'): components.append(getattr(gr, ctype)(interactive=True, value=value, label=f'{gname}.{name}')) + elif ctype == 'Dropdown': + components.append(getattr(gr, ctype)(interactive=True, choices=value, + label=f'{gname}.{name}')) else: raise KeyError(f'invalid component type: {ctype}') with gr.Row(): @@ -262,7 +266,11 @@ def get_log_and_message(s): elif 'images_base64' in r: image_data = r.pop('images_base64')[0] image = Image.open(BytesIO(base64.b64decode(image_data))) - return "The image is: ", "".join(log_history), image + return "The image is: ", "".join(log_history), {'img': image} + elif 'sounds' in r: + sound_data = r.pop('sounds') + sound_data = (sound_data[0], np.array(sound_data[1]).astype(np.int16)) + return "The Audio is: ", "".join(log_history), {'audio': sound_data} else: s = s except (ValueError, KeyError, TypeError): @@ -272,11 +280,14 @@ def get_log_and_message(s): return s, "".join(log_history), None log_history = [] - image = None + file = None if isinstance(result, (str, dict)): - result, log, image = get_log_and_message(result) - if image: - chat_history[-1][1] = gr.Image(image) + result, log, file = get_log_and_message(result) + if file: + if 'img' in file: + chat_history[-1][1] = gr.Image(file['img']) + if 'audio' in file: + chat_history[-1][1] = gr.Audio(file['audio']) elif isinstance(result, str): chat_history[-1][1] = result elif isinstance(result, GeneratorType): diff --git a/tests/advanced_tests/test_deploy.py b/tests/advanced_tests/test_deploy.py index 345e665a..87fef796 100644 --- a/tests/advanced_tests/test_deploy.py +++ b/tests/advanced_tests/test_deploy.py @@ -96,6 +96,12 @@ def test_sd3(self): res = m('a little cat') assert "images_base64" in json.loads(res) + def test_bark(self): + m = lazyllm.TrainableModule('bark') + m.update_server() + res = m('你好啊,很高兴认识你。') + assert "sounds" in json.loads(res) + def test_vlm_and_lmdeploy(self): chat = lazyllm.TrainableModule('internvl-chat-2b-v1-5').deploy_method(deploy.LMDeploy) m = lazyllm.ServerModule(chat)