Skip to content

Commit

Permalink
add support for TongyiQwen tts (infiniflow#2311)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

add support for TongyiQwen tts
infiniflow#1853

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <[email protected]>
  • Loading branch information
hangters and aopstudio authored Sep 9, 2024
1 parent 2fbe274 commit 7c665b0
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 8 deletions.
18 changes: 12 additions & 6 deletions conf/llm_factories.json
Original file line number Diff line number Diff line change
Expand Up @@ -104,18 +104,24 @@
"max_tokens": 2048,
"model_type": "embedding"
},
{
"llm_name": "sambert-zhide-v1",
"tags": "TTS",
"max_tokens": 2048,
"model_type": "tts"
},
{
"llm_name": "sambert-zhiru-v1",
"tags": "TTS",
"max_tokens": 2048,
"model_type": "tts"
},
{
"llm_name": "text-embedding-v3",
"tags": "TEXT EMBEDDING,8K",
"max_tokens": 8192,
"model_type": "embedding"
},
{
"llm_name": "paraformer-realtime-8k-v1",
"tags": "SPEECH2TEXT",
"max_tokens": 26214400,
"model_type": "speech2text"
},
{
"llm_name": "qwen-vl-max",
"tags": "LLM,CHAT,IMAGE2TEXT",
Expand Down
3 changes: 2 additions & 1 deletion rag/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,5 +137,6 @@
}

TTSModel = {
"Fish Audio": FishAudioTTS
"Fish Audio": FishAudioTTS,
"Tongyi-Qianwen": QwenTTS
}
60 changes: 59 additions & 1 deletion rag/llm/tts_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from rag.utils import num_tokens_from_string
import json
import re

import time
class ServeReferenceAudio(BaseModel):
audio: bytes
text: str
Expand Down Expand Up @@ -96,3 +96,61 @@ def tts(self, text):

except httpx.HTTPStatusError as e:
raise RuntimeError(f"**ERROR**: {e}")


class QwenTTS(Base):
def __init__(self, key, model_name, base_url=""):
import dashscope

self.model_name = model_name
dashscope.api_key = key

def tts(self, text):
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
from collections import deque

class Callback(ResultCallback):
def __init__(self) -> None:
self.dque = deque()

def _run(self):
while True:
if not self.dque:
time.sleep(0)
continue
val = self.dque.popleft()
if val:
yield val
else:
break

def on_open(self):
pass

def on_complete(self):
self.dque.append(None)

def on_error(self, response: SpeechSynthesisResponse):
raise RuntimeError(str(response))

def on_close(self):
pass

def on_event(self, result: SpeechSynthesisResult):
if result.get_audio_frame() is not None:
self.dque.append(result.get_audio_frame())

text = self.normalize_text(text)
callback = Callback()
SpeechSynthesizer.call(model=self.model_name,
text=text,
callback=callback,
format="mp3")
try:
for data in callback._run():
yield data
yield num_tokens_from_string(text)

except Exception as e:
raise RuntimeError(f"**ERROR**: {e}")

0 comments on commit 7c665b0

Please sign in to comment.