feat(api): transcription/translation endpoints (#726)
Updates the transcription endpoint to the latest OpenAI spec.
Adds a translation endpoint that takes audio in any language and converts it to English, exposed through an API route that follows the OpenAI spec.
CollectiveUnicorn authored Jul 11, 2024
1 parent f51078b commit a62b07e
Showing 6 changed files with 136 additions and 4 deletions.
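As an overview before the diffs: both endpoints this commit delivers are exercised end to end in tests/e2e/test_whisper.py below. A condensed sketch of the same calls through the openai Python client; the base_url and api_key values are deployment-specific placeholders, not something this commit defines:

from pathlib import Path
from openai import OpenAI

# Placeholder connection details; point these at a running LeapfrogAI API.
client = OpenAI(base_url="http://localhost:8080/openai/v1", api_key="<token>")

# Transcription: speech to text in the source language.
transcription = client.audio.transcriptions.create(
    model="whisper", file=Path("tests/data/0min12sec.wav"), language="en"
)

# Translation: speech in any language, English text out.
translation = client.audio.translations.create(
    model="whisper", file=Path("tests/data/arabic-audio.wav")
)

print(transcription.text)
print(translation.text)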
8 changes: 6 additions & 2 deletions packages/whisper/main.py
@@ -18,7 +18,7 @@ def make_transcribe_request(filename, task, language, temperature, prompt):
     device = "cuda" if GPU_ENABLED else "cpu"
     model = WhisperModel(model_path, device=device, compute_type="float32")

-    segments, info = model.transcribe(filename, beam_size=5)
+    segments, info = model.transcribe(filename, task=task, beam_size=5)

     output = ""
@@ -57,7 +57,11 @@ def call_whisper(
             f.name, task, inputLanguage, temperature, prompt
         )
         text = str(result["text"])
-        logger.info("Transcription complete!")
+
+        if task == "transcribe":
+            logger.info("Transcription complete!")
+        elif task == "translate":
+            logger.info("Translation complete!")
     return lfai.AudioResponse(text=text)
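The task argument threaded through to model.transcribe is what switches faster-whisper between transcription and translation. A standalone sketch of the underlying call, with a placeholder model size and audio path rather than the backend's resolved model_path:

from faster_whisper import WhisperModel

# "base" is a placeholder model size; the backend loads its own model_path.
model = WhisperModel("base", device="cpu", compute_type="float32")

# task="transcribe" keeps the source language; task="translate" emits English.
segments, info = model.transcribe(
    "tests/data/arabic-audio.wav", task="translate", beam_size=5
)
print("".join(segment.text for segment in segments))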
10 changes: 10 additions & 0 deletions src/leapfrogai_api/backend/grpc_client.py
@@ -15,6 +15,7 @@
     CreateTranscriptionResponse,
     EmbeddingResponseData,
     Usage,
+    CreateTranslationResponse,
 )
 from leapfrogai_sdk.chat.chat_pb2 import (
     ChatCompletionResponse as ProtobufChatCompletionResponse,
@@ -139,3 +140,12 @@ async def create_transcription(model: Model, request: Iterator[lfai.AudioRequest]):
         response: lfai.AudioResponse = await stub.Transcribe(request)

     return CreateTranscriptionResponse(text=response.text)
+
+
+async def create_translation(model: Model, request: Iterator[lfai.AudioRequest]):
+    """Translate audio using the specified model."""
+    async with grpc.aio.insecure_channel(model.backend) as channel:
+        stub = lfai.AudioStub(channel)
+        response: lfai.AudioResponse = await stub.Translate(request)
+
+    return CreateTranslationResponse(text=response.text)
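create_translation expects the same streaming shape as create_transcription: the first AudioRequest carries metadata, the rest carry raw audio bytes. A minimal caller sketch, assuming leapfrogai_sdk is the module imported as lfai and that AudioRequest exposes metadata and chunk_data fields, as the router code below suggests:

import leapfrogai_sdk as lfai  # assumed alias, matching the lfai usage above
from leapfrogai_api.backend.grpc_client import create_translation

def audio_requests(path: str):
    # First message: metadata only. Subsequent messages: 1 KiB audio chunks.
    yield lfai.AudioRequest(metadata=lfai.AudioMetadata(prompt="", temperature=0.3))
    with open(path, "rb") as f:
        while chunk := f.read(1024):
            yield lfai.AudioRequest(chunk_data=chunk)

async def translate_file(model):  # model: a Model whose .backend names the gRPC target
    response = await create_translation(model, audio_requests("tests/data/arabic-audio.wav"))
    print(response.text)  # run with asyncio.run(translate_file(model))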
57 changes: 57 additions & 0 deletions src/leapfrogai_api/backend/types.py
@@ -388,6 +388,10 @@ class CreateTranscriptionRequest(BaseModel):
         le=1,
         description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit.",
     )
+    timestamp_granularities: list[Literal["word", "segment"]] | None = Field(
+        default=None,
+        description="The timestamp granularities to populate for this transcription. response_format must be set to verbose_json to use timestamp granularities. Either or both of these options are supported: word, or segment. Note: There is no additional latency for segment timestamps, but generating word timestamps incurs additional latency.",
+    )

     @classmethod
     def as_form(
@@ -398,6 +402,7 @@ def as_form(
         prompt: str | None = Form(""),
         response_format: str | None = Form(""),
         temperature: float | None = Form(1.0),
+        timestamp_granularities: list[Literal["word", "segment"]] | None = Form(None),
     ) -> CreateTranscriptionRequest:
         return cls(
             file=file,
@@ -406,6 +411,7 @@ def as_form(
             prompt=prompt,
             response_format=response_format,
             temperature=temperature,
+            timestamp_granularities=timestamp_granularities,
         )


@@ -419,6 +425,57 @@ class CreateTranscriptionResponse(BaseModel):
     )


+class CreateTranslationRequest(BaseModel):
+    """Request object for creating a translation."""
+
+    file: UploadFile = Field(
+        ...,
+        description="The audio file to translate. Supports any audio format that ffmpeg can handle. For a complete list of supported formats, see: https://ffmpeg.org/ffmpeg-formats.html",
+    )
+    model: str = Field(..., description="ID of the model to use.")
+    prompt: str = Field(
+        default="",
+        description="An optional text to guide the model's style or continue a previous audio segment. The prompt should match the audio language.",
+    )
+    response_format: str = Field(
+        default="json",
+        description="The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.",
+    )
+    temperature: float = Field(
+        default=1.0,
+        ge=0,
+        le=1,
+        description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit.",
+    )
+
+    @classmethod
+    def as_form(
+        cls,
+        file: UploadFile = File(...),
+        model: str = Form(...),
+        prompt: str | None = Form(""),
+        response_format: str | None = Form(""),
+        temperature: float | None = Form(1.0),
+    ) -> CreateTranslationRequest:
+        return cls(
+            file=file,
+            model=model,
+            prompt=prompt,
+            response_format=response_format,
+            temperature=temperature,
+        )
+
+
+class CreateTranslationResponse(BaseModel):
+    """Response object for translation."""
+
+    text: str = Field(
+        ...,
+        description="The translated text.",
+        examples=["Hello, this is a translation of the audio file."],
+    )
+
+
 #############
 # FILES
 #############
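The as_form classmethod exists because FastAPI does not bind multipart/form-data fields onto a Pydantic model directly; each field is re-declared with Form/File defaults so Depends can rebuild a validated model per request. A hypothetical minimal route showing the pattern in isolation:

from fastapi import Depends, FastAPI
from leapfrogai_api.backend.types import CreateTranslationRequest

app = FastAPI()  # demo app for illustration, not part of this commit

@app.post("/demo/translations")
async def demo(
    req: CreateTranslationRequest = Depends(CreateTranslationRequest.as_form),
):
    # req arrives as a fully validated Pydantic model built from the form fields.
    return {"model": req.model, "temperature": req.temperature}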
30 changes: 29 additions & 1 deletion src/leapfrogai_api/routers/openai/audio.py
@@ -4,11 +4,12 @@
 from typing import Annotated
 from fastapi import HTTPException, APIRouter, Depends
 from fastapi.security import HTTPBearer
-from leapfrogai_api.backend.grpc_client import create_transcription
+from leapfrogai_api.backend.grpc_client import create_transcription, create_translation
 from leapfrogai_api.backend.helpers import read_chunks
 from leapfrogai_api.backend.types import (
     CreateTranscriptionRequest,
     CreateTranscriptionResponse,
+    CreateTranslationRequest,
 )
 from leapfrogai_api.routers.supabase_session import Session
 from leapfrogai_api.utils import get_model_config
@@ -46,3 +47,30 @@ async def transcribe(
     request_iterator = chain((audio_metadata_request,), chunk_iterator)

     return await create_transcription(model, request_iterator)
+
+
+@router.post("/translations")
+async def translate(
+    session: Session,
+    model_config: Annotated[Config, Depends(get_model_config)],
+    req: CreateTranslationRequest = Depends(CreateTranslationRequest.as_form),
+) -> CreateTranscriptionResponse:
+    """Create a translation to English from the given audio file."""
+    model = model_config.get_model_backend(req.model)
+    if model is None:
+        raise HTTPException(
+            status_code=405,
+            detail=f"Model {req.model} not found. Currently supported models are {list(model_config.models.keys())}",
+        )
+
+    # Create a request that contains the metadata for the AudioRequest
+    audio_metadata = lfai.AudioMetadata(prompt=req.prompt, temperature=req.temperature)
+    audio_metadata_request = lfai.AudioRequest(metadata=audio_metadata)
+
+    # Read the file and get an iterator of all the data chunks
+    chunk_iterator = read_chunks(req.file.file, 1024)
+
+    # Combine the metadata and chunk_data iterators
+    request_iterator = chain((audio_metadata_request,), chunk_iterator)
+
+    return await create_translation(model, request_iterator)
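Because the endpoint consumes multipart form data via as_form, a raw HTTP call works as well. A sketch with requests; the mount path /openai/v1/audio and the port are assumptions about how this router is included, not something this diff pins down:

import requests

resp = requests.post(
    "http://localhost:8080/openai/v1/audio/translations",  # assumed mount path
    headers={"Authorization": "Bearer <token>"},
    data={"model": "whisper", "temperature": 0.3},
    files={"file": open("tests/data/arabic-audio.wav", "rb")},
)
resp.raise_for_status()
print(resp.json()["text"])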
Binary file added tests/data/arabic-audio.wav
35 changes: 34 additions & 1 deletion tests/e2e/test_whisper.py
@@ -1,7 +1,9 @@
+import string
 from pathlib import Path

 import pytest
 from openai import InternalServerError, OpenAI
+import unicodedata

 from .utils import create_test_user
@@ -41,8 +43,39 @@ def test_embeddings():

 def test_transcriptions():
     transcription = client.audio.transcriptions.create(
-        model="whisper", file=Path("tests/data/0min12sec.wav")
+        model="whisper",
+        file=Path("tests/data/0min12sec.wav"),
+        language="en",
+        prompt="This is a test transcription.",
+        response_format="json",
+        temperature=0.5,
+        timestamp_granularities=["word", "segment"],
     )

     assert len(transcription.text) > 0  # The transcription should not be empty
     assert len(transcription.text) < 500  # The transcription should not be too long
+
+
+def test_translations():
+    translation = client.audio.translations.create(
+        model="whisper",
+        file=Path("tests/data/arabic-audio.wav"),
+        prompt="This is a test translation.",
+        response_format="json",
+        temperature=0.3,
+    )
+
+    assert len(translation.text) > 0  # The translation should not be empty
+    assert len(translation.text) < 500  # The translation should not be too long
+
+    def is_english_or_punctuation(c):
+        if c in string.punctuation or c.isspace():
+            return True
+        if c.isalpha():
+            # Allow uppercase letters (for proper nouns) and common Latin characters
+            return c.isupper() or unicodedata.name(c).startswith(("LATIN", "COMMON"))
+        return False
+
+    english_chars = [is_english_or_punctuation(c) for c in translation.text]
+
+    assert all(english_chars)  # Check that only English characters are returned

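A note on the Unicode check in test_translations: unicodedata.name embeds the script in each character's name, so any Arabic characters surviving in the translated text fail the assertion. For example:

import unicodedata

print(unicodedata.name("a"))  # LATIN SMALL LETTER A, passes the check
print(unicodedata.name("ش"))  # ARABIC LETTER SHEEN, fails the check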