Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(rag): improve RAG manager implementation #268

Merged
merged 2 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions gptme/tools/_rag_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

logger = logging.getLogger(__name__)

# Constant collection name to ensure consistency
DEFAULT_COLLECTION = "gptme-default"

try:
import gptme_rag # type: ignore # fmt: skip

Expand All @@ -20,10 +23,21 @@
_HAS_RAG = False


# Shared RAG manager instance
_rag_manager: "RAGManager | None" = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a global variable for _rag_manager can lead to issues in multi-threaded environments or when the state needs to be reset. Consider using a class or context manager to handle the RAG manager instance.


# Simple in-memory cache for search results
_search_cache: dict[str, tuple[list[Any], dict]] = {}


def get_rag_manager() -> "RAGManager":
    """Return the module-wide RAG manager, constructing it on first use."""
    global _rag_manager
    # Fast path: an instance already exists, so hand it back unchanged.
    if _rag_manager is not None:
        return _rag_manager
    # First call (or the cached instance was reset to None): build it with
    # default arguments and remember it for subsequent callers.
    _rag_manager = RAGManager()
    return _rag_manager


def _get_search_results(query: str, n_results: int) -> tuple[list[Any], dict] | None:
    """Look up cached results for a query / result-count pair.

    Returns the cached tuple, or None when nothing is stored under that key.
    """
    # The key combines the query text and the requested result count.
    cache_key = f"{query}::{n_results}"
    return _search_cache.get(cache_key)
Expand All @@ -36,11 +50,6 @@ def _set_search_results(
_search_cache[f"{query}::{n_results}"] = results


def _clear_cache() -> None:
    """Drop every entry from the in-memory search-result cache."""
    _search_cache.clear()


@dataclass
class Context:
"""Context information to be added to messages."""
Expand Down Expand Up @@ -72,7 +81,7 @@ def __init__(self, index_path: Path | None = None, collection: str | None = None

# Use config values if not overridden by parameters
self.index_path = index_path or Path("~/.cache/gptme/rag").expanduser()
self.collection = collection or "default"
self.collection = collection or DEFAULT_COLLECTION

# Initialize the indexer
self.indexer = gptme_rag.Indexer(
Expand Down
20 changes: 9 additions & 11 deletions gptme/tools/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,11 @@

from ..config import get_project_config
from ..util import get_project_dir
from ._rag_context import _HAS_RAG, RAGManager
from ._rag_context import _HAS_RAG, get_rag_manager
from .base import ToolSpec, ToolUse

logger = logging.getLogger(__name__)

rag_manager: RAGManager | None = None

instructions = """
Use RAG to index and search project documentation.
"""
Expand All @@ -82,28 +80,28 @@

def rag_index(*paths: str, glob: str | None = None) -> str:
    """Index documents in specified paths.

    Args:
        *paths: Directories to index. When none are given, the current
            directory (``"."``) is indexed.
        glob: Optional glob pattern; when set it is forwarded to the manager
            as ``glob_pattern``.

    Returns:
        A summary string reporting how many paths were processed and the
        total returned by the manager's ``index_directory`` calls.
    """
    # Always go through the shared accessor instead of a module-level global,
    # so each path is indexed exactly once by the single shared manager.
    manager = get_rag_manager()
    paths = paths or (".",)
    kwargs = {"glob_pattern": glob} if glob else {}
    total_docs = 0
    for path in paths:
        total_docs += manager.index_directory(Path(path), **kwargs)
    return f"Indexed {len(paths)} paths ({total_docs} documents)"


def rag_search(query: str) -> str:
    """Search indexed documents.

    Args:
        query: Search text passed to the shared manager's ``search``.

    Returns:
        One section per matching document, separated by blank lines. Each
        section is a ``###`` heading with the document's ``source`` metadata
        followed by the first 200 characters of its content and ``...``.
    """
    # Single lookup through the shared accessor; the second element of the
    # search result is not used here.
    manager = get_rag_manager()
    docs, _ = manager.search(query)
    return "\n\n".join(
        f"### {doc.metadata['source']}\n{doc.content[:200]}..." for doc in docs
    )


def rag_status() -> str:
    """Show index status.

    Returns:
        A one-line summary containing the document count reported by the
        shared manager's ``get_document_count``.
    """
    # Use the shared accessor so the status reflects the same manager that
    # the index and search functions use.
    manager = get_rag_manager()
    return f"Index contains {manager.get_document_count()} documents"


_init_run = False
Expand All @@ -129,8 +127,8 @@ def init() -> ToolSpec:
logger.debug("Project configuration not found, not enabling")
return replace(tool, available=False)

global rag_manager
rag_manager = RAGManager()
# Initialize the shared RAG manager
get_rag_manager()
return tool


Expand Down
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ youtube_transcript_api = {version = "^0.6.1", optional = true}
python-xlib = {version = "^0.33", optional = true} # for X11 interaction

# RAG
gptme-rag = {version = "^0.3.0", optional = true}
gptme-rag = {version = "^0.3.1", optional = true}
#gptme-rag = {path = "../gptme-rag", optional = true, develop = true}

# providers
Expand Down
11 changes: 7 additions & 4 deletions tests/test_tools_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,24 @@

from unittest.mock import patch

import gptme.tools._rag_context
import gptme.tools.rag
import pytest
from gptme.message import Message
from gptme.tools._rag_context import enhance_messages
from gptme.tools.rag import _HAS_RAG, init as init_rag, rag_index, rag_search
from gptme.tools.rag import _HAS_RAG
from gptme.tools.rag import init as init_rag
from gptme.tools.rag import rag_index, rag_search


@pytest.fixture(autouse=True)
def reset_rag():
    """Reset the RAG manager and init state before and after each test."""
    # The shared manager lives in gptme.tools._rag_context, so that is the
    # attribute to clear; both modules are imported at the top of this file.
    gptme.tools._rag_context._rag_manager = None
    gptme.tools.rag._init_run = False
    yield
    # Repeat the reset after the test body has run.
    gptme.tools._rag_context._rag_manager = None
    gptme.tools.rag._init_run = False


Expand Down
Loading