Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SparkTTS #2535

Merged
merged 2 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion api/apps/llm_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,10 @@ def apikey_json(keys):

elif factory =="XunFei Spark":
llm_name = req["llm_name"]
api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx")
if req["model_type"] == "chat":
api_key = req.get("spark_api_password", "xxxxxxxxxxxxxxx")
elif req["model_type"] == "tts":
api_key = apikey_json(["spark_app_id", "spark_api_secret","spark_api_key"])

elif factory == "BaiduYiyan":
llm_name = req["llm_name"]
Expand Down
3 changes: 2 additions & 1 deletion rag/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,5 +139,6 @@
TTSModel = {
"Fish Audio": FishAudioTTS,
"Tongyi-Qianwen": QwenTTS,
"OpenAI":OpenAITTS
"OpenAI":OpenAITTS,
"XunFei Spark":SparkTTS
}
124 changes: 118 additions & 6 deletions rag/llm/tts_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,30 @@
# limitations under the License.
#

import requests
from typing import Annotated, Literal
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
import queue
import re
import ssl
import time
from abc import ABC
from datetime import datetime
from time import mktime
from typing import Annotated, Literal
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

import httpx
import ormsgpack
import requests
import websocket
from pydantic import BaseModel, conint

from rag.utils import num_tokens_from_string
import json
import re
import time


class ServeReferenceAudio(BaseModel):
Expand Down Expand Up @@ -161,7 +175,7 @@ def on_event(self, result: SpeechSynthesisResult):

class OpenAITTS(Base):
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
if not base_url: base_url="https://api.openai.com/v1"
if not base_url: base_url = "https://api.openai.com/v1"
self.api_key = key
self.model_name = model_name
self.base_url = base_url
Expand All @@ -185,3 +199,101 @@ def tts(self, text, voice="alloy"):
for chunk in response.iter_content():
if chunk:
yield chunk


class SparkTTS:
    """Text-to-speech client for the XunFei Spark WebSocket TTS endpoint.

    ``key`` is a JSON string carrying ``spark_app_id``, ``spark_api_secret``
    and ``spark_api_key``; ``model_name`` is sent as the ``vcn`` (voice)
    business argument. ``base_url`` is accepted for signature compatibility
    with the other TTS classes and is not used.
    """

    STATUS_FIRST_FRAME = 0
    STATUS_CONTINUE_FRAME = 1
    STATUS_LAST_FRAME = 2

    def __init__(self, key, model_name, base_url=""):
        key = json.loads(key)
        self.APPID = key.get("spark_app_id", "xxxxxxx")
        self.APISecret = key.get("spark_api_secret", "xxxxxxx")
        self.APIKey = key.get("spark_api_key", "xxxxxx")
        self.model_name = model_name
        self.CommonArgs = {"app_id": self.APPID}
        # Holds decoded audio chunks; replaced by a fresh queue on every tts() call.
        self.audio_queue = queue.Queue()

    def create_url(self):
        """Build the signed ``wss://`` URL used to open the TTS connection.

        The signature is an HMAC-SHA256 over the ``host``, ``date`` and
        request-line headers, keyed with the API secret; the resulting
        authorization string is base64-encoded and passed as a query
        parameter together with ``date`` and ``host``.

        Returns:
            The endpoint URL with the ``authorization``, ``date`` and ``host``
            query parameters appended.
        """
        url = 'wss://tts-api.xfyun.cn/v2/tts'
        # NOTE(review): the signed host ("ws-api.xfyun.cn") differs from the URL
        # host; kept exactly as in the original -- confirm against vendor docs.
        host = "ws-api.xfyun.cn"
        date = format_date_time(mktime(datetime.now().timetuple()))

        signature_origin = f"host: {host}\ndate: {date}\nGET /v2/tts HTTP/1.1"
        digest = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                          digestmod=hashlib.sha256).digest()
        signature = base64.b64encode(digest).decode(encoding='utf-8')

        authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
            self.APIKey, "hmac-sha256", "host date request-line", signature)
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        params = {
            "authorization": authorization,
            "date": date,
            "host": host
        }
        return url + '?' + urlencode(params)

    def tts(self, text):
        """Synthesize ``text`` and yield the audio as decoded byte chunks.

        This is a generator: the WebSocket session runs to completion on the
        first ``next()`` call, after which the buffered chunks are yielded in
        arrival order.

        Args:
            text: The text to synthesize; sent UTF-8 encoded and base64-wrapped.

        Yields:
            ``bytes`` chunks of audio decoded from each response frame.

        Raises:
            Exception: if the service reports a non-zero ``code``, or if the
                session ends without producing any audio (typically invalid
                credentials); the underlying error is appended when known.
        """
        business_args = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name,
                         "tte": "utf8"}
        data_args = {"status": 2, "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')}
        request = json.dumps({"common": {"app_id": self.APPID},
                              "business": business_args,
                              "data": data_args})
        model_name = self.model_name

        # Fresh queue per call so leftovers from an earlier, partially consumed
        # or failed call can never be mixed into this call's audio.
        audio_queue = self.audio_queue = queue.Queue()
        # Errors are recorded here and raised from the generator itself;
        # raising inside a callback would surface in the WebSocket library's
        # dispatch code rather than reliably in the caller of tts().
        failure = {"api": None, "transport": None}

        def on_message(ws, message):
            try:
                frame = json.loads(message)
                code = frame.get("code")
                # Check the status code first: error frames may carry no "data".
                if code != 0:
                    failure["api"] = f"sid:{frame.get('sid')} call error:{frame.get('message')} code:{code}"
                    ws.close()
                    return
                payload = frame.get("data") or {}
                audio = payload.get("audio")
                if audio is not None:
                    audio_queue.put(base64.b64decode(audio))
                if payload.get("status") == 2:
                    # Status 2 marks the last frame of the synthesis.
                    ws.close()
            except (ValueError, TypeError, AttributeError) as e:
                # Malformed frame (bad JSON / base64 / unexpected shape).
                if failure["transport"] is None:
                    failure["transport"] = f"malformed response: {e}"
                ws.close()

        def on_error(ws, error):
            if failure["transport"] is None:
                failure["transport"] = str(error)

        def on_close(ws, close_status_code, close_msg):
            audio_queue.put(None)  # None marks the end of the stream

        def on_open(ws):
            # Send the request from a helper thread, as the original did.
            thread.start_new_thread(ws.send, (request,))

        websocket.enableTrace(False)
        ws = websocket.WebSocketApp(self.create_url(), on_open=on_open, on_error=on_error,
                                    on_close=on_close, on_message=on_message)
        # NOTE(review): certificate verification is disabled here (kept from the
        # original); consider enabling it once the deployment's CA setup is confirmed.
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
        # run_forever() has returned, so no more frames will arrive. Add our own
        # end marker so the loop below terminates even if on_close never fired.
        audio_queue.put(None)

        if failure["api"] is not None:
            raise Exception(failure["api"])

        received_audio = False
        while True:
            audio_chunk = audio_queue.get()
            if audio_chunk is None:
                break
            received_audio = True
            yield audio_chunk

        if not received_audio:
            detail = f" ({failure['transport']})" if failure["transport"] else ""
            raise Exception(
                f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key." + detail)
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ vertexai==1.64.0
volcengine==1.0.146
voyageai==0.2.3
webdriver_manager==4.0.1
websocket==0.2.1
websocket-client==1.8.0
Werkzeug==3.0.3
wikipedia==1.4.0
word2number==1.1
Expand Down
6 changes: 6 additions & 0 deletions web/src/locales/en.ts
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,12 @@ The above is the content you need to summarize.`,
SparkModelNameMessage: 'Please select Spark model',
addSparkAPIPassword: 'Spark APIPassword',
SparkAPIPasswordMessage: 'please input your APIPassword',
addSparkAPPID: 'Spark APPID',
SparkAPPIDMessage: 'please input your APPID',
addSparkAPISecret: 'Spark APISecret',
SparkAPISecretMessage: 'please input your APISecret',
addSparkAPIKey: 'Spark APIKey',
SparkAPIKeyMessage: 'please input your APIKey',
yiyanModelNameMessage: 'Please input model name',
addyiyanAK: 'yiyan API KEY',
yiyanAKMessage: 'Please input your API KEY',
Expand Down
6 changes: 6 additions & 0 deletions web/src/locales/zh-traditional.ts
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,12 @@ export default {
SparkModelNameMessage: '請選擇星火模型!',
addSparkAPIPassword: '星火 APIPassword',
SparkAPIPasswordMessage: '請輸入 APIPassword',
addSparkAPPID: '星火 APPID',
SparkAPPIDMessage: '請輸入 APPID',
addSparkAPISecret: '星火 APISecret',
SparkAPISecretMessage: '請輸入 APISecret',
addSparkAPIKey: '星火 APIKey',
SparkAPIKeyMessage: '請輸入 APIKey',
yiyanModelNameMessage: '輸入模型名稱',
addyiyanAK: '一言 API KEY',
yiyanAKMessage: '請輸入 API KEY',
Expand Down
6 changes: 6 additions & 0 deletions web/src/locales/zh.ts
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,12 @@ export default {
SparkModelNameMessage: '请选择星火模型!',
addSparkAPIPassword: '星火 APIPassword',
SparkAPIPasswordMessage: '请输入 APIPassword',
addSparkAPPID: '星火 APPID',
SparkAPPIDMessage: '请输入 APPID',
addSparkAPISecret: '星火 APISecret',
SparkAPISecretMessage: '请输入 APISecret',
addSparkAPIKey: '星火 APIKey',
SparkAPIKeyMessage: '请输入 APIKey',
yiyanModelNameMessage: '请输入模型名称',
addyiyanAK: '一言 API KEY',
yiyanAKMessage: '请输入 API KEY',
Expand Down
70 changes: 56 additions & 14 deletions web/src/pages/user-setting/setting-model/spark-modal/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import omit from 'lodash/omit';
type FieldType = IAddLlmRequestBody & {
vision: boolean;
spark_api_password: string;
spark_app_id: string;
spark_api_secret: string;
spark_api_key: string;
};

const { Option } = Select;
Expand Down Expand Up @@ -63,28 +66,67 @@ const SparkModal = ({
>
<Select placeholder={t('modelTypeMessage')}>
<Option value="chat">chat</Option>
<Option value="tts">tts</Option>
</Select>
</Form.Item>
<Form.Item<FieldType>
label={t('modelName')}
name="llm_name"
initialValue={'Spark-Max'}
rules={[{ required: true, message: t('SparkModelNameMessage') }]}
>
<Select placeholder={t('modelTypeMessage')}>
<Option value="Spark-Max">Spark-Max</Option>
<Option value="Spark-Lite">Spark-Lite</Option>
<Option value="Spark-Pro">Spark-Pro</Option>
<Option value="Spark-Pro-128K">Spark-Pro-128K</Option>
<Option value="Spark-4.0-Ultra">Spark-4.0-Ultra</Option>
</Select>
<Input placeholder={t('modelNameMessage')} />
</Form.Item>
<Form.Item<FieldType>
label={t('addSparkAPIPassword')}
name="spark_api_password"
rules={[{ required: true, message: t('SparkAPIPasswordMessage') }]}
>
<Input placeholder={t('SparkAPIPasswordMessage')} />
<Form.Item noStyle dependencies={['model_type']}>
{({ getFieldValue }) =>
getFieldValue('model_type') === 'chat' && (
<Form.Item<FieldType>
label={t('addSparkAPIPassword')}
name="spark_api_password"
rules={[{ required: true, message: t('SparkAPIPasswordMessage') }]}
>
<Input placeholder={t('SparkAPIPasswordMessage')} />
</Form.Item>
)
}
</Form.Item>
<Form.Item noStyle dependencies={['model_type']}>
{({ getFieldValue }) =>
getFieldValue('model_type') === 'tts' && (
<Form.Item<FieldType>
label={t('addSparkAPPID')}
name="spark_app_id"
rules={[{ required: true, message: t('SparkAPPIDMessage') }]}
>
<Input placeholder={t('SparkAPPIDMessage')} />
</Form.Item>
)
}
</Form.Item>
<Form.Item noStyle dependencies={['model_type']}>
{({ getFieldValue }) =>
getFieldValue('model_type') === 'tts' && (
<Form.Item<FieldType>
label={t('addSparkAPISecret')}
name="spark_api_secret"
rules={[{ required: true, message: t('SparkAPISecretMessage') }]}
>
<Input placeholder={t('SparkAPISecretMessage')} />
</Form.Item>
)
}
</Form.Item>
<Form.Item noStyle dependencies={['model_type']}>
{({ getFieldValue }) =>
getFieldValue('model_type') === 'tts' && (
<Form.Item<FieldType>
label={t('addSparkAPIKey')}
name="spark_api_key"
rules={[{ required: true, message: t('SparkAPIKeyMessage') }]}
>
<Input placeholder={t('SparkAPIKeyMessage')} />
</Form.Item>
)
}
</Form.Item>
</Form>
</Modal>
Expand Down