Skip to content

Commit

Permalink
openai compatible for asr/tts (opea-project#929)
Browse files Browse the repository at this point in the history
* openai compatible for asr/tts

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add dep

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Spycsh and pre-commit-ci[bot] authored Nov 25, 2024
1 parent bbca7fd commit c3948ad
Show file tree
Hide file tree
Showing 10 changed files with 160 additions and 26 deletions.
9 changes: 9 additions & 0 deletions comps/asr/whisper/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ pip install optimum[habana]
cd dependency/
nohup python whisper_server.py --device=hpu &
python check_whisper_server.py

# Or use openai protocol compatible curl command
# Please refer to https://platform.openai.com/docs/api-reference/audio/createTranscription
wget https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample.wav
curl http://localhost:7066/v1/audio/transcriptions \
-H "Content-Type: multipart/form-data" \
-F file="@./sample.wav" \
-F model="openai/whisper-small"
```

### 1.3 Start ASR Service/Test
Expand Down Expand Up @@ -114,6 +122,7 @@ docker run -d -p 9099:9099 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$
# curl
http_proxy="" curl http://localhost:9099/v1/audio/transcriptions -XPOST -d '{"byte_str": "UklGRigAAABXQVZFZm10IBIAAAABAAEARKwAAIhYAQACABAAAABkYXRhAgAAAAEA"}' -H 'Content-Type: application/json'


# python
python check_asr_server.py
```
8 changes: 4 additions & 4 deletions comps/asr/whisper/dependency/whisper_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ def __init__(
from transformers import WhisperForConditionalGeneration, WhisperProcessor

self.device = device
asr_model_name_or_path = os.environ.get("ASR_MODEL_PATH", model_name_or_path)
print("Downloading model: {}".format(asr_model_name_or_path))
self.model = WhisperForConditionalGeneration.from_pretrained(asr_model_name_or_path).to(self.device)
self.processor = WhisperProcessor.from_pretrained(asr_model_name_or_path)
self.asr_model_name_or_path = os.environ.get("ASR_MODEL_PATH", model_name_or_path)
print("Downloading model: {}".format(self.asr_model_name_or_path))
self.model = WhisperForConditionalGeneration.from_pretrained(self.asr_model_name_or_path).to(self.device)
self.processor = WhisperProcessor.from_pretrained(self.asr_model_name_or_path)
self.model.eval()

self.language = language
Expand Down
58 changes: 55 additions & 3 deletions comps/asr/whisper/dependency/whisper_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,21 @@
import base64
import os
import uuid
from typing import List, Optional, Union

import uvicorn
from fastapi import FastAPI, Request
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.responses import Response
from pydub import AudioSegment
from starlette.middleware.cors import CORSMiddleware
from whisper_model import WhisperModel

from comps import CustomLogger
from comps.cores.proto.api_protocol import AudioTranscriptionResponse

logger = CustomLogger("whisper")
logflag = os.getenv("LOGFLAG", False)

app = FastAPI()
asr = None

Expand All @@ -29,7 +36,7 @@ async def health() -> Response:

@app.post("/v1/asr")
async def audio_to_text(request: Request):
print("Whisper generation begin.")
logger.info("Whisper generation begin.")
uid = str(uuid.uuid4())
file_name = uid + ".wav"
request_dict = await request.json()
Expand All @@ -44,13 +51,58 @@ async def audio_to_text(request: Request):
try:
asr_result = asr.audio2text(file_name)
except Exception as e:
print(e)
logger.error(e)
asr_result = e
finally:
os.remove(file_name)
return {"asr_result": asr_result}


@app.post("/v1/audio/transcriptions")
async def audio_transcriptions(
file: UploadFile = File(...), # Handling the uploaded file directly
model: str = Form("openai/whisper-small"),
language: str = Form("english"),
prompt: str = Form(None),
response_format: str = Form("json"),
temperature: float = Form(0),
timestamp_granularities: List[str] = Form(None),
):
logger.info("Whisper generation begin.")
audio_content = await file.read()
# validate the request parameters
if model != asr.asr_model_name_or_path:
raise Exception(
f"ASR model mismatch! Please make sure you pass --model_name_or_path or set environment variable ASR_MODEL_PATH to {model}"
)
asr.language = language
if prompt is not None or response_format != "json" or temperature != 0 or timestamp_granularities is not None:
logger.warning(
"Currently parameters 'language', 'response_format', 'temperature', 'timestamp_granularities' are not supported!"
)

uid = str(uuid.uuid4())
file_name = uid + ".wav"
# Save the uploaded file
with open(file_name, "wb") as buffer:
buffer.write(audio_content)

audio = AudioSegment.from_file(file_name)
audio = audio.set_frame_rate(16000)

audio.export(f"{file_name}", format="wav")

try:
asr_result = asr.audio2text(file_name)
except Exception as e:
logger.error(e)
asr_result = e
finally:
os.remove(file_name)

return AudioTranscriptionResponse(text=asr_result)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
Expand Down
1 change: 1 addition & 0 deletions comps/asr/whisper/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ opentelemetry-sdk
prometheus-fastapi-instrumentator
pydantic==2.7.2
pydub
python-multipart
shortuuid
transformers
uvicorn
Expand Down
32 changes: 32 additions & 0 deletions comps/cores/proto/api_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,38 @@ class AudioChatCompletionRequest(BaseModel):
user: Optional[str] = None


# Pydantic does not support UploadFile directly
# class AudioTranscriptionRequest(BaseModel):
# # Ordered by official OpenAI API documentation
# # default values are same with
# # https://platform.openai.com/docs/api-reference/audio/createTranscription
# file: UploadFile = File(...)
# model: Optional[str] = "openai/whisper-small"
# language: Optional[str] = "english"
# prompt: Optional[str] = None
# response_format: Optional[str] = "json"
# temperature: Optional[str] = 0
# timestamp_granularities: Optional[List] = None


class AudioTranscriptionResponse(BaseModel):
# Ordered by official OpenAI API documentation
# default values are same with
# https://platform.openai.com/docs/api-reference/audio/json-object
text: str


class AudioSpeechRequest(BaseModel):
# Ordered by official OpenAI API documentation
# default values are same with
# https://platform.openai.com/docs/api-reference/audio/createSpeech
input: str
model: Optional[str] = "microsoft/speecht5_tts"
voice: Optional[str] = "default"
response_format: Optional[str] = "mp3"
speed: Optional[float] = 1.0


class ChatMessage(BaseModel):
role: str
content: str
Expand Down
2 changes: 1 addition & 1 deletion comps/tts/gpt-sovits/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ ENV LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libiomp5.so:/usr/lib/x86_64-linux-gnu/l
ENV MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000"

# Clone source repo
RUN git clone https://github.com/RVC-Boss/GPT-SoVITS.git
RUN git clone --branch openai_compat --single-branch https://github.com/Spycsh/GPT-SoVITS.git
# Download pre-trained models, and prepare env
RUN git clone https://huggingface.co/lj1995/GPT-SoVITS pretrained_models
RUN mv pretrained_models/* GPT-SoVITS/GPT_SoVITS/pretrained_models/ && \
Expand Down
8 changes: 7 additions & 1 deletion comps/tts/gpt-sovits/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,15 @@ wget https://github.com/OpenTalker/SadTalker/blob/main/examples/driven_audio/chi

docker cp chinese_poem1.wav gpt-sovits-service:/home/user/chinese_poem1.wav

http_proxy="" curl localhost:9880/change_refer -d '{
curl localhost:9880/change_refer -d '{
"refer_wav_path": "/home/user/chinese_poem1.wav",
"prompt_text": "窗前明月光,疑是地上霜,举头望明月,低头思故乡。",
"prompt_language": "zh"
}'
```

- openai protocol compatible request

```bash
curl localhost:9880/v1/audio/speech -XPOST -d '{"input":"你好呀,你是谁. Hello, who are you?"}' -H 'Content-Type: application/json' --output speech.mp3
```
5 changes: 4 additions & 1 deletion comps/tts/speecht5/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,11 @@ docker run -p 9088:9088 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$htt
#### 2.2.3 Test

```bash
# curl
curl http://localhost:7055/v1/tts -XPOST -d '{"text": "Who are you?"}' -H 'Content-Type: application/json'

# openai protocol compatible
# voice can be 'male' or 'default'
curl http://localhost:7055/v1/audio/speech -XPOST -d '{"input":"Who are you?", "voice": "male"}' -H 'Content-Type: application/json' --output speech.wav

curl http://localhost:9088/v1/audio/speech -XPOST -d '{"text": "Who are you?"}' -H 'Content-Type: application/json'
```
31 changes: 17 additions & 14 deletions comps/tts/speecht5/dependency/speecht5_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,34 @@ def __init__(self, device="cpu"):

adapt_transformers_to_gaudi()

model_name_or_path = "microsoft/speecht5_tts"
self.model_name_or_path = "microsoft/speecht5_tts"
vocoder_model_name_or_path = "microsoft/speecht5_hifigan"
self.model = SpeechT5ForTextToSpeech.from_pretrained(model_name_or_path).to(device)
self.model = SpeechT5ForTextToSpeech.from_pretrained(self.model_name_or_path).to(device)
self.model.eval()
self.processor = SpeechT5Processor.from_pretrained(model_name_or_path, normalize=True)
self.processor = SpeechT5Processor.from_pretrained(self.model_name_or_path, normalize=True)
self.vocoder = SpeechT5HifiGan.from_pretrained(vocoder_model_name_or_path).to(device)
self.vocoder.eval()

# fetch default speaker embedding
if os.path.exists("spk_embed_default.pt"):
self.default_speaker_embedding = torch.load("spk_embed_default.pt")
else:
try:
try:
for spk_embed in ["spk_embed_default.pt", "spk_embed_male.pt"]:
if os.path.exists(spk_embed):
continue

p = subprocess.Popen(
[
"curl",
"-O",
"https://raw.githubusercontent.com/intel/intel-extension-for-transformers/main/"
"intel_extension_for_transformers/neural_chat/assets/speaker_embeddings/"
"spk_embed_default.pt",
"intel_extension_for_transformers/neural_chat/assets/speaker_embeddings/" + spk_embed,
]
)

p.wait()
self.default_speaker_embedding = torch.load("spk_embed_default.pt")
except Exception as e:
print("Warning! Need to prepare speaker_embeddings, will use the backup embedding.")
self.default_speaker_embedding = torch.zeros((1, 512))
self.default_speaker_embedding = torch.load("spk_embed_default.pt")
except Exception as e:
print("Warning! Need to prepare speaker_embeddings, will use the backup embedding.")
self.default_speaker_embedding = torch.zeros((1, 512))

if self.device == "hpu":
# do hpu graph warmup with variable inputs
Expand Down Expand Up @@ -87,7 +88,9 @@ def _warmup_speecht5_hpu_graph(self):
"OPEA is an ecosystem orchestration framework to integrate performant GenAI technologies & workflows leading to quicker GenAI adoption and business value."
)

def t2s(self, text):
def t2s(self, text, voice="default"):
if voice == "male":
self.default_speaker_embedding = torch.load("spk_embed_male.pt")
if self.device == "hpu":
# See https://github.com/huggingface/optimum-habana/pull/824
from optimum.habana.utils import set_seed
Expand Down
32 changes: 30 additions & 2 deletions comps/tts/speecht5/dependency/speecht5_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@

import argparse
import base64
import os

import soundfile as sf
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import Response
from fastapi.responses import Response, StreamingResponse
from speecht5_model import SpeechT5Model
from starlette.middleware.cors import CORSMiddleware

from comps import CustomLogger
from comps.cores.proto.api_protocol import AudioSpeechRequest

logger = CustomLogger("speecht5")
logflag = os.getenv("LOGFLAG", False)

app = FastAPI()
tts = None

Expand All @@ -27,7 +34,7 @@ async def health() -> Response:

@app.post("/v1/tts")
async def text_to_speech(request: Request):
print("SpeechT5 generation begin.")
logger.info("SpeechT5 generation begin.")
request_dict = await request.json()
text = request_dict.pop("text")

Expand All @@ -40,6 +47,27 @@ async def text_to_speech(request: Request):
return {"tts_result": b64_str}


@app.post("/v1/audio/speech")
async def audio_speech(request: AudioSpeechRequest):
logger.info("SpeechT5 generation begin.")
# validate the request parameters
if request.model != tts.model_name_or_path:
raise Exception("TTS model mismatch! Currently only support model: microsoft/speecht5_tts")
if request.voice not in ["default", "male"] or request.speed != 1.0:
logger.warning("Currently parameter 'speed' can only be 1.0 and 'voice' can only be default or male!")

speech = tts.t2s(request.input, voice=request.voice)

tmp_path = "tmp.wav"
sf.write(tmp_path, speech, samplerate=16000)

def audio_gen():
with open(tmp_path, "rb") as f:
yield from f

return StreamingResponse(audio_gen(), media_type=f"audio/{request.response_format}")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
Expand Down

0 comments on commit c3948ad

Please sign in to comment.