From 63e8d04f69a182bc95db8057689e9beb2054d7b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Fri, 22 Nov 2024 21:20:58 +0100 Subject: [PATCH] refactor(rag): improve RAG manager implementation (#268) * refactor(rag): improve RAG manager implementation - Introduce shared RAG manager singleton via get_rag_manager() - Add DEFAULT_COLLECTION constant for consistent collection naming - Remove unused _clear_cache function - Update gptme-rag dependency to 0.3.1 * fix: use correct path to reset RAG manager state between tests - Fix type error by using correct internal path to _rag_manager - Split imports for better readability - Organize imports more clearly --- gptme/tools/_rag_context.py | 21 +++++++++++++++------ gptme/tools/rag.py | 20 +++++++++----------- poetry.lock | 8 ++++---- pyproject.toml | 2 +- tests/test_tools_rag.py | 11 +++++++---- 5 files changed, 36 insertions(+), 26 deletions(-) diff --git a/gptme/tools/_rag_context.py b/gptme/tools/_rag_context.py index 3c825ba9..01ca66e5 100644 --- a/gptme/tools/_rag_context.py +++ b/gptme/tools/_rag_context.py @@ -11,6 +11,9 @@ logger = logging.getLogger(__name__) +# Constant collection name to ensure consistency +DEFAULT_COLLECTION = "gptme-default" + try: import gptme_rag # type: ignore # fmt: skip @@ -20,10 +23,21 @@ _HAS_RAG = False +# Shared RAG manager instance +_rag_manager: "RAGManager | None" = None + # Simple in-memory cache for search results _search_cache: dict[str, tuple[list[Any], dict]] = {} +def get_rag_manager() -> "RAGManager": + """Get or create the shared RAG manager instance.""" + global _rag_manager + if _rag_manager is None: + _rag_manager = RAGManager() + return _rag_manager + + def _get_search_results(query: str, n_results: int) -> tuple[list[Any], dict] | None: """Get cached search results.""" return _search_cache.get(f"{query}::{n_results}") @@ -36,11 +50,6 @@ def _set_search_results( _search_cache[f"{query}::{n_results}"] = results -def _clear_cache() -> None: - """Clear the search cache.""" - _search_cache.clear() - - @dataclass class Context: """Context information to be added to messages.""" @@ -72,7 +81,7 @@ def __init__(self, index_path: Path | None = None, collection: str | None = None # Use config values if not overridden by parameters self.index_path = index_path or Path("~/.cache/gptme/rag").expanduser() - self.collection = collection or "default" + self.collection = collection or DEFAULT_COLLECTION # Initialize the indexer self.indexer = gptme_rag.Indexer( diff --git a/gptme/tools/rag.py b/gptme/tools/rag.py index 63388cef..b04f262c 100644 --- a/gptme/tools/rag.py +++ b/gptme/tools/rag.py @@ -50,13 +50,11 @@ from ..config import get_project_config from ..util import get_project_dir -from ._rag_context import _HAS_RAG, RAGManager +from ._rag_context import _HAS_RAG, get_rag_manager from .base import ToolSpec, ToolUse logger = logging.getLogger(__name__) -rag_manager: RAGManager | None = None - instructions = """ Use RAG to index and search project documentation. """ @@ -82,19 +80,19 @@ def rag_index(*paths: str, glob: str | None = None) -> str: """Index documents in specified paths.""" - assert rag_manager is not None, "RAG manager not initialized" + manager = get_rag_manager() paths = paths or (".",) kwargs = {"glob_pattern": glob} if glob else {} total_docs = 0 for path in paths: - total_docs += rag_manager.index_directory(Path(path), **kwargs) + total_docs += manager.index_directory(Path(path), **kwargs) return f"Indexed {len(paths)} paths ({total_docs} documents)" def rag_search(query: str) -> str: """Search indexed documents.""" - assert rag_manager is not None, "RAG manager not initialized" - docs, _ = rag_manager.search(query) + manager = get_rag_manager() + docs, _ = manager.search(query) return "\n\n".join( f"### {doc.metadata['source']}\n{doc.content[:200]}..." for doc in docs ) @@ -102,8 +100,8 @@ def rag_search(query: str) -> str: def rag_status() -> str: """Show index status.""" - assert rag_manager is not None, "RAG manager not initialized" - return f"Index contains {rag_manager.get_document_count()} documents" + manager = get_rag_manager() + return f"Index contains {manager.get_document_count()} documents" _init_run = False @@ -129,8 +127,8 @@ def init() -> ToolSpec: logger.debug("Project configuration not found, not enabling") return replace(tool, available=False) - global rag_manager - rag_manager = RAGManager() + # Initialize the shared RAG manager + get_rag_manager() return tool diff --git a/poetry.lock b/poetry.lock index 1d57ef33..67ef9c6b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1098,13 +1098,13 @@ files = [ [[package]] name = "gptme-rag" -version = "0.3.0" +version = "0.3.1" description = "RAG implementation for gptme context management" optional = true python-versions = "<4.0,>=3.10" files = [ - {file = "gptme_rag-0.3.0-py3-none-any.whl", hash = "sha256:76e0b8ffb0367971b33024815644c2d0c6db2922424d2e81093ef4db855e87e5"}, - {file = "gptme_rag-0.3.0.tar.gz", hash = "sha256:72318084c134236080f83e72d7eb62dff6274b1594c81c68587b4577b7a9cb40"}, + {file = "gptme_rag-0.3.1-py3-none-any.whl", hash = "sha256:57084ddcfc70959d9efbd91c10c71b7b7702387e72a3343a6de413518e242cb5"}, + {file = "gptme_rag-0.3.1.tar.gz", hash = "sha256:5dea6cb66aa44271f89f465b0e09306a93164bcf259646e709539f1ee60d4169"}, ] [package.dependencies] @@ -5184,4 +5184,4 @@ server = ["flask", "flask-cors"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "2fa48920c632e5f420cfcc3f0d37b68ce630e9b6941a9b8f2acf091640cd568c" +content-hash = "8ba9182b2c6c57a16bde1a5d0f37003f07eb72225f6690af7acd9df5612cfae6" diff --git a/pyproject.toml b/pyproject.toml index 4cd2e10e..b7eb0e88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ youtube_transcript_api = {version = "^0.6.1", optional = true} python-xlib = {version = "^0.33", optional = true} # for X11 interaction # RAG -gptme-rag = {version = "^0.3.0", optional = true} +gptme-rag = {version = "^0.3.1", optional = true} #gptme-rag = {path = "../gptme-rag", optional = true, develop = true} # providers diff --git a/tests/test_tools_rag.py b/tests/test_tools_rag.py index 45a215b7..55edc740 100644 --- a/tests/test_tools_rag.py +++ b/tests/test_tools_rag.py @@ -2,21 +2,24 @@ from unittest.mock import patch +import gptme.tools._rag_context +import gptme.tools.rag import pytest from gptme.message import Message from gptme.tools._rag_context import enhance_messages -from gptme.tools.rag import _HAS_RAG, init as init_rag, rag_index, rag_search +from gptme.tools.rag import _HAS_RAG +from gptme.tools.rag import init as init_rag +from gptme.tools.rag import rag_index, rag_search @pytest.fixture(autouse=True) def reset_rag(): """Reset the RAG manager and init state before and after each test.""" - import gptme.tools.rag - gptme.tools.rag.rag_manager = None + gptme.tools._rag_context._rag_manager = None gptme.tools.rag._init_run = False yield - gptme.tools.rag.rag_manager = None + gptme.tools._rag_context._rag_manager = None gptme.tools.rag._init_run = False