diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index a1da4156bb..fbff768508 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -15,8 +15,10 @@ # from copy import deepcopy from flask import request, Response -from flask_login import login_required +from flask_login import login_required,current_user from api.db.services.dialog_service import DialogService, ConversationService, chat +from api.db.services.llm_service import LLMBundle, TenantService +from api.db import LLMType from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils import get_uuid from api.utils.api_utils import get_json_result @@ -176,6 +178,38 @@ def stream(): return server_error_response(e) +@manager.route('/tts', methods=['POST']) +@login_required +def tts(): + req = request.json + text = req["text"] + + tenants = TenantService.get_by_user_id(current_user.id) + if not tenants: + return get_data_error_result(retmsg="Tenant not found!") + + tts_id = tenants[0]["tts_id"] + if not tts_id: + return get_data_error_result(retmsg="No default TTS model is set") + + tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id) + def stream_audio(): + try: + for chunk in tts_mdl(text): + yield chunk + except Exception as e: + yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e), + "data": {"answer": "**ERROR**: "+str(e)}}, + ensure_ascii=False).encode('utf-8') + + resp = Response(stream_audio(), mimetype="audio/mpeg") + resp.headers.add_header("Cache-Control", "no-cache") + resp.headers.add_header("Connection", "keep-alive") + resp.headers.add_header("X-Accel-Buffering", "no") + + return resp + + @manager.route('/delete_msg', methods=['POST']) @login_required @validate_request("conversation_id", "message_id") @@ -221,4 +255,4 @@ def thumbup(): break ConversationService.update_by_id(conv["id"], conv) - return get_json_result(data=conv) \ No newline at end of file + return get_json_result(data=conv) diff --git a/api/db/services/user_service.py b/api/db/services/user_service.py index 07468b814b..07e20d47a3 100644 --- a/api/db/services/user_service.py +++ b/api/db/services/user_service.py @@ -96,6 +96,7 @@ def get_by_user_id(cls, user_id): cls.model.rerank_id, cls.model.asr_id, cls.model.img2txt_id, + cls.model.tts_id, cls.model.parser_ids, UserTenant.role] return list(cls.model.select(*fields)