-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(api): reranking backend integrated in with rag (#1090)
* Adds reranking to the RAG pipeline * Adds a RAG configuration endpoint when in dev mode * Adds additional logging * Refactors the pytest test_routes API tests * Splits the default RAG values into two steps, retrieval and ranking: retrieval returns 100 results, and after reranking the results are filtered down to the user-specified k value. If reranking is not enabled, the user-specified k results are returned directly from the retrieval step. * Adds Zarf configs to enable dev mode.
- Loading branch information
1 parent
185dcbb
commit 2f80d87
Showing
18 changed files
with
408 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
"""LeapfrogAI endpoints for RAG.""" | ||
|
||
from fastapi import APIRouter | ||
from leapfrogai_api.typedef.rag.rag_types import ( | ||
ConfigurationSingleton, | ||
ConfigurationPayload, | ||
) | ||
from leapfrogai_api.routers.supabase_session import Session | ||
from leapfrogai_api.utils.logging_tools import logger | ||
|
||
router = APIRouter(prefix="/leapfrogai/v1/rag", tags=["leapfrogai/rag"]) | ||
|
||
|
||
@router.patch("/configure")
async def configure(session: Session, configuration: ConfigurationPayload) -> None:
    """
    Apply a partial update to the global RAG configuration at runtime.

    Fields left as ``None`` in the payload are ignored, so the corresponding
    settings keep their current values.

    Args:
        session (Session): The database session.
        configuration (ConfigurationPayload): The settings to override.
    """
    current_settings = ConfigurationSingleton.get_instance()
    overrides = configuration.dict(exclude_none=True)

    # Rebinding the class-level attribute makes the merged copy the instance
    # that every later get_instance() call returns.
    ConfigurationSingleton._instance = current_settings.copy(update=overrides)
|
||
|
||
@router.get("/configure")
async def get_configuration(session: Session) -> ConfigurationPayload:
    """
    Retrieve the current RAG configuration.

    Args:
        session (Session): The database session.

    Returns:
        ConfigurationPayload: A new payload built from the public attributes
        of the shared configuration instance.
    """
    instance = ConfigurationSingleton.get_instance()

    # Copy only the public attributes; anything prefixed with "_" is treated
    # as internal state and is not echoed back to the caller.
    config_dict = {
        key: value
        for key, value in instance.__dict__.items()
        if not key.startswith("_")
    }

    current_configuration = ConfigurationPayload(**config_dict)

    # This endpoint only reads the configuration, so the message reports the
    # current state instead of claiming something "has been set". Lazy %-args
    # avoid formatting the payload when INFO logging is disabled.
    logger.info("The current configuration is %s", current_configuration)

    return current_configuration
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .rag_types import ( | ||
ConfigurationSingleton as ConfigurationSingleton, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from typing import Optional | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
|
||
class ConfigurationSingleton:
    """Holds the single shared ConfigurationPayload used for RAG settings."""

    # Created lazily by get_instance(); None until the first call.
    _instance = None

    @classmethod
    def get_instance(cls):
        """Return the shared configuration, building it with defaults on first use."""
        if cls._instance is not None:
            return cls._instance

        # First access: start from an empty payload and fill in the defaults.
        defaults = ConfigurationPayload()
        cls._instance = defaults
        defaults.enable_reranking = True
        defaults.rag_top_k_when_reranking = 100
        defaults.ranking_model = "flashrank"
        return defaults
|
||
|
||
class ConfigurationPayload(BaseModel):
    """Response for RAG configuration."""

    # Every field is optional and defaults to None, so a payload can carry
    # only the settings a caller wants to touch.

    # Toggle for the reranking step.
    enable_reranking: Optional[bool] = Field(
        description="Enables reranking for RAG queries",
        examples=[True, False],
        default=None,
    )

    # Name of the reranker backend. Background on the available models:
    #   https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file
    #   https://pypi.org/project/rerankers/
    ranking_model: Optional[str] = Field(
        description="What model to use for reranking. Some options may require additional python dependencies.",
        examples=["flashrank", "rankllm", "cross-encoder", "colbert"],
        default=None,
    )

    # Size of the candidate set fetched before reranking.
    rag_top_k_when_reranking: Optional[int] = Field(
        description="The top-k results returned from the RAG call before reranking",
        default=None,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import os | ||
import logging | ||
from dotenv import load_dotenv | ||
|
||
# Load variables from a .env file (if one is found) into the process
# environment before the log level is read from LFAI_LOG_LEVEL below.
load_dotenv()
|
||
def resolve_log_level(raw_level, default=logging.INFO):
    """Normalise a log level taken from the environment.

    ``logging`` only accepts integers or exact upper-case level names, so a
    value such as ``"debug"``, ``"10"`` or an empty string would otherwise
    raise ``ValueError`` while this module is being imported.

    Args:
        raw_level: The raw value (typically a string from ``os.getenv``), an
            integer level, or ``None`` when the variable is unset.
        default: Level returned when ``raw_level`` is ``None`` or blank.

    Returns:
        An integer level or an upper-case level name. Unknown names are
        passed through so that ``logging`` still rejects them.
    """
    if raw_level is None:
        return default
    if isinstance(raw_level, int):
        return raw_level

    candidate = str(raw_level).strip()
    if not candidate:
        return default
    if candidate.isdigit():
        # Numeric levels arrive from the environment as strings, e.g. "10".
        return int(candidate)
    return candidate.upper()


logging.basicConfig(
    level=resolve_log_level(os.getenv("LFAI_LOG_LEVEL")),
    format="%(name)s: %(asctime)s | %(levelname)s | %(filename)s:%(lineno)s >>> %(message)s",
)

logger = logging.getLogger(__name__)
Oops, something went wrong.