From f4121844e66328654c76e90cbf767f39a5b14ab0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?=
<41315874+fumiama@users.noreply.github.com>
Date: Mon, 24 Jun 2024 22:11:56 +0900
Subject: [PATCH] doc: sync to latest grammar (#427)
---
ChatTTS/core.py | 6 +-
ChatTTS/model/gpt.py | 2 +-
ChatTTS/norm.py | 6 +-
ChatTTS/utils/gpu.py | 4 +-
README.md | 62 +++++++---
examples/ipynb/colab.ipynb | 127 ++++++++++++++-------
examples/ipynb/example.ipynb | 87 +++++++++++---
examples/web/funcs.py | 8 +-
tools/llm/__init__.py | 1 +
{ChatTTS/experimental => tools/llm}/llm.py | 3 +-
tools/logger/log.py | 4 +-
11 files changed, 224 insertions(+), 86 deletions(-)
create mode 100644 tools/llm/__init__.py
rename {ChatTTS/experimental => tools/llm}/llm.py (99%)
diff --git a/ChatTTS/core.py b/ChatTTS/core.py
index b3300c943..7ca0bbe99 100644
--- a/ChatTTS/core.py
+++ b/ChatTTS/core.py
@@ -46,7 +46,7 @@ def has_loaded(self, use_decoder = False):
for module in check_list:
if not hasattr(self, module) and module not in self.pretrain_models:
- self.logger.warn(f'{module} not initialized.')
+ self.logger.warning(f'{module} not initialized.')
not_finish = True
if not not_finish:
@@ -75,7 +75,7 @@ def download_models(
except:
download_path = None
if download_path is None or force_redownload:
- self.logger.log(logging.INFO, f'Download from HF: https://huggingface.co/2Noise/ChatTTS')
+ self.logger.log(logging.INFO, f'download from HF: https://huggingface.co/2Noise/ChatTTS')
try:
download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
except:
@@ -232,7 +232,7 @@ def _load(
try:
gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True)
except RuntimeError as e:
- self.logger.warning(f'Compile failed,{e}. fallback to normal mode.')
+ self.logger.warning(f'compile failed: {e}. fallback to normal mode.')
self.gpt = gpt
spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), 'spk_stat.pt')
assert os.path.exists(spk_stat_path), f'Missing spk_stat.pt: {spk_stat_path}'
diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py
index 171c010eb..4a1b264f6 100644
--- a/ChatTTS/model/gpt.py
+++ b/ChatTTS/model/gpt.py
@@ -412,7 +412,7 @@ def generate(
pbar.update(1)
if not finish.all():
- self.logger.warn(f'Incomplete result. hit max_new_token: {max_new_token}')
+ self.logger.warning(f'incomplete result. hit max_new_token: {max_new_token}')
del finish
diff --git a/ChatTTS/norm.py b/ChatTTS/norm.py
index 7c1440d3c..08e710e03 100644
--- a/ChatTTS/norm.py
+++ b/ChatTTS/norm.py
@@ -137,7 +137,7 @@ def __call__(
text = self._apply_half2full_map(text)
invalid_characters = self._count_invalid_characters(text)
if len(invalid_characters):
- self.logger.warn(f'found invalid characters: {invalid_characters}')
+ self.logger.warning(f'found invalid characters: {invalid_characters}')
text = self._apply_character_map(text)
if do_homophone_replacement:
arr, replaced_words = _fast_replace(
@@ -153,10 +153,10 @@ def __call__(
def register(self, name: str, normalizer: Callable[[str], str]) -> bool:
if name in self.normalizers:
- self.logger.warn(f"name {name} has been registered")
+ self.logger.warning(f"name {name} has been registered")
return False
if not isinstance(normalizer, Callable[[str], str]):
- self.logger.warn("normalizer must have caller type (str) -> str")
+ self.logger.warning("normalizer must have caller type (str) -> str")
return False
self.normalizers[name] = normalizer
return True
diff --git a/ChatTTS/utils/gpu.py b/ChatTTS/utils/gpu.py
index 248dff3cc..17a30670c 100644
--- a/ChatTTS/utils/gpu.py
+++ b/ChatTTS/utils/gpu.py
@@ -18,10 +18,10 @@ def select_device(min_memory=2047):
device = torch.device('cpu')
elif torch.backends.mps.is_available():
# For Apple M1/M2 chips with Metal Performance Shaders
- logger.get_logger().info('Apple GPU found, using MPS.')
+ logger.get_logger().info('apple GPU found, using MPS.')
device = torch.device('mps')
else:
- logger.get_logger().warning('No GPU found, use CPU instead')
+ logger.get_logger().warning('no GPU found, use CPU instead')
device = torch.device('cpu')
return device
diff --git a/README.md b/README.md
index dadd521e4..1312c0319 100644
--- a/README.md
+++ b/README.md
@@ -112,7 +112,7 @@ chat.load_models(compile=False) # Set to True for better performance
texts = ["PUT YOUR TEXT HERE",]
-wavs = chat.infer(texts, )
+wavs = chat.infer(texts)
torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
```
@@ -125,23 +125,27 @@ torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
rand_spk = chat.sample_random_speaker()
-params_infer_code = {
- 'spk_emb': rand_spk, # add sampled speaker
- 'temperature': .3, # using custom temperature
- 'top_P': 0.7, # top P decode
- 'top_K': 20, # top K decode
-}
+params_infer_code = ChatTTS.Chat.InferCodeParams(
+ spk_emb = rand_spk, # add sampled speaker
+ temperature = .3, # using custom temperature
+ top_P = 0.7, # top P decode
+ top_K = 20, # top K decode
+)
###################################
# For sentence level manual control.
# use oral_(0-9), laugh_(0-2), break_(0-7)
# to generate special token in text to synthesize.
-params_refine_text = {
- 'prompt': '[oral_2][laugh_0][break_6]'
-}
+params_refine_text = ChatTTS.Chat.RefineTextParams(
+ prompt='[oral_2][laugh_0][break_6]',
+)
-wavs = chat.infer(texts, params_refine_text=params_refine_text, params_infer_code=params_infer_code)
+wavs = chat.infer(
+ texts,
+ params_refine_text=params_refine_text,
+ params_infer_code=params_infer_code,
+)
###################################
# For word level manual control.
@@ -163,16 +167,42 @@ capabilities with precise control over prosodic elements [laugh]like like
[uv_break] use the project responsibly at your own risk.[uv_break]
""".replace('\n', '') # English is still experimental.
-params_refine_text = {
- 'prompt': '[oral_2][laugh_0][break_4]'
-}
-# audio_array_cn = chat.infer(inputs_cn, params_refine_text=params_refine_text)
+params_refine_text = ChatTTS.Chat.RefineTextParams(
+ prompt='[oral_2][laugh_0][break_4]',
+)
+
audio_array_en = chat.infer(inputs_en, params_refine_text=params_refine_text)
torchaudio.save("output3.wav", torch.from_numpy(audio_array_en[0]), 24000)
```
+
+
+
+
+
+**male speaker**
+
+ |
+
+
+**female speaker**
+
+ |
+
+
+
+
[male speaker](https://github.com/2noise/ChatTTS/assets/130631963/e0f51251-db7f-4d39-a0e9-3e095bb65de1)
+ |
+
+
[female speaker](https://github.com/2noise/ChatTTS/assets/130631963/f5dcdd01-1091-47c5-8241-c4f6aaaa8bbd)
+
+ |
+
+
+
+
## FAQ
@@ -206,4 +236,4 @@ In the current released model, the only token-level control units are `[laugh]`,
data:image/s3,"s3://crabby-images/fa980/fa980b8a26023d603bcdc17a7f466353d244ba57" alt="counter"
-
\ No newline at end of file
+
diff --git a/examples/ipynb/colab.ipynb b/examples/ipynb/colab.ipynb
index 589eaace1..6e604c5b8 100644
--- a/examples/ipynb/colab.ipynb
+++ b/examples/ipynb/colab.ipynb
@@ -51,6 +51,7 @@
"\n",
"from ChatTTS import ChatTTS\n",
"from ChatTTS.tools.logger import get_logger\n",
+ "from ChatTTS.tools.normalizer import normalizer_en_nemo_text, normalizer_zh_tn\n",
"from IPython.display import Audio"
]
},
@@ -71,19 +72,40 @@
},
"outputs": [],
"source": [
- "chat = ChatTTS.Chat(get_logger(\"ChatTTS\"))"
+ "logger = get_logger(\"ChatTTS\")\n",
+ "chat = ChatTTS.Chat(logger, remove_exist=True)\n",
+ "\n",
+ "# try to load normalizer\n",
+ "try:\n",
+ " chat.normalizer.register(\"en\", normalizer_en_nemo_text())\n",
+ "except:\n",
+ " logger.warning('Package nemo_text_processing not found!')\n",
+ " logger.warning(\n",
+ " 'Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing',\n",
+ " )\n",
+ "try:\n",
+ " chat.normalizer.register(\"zh\", normalizer_zh_tn())\n",
+ "except:\n",
+ " logger.warning('Package WeTextProcessing not found!')\n",
+ " logger.warning(\n",
+ " 'Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing',\n",
+ " )"
]
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "id": "3Ty427FZNH30"
+ },
"source": [
"### Here are three choices for loading models:"
]
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "id": "NInF7Lk1NH30"
+ },
"source": [
"#### 1. Load models from Hugging Face:"
]
@@ -91,7 +113,9 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {},
+ "metadata": {
+ "id": "VVtNlNosNH30"
+ },
"outputs": [],
"source": [
"# use force_redownload=True if the weights have been updated.\n",
@@ -100,7 +124,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "id": "AhBD5WUPNH30"
+ },
"source": [
"#### 2. Load models from local directories 'asset' and 'config':"
]
@@ -108,7 +134,9 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {},
+ "metadata": {
+ "id": "83UwV6SGNH31"
+ },
"outputs": [],
"source": [
"chat.load_models()\n",
@@ -117,7 +145,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "id": "c0qjGPNkNH31"
+ },
"source": [
"#### 3. Load models from a custom path:"
]
@@ -125,7 +155,9 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {},
+ "metadata": {
+ "id": "oCSBx0Q7NH31"
+ },
"outputs": [],
"source": [
"# write the model path into custom_path\n",
@@ -134,7 +166,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "id": "VoEki3XMNH31"
+ },
"source": [
"### You can also unload models to save the memory"
]
@@ -142,7 +176,9 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {},
+ "metadata": {
+ "id": "3FdsTSxoNH31"
+ },
"outputs": [],
"source": [
"chat.unload()"
@@ -219,8 +255,13 @@
},
"outputs": [],
"source": [
- "params_infer_code = {'prompt':'[speed_5]', 'temperature':.3}\n",
- "params_refine_text = {'prompt':'[oral_2][laugh_0][break_6]'}\n",
+ "params_infer_code = ChatTTS.Chat.InferCodeParams(\n",
+ " prompt='[speed_5]',\n",
+ " temperature=.3,\n",
+ ")\n",
+ "params_refine_text = ChatTTS.Chat.RefineTextParams(\n",
+ " prompt='[oral_2][laugh_0][break_6]',\n",
+ ")\n",
"\n",
"wav = chat.infer('四川美食可多了,有麻辣火锅、宫保鸡丁、麻婆豆腐、担担面、回锅肉、夫妻肺片等,每样都让人垂涎三尺。', \\\n",
" params_refine_text=params_refine_text, params_infer_code=params_infer_code)"
@@ -255,7 +296,9 @@
"outputs": [],
"source": [
"rand_spk = chat.sample_random_speaker()\n",
- "params_infer_code = {'spk_emb' : rand_spk, }\n",
+ "params_infer_code = ChatTTS.Chat.InferCodeParams(\n",
+ " spk_emb=rand_spk,\n",
+ ")\n",
"\n",
"wav = chat.infer('四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。', \\\n",
" params_refine_text=params_refine_text, params_infer_code=params_infer_code)"
@@ -302,7 +345,7 @@
},
"outputs": [],
"source": [
- "wav = chat.infer(refined_text)"
+ "wav = chat.infer(refined_text, skip_refine_text=True)"
]
},
{
@@ -316,80 +359,86 @@
"Audio(wav[0], rate=24_000, autoplay=True)"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GG5AMbQbbSrl"
+ },
+ "source": [
+ "## LLM Call"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
"metadata": {
- "id": "R2WjuVrWbSrl"
+ "id": "3rkfwc3UbSrl"
},
"outputs": [],
"source": [
- "text = 'so we found being competitive and collaborative [uv_break] was a huge way of staying [uv_break] motivated towards our goals, [uv_break] so [uv_break] one person to call [uv_break] when you fall off, [uv_break] one person who [uv_break] gets you back [uv_break] on then [uv_break] one person [uv_break] to actually do the activity with.'\n",
- "wav = chat.infer(text, skip_refine_text=True)"
+ "from ChatTTS.tools.llm import ChatOpenAI\n",
+ "\n",
+ "API_KEY = ''\n",
+ "client = ChatOpenAI(api_key=API_KEY,\n",
+ " base_url=\"https://api.deepseek.com\",\n",
+ " model=\"deepseek-chat\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
- "id": "71Y4pBdl-_Yd"
+ "id": "TTkIsXozbSrm"
},
"outputs": [],
"source": [
- "Audio(wav[0], rate=24_000, autoplay=True)"
+ "user_question = '四川有哪些好吃的美食呢?'"
]
},
{
- "cell_type": "markdown",
+ "cell_type": "code",
+ "execution_count": null,
"metadata": {
- "id": "GG5AMbQbbSrl"
+ "id": "3yT8uNz-RVy1"
},
+ "outputs": [],
"source": [
- "## LLM Call"
+ "text = client.call(user_question, prompt_version = 'deepseek')\n",
+ "text"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
- "id": "3rkfwc3UbSrl"
+ "id": "6qddpv7lRW-3"
},
"outputs": [],
"source": [
- "from ChatTTS.experimental.llm import llm_api\n",
- "\n",
- "API_KEY = ''\n",
- "client = llm_api(api_key=API_KEY,\n",
- " base_url=\"https://api.deepseek.com\",\n",
- " model=\"deepseek-chat\")"
+ "text = client.call(text, prompt_version = 'deepseek_TN')\n",
+ "text"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
- "id": "TTkIsXozbSrm"
+ "id": "qNhCJG4VbSrm"
},
"outputs": [],
"source": [
- "user_question = '四川有哪些好吃的美食呢?'\n",
- "text = client.call(user_question, prompt_version = 'deepseek')\n",
- "print(text)\n",
- "text = client.call(text, prompt_version = 'deepseek_TN')\n",
- "print(text)"
+ "wav = chat.infer(text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
- "id": "qNhCJG4VbSrm"
+ "id": "Wq1XQHmFRQI3"
},
"outputs": [],
"source": [
- "params_infer_code = {'spk_emb' : rand_spk, 'temperature':.3}\n",
- "\n",
- "wav = chat.infer(text, params_infer_code=params_infer_code)"
+ "Audio(wav[0], rate=24_000, autoplay=True)"
]
}
],
diff --git a/examples/ipynb/example.ipynb b/examples/ipynb/example.ipynb
index 0ef0e6a0a..6eb3cd172 100644
--- a/examples/ipynb/example.ipynb
+++ b/examples/ipynb/example.ipynb
@@ -34,6 +34,7 @@
"\n",
"import ChatTTS\n",
"from tools.logger import get_logger\n",
+ "from tools.normalizer import normalizer_en_nemo_text, normalizer_zh_tn\n",
"from IPython.display import Audio"
]
},
@@ -52,7 +53,24 @@
"source": [
"os.chdir(root_dir)\n",
"\n",
- "chat = ChatTTS.Chat(get_logger(\"ChatTTS\"))"
+ "logger = get_logger(\"ChatTTS\")\n",
+ "chat = ChatTTS.Chat(logger)\n",
+ "\n",
+ "# try to load normalizer\n",
+ "try:\n",
+ " chat.normalizer.register(\"en\", normalizer_en_nemo_text())\n",
+ "except:\n",
+ " logger.warning('Package nemo_text_processing not found!')\n",
+ " logger.warning(\n",
+ " 'Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing',\n",
+ " )\n",
+ "try:\n",
+ " chat.normalizer.register(\"zh\", normalizer_zh_tn())\n",
+ "except:\n",
+ " logger.warning('Package WeTextProcessing not found!')\n",
+ " logger.warning(\n",
+ " 'Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing',\n",
+ " )"
]
},
{
@@ -186,8 +204,13 @@
"metadata": {},
"outputs": [],
"source": [
- "params_infer_code = {'prompt':'[speed_5]', 'temperature':.3}\n",
- "params_refine_text = {'prompt':'[oral_2][laugh_0][break_6]'}\n",
+ "params_infer_code = ChatTTS.Chat.InferCodeParams(\n",
+ " prompt='[speed_5]',\n",
+ " temperature=.3,\n",
+ ")\n",
+ "params_refine_text = ChatTTS.Chat.RefineTextParams(\n",
+ " prompt='[oral_2][laugh_0][break_6]',\n",
+ ")\n",
"\n",
"wav = chat.infer('四川美食可多了,有麻辣火锅、宫保鸡丁、麻婆豆腐、担担面、回锅肉、夫妻肺片等,每样都让人垂涎三尺。', \\\n",
" params_refine_text=params_refine_text, params_infer_code=params_infer_code)"
@@ -216,7 +239,9 @@
"outputs": [],
"source": [
"rand_spk = chat.sample_random_speaker()\n",
- "params_infer_code = {'spk_emb' : rand_spk, }\n",
+ "params_infer_code = ChatTTS.Chat.InferCodeParams(\n",
+ " spk_emb=rand_spk,\n",
+ ")\n",
"\n",
"wav = chat.infer('四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。', \\\n",
" params_refine_text=params_refine_text, params_infer_code=params_infer_code)"
@@ -245,7 +270,17 @@
"outputs": [],
"source": [
"text = \"So we found being competitive and collaborative was a huge way of staying motivated towards our goals, so one person to call when you fall off, one person who gets you back on then one person to actually do the activity with.\"\n",
- "chat.infer(text, refine_text_only=True)"
+ "refined_text = chat.infer(text, refine_text_only=True)\n",
+ "refined_text"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wav = chat.infer(refined_text, skip_refine_text=True)"
]
},
{
@@ -254,8 +289,7 @@
"metadata": {},
"outputs": [],
"source": [
- "text = 'so we found being competitive and collaborative [uv_break] was a huge way of staying [uv_break] motivated towards our goals, [uv_break] so [uv_break] one person to call [uv_break] when you fall off, [uv_break] one person who [uv_break] gets you back [uv_break] on then [uv_break] one person [uv_break] to actually do the activity with.'\n",
- "wav = chat.infer(text, skip_refine_text=True)"
+ "Audio(wav[0], rate=24_000, autoplay=True)"
]
},
{
@@ -271,10 +305,10 @@
"metadata": {},
"outputs": [],
"source": [
- "from ChatTTS.experimental.llm import llm_api\n",
+ "from tools.llm import ChatOpenAI\n",
"\n",
"API_KEY = ''\n",
- "client = llm_api(api_key=API_KEY,\n",
+ "client = ChatOpenAI(api_key=API_KEY,\n",
" base_url=\"https://api.deepseek.com\",\n",
" model=\"deepseek-chat\")"
]
@@ -285,11 +319,27 @@
"metadata": {},
"outputs": [],
"source": [
- "user_question = '四川有哪些好吃的美食呢?'\n",
+ "user_question = '四川有哪些好吃的美食呢?'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"text = client.call(user_question, prompt_version = 'deepseek')\n",
- "print(text)\n",
+ "text"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"text = client.call(text, prompt_version = 'deepseek_TN')\n",
- "print(text)"
+ "text"
]
},
{
@@ -298,9 +348,16 @@
"metadata": {},
"outputs": [],
"source": [
- "params_infer_code = {'spk_emb' : rand_spk, 'temperature':.3}\n",
- "\n",
- "wav = chat.infer(text, params_infer_code=params_infer_code)"
+ "wav = chat.infer(text)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "Audio(wav[0], rate=24_000, autoplay=True)"
]
}
],
diff --git a/examples/web/funcs.py b/examples/web/funcs.py
index 4e5b89af1..5d00761db 100644
--- a/examples/web/funcs.py
+++ b/examples/web/funcs.py
@@ -50,15 +50,15 @@ def load_chat(cust_path: Optional[str], coef: Optional[str]) -> bool:
try:
chat.normalizer.register("en", normalizer_en_nemo_text())
except:
- logger.warn('Package nemo_text_processing not found!')
- logger.warn(
+ logger.warning('Package nemo_text_processing not found!')
+ logger.warning(
'Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing',
)
try:
chat.normalizer.register("zh", normalizer_zh_tn())
except:
- logger.warn('Package WeTextProcessing not found!')
- logger.warn(
+ logger.warning('Package WeTextProcessing not found!')
+ logger.warning(
'Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing',
)
return ret
diff --git a/tools/llm/__init__.py b/tools/llm/__init__.py
new file mode 100644
index 000000000..19983fc7e
--- /dev/null
+++ b/tools/llm/__init__.py
@@ -0,0 +1 @@
+from .llm import ChatOpenAI
diff --git a/ChatTTS/experimental/llm.py b/tools/llm/llm.py
similarity index 99%
rename from ChatTTS/experimental/llm.py
rename to tools/llm/llm.py
index a4e5f94d8..352e3907b 100644
--- a/ChatTTS/experimental/llm.py
+++ b/tools/llm/llm.py
@@ -1,4 +1,3 @@
-
from openai import OpenAI
prompt_dict = {
@@ -22,7 +21,7 @@
],
}
-class llm_api:
+class ChatOpenAI:
def __init__(self, api_key, base_url, model):
self.client = OpenAI(
api_key = api_key,
diff --git a/tools/logger/log.py b/tools/logger/log.py
index 7d26ce78d..4400116d7 100644
--- a/tools/logger/log.py
+++ b/tools/logger/log.py
@@ -51,9 +51,11 @@ def format(self, record: logging.LogRecord):
return logstr
-def get_logger(name: str, lv = logging.INFO):
+def get_logger(name: str, lv = logging.INFO, remove_exist=False):
logger = logging.getLogger(name)
logger.setLevel(lv)
+ if remove_exist and logger.hasHandlers():
+ logger.handlers.clear()
if not logger.hasHandlers():
syslog = logging.StreamHandler()
syslog.setFormatter(Formatter())