feat(api): reranking backend integrated with RAG (#1090)
* Adds reranking to RAG pipeline
* Adds RAG configuration endpoint when in dev mode
* Adds additional logging
* Refactors the pytest's test_routes api tests
* Splits the default RAG flow into two steps, retrieval and ranking: retrieval returns 100 results by default, which are filtered down to the user-specified k value after ranking. If reranking is not enabled, the user-specified k results are returned directly from the retrieval step.
* Adds Zarf configs to enable dev mode.
CollectiveUnicorn authored Oct 1, 2024
1 parent 185dcbb commit 2f80d87
Showing 18 changed files with 408 additions and 16 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pytest.yaml
@@ -64,6 +64,7 @@ jobs:
run: make test-api-unit
env:
LFAI_RUN_REPEATER_TESTS: true
DEV: true

integration:
runs-on: ai-ubuntu-big-boy-8-core
2 changes: 2 additions & 0 deletions packages/api/chart/values.yaml
@@ -25,6 +25,8 @@ api:
value: "*.toml"
- name: DEFAULT_EMBEDDINGS_MODEL
value: "text-embeddings"
- name: DEV
value: "false"
- name: PORT
value: "8080"
- name: SUPABASE_URL
2 changes: 2 additions & 0 deletions packages/api/values/registry1-values.yaml
@@ -16,6 +16,8 @@ api:
value: "*.toml"
- name: DEFAULT_EMBEDDINGS_MODEL
value: "###ZARF_VAR_DEFAULT_EMBEDDINGS_MODEL###"
- name: DEV
value: "###ZARF_VAR_DEV###"
- name: PORT
value: "8080"
- name: SUPABASE_URL
2 changes: 2 additions & 0 deletions packages/api/values/upstream-values.yaml
@@ -14,6 +14,8 @@ api:
value: "*.toml"
- name: DEFAULT_EMBEDDINGS_MODEL
value: "###ZARF_VAR_DEFAULT_EMBEDDINGS_MODEL###"
- name: DEV
value: "###ZARF_VAR_DEV###"
- name: PORT
value: "8080"
- name: SUPABASE_URL
3 changes: 3 additions & 0 deletions packages/api/zarf.yaml
@@ -16,6 +16,9 @@ variables:
description: "Flag to expose the OpenAPI schema for debugging."
- name: DEFAULT_EMBEDDINGS_MODEL
default: "text-embeddings"
- name: DEV
default: "false"
description: "Flag to enable development endpoints."

components:
- name: leapfrogai-api
69 changes: 69 additions & 0 deletions src/leapfrogai_api/README.md
@@ -56,3 +56,72 @@ See the ["Access" section of the DEVELOPMENT.md](../../docs/DEVELOPMENT.md#acces
### Tests

See the [tests directory documentation](../../tests/README.md) for more details.

### Reranking Configuration

The LeapfrogAI API includes a Retrieval Augmented Generation (RAG) pipeline for enhanced question answering. This section details how to configure its reranking options. All RAG configurations are managed through the `/leapfrogai/v1/rag/configure` API endpoint.

#### 1. Enabling/Disabling Reranking

Reranking improves the accuracy and relevance of RAG responses. You can enable or disable it using the `enable_reranking` parameter:

* **Enable Reranking:** Send a PATCH request to `/leapfrogai/v1/rag/configure` with the following JSON payload:

```json
{
"enable_reranking": true
}
```

* **Disable Reranking:** Send a PATCH request with:

```json
{
"enable_reranking": false
}
```
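
Either payload can be sent with a few lines of Python. The sketch below is illustrative only: `httpx` is used here simply because it is a common HTTP client, and the base URL and bearer token are placeholders you would replace for your own deployment.

```python
import httpx

API_BASE = "http://localhost:8080"        # assumed local deployment
API_TOKEN = "replace-with-a-valid-token"  # assumed bearer token for your deployment

# Toggle reranking on; set the flag to False to disable it again
response = httpx.patch(
    f"{API_BASE}/leapfrogai/v1/rag/configure",
    json={"enable_reranking": True},
    headers={"Authorization": f"Bearer {API_TOKEN}"},
)
response.raise_for_status()
```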

#### 2. Selecting a Reranking Model

Multiple reranking models are supported, each offering different performance characteristics. Choose your preferred model using the `ranking_model` parameter. Ensure you've installed any necessary Python dependencies for your chosen model (see the [rerankers library documentation](https://github.com/AnswerDotAI/rerankers) on dependencies).
* **Supported Models:** The system supports several models, including (but not limited to) `flashrank`, `rankllm`, `cross-encoder`, and `colbert`. Refer to the [rerankers library documentation](https://github.com/AnswerDotAI/rerankers) for a complete list and details on their capabilities.
* **Model Selection:** Use a PATCH request to `/leapfrogai/v1/rag/configure` with the desired model:
```json
{
"enable_reranking": true, // Reranking must be enabled
"ranking_model": "rankllm" // Or another supported model
}
```
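
For reference, the configured `ranking_model` string is handed directly to the [rerankers](https://github.com/AnswerDotAI/rerankers) library (see `query.py` in this commit), so any identifier that library accepts should work. A standalone sketch of that call, using made-up documents, looks roughly like this:

```python
from rerankers import Reranker

# "flashrank" matches the rerankers[flashrank] extra pinned in pyproject.toml;
# other model names may require additional dependencies.
ranker = Reranker("flashrank")

ranked = ranker.rank(
    query="How is reranking configured?",
    docs=["chunk about reranking", "chunk about embeddings", "chunk about Zarf"],
    doc_ids=["chunk-1", "chunk-2", "chunk-3"],
)

# RankedResults exposes the reordered documents along with their rank and score
for result in ranked.results:
    print(result.document.doc_id, result.rank, result.score)
```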

#### 3. Adjusting the Number of Results Before Reranking (`rag_top_k_when_reranking`)
This parameter sets the number of top results retrieved from the vector database *before* the reranking process begins. A higher value increases the diversity of candidates considered for reranking but also increases processing time, while a value that is too low may cause relevant results to be missed entirely. This setting is only relevant when reranking is enabled.
* **Configuration:** Use a PATCH request to `/leapfrogai/v1/rag/configure` to set this value:
```json
{
"enable_reranking": true,
"ranking_model": "flashrank",
"rag_top_k_when_reranking": 150 // Adjust this value as needed
}
```
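
For example, with `rag_top_k_when_reranking` set to 150 and a search request that asks for `k = 5`, the vector store returns 150 candidate chunks, the reranker reorders all 150, and only the top 5 reranked items are returned to the caller (this is the behavior implemented in `query_rag` in `query.py` below).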

#### 4. Retrieving the Current RAG Configuration
To check the current RAG configuration (including reranking status, model, and `rag_top_k_when_reranking`), send a GET request to `/leapfrogai/v1/rag/configure`. The response will be a JSON object containing all the current settings.
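
A minimal retrieval sketch, again assuming a local deployment and a placeholder bearer token:

```python
import httpx

response = httpx.get(
    "http://localhost:8080/leapfrogai/v1/rag/configure",  # assumed base URL
    headers={"Authorization": "Bearer replace-with-a-valid-token"},
)
print(response.json())
# With the defaults from ConfigurationSingleton this prints something like:
# {'enable_reranking': True, 'ranking_model': 'flashrank', 'rag_top_k_when_reranking': 100}
```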

#### 5. Example Configuration Flow
1. **Initial Setup:** Start with reranking enabled using the default `flashrank` model and a `rag_top_k_when_reranking` value of 100.
2. **Experiment with Models:** Test different reranking models (`rankllm`, `colbert`, etc.) by changing the `ranking_model` parameter and observing the impact on response quality. Adjust `rag_top_k_when_reranking` as needed to find the optimal balance between diversity and performance.
3. **Fine-tuning:** Once you identify a suitable model, fine-tune the `rag_top_k_when_reranking` parameter for optimal performance. Monitor response times and quality to determine the best setting.
4. **Disabling Reranking:** If needed, disable reranking by setting `"enable_reranking": false`.

Remember to always consult the [rerankers library documentation](https://github.com/AnswerDotAI/rerankers) for information on supported models and their specific requirements. The API documentation provides further details on request formats and potential error responses.
74 changes: 70 additions & 4 deletions src/leapfrogai_api/backend/rag/query.py
@@ -1,11 +1,15 @@
"""Service for querying the RAG model."""

from rerankers.results import RankedResults
from supabase import AClient as AsyncClient
from langchain_core.embeddings import Embeddings
from leapfrogai_api.backend.rag.leapfrogai_embeddings import LeapfrogAIEmbeddings
from leapfrogai_api.data.crud_vector_content import CRUDVectorContent
from leapfrogai_api.typedef.vectorstores.search_types import SearchResponse
from leapfrogai_api.typedef.rag.rag_types import ConfigurationSingleton
from leapfrogai_api.typedef.vectorstores.search_types import SearchResponse, SearchItem
from leapfrogai_api.backend.constants import TOP_K
from leapfrogai_api.utils.logging_tools import logger
from rerankers import Reranker

# Allows for overwriting type of embeddings that will be instantiated
embeddings_type: type[Embeddings] | type[LeapfrogAIEmbeddings] | None = (
@@ -22,7 +26,10 @@ def __init__(self, db: AsyncClient) -> None:
self.embeddings = embeddings_type()

async def query_rag(
self, query: str, vector_store_id: str, k: int = TOP_K
self,
query: str,
vector_store_id: str,
k: int = TOP_K,
) -> SearchResponse:
"""
Query the Vector Store.
@@ -36,11 +43,70 @@ async def query_rag(
SearchResponse: The search response from the vector store.
"""

logger.debug("Beginning RAG query...")

# 1. Embed query
vector = await self.embeddings.aembed_query(query)

# 2. Perform similarity search
_k: int = k
if ConfigurationSingleton.get_instance().enable_reranking:
"""Use the user specified top-k value unless reranking.
When reranking, use the reranking top-k value to get the initial results.
Then filter the list down later to just the k that the user has requested after reranking."""
_k = ConfigurationSingleton.get_instance().rag_top_k_when_reranking

crud_vector_content = CRUDVectorContent(db=self.db)
return await crud_vector_content.similarity_search(
query=vector, vector_store_id=vector_store_id, k=k
results = await crud_vector_content.similarity_search(
query=vector, vector_store_id=vector_store_id, k=_k
)

# 3. Rerank results
if (
ConfigurationSingleton.get_instance().enable_reranking
and len(results.data) > 0
):
ranker = Reranker(ConfigurationSingleton.get_instance().ranking_model)
ranked_results: RankedResults = ranker.rank(
query=query,
docs=[result.content for result in results.data],
doc_ids=[result.id for result in results.data],
)
results = rerank_search_response(results, ranked_results)
# Narrow down the results to the top-k value specified by the user
results.data = results.data[0:k]

logger.debug("Ending RAG query...")

return results


def rerank_search_response(
original_response: SearchResponse, ranked_results: RankedResults
) -> SearchResponse:
"""
Reorder the SearchResponse based on reranked results.
Args:
original_response (SearchResponse): The original search response.
ranked_results (RankedResults): The reranked results returned by the reranker.
Returns:
SearchResponse: A new SearchResponse with reordered items.
"""
# Create a mapping of id to original SearchItem
content_to_item = {item.id: item for item in original_response.data}

# Create new SearchItems based on reranked results
ranked_items = []
for content in ranked_results.results:
if content.document.doc_id in content_to_item:
item: SearchItem = content_to_item[content.document.doc_id]
item.rank = content.rank
item.score = content.score
ranked_items.append(item)

# Create a new SearchResponse with the reranked items
ranked_response = SearchResponse(data=ranked_items)

return ranked_response
3 changes: 3 additions & 0 deletions src/leapfrogai_api/main.py
@@ -14,6 +14,7 @@
from leapfrogai_api.routers.leapfrogai import models as lfai_models
from leapfrogai_api.routers.leapfrogai import vector_stores as lfai_vector_stores
from leapfrogai_api.routers.leapfrogai import count as lfai_token_count
from leapfrogai_api.routers.leapfrogai import rag as lfai_rag
from leapfrogai_api.routers.openai import (
assistants,
audio,
@@ -81,6 +82,8 @@ async def validation_exception_handler(request, exc):
app.include_router(messages.router)
app.include_router(runs_steps.router)
app.include_router(lfai_vector_stores.router)
if os.environ.get("DEV"):
app.include_router(lfai_rag.router)
app.include_router(lfai_token_count.router)
app.include_router(lfai_models.router)
# This should be at the bottom to prevent it preempting more specific runs endpoints
1 change: 1 addition & 0 deletions src/leapfrogai_api/pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
"postgrest==0.16.11", # required by supabase, bug when using previous versions
"openpyxl == 3.1.5",
"psutil == 6.0.0",
"rerankers[flashrank] == 0.5.3"
]
requires-python = "~=3.11"

56 changes: 56 additions & 0 deletions src/leapfrogai_api/routers/leapfrogai/rag.py
@@ -0,0 +1,56 @@
"""LeapfrogAI endpoints for RAG."""

from fastapi import APIRouter
from leapfrogai_api.typedef.rag.rag_types import (
ConfigurationSingleton,
ConfigurationPayload,
)
from leapfrogai_api.routers.supabase_session import Session
from leapfrogai_api.utils.logging_tools import logger

router = APIRouter(prefix="/leapfrogai/v1/rag", tags=["leapfrogai/rag"])


@router.patch("/configure")
async def configure(session: Session, configuration: ConfigurationPayload) -> None:
"""
Configures the RAG settings at runtime.
Args:
session (Session): The database session.
configuration (ConfigurationPayload): The configuration to update.
"""

# We set the class variable to update the configuration globally
ConfigurationSingleton._instance = ConfigurationSingleton.get_instance().copy(
update=configuration.dict(exclude_none=True)
)


@router.get("/configure")
async def get_configuration(session: Session) -> ConfigurationPayload:
"""
Retrieves the current RAG configuration.
Args:
session (Session): The database session.
Returns:
ConfigurationPayload: The current RAG configuration.
"""

instance = ConfigurationSingleton.get_instance()

# Create a new dictionary with only the relevant attributes
config_dict = {
key: value
for key, value in instance.__dict__.items()
if not key.startswith("_") # Exclude private attributes
}

# Create a new ConfigurationPayload instance with the filtered dictionary
new_configuration = ConfigurationPayload(**config_dict)

logger.info(f"The current configuration has been set to {new_configuration}")

return new_configuration
4 changes: 1 addition & 3 deletions src/leapfrogai_api/routers/leapfrogai/vector_stores.py
@@ -33,9 +33,7 @@ async def search(
"""
query_service = QueryService(db=session)
return await query_service.query_rag(
query=query,
vector_store_id=vector_store_id,
k=k,
query=query, vector_store_id=vector_store_id, k=k
)


3 changes: 3 additions & 0 deletions src/leapfrogai_api/typedef/rag/__init__.py
@@ -0,0 +1,3 @@
from .rag_types import (
ConfigurationSingleton as ConfigurationSingleton,
)
40 changes: 40 additions & 0 deletions src/leapfrogai_api/typedef/rag/rag_types.py
@@ -0,0 +1,40 @@
from typing import Optional

from pydantic import BaseModel, Field


class ConfigurationSingleton:
"""Singleton manager for ConfigurationPayload."""

_instance = None

@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = ConfigurationPayload()
cls._instance.enable_reranking = True
cls._instance.rag_top_k_when_reranking = 100
cls._instance.ranking_model = "flashrank"
return cls._instance


class ConfigurationPayload(BaseModel):
"""Response for RAG configuration."""

enable_reranking: Optional[bool] = Field(
default=None,
examples=[True, False],
description="Enables reranking for RAG queries",
)
# More model info can be found here:
# https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file
# https://pypi.org/project/rerankers/
ranking_model: Optional[str] = Field(
default=None,
description="What model to use for reranking. Some options may require additional python dependencies.",
examples=["flashrank", "rankllm", "cross-encoder", "colbert"],
)
rag_top_k_when_reranking: Optional[int] = Field(
default=None,
description="The top-k results returned from the RAG call before reranking",
)
10 changes: 10 additions & 0 deletions src/leapfrogai_api/typedef/vectorstores/search_types.py
@@ -1,3 +1,5 @@
from typing import Optional

from pydantic import BaseModel, Field


@@ -25,6 +27,14 @@ class SearchItem(BaseModel):
similarity: float = Field(
..., description="Similarity score of this item to the query."
)
rank: Optional[int] = Field(
default=None,
description="The rank of this search item after ranking has occurred.",
)
score: Optional[float] = Field(
default=None,
description="The score of this search item after ranking has occurred.",
)


class SearchResponse(BaseModel):
12 changes: 12 additions & 0 deletions src/leapfrogai_api/utils/logging_tools.py
@@ -0,0 +1,12 @@
import os
import logging
from dotenv import load_dotenv

load_dotenv()

logging.basicConfig(
level=os.getenv("LFAI_LOG_LEVEL", logging.INFO),
format="%(name)s: %(asctime)s | %(levelname)s | %(filename)s:%(lineno)s >>> %(message)s",
)

logger = logging.getLogger(__name__)