diff --git a/src/leapfrogai_ui/src/app.css b/src/leapfrogai_ui/src/app.css
index b1f6ef61b..1afa8ffdd 100644
--- a/src/leapfrogai_ui/src/app.css
+++ b/src/leapfrogai_ui/src/app.css
@@ -9,6 +9,17 @@
scrollbar-color: #4b5563 #1f2937;
}
+/* Override TailwindCSS default Preflight styles for lists in messages */
+#message-content-container {
+ ul {
+ margin: revert;
+ padding: revert;
+ li {
+ list-style: square;
+ }
+ }
+}
+
/*TODO - can we get rid of some of these?*/
@layer utilities {
.content {
diff --git a/src/leapfrogai_ui/src/lib/components/Message.svelte b/src/leapfrogai_ui/src/lib/components/Message.svelte
index 0af165b1f..d59a5e448 100644
--- a/src/leapfrogai_ui/src/lib/components/Message.svelte
+++ b/src/leapfrogai_ui/src/lib/components/Message.svelte
@@ -178,14 +178,16 @@
{#if message.role !== 'user' && !messageText}
{:else}
-
- {@html DOMPurify.sanitize(md.render(messageText), {
- CUSTOM_ELEMENT_HANDLING: {
- tagNameCheck: /^code-block$/,
- attributeNameCheck: /^(code|language)$/,
- allowCustomizedBuiltInElements: false
- }
- })}
+
+
+ {@html DOMPurify.sanitize(md.render(messageText), {
+ CUSTOM_ELEMENT_HANDLING: {
+ tagNameCheck: /^code-block$/,
+ attributeNameCheck: /^(code|language)$/,
+ allowCustomizedBuiltInElements: false
+ }
+ })}
+
{#each getCitations(message, $page.data.files) as { component: Component, props }}
diff --git a/tests/Makefile b/tests/Makefile
index 1b22ff443..b62ca37b0 100644
--- a/tests/Makefile
+++ b/tests/Makefile
@@ -74,3 +74,6 @@ test-api-unit:
test-load:
python -m locust -f $$(pwd)/tests/load/loadtest.py --web-port 8089
+
+test-conformance:
+ PYTHONPATH=$$(pwd) pytest -vv -s tests/conformance
diff --git a/tests/README.md b/tests/README.md
index 765695d8d..b9d41eda9 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -107,3 +107,23 @@ python -m pytest tests/e2e/test_llama.py -v
# Cleanup after yourself
k3d cluster delete uds
```
+
+## Conformance Testing
+
+We include a set of conformance tests to verify our spec against OpenAI to guarantee interoperability with tools that support OpenAI's API (MatterMost, Continue.dev, etc.) and SDKs (Vercel, Azure, etc.). To run these tests, the following environment variables need to be set:
+
+```bash
+LEAPFROGAI_API_KEY="" # this can be created via the LeapfrogAI UI or Supabase
+LEAPFROGAI_API_URL="https://leapfrogai-api.uds.dev/openai/v1" # This is the default when using a UDS-bundle locally
+LEAPFROGAI_MODEL="vllm" # or whatever model you have installed
+OPENAI_API_KEY="" # you need a funded OpenAI account for this
+OPENAI_MODEL="gpt-4o-mini" # or whatever model you prefer
+```
+
+To run the tests, from the root directory of the LeapfrogAI project:
+
+```bash
+make install # to ensure all python dependencies are installed
+
+make test-conformance # runs the entire suite
+```
diff --git a/tests/conformance/test_conformance_assistants.py b/tests/conformance/test_assistants.py
similarity index 100%
rename from tests/conformance/test_conformance_assistants.py
rename to tests/conformance/test_assistants.py
diff --git a/tests/conformance/test_completions.py b/tests/conformance/test_completions.py
index 6e53dfdcc..d400d985a 100644
--- a/tests/conformance/test_completions.py
+++ b/tests/conformance/test_completions.py
@@ -1,7 +1,7 @@
import pytest
from openai.types.beta.threads import Run, Message, TextContentBlock, Text
-from .utils import client_config_factory
+from tests.utils.client import client_config_factory
def make_mock_message_object(role, message_text):
@@ -37,12 +37,12 @@ def make_mock_message_simple(role, message_text):
def test_run_completion(client_name, test_messages):
# Setup
config = client_config_factory(client_name)
- client = config["client"]
+ client = config.client
assistant = client.beta.assistants.create(
name="Test Assistant",
instructions="You must provide a response based on the attached files.",
- model=config["model"],
+ model=config.model,
)
thread = client.beta.threads.create()
diff --git a/tests/conformance/test_files.py b/tests/conformance/test_files.py
index 02b67530a..18074e259 100644
--- a/tests/conformance/test_files.py
+++ b/tests/conformance/test_files.py
@@ -6,16 +6,17 @@
)
from openai.types.beta.vector_stores.vector_store_file import VectorStoreFile
-from ..utils.client import client_config_factory, text_file_path
+from tests.utils.client import client_config_factory
+from tests.utils.data_path import data_path, TXT_DATA_FILE
@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_file_upload(client_name):
config = client_config_factory(client_name)
- client = config.client # shorthand
+ client = config.client
vector_store = client.beta.vector_stores.create(name="Test data")
- with open(text_file_path(), "rb") as file:
+ with open(data_path(TXT_DATA_FILE), "rb") as file:
vector_store_file = client.beta.vector_stores.files.upload(
vector_store_id=vector_store.id, file=file
)
@@ -24,13 +25,14 @@ def test_file_upload(client_name):
assert isinstance(vector_store_file, VectorStoreFile)
+@pytest.mark.xfail(reason="File Batch Upload is not yet implemented in LeapfrogAI")
@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_file_delete(client_name):
config = client_config_factory(client_name)
client = config.client
vector_store = client.beta.vector_stores.create(name="Test data")
- with open(text_file_path(), "rb") as file:
+ with open(data_path(TXT_DATA_FILE), "rb") as file:
vector_store_file = client.beta.vector_stores.files.upload(
vector_store_id=vector_store.id, file=file
)
diff --git a/tests/conformance/test_messages.py b/tests/conformance/test_messages.py
index f58f22b9c..24e1f312f 100644
--- a/tests/conformance/test_messages.py
+++ b/tests/conformance/test_messages.py
@@ -2,7 +2,7 @@
from openai.types.beta.threads.message import Message
-from ..utils.client import client_config_factory
+from tests.utils.client import client_config_factory
@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
diff --git a/tests/conformance/test_conformance_runs.py b/tests/conformance/test_runs.py
similarity index 93%
rename from tests/conformance/test_conformance_runs.py
rename to tests/conformance/test_runs.py
index 7a4447bfc..d8039864e 100644
--- a/tests/conformance/test_conformance_runs.py
+++ b/tests/conformance/test_runs.py
@@ -1,7 +1,7 @@
import pytest
from openai.types.beta.threads import Run, Message, TextContentBlock, Text
-from .utils import client_config_factory
+from tests.utils.client import client_config_factory
def make_mock_message_object(role, message_text):
@@ -37,12 +37,12 @@ def make_mock_message_simple(role, message_text):
def test_run_create(client_name, test_messages):
# Setup
config = client_config_factory(client_name)
- client = config["client"]
+ client = config.client
assistant = client.beta.assistants.create(
name="Test Assistant",
instructions="You must provide a response based on the attached files.",
- model=config["model"],
+ model=config.model,
)
thread = client.beta.threads.create()
diff --git a/tests/conformance/test_conformance_threads.py b/tests/conformance/test_threads.py
similarity index 95%
rename from tests/conformance/test_conformance_threads.py
rename to tests/conformance/test_threads.py
index 91d17c940..2a56528c7 100644
--- a/tests/conformance/test_conformance_threads.py
+++ b/tests/conformance/test_threads.py
@@ -2,7 +2,7 @@
from openai.types.beta.thread import Thread
from openai.types.beta.threads import Message, TextContentBlock, Text
-from ..utils.client import client_config_factory
+from tests.utils.client import client_config_factory
def make_mock_message_object(role, message_text):
diff --git a/tests/conformance/test_conformance_tools.py b/tests/conformance/test_tools.py
similarity index 92%
rename from tests/conformance/test_conformance_tools.py
rename to tests/conformance/test_tools.py
index 9b69193d5..cff821545 100644
--- a/tests/conformance/test_conformance_tools.py
+++ b/tests/conformance/test_tools.py
@@ -7,12 +7,13 @@
from openai.types.beta.threads.message import Message
import re
-from ..utils.client import client_config_factory, text_file_path
+from tests.utils.client import client_config_factory
+from tests.utils.data_path import data_path, TXT_DATA_FILE
def make_vector_store_with_file(client):
vector_store = client.beta.vector_stores.create(name="Test data")
- with open(text_file_path(), "rb") as file:
+ with open(data_path(TXT_DATA_FILE), "rb") as file:
client.beta.vector_stores.files.upload(
vector_store_id=vector_store.id, file=file
)
@@ -46,7 +47,7 @@ def validate_annotation_format(annotation):
@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_thread_file_annotations(client_name):
config = client_config_factory(client_name)
- client = config.client # shorthand
+ client = config.client
vector_store = make_vector_store_with_file(client)
assistant = make_test_assistant(client, config.model, vector_store.id)
diff --git a/tests/conformance/test_conformance_vectorstore.py b/tests/conformance/test_vectorstore.py
similarity index 90%
rename from tests/conformance/test_conformance_vectorstore.py
rename to tests/conformance/test_vectorstore.py
index 25ad52f9d..b9e65b4d0 100644
--- a/tests/conformance/test_conformance_vectorstore.py
+++ b/tests/conformance/test_vectorstore.py
@@ -3,13 +3,13 @@
from openai.types.beta.vector_store import VectorStore
from openai.types.beta.vector_store_deleted import VectorStoreDeleted
-from ..utils.client import client_config_factory
+from tests.utils.client import client_config_factory
@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_vector_store_create(client_name):
config = client_config_factory(client_name)
- client = config.client # shorthand
+ client = config.client
vector_store = client.beta.vector_stores.create(name="Test data")
@@ -19,7 +19,7 @@ def test_vector_store_create(client_name):
@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_vector_store_list(client_name):
config = client_config_factory(client_name)
- client = config.client # shorthand
+ client = config.client
client.beta.vector_stores.create(name="Test data")
diff --git a/tests/e2e/test_llm_generation.py b/tests/e2e/test_llm_generation.py
index 4f54e7f0f..badb0dd3e 100644
--- a/tests/e2e/test_llm_generation.py
+++ b/tests/e2e/test_llm_generation.py
@@ -1,11 +1,11 @@
import os
-from pathlib import Path
from typing import Iterable
import warnings
import pytest
from openai import InternalServerError, OpenAI
from openai.types.chat import ChatCompletionMessageParam
+from tests.utils.data_path import data_path, WAV_FILE
DEFAULT_LEAPFROGAI_MODEL = "llama-cpp-python"
@@ -72,7 +72,8 @@ def test_embeddings(client: OpenAI, model_name: str):
def test_transcriptions(client: OpenAI, model_name: str):
with pytest.raises(InternalServerError) as excinfo:
client.audio.transcriptions.create(
- model=model_name, file=Path("tests/data/0min12sec.wav")
+ model=model_name,
+ file=data_path(WAV_FILE),
)
assert str(excinfo.value) == "Internal Server Error"
diff --git a/tests/e2e/test_text_embeddings.py b/tests/e2e/test_text_embeddings.py
index 1912228e1..23fdb4571 100644
--- a/tests/e2e/test_text_embeddings.py
+++ b/tests/e2e/test_text_embeddings.py
@@ -1,7 +1,6 @@
-from pathlib import Path
-
import pytest
from openai import InternalServerError, OpenAI
+from tests.utils.data_path import data_path, WAV_FILE
model_name = "text-embeddings"
@@ -41,6 +40,7 @@ def test_embeddings(client: OpenAI):
def test_transcriptions(client: OpenAI):
with pytest.raises(InternalServerError) as excinfo:
client.audio.transcriptions.create(
- model=model_name, file=Path("tests/data/0min12sec.wav")
+ model=model_name,
+ file=data_path(WAV_FILE),
)
assert str(excinfo.value) == "Internal Server Error"
diff --git a/tests/e2e/test_whisper.py b/tests/e2e/test_whisper.py
index bae880039..3b5256f91 100644
--- a/tests/e2e/test_whisper.py
+++ b/tests/e2e/test_whisper.py
@@ -4,6 +4,7 @@
import pytest
from openai import InternalServerError, OpenAI
import unicodedata
+from tests.utils.data_path import data_path, WAV_FILE, WAV_FILE_ARABIC
def test_completions(client: OpenAI):
@@ -38,7 +39,7 @@ def test_embeddings(client: OpenAI):
def test_transcriptions(client: OpenAI):
transcription = client.audio.transcriptions.create(
model="whisper",
- file=Path("tests/data/0min12sec.wav"),
+ file=data_path(WAV_FILE),
language="en",
prompt="This is a test transcription.",
response_format="json",
@@ -53,7 +54,7 @@ def test_transcriptions(client: OpenAI):
def test_translations(client: OpenAI):
translation = client.audio.translations.create(
model="whisper",
- file=Path("tests/data/arabic-audio.wav"),
+ file=data_path(WAV_FILE_ARABIC),
prompt="This is a test translation.",
response_format="json",
temperature=0.0,
@@ -79,7 +80,7 @@ def test_non_english_transcription(client: OpenAI):
# Arabic transcription
arabic_transcription = client.audio.transcriptions.create(
model="whisper",
- file=Path("tests/data/arabic-audio.wav"),
+ file=data_path(WAV_FILE_ARABIC),
response_format="json",
temperature=0.5,
timestamp_granularities=["word", "segment"],
diff --git a/tests/integration/api/test_assistants.py b/tests/integration/api/test_assistants.py
index deb341904..06c0444af 100644
--- a/tests/integration/api/test_assistants.py
+++ b/tests/integration/api/test_assistants.py
@@ -21,6 +21,7 @@
CreateAssistantRequest,
ModifyAssistantRequest,
)
+from tests.utils.data_path import data_path, TXT_FILE
INSTRUCTOR_XL_EMBEDDING_SIZE: int = 768
@@ -92,9 +93,7 @@ class MissingEnvironmentVariable(Exception):
def read_testfile():
"""Read the test file content."""
- with open(
- os.path.dirname(__file__) + "/../../../tests/data/test.txt", "rb"
- ) as testfile:
+ with open(data_path(TXT_FILE), "rb") as testfile:
testfile_content = testfile.read()
return testfile_content
@@ -109,7 +108,7 @@ def create_file(read_testfile): # pylint: disable=redefined-outer-name, unused-
file_response = files_client.post(
"/openai/v1/files",
- files={"file": ("test.txt", read_testfile, "text/plain")},
+ files={"file": (TXT_FILE, read_testfile, "text/plain")},
data={"purpose": "assistants"},
)
diff --git a/tests/integration/api/test_files.py b/tests/integration/api/test_files.py
index 1e7bb51e4..1a0711184 100644
--- a/tests/integration/api/test_files.py
+++ b/tests/integration/api/test_files.py
@@ -8,6 +8,7 @@
from leapfrogai_api.backend.rag.document_loader import load_file, split
from leapfrogai_api.routers.openai.files import router
+from tests.utils.data_path import data_path, WAV_FILE, TXT_FILE, PPTX_FILE, XLSX_FILE
file_response: Response
testfile_content: bytes
@@ -34,7 +35,7 @@ class MissingEnvironmentVariable(Exception):
def read_testfile():
"""Read the test file content."""
global testfile_content # pylint: disable=global-statement
- with open(os.path.dirname(__file__) + "/../../data/test.txt", "rb") as testfile:
+ with open(data_path(TXT_FILE), "rb") as testfile:
testfile_content = testfile.read()
@@ -46,7 +47,7 @@ def create_file(read_testfile): # pylint: disable=redefined-outer-name, unused-
file_response = client.post(
"/openai/v1/files",
- files={"file": ("test.txt", testfile_content, "text/plain")},
+ files={"file": (TXT_FILE, testfile_content, "text/plain")},
data={"purpose": "assistants"},
)
@@ -132,15 +133,11 @@ def test_get_nonexistent():
def test_invalid_file_type():
"""Test creating uploading an invalid file type."""
- file_path = "../../../tests/data/0min12sec.wav"
- dir_path = os.path.dirname(os.path.realpath(__file__))
- relative_file_path = os.path.join(dir_path, file_path)
-
with pytest.raises(HTTPException) as exception:
- with open(relative_file_path, "rb") as testfile:
+ with open(data_path(WAV_FILE), "rb") as testfile:
_ = client.post(
"/openai/v1/files",
- files={"file": ("0min12sec.wav", testfile, "audio/wav")},
+ files={"file": (WAV_FILE, testfile, "audio/wav")},
data={"purpose": "assistants"},
)
assert exception.status_code == status.HTTP_415_UNSUPPORTED_MEDIA_TYPE
@@ -149,16 +146,8 @@ def test_invalid_file_type():
@pytest.mark.asyncio
async def test_excel_file_handling():
"""Test handling of an Excel file including upload, retrieval, and deletion."""
- # Path to the test Excel file
- excel_file_path = os.path.join(os.path.dirname(__file__), "../../data/test.xlsx")
-
- # Ensure the file exists
- assert os.path.exists(
- excel_file_path
- ), f"Test Excel file not found at {excel_file_path}"
-
# Test file loading and splitting
- documents = await load_file(excel_file_path)
+ documents = await load_file(data_path(XLSX_FILE))
assert len(documents) > 0, "No documents were loaded from the Excel file"
assert documents[0].page_content, "The first document has no content"
@@ -167,12 +156,12 @@ async def test_excel_file_handling():
assert split_documents[0].page_content, "The first split document has no content"
# Test file upload via API
- with open(excel_file_path, "rb") as excel_file:
+ with open(data_path(XLSX_FILE), "rb") as excel_file:
response = client.post(
"/openai/v1/files",
files={
"file": (
- "test.xlsx",
+ XLSX_FILE,
excel_file,
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
)
@@ -228,16 +217,9 @@ async def test_excel_file_handling():
@pytest.mark.asyncio
async def test_powerpoint_file_handling():
"""Test handling of a PowerPoint file including upload, retrieval, and deletion."""
- # Path to the test PowerPoint file
- pptx_file_path = os.path.join(os.path.dirname(__file__), "../../data/test.pptx")
-
- # Ensure the file exists
- assert os.path.exists(
- pptx_file_path
- ), f"Test PowerPoint file not found at {pptx_file_path}"
# Test file loading and splitting
- documents = await load_file(pptx_file_path)
+ documents = await load_file(data_path(PPTX_FILE).__str__())
assert len(documents) > 0, "No documents were loaded from the PowerPoint file"
assert documents[0].page_content, "The first document has no content"
@@ -246,13 +228,13 @@ async def test_powerpoint_file_handling():
assert split_documents[0].page_content, "The first split document has no content"
# Test file upload via API
- with open(pptx_file_path, "rb") as pptx_file:
+ with open(data_path(PPTX_FILE), "rb") as file:
response = client.post(
"/openai/v1/files",
files={
"file": (
- "test.pptx",
- pptx_file,
+ PPTX_FILE,
+ file,
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
)
},
diff --git a/tests/integration/api/test_rag_files.py b/tests/integration/api/test_rag_files.py
index 9ed2ad28c..45f832418 100644
--- a/tests/integration/api/test_rag_files.py
+++ b/tests/integration/api/test_rag_files.py
@@ -1,9 +1,9 @@
import os
-from pathlib import Path
from openai.types.beta.threads.text import Text
import pytest
+from tests.utils.data_path import data_path
-from ...utils.client import client_config_factory
+from tests.utils.client import client_config_factory
def make_test_assistant(client, model, vector_store_id):
@@ -33,7 +33,6 @@ def test_rag_needle_haystack():
client = config.client
vector_store = client.beta.vector_stores.create(name="Test data")
- file_path = "../../data"
file_names = [
"test_rag_1.1.txt",
"test_rag_1.2.txt",
@@ -44,9 +43,7 @@ def test_rag_needle_haystack():
]
vector_store_files = []
for file_name in file_names:
- with open(
- f"{Path(os.path.dirname(__file__))}/{file_path}/{file_name}", "rb"
- ) as file:
+ with open(data_path(file_name), "rb") as file:
vector_store_files.append(
client.beta.vector_stores.files.upload(
vector_store_id=vector_store.id, file=file
diff --git a/tests/integration/api/test_vector_stores.py b/tests/integration/api/test_vector_stores.py
index f9c69f8e8..5427a0943 100644
--- a/tests/integration/api/test_vector_stores.py
+++ b/tests/integration/api/test_vector_stores.py
@@ -19,6 +19,7 @@
)
from leapfrogai_api.routers.openai.vector_stores import router as vector_store_router
from leapfrogai_api.routers.openai.files import router as files_router
+from tests.utils.data_path import data_path, TXT_FILE
INSTRUCTOR_XL_EMBEDDING_SIZE: int = 768
@@ -52,9 +53,7 @@ class MissingEnvironmentVariable(Exception):
def read_testfile():
"""Read the test file content."""
- with open(
- os.path.dirname(__file__) + "/../../../tests/data/test.txt", "rb"
- ) as testfile:
+ with open(data_path(TXT_FILE), "rb") as testfile:
testfile_content = testfile.read()
return testfile_content
@@ -67,7 +66,7 @@ def create_file(read_testfile): # pylint: disable=redefined-outer-name, unused-
file_response = files_client.post(
"/openai/v1/files",
- files={"file": ("test.txt", read_testfile, "text/plain")},
+ files={"file": (TXT_FILE, read_testfile, "text/plain")},
data={"purpose": "assistants"},
)
diff --git a/tests/load/loadtest.py b/tests/load/loadtest.py
index 1a3bd8faa..745379e4e 100644
--- a/tests/load/loadtest.py
+++ b/tests/load/loadtest.py
@@ -8,6 +8,7 @@
import warnings
import tempfile
import uuid
+from tests.utils.data_path import data_path, MP3_FILE_RUSSIAN
# Suppress SSL-related warnings
warnings.filterwarnings("ignore", category=Warning)
@@ -59,9 +60,7 @@ def download_arxiv_pdf():
def load_audio_file():
- script_dir = os.path.dirname(os.path.abspath(__file__))
- file_path = os.path.join(script_dir, "..", "data", "russian.mp3")
- with open(file_path, "rb") as file:
+ with open(data_path(MP3_FILE_RUSSIAN), "rb") as file:
return file.read()
@@ -211,14 +210,14 @@ def test_embeddings(self):
@task
def test_transcribe(self):
audio_content = load_audio_file()
- files = {"file": ("russian.mp3", audio_content, "audio/mpeg")}
+ files = {"file": (MP3_FILE_RUSSIAN, audio_content, "audio/mpeg")}
data = {"model": "whisper", "language": "ru"}
self.client.post("/openai/v1/audio/transcriptions", files=files, data=data)
@task
def test_translate(self):
audio_content = load_audio_file()
- files = {"file": ("russian.mp3", audio_content, "audio/mpeg")}
+ files = {"file": (MP3_FILE_RUSSIAN, audio_content, "audio/mpeg")}
data = {"model": "whisper"}
self.client.post("/openai/v1/audio/translations", files=files, data=data)
diff --git a/tests/pytest/leapfrogai_api/test_api.py b/tests/pytest/leapfrogai_api/test_api.py
index 10fbe698b..724b0dc58 100644
--- a/tests/pytest/leapfrogai_api/test_api.py
+++ b/tests/pytest/leapfrogai_api/test_api.py
@@ -15,6 +15,7 @@
from leapfrogai_api.typedef.embeddings import CreateEmbeddingRequest
from leapfrogai_api.main import app
from leapfrogai_api.routers.supabase_session import init_supabase_client
+from tests.utils.data_path import data_path, WAV_FILE, WAV_FILE_ARABIC
security = HTTPBearer()
@@ -65,12 +66,6 @@ async def pack_dummy_bearer_token(request: _CachedRequest, call_next):
return await call_next(request)
-def load_audio_file(path: str):
- file_path = os.path.join("tests", "data", path)
- with open(file_path, "rb") as file:
- return file.read()
-
-
@pytest.fixture
def dummy_auth_middleware():
app.dependency_overrides[init_supabase_client] = mock_init_supabase_client
@@ -269,13 +264,12 @@ def test_transcription(dummy_auth_middleware):
expected_transcription = "The repeater model received a transcribe request"
with TestClient(app) as client:
- audio_filename = "0min12sec.wav"
- audio_content = load_audio_file(audio_filename)
- files = {"file": (audio_filename, audio_content, "audio/mpeg")}
- data = {"model": MODEL}
- response = client.post(
- "/openai/v1/audio/transcriptions", files=files, data=data
- )
+ with open(data_path(WAV_FILE), "rb") as audio_content:
+ files = {"file": (WAV_FILE, audio_content, "audio/mpeg")}
+ data = {"model": MODEL}
+ response = client.post(
+ "/openai/v1/audio/transcriptions", files=files, data=data
+ )
assert response.status_code == 200
@@ -292,11 +286,12 @@ def test_translation(dummy_auth_middleware):
expected_translation = "The repeater model received a translation request"
with TestClient(app) as client:
- audio_filename = "arabic-audio.wav"
- audio_content = load_audio_file(audio_filename)
- files = {"file": (audio_filename, audio_content, "audio/mpeg")}
- data = {"model": MODEL}
- response = client.post("/openai/v1/audio/translations", files=files, data=data)
+ with open(data_path(WAV_FILE_ARABIC), "rb") as audio_content:
+ files = {"file": (WAV_FILE_ARABIC, audio_content, "audio/mpeg")}
+ data = {"model": MODEL}
+ response = client.post(
+ "/openai/v1/audio/translations", files=files, data=data
+ )
assert response.status_code == 200
diff --git a/tests/utils/client.py b/tests/utils/client.py
index 08855e6c0..8411d5077 100644
--- a/tests/utils/client.py
+++ b/tests/utils/client.py
@@ -1,16 +1,10 @@
from openai import OpenAI
import os
-from pathlib import Path
-
LEAPFROGAI_MODEL = os.getenv("LEAPFROGAI_MODEL", "llama-cpp-python")
OPENAI_MODEL = "gpt-4o-mini"
-def text_file_path():
- return Path(os.path.dirname(__file__) + "/../data/test_with_data.txt")
-
-
def openai_client():
return OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
diff --git a/tests/utils/data_path.py b/tests/utils/data_path.py
new file mode 100644
index 000000000..88f6a1db4
--- /dev/null
+++ b/tests/utils/data_path.py
@@ -0,0 +1,35 @@
+import os
+from pathlib import Path
+
+TXT_FILE = "test.txt"
+TXT_DATA_FILE = "test_with_data.txt"
+PPTX_FILE = "test.pptx"
+WAV_FILE = "0min12sec.wav"
+WAV_FILE_ARABIC = "arabic-audio.wav"
+MP3_FILE_RUSSIAN = "russian.mp3"
+XLSX_FILE = "test.xlsx"
+
+
+def data_path(filename: str) -> Path:
+ """Return the path to a test file in the data directory. (See constants for specific files.)
+
+ Args:
+ filename (str): The name of the file to return the path.
+
+ Returns:
+ Path: The path to the file in the data directory.
+
+ Raises:
+ FileNotFoundError: If the file does not exist in the data directory.
+ """
+
+ data_path = Path(
+ os.path.realpath(os.path.dirname(__file__) + f"/../data/{filename}")
+ )
+
+ try:
+ # Check if the file exists
+ with open(data_path, "r"):
+ return data_path
+ except FileNotFoundError:
+ raise FileNotFoundError(f"File not found in data directory: {data_path}")