From 2f80d879c218a72b9ce5bf370dc1f9124394d50f Mon Sep 17 00:00:00 2001
From: Gato <115658935+CollectiveUnicorn@users.noreply.github.com>
Date: Tue, 1 Oct 2024 12:34:46 -0700
Subject: [PATCH] feat(api): reranking backend integrated with rag (#1090)

* Adds reranking to the RAG pipeline

* Adds a RAG configuration endpoint when in dev mode

* Adds additional logging

* Refactors the pytest test_routes API tests

* Splits the default RAG flow into two steps, retrieval and ranking: the
  retrieval step fetches the top 100 results, and ranking then filters them
  down to the user-specified k value. If reranking is not enabled, the
  user-specified k results are returned directly from the retrieval step.

* Adds Zarf configs to enable dev mode.
---
 .github/workflows/pytest.yaml                 |  1 +
 packages/api/chart/values.yaml                |  2 +
 packages/api/values/registry1-values.yaml     |  2 +
 packages/api/values/upstream-values.yaml      |  2 +
 packages/api/zarf.yaml                        |  3 +
 src/leapfrogai_api/README.md                  | 69 +++++++++++++++++
 src/leapfrogai_api/backend/rag/query.py       | 74 ++++++++++++++++++-
 src/leapfrogai_api/main.py                    |  3 +
 src/leapfrogai_api/pyproject.toml             |  1 +
 src/leapfrogai_api/routers/leapfrogai/rag.py  | 56 ++++++++++++++
 .../routers/leapfrogai/vector_stores.py       |  4 +-
 src/leapfrogai_api/typedef/rag/__init__.py    |  3 +
 src/leapfrogai_api/typedef/rag/rag_types.py   | 40 ++++++++++
 .../typedef/vectorstores/search_types.py      | 10 +++
 src/leapfrogai_api/utils/logging_tools.py     | 12 +++
 src/leapfrogai_evals/pyproject.toml           |  5 +-
 tests/integration/api/test_rag_files.py       | 69 ++++++++++++++++-
 tests/pytest/leapfrogai_api/test_api.py       | 68 +++++++++++++++--
 18 files changed, 408 insertions(+), 16 deletions(-)
 create mode 100644 src/leapfrogai_api/routers/leapfrogai/rag.py
 create mode 100644 src/leapfrogai_api/typedef/rag/__init__.py
 create mode 100644 src/leapfrogai_api/typedef/rag/rag_types.py
 create mode 100644 src/leapfrogai_api/utils/logging_tools.py

diff --git a/.github/workflows/pytest.yaml b/.github/workflows/pytest.yaml
index 93d0f0832..21d2e1985 100644
--- a/.github/workflows/pytest.yaml
+++ b/.github/workflows/pytest.yaml
@@ -64,6 +64,7 @@ jobs:
         run: make test-api-unit
         env:
           LFAI_RUN_REPEATER_TESTS: true
+          DEV: true
 
   integration:
     runs-on: ai-ubuntu-big-boy-8-core
diff --git a/packages/api/chart/values.yaml b/packages/api/chart/values.yaml
index 65b397e46..4c217ba8a 100644
--- a/packages/api/chart/values.yaml
+++ b/packages/api/chart/values.yaml
@@ -25,6 +25,8 @@ api:
         value: "*.toml"
       - name: DEFAULT_EMBEDDINGS_MODEL
         value: "text-embeddings"
+      - name: DEV
+        value: "false"
       - name: PORT
         value: "8080"
       - name: SUPABASE_URL
diff --git a/packages/api/values/registry1-values.yaml b/packages/api/values/registry1-values.yaml
index d269c6415..4bd35ee39 100644
--- a/packages/api/values/registry1-values.yaml
+++ b/packages/api/values/registry1-values.yaml
@@ -16,6 +16,8 @@ api:
         value: "*.toml"
      - name: DEFAULT_EMBEDDINGS_MODEL
         value: "###ZARF_VAR_DEFAULT_EMBEDDINGS_MODEL###"
+      - name: DEV
+        value: "###ZARF_VAR_DEV###"
       - name: PORT
         value: "8080"
       - name: SUPABASE_URL
diff --git a/packages/api/values/upstream-values.yaml b/packages/api/values/upstream-values.yaml
index 6d867260e..ef2dcdad9 100644
--- a/packages/api/values/upstream-values.yaml
+++ b/packages/api/values/upstream-values.yaml
@@ -14,6 +14,8 @@ api:
         value: "*.toml"
       - name: DEFAULT_EMBEDDINGS_MODEL
         value: "###ZARF_VAR_DEFAULT_EMBEDDINGS_MODEL###"
+      - name: DEV
+        value: "###ZARF_VAR_DEV###"
       - name: PORT
         value: "8080"
       - name: SUPABASE_URL
diff --git a/packages/api/zarf.yaml b/packages/api/zarf.yaml
index 4fa6c59f2..92b3c8123 100644
--- a/packages/api/zarf.yaml
+++ b/packages/api/zarf.yaml
@@ -16,6 +16,9 @@ variables:
     description: "Flag to expose the OpenAPI schema for debugging."
   - name: DEFAULT_EMBEDDINGS_MODEL
     default: "text-embeddings"
+  - name: DEV
+    default: "false"
+    description: "Flag to enable development endpoints."
 
 components:
   - name: leapfrogai-api
diff --git a/src/leapfrogai_api/README.md b/src/leapfrogai_api/README.md
index eec4dd0c6..214c986a9 100644
--- a/src/leapfrogai_api/README.md
+++ b/src/leapfrogai_api/README.md
@@ -56,3 +56,72 @@ See the ["Access" section of the DEVELOPMENT.md](../../docs/DEVELOPMENT.md#acces
 ### Tests
 
 See the [tests directory documentation](../../tests/README.md) for more details.
+
+### Reranking Configuration
+
+The LeapfrogAI API includes a Retrieval Augmented Generation (RAG) pipeline for enhanced question answering. This section details how to configure its reranking options. All RAG configuration is managed through the `/leapfrogai/v1/rag/configure` API endpoint.
+
+#### 1. Enabling/Disabling Reranking
+
+Reranking improves the accuracy and relevance of RAG responses. You can enable or disable it using the `enable_reranking` parameter:
+
+* **Enable Reranking:** Send a PATCH request to `/leapfrogai/v1/rag/configure` with the following JSON payload:
+
+```json
+{
+  "enable_reranking": true
+}
+```
+
+* **Disable Reranking:** Send a PATCH request with:
+
+```json
+{
+  "enable_reranking": false
+}
+```
+
+#### 2. Selecting a Reranking Model
+
+Multiple reranking models are supported, each offering different performance characteristics. Choose your preferred model using the `ranking_model` parameter. Ensure you have installed any necessary Python dependencies for your chosen model (see the [rerankers library documentation](https://github.com/AnswerDotAI/rerankers) for details on dependencies).
+
+* **Supported Models:** The system supports several models, including (but not limited to) `flashrank`, `rankllm`, `cross-encoder`, and `colbert`. Refer to the [rerankers library documentation](https://github.com/AnswerDotAI/rerankers) for a complete list and details on their capabilities.
+
+* **Model Selection:** Use a PATCH request to `/leapfrogai/v1/rag/configure` with the desired model:
+
+```json
+{
+  "enable_reranking": true,   // Reranking must be enabled
+  "ranking_model": "rankllm"  // Or another supported model
+}
+```
+
+#### 3. Adjusting the Number of Results Before Reranking (`rag_top_k_when_reranking`)
+
+This parameter sets the number of top results retrieved from the vector database *before* the reranking process begins. A higher value increases the diversity of candidates considered for reranking but also increases processing time; too low a value risks dropping relevant results before they can be reranked. This setting is only relevant when reranking is enabled.
+
+* **Configuration:** Use a PATCH request to `/leapfrogai/v1/rag/configure` to set this value:
+
+```json
+{
+  "enable_reranking": true,
+  "ranking_model": "flashrank",
+  "rag_top_k_when_reranking": 150  // Adjust this value as needed
+}
+```
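+
+As a concrete sketch, the same request can be sent from Python with the `requests` library (the base URL and key below are placeholders; substitute your own deployment's values):
+
+```python
+import requests
+
+BASE_URL = "http://localhost:8080"  # placeholder for your LeapfrogAI API URL
+HEADERS = {"Authorization": "Bearer <your-api-key>"}  # if your deployment requires auth
+
+# PATCH only the fields you want to change; omitted fields keep their current values.
+payload = {
+    "enable_reranking": True,
+    "ranking_model": "flashrank",
+    "rag_top_k_when_reranking": 150,
+}
+
+response = requests.patch(
+    f"{BASE_URL}/leapfrogai/v1/rag/configure", json=payload, headers=HEADERS
+)
+response.raise_for_status()
+```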
+
+#### 4. Retrieving the Current RAG Configuration
+
+To check the current RAG configuration (including reranking status, model, and `rag_top_k_when_reranking`), send a GET request to `/leapfrogai/v1/rag/configure`. The response will be a JSON object containing all the current settings.
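+
+A matching sketch for this lookup, under the same placeholder assumptions as above:
+
+```python
+import requests
+
+BASE_URL = "http://localhost:8080"  # placeholder for your LeapfrogAI API URL
+
+response = requests.get(f"{BASE_URL}/leapfrogai/v1/rag/configure")
+response.raise_for_status()
+# e.g. {"enable_reranking": true, "ranking_model": "flashrank", "rag_top_k_when_reranking": 100}
+print(response.json())
+```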
+
+#### 5. Example Configuration Flow
+
+1. **Initial Setup:** Start with reranking enabled using the default `flashrank` model and a `rag_top_k_when_reranking` value of 100.
+
+2. **Experiment with Models:** Test different reranking models (`rankllm`, `colbert`, etc.) by changing the `ranking_model` parameter and observing the impact on response quality. Adjust `rag_top_k_when_reranking` as needed to find the optimal balance between diversity and performance.
+
+3. **Fine-tuning:** Once you identify a suitable model, fine-tune the `rag_top_k_when_reranking` parameter for optimal performance. Monitor response times and quality to determine the best setting.
+
+4. **Disabling Reranking:** If needed, disable reranking by setting `"enable_reranking": false`.
+
+Always consult the [rerankers library documentation](https://github.com/AnswerDotAI/rerankers) for information on supported models and their specific requirements. The API documentation provides further details on request formats and potential error responses.
diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py
index e5e0decce..bd0ae9bf6 100644
--- a/src/leapfrogai_api/backend/rag/query.py
+++ b/src/leapfrogai_api/backend/rag/query.py
@@ -1,11 +1,15 @@
 """Service for querying the RAG model."""
 
+from rerankers.results import RankedResults
 from supabase import AClient as AsyncClient
 from langchain_core.embeddings import Embeddings
 from leapfrogai_api.backend.rag.leapfrogai_embeddings import LeapfrogAIEmbeddings
 from leapfrogai_api.data.crud_vector_content import CRUDVectorContent
-from leapfrogai_api.typedef.vectorstores.search_types import SearchResponse
+from leapfrogai_api.typedef.rag.rag_types import ConfigurationSingleton
+from leapfrogai_api.typedef.vectorstores.search_types import SearchResponse, SearchItem
 from leapfrogai_api.backend.constants import TOP_K
+from leapfrogai_api.utils.logging_tools import logger
+from rerankers import Reranker
 
 # Allows for overwriting type of embeddings that will be instantiated
 embeddings_type: type[Embeddings] | type[LeapfrogAIEmbeddings] | None = (
@@ -22,7 +26,10 @@ def __init__(self, db: AsyncClient) -> None:
         self.embeddings = embeddings_type()
 
     async def query_rag(
-        self, query: str, vector_store_id: str, k: int = TOP_K
+        self,
+        query: str,
+        vector_store_id: str,
+        k: int = TOP_K,
     ) -> SearchResponse:
         """
         Query the Vector Store.
@@ -36,11 +43,70 @@
             SearchResponse: The search response from the vector store.
         """
 
+        logger.debug("Beginning RAG query...")
+
         # 1. Embed query
         vector = await self.embeddings.aembed_query(query)
 
         # 2. Perform similarity search
+        _k: int = k
+        if ConfigurationSingleton.get_instance().enable_reranking:
+            # Use the user-specified top-k value unless reranking is enabled.
+            # When reranking, use the larger reranking top-k to fetch the initial
+            # results, then filter the list down to the user-requested k afterwards.
+            _k = ConfigurationSingleton.get_instance().rag_top_k_when_reranking
+
         crud_vector_content = CRUDVectorContent(db=self.db)
-        return await crud_vector_content.similarity_search(
-            query=vector, vector_store_id=vector_store_id, k=k
+        results = await crud_vector_content.similarity_search(
+            query=vector, vector_store_id=vector_store_id, k=_k
         )
+
+        # 3. Rerank results
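+        # The reranker re-scores the _k retrieved chunks against the original
+        # query text and returns them in relevance order; the list is trimmed
+        # back down to the user-requested k after reranking.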
+        if (
+            ConfigurationSingleton.get_instance().enable_reranking
+            and len(results.data) > 0
+        ):
+            ranker = Reranker(ConfigurationSingleton.get_instance().ranking_model)
+            ranked_results: RankedResults = ranker.rank(
+                query=query,
+                docs=[result.content for result in results.data],
+                doc_ids=[result.id for result in results.data],
+            )
+            results = rerank_search_response(results, ranked_results)
+            # Narrow down the results to the top-k value specified by the user
+            results.data = results.data[:k]
+
+        logger.debug("Ending RAG query...")
+
+        return results
+
+
+def rerank_search_response(
+    original_response: SearchResponse, ranked_results: RankedResults
+) -> SearchResponse:
+    """
+    Reorder the SearchResponse based on reranked results.
+
+    Args:
+        original_response (SearchResponse): The original search response.
+        ranked_results (RankedResults): The ranked results from the reranker.
+
+    Returns:
+        SearchResponse: A new SearchResponse with reordered items.
+    """
+    # Create a mapping of id to the original SearchItem
+    id_to_item = {item.id: item for item in original_response.data}
+
+    # Reorder the existing SearchItems based on the reranked results
+    ranked_items = []
+    for content in ranked_results.results:
+        if content.document.doc_id in id_to_item:
+            item: SearchItem = id_to_item[content.document.doc_id]
+            item.rank = content.rank
+            item.score = content.score
+            ranked_items.append(item)
+
+    # Create a new SearchResponse with the reranked items
+    ranked_response = SearchResponse(data=ranked_items)
+
+    return ranked_response
diff --git a/src/leapfrogai_api/main.py b/src/leapfrogai_api/main.py
index 85822f7f3..f9b3682d4 100644
--- a/src/leapfrogai_api/main.py
+++ b/src/leapfrogai_api/main.py
@@ -14,6 +14,7 @@
 from leapfrogai_api.routers.leapfrogai import models as lfai_models
 from leapfrogai_api.routers.leapfrogai import vector_stores as lfai_vector_stores
 from leapfrogai_api.routers.leapfrogai import count as lfai_token_count
+from leapfrogai_api.routers.leapfrogai import rag as lfai_rag
 from leapfrogai_api.routers.openai import (
     assistants,
     audio,
@@ -81,6 +82,8 @@ async def validation_exception_handler(request, exc):
 app.include_router(messages.router)
 app.include_router(runs_steps.router)
 app.include_router(lfai_vector_stores.router)
+if os.environ.get("DEV", "").lower() == "true":
+    app.include_router(lfai_rag.router)
 app.include_router(lfai_token_count.router)
 app.include_router(lfai_models.router)
 # This should be at the bottom to prevent it preempting more specific runs endpoints
diff --git a/src/leapfrogai_api/pyproject.toml b/src/leapfrogai_api/pyproject.toml
index a18f6422f..ea9b8f7e4 100644
--- a/src/leapfrogai_api/pyproject.toml
+++ b/src/leapfrogai_api/pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
     "postgrest==0.16.11", # required by supabase, bug when using previous versions
     "openpyxl == 3.1.5",
     "psutil == 6.0.0",
+    "rerankers[flashrank] == 0.5.3"
 ]
 
 requires-python = "~=3.11"
diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py
new file mode 100644
index 000000000..3b61b616e
--- /dev/null
+++ b/src/leapfrogai_api/routers/leapfrogai/rag.py
@@ -0,0 +1,56 @@
+"""LeapfrogAI endpoints for RAG."""
+
+from fastapi import APIRouter
+from leapfrogai_api.typedef.rag.rag_types import (
+    ConfigurationSingleton,
+    ConfigurationPayload,
+)
+from leapfrogai_api.routers.supabase_session import Session
+from leapfrogai_api.utils.logging_tools import logger
+
+router = APIRouter(prefix="/leapfrogai/v1/rag", tags=["leapfrogai/rag"])
+
+
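+# NOTE: main.py only mounts this router when the DEV environment variable is
+# set to "true", so these runtime configuration endpoints are development-only.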
+@router.patch("/configure")
+async def configure(session: Session, configuration: ConfigurationPayload) -> None:
+    """
+    Configures the RAG settings at runtime.
+
+    Args:
+        session (Session): The database session.
+        configuration (ConfigurationPayload): The configuration to update.
+    """
+
+    # We set the class variable to update the configuration globally
+    ConfigurationSingleton._instance = ConfigurationSingleton.get_instance().model_copy(
+        update=configuration.model_dump(exclude_none=True)
+    )
+
+
+@router.get("/configure")
+async def get_configuration(session: Session) -> ConfigurationPayload:
+    """
+    Retrieves the current RAG configuration.
+
+    Args:
+        session (Session): The database session.
+
+    Returns:
+        ConfigurationPayload: The current RAG configuration.
+    """
+
+    instance = ConfigurationSingleton.get_instance()
+
+    # Create a new dictionary with only the relevant attributes
+    config_dict = {
+        key: value
+        for key, value in instance.__dict__.items()
+        if not key.startswith("_")  # Exclude private attributes
+    }
+
+    # Create a new ConfigurationPayload instance with the filtered dictionary
+    new_configuration = ConfigurationPayload(**config_dict)
+
+    logger.info(f"The current configuration has been set to {new_configuration}")
+
+    return new_configuration
diff --git a/src/leapfrogai_api/routers/leapfrogai/vector_stores.py b/src/leapfrogai_api/routers/leapfrogai/vector_stores.py
index 09f8f4a77..5251440c1 100644
--- a/src/leapfrogai_api/routers/leapfrogai/vector_stores.py
+++ b/src/leapfrogai_api/routers/leapfrogai/vector_stores.py
@@ -33,9 +33,7 @@ async def search(
     """
     query_service = QueryService(db=session)
     return await query_service.query_rag(
-        query=query,
-        vector_store_id=vector_store_id,
-        k=k,
+        query=query, vector_store_id=vector_store_id, k=k
     )
diff --git a/src/leapfrogai_api/typedef/rag/__init__.py b/src/leapfrogai_api/typedef/rag/__init__.py
new file mode 100644
index 000000000..65c2e26cd
--- /dev/null
+++ b/src/leapfrogai_api/typedef/rag/__init__.py
@@ -0,0 +1,3 @@
+from .rag_types import (
+    ConfigurationSingleton as ConfigurationSingleton,
+)
diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py
new file mode 100644
index 000000000..17fe6601c
--- /dev/null
+++ b/src/leapfrogai_api/typedef/rag/rag_types.py
@@ -0,0 +1,40 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+class ConfigurationSingleton:
+    """Singleton manager for ConfigurationPayload."""
+
+    _instance = None
+
+    @classmethod
+    def get_instance(cls):
+        if cls._instance is None:
+            cls._instance = ConfigurationPayload()
+            cls._instance.enable_reranking = True
+            cls._instance.rag_top_k_when_reranking = 100
+            cls._instance.ranking_model = "flashrank"
+        return cls._instance
+
+
+class ConfigurationPayload(BaseModel):
+    """Response for RAG configuration."""
+
+    enable_reranking: Optional[bool] = Field(
+        default=None,
+        examples=[True, False],
+        description="Enables reranking for RAG queries",
+    )
+    # More model info can be found here:
+    # https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file
+    # https://pypi.org/project/rerankers/
+    ranking_model: Optional[str] = Field(
+        default=None,
+        description="What model to use for reranking. Some options may require additional Python dependencies.",
+        examples=["flashrank", "rankllm", "cross-encoder", "colbert"],
+    )
+    rag_top_k_when_reranking: Optional[int] = Field(
+        default=None,
+        description="The top-k results returned from the RAG call before reranking",
+    )
diff --git a/src/leapfrogai_api/typedef/vectorstores/search_types.py b/src/leapfrogai_api/typedef/vectorstores/search_types.py
index d8d2a2d13..ea69df1fe 100644
--- a/src/leapfrogai_api/typedef/vectorstores/search_types.py
+++ b/src/leapfrogai_api/typedef/vectorstores/search_types.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from pydantic import BaseModel, Field
 
 
@@ -25,6 +27,14 @@ class SearchItem(BaseModel):
     similarity: float = Field(
         ..., description="Similarity score of this item to the query."
     )
+    rank: Optional[int] = Field(
+        default=None,
+        description="The rank of this search item after ranking has occurred.",
+    )
+    score: Optional[float] = Field(
+        default=None,
+        description="The score of this search item after ranking has occurred.",
+    )
 
 
 class SearchResponse(BaseModel):
diff --git a/src/leapfrogai_api/utils/logging_tools.py b/src/leapfrogai_api/utils/logging_tools.py
new file mode 100644
index 000000000..aa2448288
--- /dev/null
+++ b/src/leapfrogai_api/utils/logging_tools.py
@@ -0,0 +1,12 @@
+import os
+import logging
+from dotenv import load_dotenv
+
+load_dotenv()
+
+logging.basicConfig(
+    level=os.getenv("LFAI_LOG_LEVEL", logging.INFO),
+    format="%(name)s: %(asctime)s | %(levelname)s | %(filename)s:%(lineno)s >>> %(message)s",
+)
+
+logger = logging.getLogger(__name__)
diff --git a/src/leapfrogai_evals/pyproject.toml b/src/leapfrogai_evals/pyproject.toml
index 1974da81a..9726c51c0 100644
--- a/src/leapfrogai_evals/pyproject.toml
+++ b/src/leapfrogai_evals/pyproject.toml
@@ -8,7 +8,7 @@ version = "0.13.1"
 
 dependencies = [
     "deepeval == 1.3.0",
-    "openai == 1.42.0",
+    "openai == 1.45.0",
     "tqdm == 4.66.5",
     "python-dotenv == 1.0.1",
     "seaborn == 0.13.2",
@@ -16,7 +16,8 @@
     "huggingface-hub == 0.24.6",
     "anthropic ==0.34.2",
     "instructor ==1.4.3",
-    "pyPDF2 == 3.0.1"
+    "pyPDF2 == 3.0.1",
+    "python-dotenv == 1.0.1"
 ]
 requires-python = "~=3.11"
 readme = "README.md"
diff --git a/tests/integration/api/test_rag_files.py b/tests/integration/api/test_rag_files.py
index 45f832418..7520ddbcc 100644
--- a/tests/integration/api/test_rag_files.py
+++ b/tests/integration/api/test_rag_files.py
@@ -1,9 +1,13 @@
 import os
+from typing import Optional
+
+import requests
 from openai.types.beta.threads.text import Text
 import pytest
 
 from tests.utils.data_path import data_path
-from tests.utils.client import client_config_factory
+from leapfrogai_api.typedef.rag.rag_types import ConfigurationPayload
+from tests.utils.client import client_config_factory, get_leapfrogai_api_url_base
 
 
 def make_test_assistant(client, model, vector_store_id):
@@ -77,3 +81,66 @@ def test_rag_needle_haystack():
 
     for a in message_content.annotations:
         print(a.text)
+
+
+def configure_rag(
+    enable_reranking: bool,
+    ranking_model: str,
+    rag_top_k_when_reranking: int,
+):
+    """
+    Configures the RAG settings.
+
+    Args:
+        enable_reranking: Whether to enable reranking.
+        ranking_model: The ranking model to use.
+        rag_top_k_when_reranking: The top-k results to return before reranking.
+    """
+    url = f"{get_leapfrogai_api_url_base()}/leapfrogai/v1/rag/configure"
+    configuration = ConfigurationPayload(
+        enable_reranking=enable_reranking,
+        ranking_model=ranking_model,
+        rag_top_k_when_reranking=rag_top_k_when_reranking,
+    )
+
+    try:
+        response = requests.patch(url, json=configuration.model_dump())
+        response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
+        print("RAG configuration updated successfully.")
+    except requests.exceptions.RequestException as e:
+        print(f"Error configuring RAG: {e}")
+
+
+def get_rag_configuration() -> Optional[ConfigurationPayload]:
+    """
+    Retrieves the current RAG configuration.
+
+    The API base URL is resolved via get_leapfrogai_api_url_base(),
+    so no arguments are required.
+
+    Returns:
+        The RAG configuration, or None if there was an error.
+    """
+    url = f"{get_leapfrogai_api_url_base()}/leapfrogai/v1/rag/configure"
+
+    try:
+        response = requests.get(url)
+        response.raise_for_status()
+        config = ConfigurationPayload.model_validate_json(response.text)
+        print(f"Current RAG configuration: {config}")
+        return config
+    except requests.exceptions.RequestException as e:
+        print(f"Error getting RAG configuration: {e}")
+        return None
+
+
+@pytest.mark.skipif(
+    os.environ.get("LFAI_RUN_NIAH_TESTS") != "true",
+    reason="LFAI_RUN_NIAH_TESTS envvar was not set to true",
+)
+def test_rag_needle_haystack_with_reranking():
+    configure_rag(True, "flashrank", 100)
+    config_result = get_rag_configuration()
+    assert config_result is not None
+    assert config_result.enable_reranking is True
+    test_rag_needle_haystack()
diff --git a/tests/pytest/leapfrogai_api/test_api.py b/tests/pytest/leapfrogai_api/test_api.py
index 724b0dc58..ec6460fda 100644
--- a/tests/pytest/leapfrogai_api/test_api.py
+++ b/tests/pytest/leapfrogai_api/test_api.py
@@ -32,6 +32,7 @@
 )
 TEXT_INPUT_LEN = len(TEXT_INPUT)
 
+
 #########################
 #########################
@@ -147,6 +148,7 @@ def test_routes():
         "/openai/v1/files": ["POST"],
         "/openai/v1/assistants": ["POST"],
         "/leapfrogai/v1/count/tokens": ["POST"],
+        "/leapfrogai/v1/rag/configure": ["GET", "PATCH"],
     }
 
     openai_routes = [
@@ -196,10 +198,14 @@ def test_routes():
     ]
 
     actual_routes = app.routes
-    for route in actual_routes:
-        if hasattr(route, "path") and route.path in expected_routes:
-            assert route.methods == set(expected_routes[route.path])
-            del expected_routes[route.path]
+    for expected_route in expected_routes:
+        matching_routes = {expected_route: []}
+        for actual_route in actual_routes:
+            if hasattr(actual_route, "path") and expected_route == actual_route.path:
+                matching_routes[actual_route.path].extend(actual_route.methods)
+        assert set(expected_routes[expected_route]) <= set(
+            matching_routes[expected_route]
+        )
 
     for route, name, methods in openai_routes:
         found = False
@@ -214,8 +220,6 @@ def test_routes():
             break
         assert found, f"Missing route: {route}, {name}, {methods}"
 
-    assert len(expected_routes) == 0
-
 
 def test_healthz():
     """Test the healthz endpoint."""
@@ -535,3 +539,55 @@ def test_token_count(dummy_auth_middleware):
     assert "token_count" in response_data
     assert isinstance(response_data["token_count"], int)
     assert response_data["token_count"] == len(input_text)
+
+
+@pytest.mark.skipif(
+    os.environ.get("LFAI_RUN_REPEATER_TESTS") != "true"
+    or os.environ.get("DEV") != "true",
+    reason="LFAI_RUN_REPEATER_TESTS or DEV envvar was not set to true",
+)
+def test_configure(dummy_auth_middleware):
+    """Test the RAG configuration endpoints."""
+    with TestClient(app) as client:
+        rag_configuration_request = {
+            "enable_reranking": True,
"rankllm", + "rag_top_k_when_reranking": 50, + } + response = client.patch( + "/leapfrogai/v1/rag/configure", json=rag_configuration_request + ) + assert response.status_code == 200 + + response = client.get("/leapfrogai/v1/rag/configure") + assert response.status_code == 200 + response_data = response.json() + assert "enable_reranking" in response_data + assert "ranking_model" in response_data + assert "rag_top_k_when_reranking" in response_data + assert isinstance(response_data["enable_reranking"], bool) + assert isinstance(response_data["ranking_model"], str) + assert isinstance(response_data["rag_top_k_when_reranking"], int) + assert response_data["enable_reranking"] is True + assert response_data["ranking_model"] == "rankllm" + assert response_data["rag_top_k_when_reranking"] == 50 + + # Update only some of the configs to see if the existing ones persist + rag_configuration_request = {"ranking_model": "flashrank"} + response = client.patch( + "/leapfrogai/v1/rag/configure", json=rag_configuration_request + ) + assert response.status_code == 200 + + response = client.get("/leapfrogai/v1/rag/configure") + assert response.status_code == 200 + response_data = response.json() + assert "enable_reranking" in response_data + assert "ranking_model" in response_data + assert "rag_top_k_when_reranking" in response_data + assert isinstance(response_data["enable_reranking"], bool) + assert isinstance(response_data["ranking_model"], str) + assert isinstance(response_data["rag_top_k_when_reranking"], int) + assert response_data["enable_reranking"] is True + assert response_data["ranking_model"] == "flashrank" + assert response_data["rag_top_k_when_reranking"] == 50