From c3948ad531b9366fa0d24a99f87f4cd6fc0eb96a Mon Sep 17 00:00:00 2001
From: Sihan Chen <39623753+Spycsh@users.noreply.github.com>
Date: Mon, 25 Nov 2024 09:55:48 +0800
Subject: [PATCH] openai compatible for asr/tts (#929)

* openai compatible for asr/tts

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add dep

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 comps/asr/whisper/README.md                   |  9 +++
 comps/asr/whisper/dependency/whisper_model.py |  8 +--
 .../asr/whisper/dependency/whisper_server.py  | 58 ++++++++++++++++++-
 comps/asr/whisper/requirements.txt            |  1 +
 comps/cores/proto/api_protocol.py             | 32 ++++++++++
 comps/tts/gpt-sovits/Dockerfile               |  2 +-
 comps/tts/gpt-sovits/README.md                |  8 ++-
 comps/tts/speecht5/README.md                  |  5 +-
 .../tts/speecht5/dependency/speecht5_model.py | 31 +++++-----
 .../speecht5/dependency/speecht5_server.py    | 32 +++++++++-
 10 files changed, 160 insertions(+), 26 deletions(-)

diff --git a/comps/asr/whisper/README.md b/comps/asr/whisper/README.md
index d12a23518..71285dd8a 100644
--- a/comps/asr/whisper/README.md
+++ b/comps/asr/whisper/README.md
@@ -38,6 +38,14 @@ pip install optimum[habana]
 cd dependency/
 nohup python whisper_server.py --device=hpu &
 python check_whisper_server.py
+
+# Or use openai protocol compatible curl command
+# Please refer to https://platform.openai.com/docs/api-reference/audio/createTranscription
+wget https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample.wav
+curl http://localhost:7066/v1/audio/transcriptions \
+  -H "Content-Type: multipart/form-data" \
+  -F file="@./sample.wav" \
+  -F model="openai/whisper-small"
 ```
 
 ### 1.3 Start ASR Service/Test
@@ -114,6 +122,7 @@ docker run -d -p 9099:9099 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$
 
 # curl
 http_proxy="" curl http://localhost:9099/v1/audio/transcriptions -XPOST -d '{"byte_str": "UklGRigAAABXQVZFZm10IBIAAAABAAEARKwAAIhYAQACABAAAABkYXRhAgAAAAEA"}' -H 'Content-Type: application/json'
+
 # python
 python check_asr_server.py
 ```
diff --git a/comps/asr/whisper/dependency/whisper_model.py b/comps/asr/whisper/dependency/whisper_model.py
index 94f1c7ce5..c3e810803 100644
--- a/comps/asr/whisper/dependency/whisper_model.py
+++ b/comps/asr/whisper/dependency/whisper_model.py
@@ -30,10 +30,10 @@ def __init__(
         from transformers import WhisperForConditionalGeneration, WhisperProcessor
 
         self.device = device
-        asr_model_name_or_path = os.environ.get("ASR_MODEL_PATH", model_name_or_path)
-        print("Downloading model: {}".format(asr_model_name_or_path))
-        self.model = WhisperForConditionalGeneration.from_pretrained(asr_model_name_or_path).to(self.device)
-        self.processor = WhisperProcessor.from_pretrained(asr_model_name_or_path)
+        self.asr_model_name_or_path = os.environ.get("ASR_MODEL_PATH", model_name_or_path)
+        print("Downloading model: {}".format(self.asr_model_name_or_path))
+        self.model = WhisperForConditionalGeneration.from_pretrained(self.asr_model_name_or_path).to(self.device)
+        self.processor = WhisperProcessor.from_pretrained(self.asr_model_name_or_path)
         self.model.eval()
 
         self.language = language
diff --git a/comps/asr/whisper/dependency/whisper_server.py b/comps/asr/whisper/dependency/whisper_server.py
index 481bf0da0..dcb3dd19c 100644
--- a/comps/asr/whisper/dependency/whisper_server.py
+++ b/comps/asr/whisper/dependency/whisper_server.py
@@ -5,14 +5,21 @@
 import base64
 import os
 import uuid
+from typing import List, Optional, Union
 
 import uvicorn
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.responses import Response
 from pydub import AudioSegment
 from starlette.middleware.cors import CORSMiddleware
 from whisper_model import WhisperModel
 
+from comps import CustomLogger
+from comps.cores.proto.api_protocol import AudioTranscriptionResponse
+
+logger = CustomLogger("whisper")
+logflag = os.getenv("LOGFLAG", False)
+
 app = FastAPI()
 
 asr = None
@@ -29,7 +36,7 @@ async def health() -> Response:
 
 @app.post("/v1/asr")
 async def audio_to_text(request: Request):
-    print("Whisper generation begin.")
+    logger.info("Whisper generation begin.")
     uid = str(uuid.uuid4())
     file_name = uid + ".wav"
     request_dict = await request.json()
@@ -44,13 +51,58 @@
     try:
         asr_result = asr.audio2text(file_name)
     except Exception as e:
-        print(e)
+        logger.error(e)
         asr_result = e
     finally:
         os.remove(file_name)
     return {"asr_result": asr_result}
 
 
+@app.post("/v1/audio/transcriptions")
+async def audio_transcriptions(
+    file: UploadFile = File(...),  # Handling the uploaded file directly
+    model: str = Form("openai/whisper-small"),
+    language: str = Form("english"),
+    prompt: str = Form(None),
+    response_format: str = Form("json"),
+    temperature: float = Form(0),
+    timestamp_granularities: List[str] = Form(None),
+):
+    logger.info("Whisper generation begin.")
+    audio_content = await file.read()
+    # validate the request parameters
+    if model != asr.asr_model_name_or_path:
+        raise Exception(
+            f"ASR model mismatch! Please make sure you pass --model_name_or_path or set environment variable ASR_MODEL_PATH to {model}"
+        )
+    asr.language = language
+    if prompt is not None or response_format != "json" or temperature != 0 or timestamp_granularities is not None:
+        logger.warning(
+            "Currently parameters 'language', 'response_format', 'temperature', 'timestamp_granularities' are not supported!"
+        )
+
+    uid = str(uuid.uuid4())
+    file_name = uid + ".wav"
+    # Save the uploaded file
+    with open(file_name, "wb") as buffer:
+        buffer.write(audio_content)
+
+    audio = AudioSegment.from_file(file_name)
+    audio = audio.set_frame_rate(16000)
+
+    audio.export(f"{file_name}", format="wav")
+
+    try:
+        asr_result = asr.audio2text(file_name)
+    except Exception as e:
+        logger.error(e)
+        asr_result = e
+    finally:
+        os.remove(file_name)
+
+    return AudioTranscriptionResponse(text=asr_result)
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="0.0.0.0")
diff --git a/comps/asr/whisper/requirements.txt b/comps/asr/whisper/requirements.txt
index def6a51b8..8b4644c6f 100644
--- a/comps/asr/whisper/requirements.txt
+++ b/comps/asr/whisper/requirements.txt
@@ -7,6 +7,7 @@ opentelemetry-sdk
 prometheus-fastapi-instrumentator
 pydantic==2.7.2
 pydub
+python-multipart
 shortuuid
 transformers
 uvicorn
diff --git a/comps/cores/proto/api_protocol.py b/comps/cores/proto/api_protocol.py
index 75cab6df5..69156fbc2 100644
--- a/comps/cores/proto/api_protocol.py
+++ b/comps/cores/proto/api_protocol.py
@@ -300,6 +300,38 @@ class AudioChatCompletionRequest(BaseModel):
     user: Optional[str] = None
 
 
+# Pydantic does not support UploadFile directly
+# class AudioTranscriptionRequest(BaseModel):
+#     # Ordered by official OpenAI API documentation
+#     # default values are same with
+#     # https://platform.openai.com/docs/api-reference/audio/createTranscription
+#     file: UploadFile = File(...)
+#     model: Optional[str] = "openai/whisper-small"
+#     language: Optional[str] = "english"
+#     prompt: Optional[str] = None
+#     response_format: Optional[str] = "json"
+#     temperature: Optional[str] = 0
+#     timestamp_granularities: Optional[List] = None
+
+
+class AudioTranscriptionResponse(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # default values are same with
+    # https://platform.openai.com/docs/api-reference/audio/json-object
+    text: str
+
+
+class AudioSpeechRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # default values are same with
+    # https://platform.openai.com/docs/api-reference/audio/createSpeech
+    input: str
+    model: Optional[str] = "microsoft/speecht5_tts"
+    voice: Optional[str] = "default"
+    response_format: Optional[str] = "mp3"
+    speed: Optional[float] = 1.0
+
+
 class ChatMessage(BaseModel):
     role: str
     content: str
diff --git a/comps/tts/gpt-sovits/Dockerfile b/comps/tts/gpt-sovits/Dockerfile
index e5f004819..5fce95770 100644
--- a/comps/tts/gpt-sovits/Dockerfile
+++ b/comps/tts/gpt-sovits/Dockerfile
@@ -16,7 +16,7 @@ ENV LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libiomp5.so:/usr/lib/x86_64-linux-gnu/l
 ENV MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000"
 
 # Clone source repo
-RUN git clone https://github.com/RVC-Boss/GPT-SoVITS.git
+RUN git clone --branch openai_compat --single-branch https://github.com/Spycsh/GPT-SoVITS.git
 # Download pre-trained models, and prepare env
 RUN git clone https://huggingface.co/lj1995/GPT-SoVITS pretrained_models
 RUN mv pretrained_models/* GPT-SoVITS/GPT_SoVITS/pretrained_models/ && \
diff --git a/comps/tts/gpt-sovits/README.md b/comps/tts/gpt-sovits/README.md
index d051065f6..4876764cb 100644
--- a/comps/tts/gpt-sovits/README.md
+++ b/comps/tts/gpt-sovits/README.md
@@ -54,9 +54,15 @@ wget https://github.com/OpenTalker/SadTalker/blob/main/examples/driven_audio/chi
 
 docker cp chinese_poem1.wav gpt-sovits-service:/home/user/chinese_poem1.wav
 
-http_proxy="" curl localhost:9880/change_refer -d '{
+curl localhost:9880/change_refer -d '{
     "refer_wav_path": "/home/user/chinese_poem1.wav",
     "prompt_text": "窗前明月光,疑是地上霜,举头望明月,低头思故乡。",
     "prompt_language": "zh"
 }'
 ```
+
+- openai protocol compatible request
+
+```bash
+curl localhost:9880/v1/audio/speech -XPOST -d '{"input":"你好呀,你是谁. Hello, who are you?"}' -H 'Content-Type: application/json' --output speech.mp3
+```
diff --git a/comps/tts/speecht5/README.md b/comps/tts/speecht5/README.md
index 4539a24a8..fba5e87b8 100644
--- a/comps/tts/speecht5/README.md
+++ b/comps/tts/speecht5/README.md
@@ -85,8 +85,11 @@ docker run -p 9088:9088 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$htt
 #### 2.2.3 Test
 
 ```bash
-# curl
 curl http://localhost:7055/v1/tts -XPOST -d '{"text": "Who are you?"}' -H 'Content-Type: application/json'
 
+# openai protocol compatible
+# voice can be 'male' or 'default'
+curl http://localhost:7055/v1/audio/speech -XPOST -d '{"input":"Who are you?", "voice": "male"}' -H 'Content-Type: application/json' --output speech.wav
+
 curl http://localhost:9088/v1/audio/speech -XPOST -d '{"text": "Who are you?"}' -H 'Content-Type: application/json'
 ```
diff --git a/comps/tts/speecht5/dependency/speecht5_model.py b/comps/tts/speecht5/dependency/speecht5_model.py
index 8a800d68d..778323b56 100644
--- a/comps/tts/speecht5/dependency/speecht5_model.py
+++ b/comps/tts/speecht5/dependency/speecht5_model.py
@@ -17,33 +17,34 @@ def __init__(self, device="cpu"):
 
             adapt_transformers_to_gaudi()
 
-        model_name_or_path = "microsoft/speecht5_tts"
+        self.model_name_or_path = "microsoft/speecht5_tts"
         vocoder_model_name_or_path = "microsoft/speecht5_hifigan"
-        self.model = SpeechT5ForTextToSpeech.from_pretrained(model_name_or_path).to(device)
+        self.model = SpeechT5ForTextToSpeech.from_pretrained(self.model_name_or_path).to(device)
         self.model.eval()
-        self.processor = SpeechT5Processor.from_pretrained(model_name_or_path, normalize=True)
+        self.processor = SpeechT5Processor.from_pretrained(self.model_name_or_path, normalize=True)
         self.vocoder = SpeechT5HifiGan.from_pretrained(vocoder_model_name_or_path).to(device)
         self.vocoder.eval()
 
         # fetch default speaker embedding
-        if os.path.exists("spk_embed_default.pt"):
-            self.default_speaker_embedding = torch.load("spk_embed_default.pt")
-        else:
-            try:
+        try:
+            for spk_embed in ["spk_embed_default.pt", "spk_embed_male.pt"]:
+                if os.path.exists(spk_embed):
+                    continue
+
                 p = subprocess.Popen(
                     [
                         "curl",
                         "-O",
                         "https://raw.githubusercontent.com/intel/intel-extension-for-transformers/main/"
-                        "intel_extension_for_transformers/neural_chat/assets/speaker_embeddings/"
-                        "spk_embed_default.pt",
+                        "intel_extension_for_transformers/neural_chat/assets/speaker_embeddings/" + spk_embed,
                     ]
                 )
+                p.wait()
 
-                self.default_speaker_embedding = torch.load("spk_embed_default.pt")
-            except Exception as e:
-                print("Warning! Need to prepare speaker_embeddings, will use the backup embedding.")
-                self.default_speaker_embedding = torch.zeros((1, 512))
+            self.default_speaker_embedding = torch.load("spk_embed_default.pt")
+        except Exception as e:
+            print("Warning! Need to prepare speaker_embeddings, will use the backup embedding.")
+            self.default_speaker_embedding = torch.zeros((1, 512))
 
         if self.device == "hpu":
             # do hpu graph warmup with variable inputs
@@ -87,7 +88,9 @@ def _warmup_speecht5_hpu_graph(self):
             "OPEA is an ecosystem orchestration framework to integrate performant GenAI technologies & workflows leading to quicker GenAI adoption and business value."
         )
 
-    def t2s(self, text):
+    def t2s(self, text, voice="default"):
+        if voice == "male":
+            self.default_speaker_embedding = torch.load("spk_embed_male.pt")
         if self.device == "hpu":
             # See https://github.com/huggingface/optimum-habana/pull/824
             from optimum.habana.utils import set_seed
diff --git a/comps/tts/speecht5/dependency/speecht5_server.py b/comps/tts/speecht5/dependency/speecht5_server.py
index ae4d5d0dd..5435f91b9 100644
--- a/comps/tts/speecht5/dependency/speecht5_server.py
+++ b/comps/tts/speecht5/dependency/speecht5_server.py
@@ -3,14 +3,21 @@
 
 import argparse
 import base64
+import os
 
 import soundfile as sf
 import uvicorn
 from fastapi import FastAPI, Request
-from fastapi.responses import Response
+from fastapi.responses import Response, StreamingResponse
 from speecht5_model import SpeechT5Model
 from starlette.middleware.cors import CORSMiddleware
 
+from comps import CustomLogger
+from comps.cores.proto.api_protocol import AudioSpeechRequest
+
+logger = CustomLogger("speecht5")
+logflag = os.getenv("LOGFLAG", False)
+
 app = FastAPI()
 
 tts = None
@@ -27,7 +34,7 @@ async def health() -> Response:
 
 @app.post("/v1/tts")
 async def text_to_speech(request: Request):
-    print("SpeechT5 generation begin.")
+    logger.info("SpeechT5 generation begin.")
     request_dict = await request.json()
     text = request_dict.pop("text")
 
@@ -40,6 +47,27 @@
     return {"tts_result": b64_str}
 
 
+@app.post("/v1/audio/speech")
+async def audio_speech(request: AudioSpeechRequest):
+    logger.info("SpeechT5 generation begin.")
+    # validate the request parameters
+    if request.model != tts.model_name_or_path:
+        raise Exception("TTS model mismatch! Currently only support model: microsoft/speecht5_tts")
+    if request.voice not in ["default", "male"] or request.speed != 1.0:
+        logger.warning("Currently parameter 'speed' can only be 1.0 and 'voice' can only be default or male!")
+
+    speech = tts.t2s(request.input, voice=request.voice)
+
+    tmp_path = "tmp.wav"
+    sf.write(tmp_path, speech, samplerate=16000)
+
+    def audio_gen():
+        with open(tmp_path, "rb") as f:
+            yield from f
+
+    return StreamingResponse(audio_gen(), media_type=f"audio/{request.response_format}")
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="0.0.0.0")