Skip to content

Commit

Permalink
fix: tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Fedir Zadniprovskyi authored and fedirz committed Jan 12, 2025
1 parent 22f774d commit 8980efe
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v5
with:
version: "0.4.11"
version: "latest"
enable-cache: true
- run: uv python install 3.12
- run: uv sync --extra dev
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/publish-docs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v5
with:
version: "0.4.11"
version: "latest"
enable-cache: true
- run: uv python install 3.12
- run: uv sync --extra dev
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v5
with:
version: "0.4.11"
version: "latest"
enable-cache: true
- run: uv python install 3.12
- run: uv sync --all-extras
Expand Down
3 changes: 2 additions & 1 deletion src/speaches/hf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,14 @@ def get_kokoro_model_path() -> Path:
def download_kokoro_model() -> None:
model_id = "hexgrad/Kokoro-82M"
model_repo_path = Path(
huggingface_hub.snapshot_download(model_id, repo_type="model", allow_patterns="**/kokoro-v0_19.onnx")
huggingface_hub.snapshot_download(model_id, repo_type="model", allow_patterns=["kokoro-v0_19.onnx"])
)
# HACK
res = httpx.get(
"https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files/voices.json", follow_redirects=True
).raise_for_status()
voices_path = model_repo_path / "voices.json"
voices_path.touch(exist_ok=True)
voices_path.write_bytes(res.content)


Expand Down
17 changes: 12 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
from typing import Protocol

from fastapi.testclient import TestClient
import httpx
from httpx import ASGITransport, AsyncClient
from huggingface_hub import snapshot_download
from openai import AsyncOpenAI
import pytest
import pytest_asyncio
from pytest_mock import MockerFixture

from speaches.config import Config, WhisperConfig
from speaches.dependencies import get_config
from speaches.hf_utils import download_kokoro_model
from speaches.main import create_app

DISABLE_LOGGERS = ["multipart.multipart", "faster_whisper"]
Expand All @@ -26,6 +27,7 @@
# disable the UI as it slightly increases the app startup time due to the imports it's doing
enable_ui=False,
)
TIMEOUT = httpx.Timeout(15.0)


def pytest_configure() -> None:
Expand Down Expand Up @@ -64,7 +66,7 @@ async def inner(config: Config = DEFAULT_CONFIG) -> AsyncGenerator[AsyncClient,
app = create_app()
# https://fastapi.tiangolo.com/advanced/testing-dependencies/
app.dependency_overrides[get_config] = lambda: config
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as aclient:
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test", timeout=TIMEOUT) as aclient:
yield aclient

return inner
Expand All @@ -91,7 +93,12 @@ def actual_openai_client() -> AsyncOpenAI:

# TODO: remove the download after running the tests
# TODO: do not download when not needed
# @pytest.fixture(scope="session", autouse=True)
# def download_piper_voices() -> None:
# # Only download `voices.json` and the default voice
# snapshot_download("rhasspy/piper-voices", allow_patterns=["voices.json", "en/en_US/amy/**"])


@pytest.fixture(scope="session", autouse=True)
def download_piper_voices() -> None:
# Only download `voices.json` and the default voice
snapshot_download("rhasspy/piper-voices", allow_patterns=["voices.json", "en/en_US/amy/**"])
def download_kokoro() -> None:
download_kokoro_model()
11 changes: 3 additions & 8 deletions tests/speech_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
@pytest.mark.parametrize("response_format", SUPPORTED_RESPONSE_FORMATS)
async def test_create_speech_formats(openai_client: AsyncOpenAI, response_format: ResponseFormat) -> None:
await openai_client.audio.speech.create(
Expand All @@ -42,7 +41,6 @@ async def test_create_speech_formats(openai_client: AsyncOpenAI, response_format


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
@pytest.mark.parametrize(("model", "voice"), GOOD_MODEL_VOICE_PAIRS)
async def test_create_speech_good_model_voice_pair(openai_client: AsyncOpenAI, model: str, voice: str) -> None:
await openai_client.audio.speech.create(
Expand All @@ -63,7 +61,6 @@ async def test_create_speech_good_model_voice_pair(openai_client: AsyncOpenAI, m


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
@pytest.mark.parametrize(("model", "voice"), BAD_MODEL_VOICE_PAIRS)
async def test_create_speech_bad_model_voice_pair(openai_client: AsyncOpenAI, model: str, voice: str) -> None:
# NOTE: not sure why `APIConnectionError` is sometimes raised
Expand All @@ -76,11 +73,10 @@ async def test_create_speech_bad_model_voice_pair(openai_client: AsyncOpenAI, mo
)


SUPPORTED_SPEEDS = [0.25, 0.5, 1.0, 2.0, 4.0]
SUPPORTED_SPEEDS = [0.5, 1.0, 2.0]


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
async def test_create_speech_with_varying_speed(openai_client: AsyncOpenAI) -> None:
previous_size: int | None = None
for speed in SUPPORTED_SPEEDS:
Expand All @@ -101,7 +97,6 @@ async def test_create_speech_with_varying_speed(openai_client: AsyncOpenAI) -> N


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
@pytest.mark.parametrize("speed", UNSUPPORTED_SPEEDS)
async def test_create_speech_with_unsupported_speed(openai_client: AsyncOpenAI, speed: float) -> None:
with pytest.raises(UnprocessableEntityError):
Expand All @@ -118,7 +113,6 @@ async def test_create_speech_with_unsupported_speed(openai_client: AsyncOpenAI,


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
@pytest.mark.parametrize("sample_rate", VALID_SAMPLE_RATES)
async def test_speech_valid_resample(openai_client: AsyncOpenAI, sample_rate: int) -> None:
res = await openai_client.audio.speech.create(
Expand All @@ -136,7 +130,6 @@ async def test_speech_valid_resample(openai_client: AsyncOpenAI, sample_rate: in


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
@pytest.mark.parametrize("sample_rate", INVALID_SAMPLE_RATES)
async def test_speech_invalid_resample(openai_client: AsyncOpenAI, sample_rate: int) -> None:
with pytest.raises(UnprocessableEntityError):
Expand All @@ -149,6 +142,8 @@ async def test_speech_invalid_resample(openai_client: AsyncOpenAI, sample_rate:
)


# TODO: add piper tests

# TODO: implement the following test

# NUMBER_OF_MODELS = 1
Expand Down

0 comments on commit 8980efe

Please sign in to comment.