Skip to content

Commit

Permalink
fix(whisper): fix whisper transcription of non-english audio (#1066)
Browse files Browse the repository at this point in the history
* This fix ensures that the transcribed text is output in the original language of the audio.
* Adds more logging to the whisper backend
* Removes English as the default language and uses automatic language detection instead
  • Loading branch information
CollectiveUnicorn authored Sep 25, 2024
1 parent c4c7e9d commit 8dd467a
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 19 deletions.
10 changes: 10 additions & 0 deletions packages/whisper/Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
# Model handed to ct2-transformers-converter via --model.
# Override from the environment or the command line, e.g.
#   MODEL_NAME=<model> make download-model
MODEL_NAME ?= openai/whisper-base

# These targets are commands, not files; declaring them phony keeps a stray
# file named "install" or "dev" from silently disabling the target.
.PHONY: install download-model dev

# Install the local SDK first, then this package (editable) with its dev extras.
install:
	python -m pip install ../../src/leapfrogai_sdk
	python -m pip install -e ".[dev]"

# Convert $(MODEL_NAME) into .model, copying the listed tokenizer and
# preprocessor files next to the converted weights. --force lets the
# converter run again when .model already exists.
download-model:
	mkdir -p .model
	ct2-transformers-converter --model $(MODEL_NAME) \
		--output_dir .model \
		--copy_files tokenizer.json special_tokens_map.json preprocessor_config.json normalizer.json tokenizer_config.json vocab.json \
		--quantization float32 \
		--force

# Install dependencies, then start the backend through the SDK CLI.
# $(MAKE) (rather than a literal "make") passes flags and the jobserver on to
# the sub-make and honours whichever make binary was originally invoked.
dev:
	$(MAKE) install
	python -m leapfrogai_sdk.cli --app-dir=src/ main:Model
5 changes: 3 additions & 2 deletions packages/whisper/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ To run the vllm backend locally without K8s (starting from the root directory of
make install

# Download and convert model
# Change the value for --model to change the whisper base
ct2-transformers-converter --model openai/whisper-base --output_dir .model --copy_files tokenizer.json --quantization float32
# Set MODEL_NAME to change which Whisper base model is downloaded and converted
export MODEL_NAME=openai/whisper-base
make download-model

# Start the model backend
make dev
Expand Down
45 changes: 29 additions & 16 deletions packages/whisper/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
GPU_ENABLED = True if int(os.environ.get("GPU_REQUEST", 0)) > 0 else False


def make_transcribe_request(filename, task, language, temperature, prompt):
def make_whisper_request(filename, task, language, temperature, prompt):
device = "cuda" if GPU_ENABLED else "cpu"
model = WhisperModel(model_path, device=device, compute_type="float32")

Expand All @@ -27,18 +27,22 @@ def make_transcribe_request(filename, task, language, temperature, prompt):
if task:
if task in ["transcribe", "translate"]:
kwargs["task"] = task
logger.info(f"Task {task} is starting")
else:
logger.error(f"Task {task} is not supported")
return {"text": ""}
if language:
if language in model.supported_languages:
kwargs["language"] = language
logger.info(f"Language {language} is supported")
else:
logger.error(f"Language {language} is not supported")
if temperature:
kwargs["temperature"] = temperature
logger.info(f"Temperature {temperature} is set")
if prompt:
kwargs["initial_prompt"] = prompt
logger.info(f"Prompt {prompt} is set")

try:
# Call transcribe with only non-None parameters
Expand All @@ -62,26 +66,35 @@ def call_whisper(
data = bytearray()
prompt = ""
temperature = 0.0
inputLanguage = "en"
# By default, automatically detect the language
input_language = None

for request in request_iterator:
if (
request.metadata.prompt
and request.metadata.temperature
and request.metadata.inputlanguage
):
prompt = request.metadata.prompt
temperature = request.metadata.temperature
inputLanguage = request.metadata.inputlanguage
continue

data.extend(request.chunk_data)
metadata = request.metadata
updated = False

if metadata.prompt:
logger.info(f"Updated metadata: Prompt='{prompt}'")
prompt = metadata.prompt
updated = True

if metadata.temperature:
logger.info(f"Updated metadata: Temperature={temperature}")
temperature = metadata.temperature
updated = True

if metadata.inputlanguage:
logger.info(f"Updated metadata: Input Language='{input_language}'")
input_language = metadata.inputlanguage
updated = True

# Metadata updates are done separate from data updates
if not updated:
data.extend(request.chunk_data)

with tempfile.NamedTemporaryFile("wb") as f:
f.write(data)
result = make_transcribe_request(
f.name, task, inputLanguage, temperature, prompt
)
result = make_whisper_request(f.name, task, input_language, temperature, prompt)
text = str(result["text"])

if task == "transcribe":
Expand Down
Binary file modified tests/data/russian.mp3
Binary file not shown.
56 changes: 55 additions & 1 deletion tests/e2e/test_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_translations(client: OpenAI):
file=Path("tests/data/arabic-audio.wav"),
prompt="This is a test translation.",
response_format="json",
temperature=0.3,
temperature=0.0,
)

assert len(translation.text) > 0, "The translation should not be empty"
Expand All @@ -73,3 +73,57 @@ def is_english_or_punctuation(c):
english_chars = [is_english_or_punctuation(c) for c in translation.text]

assert all(english_chars), "Non-English characters have been returned"


def test_non_english_transcription(client: OpenAI):
    """Check that non-English audio is transcribed in its original script.

    Transcribes one Arabic and one Russian sample without passing a language
    in the request, then asserts that each result is non-empty, shorter than
    500 characters, and made up only of characters from the expected script
    plus ASCII punctuation and whitespace.

    Args:
        client: OpenAI client used to issue the transcription requests.
    """

    def transcribe(audio_path):
        """Request a JSON transcription of the audio file at ``audio_path``."""
        return client.audio.transcriptions.create(
            model="whisper",
            file=Path(audio_path),
            response_format="json",
            temperature=0.5,
            timestamp_granularities=["word", "segment"],
        )

    def is_script_or_punctuation(text, script_prefix):
        """Return True if every character is punctuation, whitespace, or has a
        Unicode name starting with ``script_prefix``."""
        return all(
            c in string.punctuation
            or c.isspace()
            # The "" default keeps unnamed code points (e.g. control
            # characters) from raising ValueError; they count as a mismatch
            # so the assertion below reports the failure instead of a crash.
            or unicodedata.name(c, "").startswith(script_prefix)
            for c in text
        )

    # Arabic transcription
    arabic_transcription = transcribe("tests/data/arabic-audio.wav")

    assert (
        len(arabic_transcription.text) > 0
    ), "The Arabic transcription should not be empty"
    assert (
        len(arabic_transcription.text) < 500
    ), "The Arabic transcription should not be too long"
    assert is_script_or_punctuation(
        arabic_transcription.text, "ARABIC"
    ), "Non-Arabic characters have been returned in Arabic transcription"

    # Russian transcription
    russian_transcription = transcribe("tests/data/russian.mp3")

    assert (
        len(russian_transcription.text) > 0
    ), "The Russian transcription should not be empty"
    assert (
        len(russian_transcription.text) < 500
    ), "The Russian transcription should not be too long"
    assert is_script_or_punctuation(
        russian_transcription.text, "CYRILLIC"
    ), "Non-Russian characters have been returned in Russian transcription"

0 comments on commit 8dd467a

Please sign in to comment.