Skip to content

Commit

Permalink
refactor(rag): improve RAG manager implementation
Browse files Browse the repository at this point in the history
- Introduce shared RAG manager singleton via get_rag_manager()
- Add DEFAULT_COLLECTION constant for consistent collection naming
- Remove unused _clear_cache function
- Update gptme-rag dependency to 0.3.1
  • Loading branch information
ErikBjare committed Nov 22, 2024
1 parent 990965f commit 6d52f30
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 22 deletions.
21 changes: 15 additions & 6 deletions gptme/tools/_rag_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

logger = logging.getLogger(__name__)

# Constant collection name to ensure consistency
DEFAULT_COLLECTION = "gptme-default"

try:
import gptme_rag # type: ignore # fmt: skip

Expand All @@ -20,10 +23,21 @@
_HAS_RAG = False


# Shared RAG manager instance
_rag_manager: "RAGManager | None" = None

# Simple in-memory cache for search results
_search_cache: dict[str, tuple[list[Any], dict]] = {}


def get_rag_manager() -> "RAGManager":
"""Get or create the shared RAG manager instance."""
global _rag_manager
if _rag_manager is None:
_rag_manager = RAGManager()
return _rag_manager


def _get_search_results(query: str, n_results: int) -> tuple[list[Any], dict] | None:
"""Get cached search results."""
return _search_cache.get(f"{query}::{n_results}")
Expand All @@ -36,11 +50,6 @@ def _set_search_results(
_search_cache[f"{query}::{n_results}"] = results


def _clear_cache() -> None:
"""Clear the search cache."""
_search_cache.clear()


@dataclass
class Context:
"""Context information to be added to messages."""
Expand Down Expand Up @@ -72,7 +81,7 @@ def __init__(self, index_path: Path | None = None, collection: str | None = None

# Use config values if not overridden by parameters
self.index_path = index_path or Path("~/.cache/gptme/rag").expanduser()
self.collection = collection or "default"
self.collection = collection or DEFAULT_COLLECTION

# Initialize the indexer
self.indexer = gptme_rag.Indexer(
Expand Down
20 changes: 9 additions & 11 deletions gptme/tools/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,11 @@

from ..config import get_project_config
from ..util import get_project_dir
from ._rag_context import _HAS_RAG, RAGManager
from ._rag_context import _HAS_RAG, get_rag_manager
from .base import ToolSpec, ToolUse

logger = logging.getLogger(__name__)

rag_manager: RAGManager | None = None

instructions = """
Use RAG to index and search project documentation.
"""
Expand All @@ -82,28 +80,28 @@

def rag_index(*paths: str, glob: str | None = None) -> str:
"""Index documents in specified paths."""
assert rag_manager is not None, "RAG manager not initialized"
manager = get_rag_manager()
paths = paths or (".",)
kwargs = {"glob_pattern": glob} if glob else {}
total_docs = 0
for path in paths:
total_docs += rag_manager.index_directory(Path(path), **kwargs)
total_docs += manager.index_directory(Path(path), **kwargs)
return f"Indexed {len(paths)} paths ({total_docs} documents)"


def rag_search(query: str) -> str:
"""Search indexed documents."""
assert rag_manager is not None, "RAG manager not initialized"
docs, _ = rag_manager.search(query)
manager = get_rag_manager()
docs, _ = manager.search(query)
return "\n\n".join(
f"### {doc.metadata['source']}\n{doc.content[:200]}..." for doc in docs
)


def rag_status() -> str:
"""Show index status."""
assert rag_manager is not None, "RAG manager not initialized"
return f"Index contains {rag_manager.get_document_count()} documents"
manager = get_rag_manager()
return f"Index contains {manager.get_document_count()} documents"


_init_run = False
Expand All @@ -129,8 +127,8 @@ def init() -> ToolSpec:
logger.debug("Project configuration not found, not enabling")
return replace(tool, available=False)

global rag_manager
rag_manager = RAGManager()
# Initialize the shared RAG manager
get_rag_manager()
return tool


Expand Down
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ youtube_transcript_api = {version = "^0.6.1", optional = true}
python-xlib = {version = "^0.33", optional = true} # for X11 interaction

# RAG
gptme-rag = {version = "^0.3.0", optional = true}
gptme-rag = {version = "^0.3.1", optional = true}
#gptme-rag = {path = "../gptme-rag", optional = true, develop = true}

# providers
Expand Down

0 comments on commit 6d52f30

Please sign in to comment.