From 84e5ab6301b13cb3a9eaafccc54c44af9649a4d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Fri, 22 Nov 2024 15:11:45 +0100 Subject: [PATCH] refactor(rag): simplify configuration and initialization (#266) * refactor(rag): simplify configuration and initialization - Reduce configuration complexity to just an enable flag - Use default paths and collection names - Add initialization state tracking - Improve code organization and documentation * fix: add proper logging cleanup to prevent errors during shutdown * feat: add custom index path and collection support to RAGManager Makes RAGManager more flexible by allowing custom index paths and collection names to be specified during initialization, while maintaining backward compatibility with default values. - Added index_path and collection parameters to RAGManager.__init__ - Maintains defaults (~/.cache/gptme/rag and 'default' respectively) - Allows for better testing and customization of RAG functionality * test: simplify RAG tests Major refactoring of RAG tests to be more focused and maintainable: - Added reset_rag fixture to ensure clean state between tests - Simplified test cases to focus on core functionality - Removed complex mocking in favor of simpler integration tests - Improved test organization and readability * build: bump gptme-rag to 0.3.0 Required for the new custom index path and collection features in RAGManager. --- gptme.toml | 3 + gptme/init.py | 12 +- gptme/tools/_rag_context.py | 10 +- gptme/tools/rag.py | 48 +++-- poetry.lock | 10 +- pyproject.toml | 2 +- tests/test_tools_rag.py | 347 ++++-------------------------------- 7 files changed, 84 insertions(+), 348 deletions(-) diff --git a/gptme.toml b/gptme.toml index b417d8ed..2edb4cfe 100644 --- a/gptme.toml +++ b/gptme.toml @@ -1,2 +1,5 @@ files = ["README.md", "Makefile"] #files = ["README.md", "Makefile", "gptme/cli.py", "docs/*.rst", "docs/*.md"] + +[rag] +enabled = true diff --git a/gptme/init.py b/gptme/init.py index d6b86a43..c4a0e6f0 100644 --- a/gptme/init.py +++ b/gptme/init.py @@ -71,15 +71,25 @@ def init(model: str | None, interactive: bool, tool_allowlist: list[str] | None) def init_logging(verbose): # log init + handler = RichHandler() logging.basicConfig( level=logging.DEBUG if verbose else logging.INFO, format="%(message)s", datefmt="[%X]", - handlers=[RichHandler()], + handlers=[handler], ) # set httpx logging to WARNING logging.getLogger("httpx").setLevel(logging.WARNING) + # Register cleanup handler + import atexit + + def cleanup_logging(): + logging.getLogger().removeHandler(handler) + logging.shutdown() + + atexit.register(cleanup_logging) + def _prompt_api_key() -> tuple[str, str, str]: # pragma: no cover api_key = input("Your OpenAI, Anthropic, or OpenRouter API key: ").strip() diff --git a/gptme/tools/_rag_context.py b/gptme/tools/_rag_context.py index 3ac94d5c..3c825ba9 100644 --- a/gptme/tools/_rag_context.py +++ b/gptme/tools/_rag_context.py @@ -4,9 +4,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import ( - Any, -) +from typing import Any from ..config import get_project_config from ..message import Message @@ -73,10 +71,8 @@ def __init__(self, index_path: Path | None = None, collection: str | None = None self.config = config.rag if config and config.rag else {} # Use config values if not overridden by parameters - self.index_path = Path( - index_path or self.config.get("index_path", "~/.cache/gptme/rag") - ).expanduser() - self.collection = collection or self.config.get("collection", "default") + self.index_path = index_path or Path("~/.cache/gptme/rag").expanduser() + self.collection = collection or "default" # Initialize the indexer self.indexer = gptme_rag.Indexer( diff --git a/gptme/tools/rag.py b/gptme/tools/rag.py index d256c6b1..63388cef 100644 --- a/gptme/tools/rag.py +++ b/gptme/tools/rag.py @@ -14,40 +14,26 @@ Configure RAG in your ``gptme.toml``:: [rag] - # Storage configuration - index_path = "~/.cache/gptme/rag" # Where to store the index - collection = "gptme_docs" # Collection name for documents - - # Context enhancement settings - max_tokens = 2000 # Maximum tokens for context window - auto_context = true # Enable automatic context enhancement - min_relevance = 0.5 # Minimum relevance score for including context - max_results = 5 # Maximum number of results to consider - - # Cache configuration - [rag.cache] - max_embeddings = 10000 # Maximum number of cached embeddings - max_searches = 1000 # Maximum number of cached search results - embedding_ttl = 86400 # Embedding cache TTL in seconds (24h) - search_ttl = 3600 # Search cache TTL in seconds (1h) + enabled = true .. rubric:: Features 1. Manual Search and Indexing + - Index project documentation with ``rag_index`` - Search indexed documents with ``rag_search`` - Check index status with ``rag_status`` 2. Automatic Context Enhancement + - Automatically adds relevant context to user messages - Retrieves semantically similar documents - Manages token budget to avoid context overflow - Preserves conversation flow with hidden context messages 3. Performance Optimization + - Intelligent caching system for embeddings and search results - - Configurable cache sizes and TTLs - - Automatic cache invalidation - Memory-efficient storage .. rubric:: Benefits @@ -59,12 +45,13 @@ """ import logging +from dataclasses import replace from pathlib import Path from ..config import get_project_config from ..util import get_project_dir +from ._rag_context import _HAS_RAG, RAGManager from .base import ToolSpec, ToolUse -from ._rag_context import RAGManager, _HAS_RAG logger = logging.getLogger(__name__) @@ -119,20 +106,31 @@ def rag_status() -> str: return f"Index contains {rag_manager.get_document_count()} documents" +_init_run = False + + def init() -> ToolSpec: """Initialize the RAG tool.""" - if not _HAS_RAG: + global _init_run + if _init_run: + return tool + _init_run = True + + if not tool.available: return tool project_dir = get_project_dir() - index_path = Path("~/.cache/gptme/rag").expanduser() - collection = "default" if project_dir and (config := get_project_config(project_dir)): - index_path = Path(config.rag.get("index_path", index_path)).expanduser() - collection = config.rag.get("collection", project_dir.name) + enabled = config.rag.get("enabled", False) + if not enabled: + logger.debug("RAG not enabled in the project configuration") + return replace(tool, available=False) + else: + logger.debug("Project configuration not found, not enabling") + return replace(tool, available=False) global rag_manager - rag_manager = RAGManager(index_path=index_path, collection=collection) + rag_manager = RAGManager() return tool diff --git a/poetry.lock b/poetry.lock index 88383237..1d57ef33 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "accessible-pygments" @@ -1098,13 +1098,13 @@ files = [ [[package]] name = "gptme-rag" -version = "0.2.1" +version = "0.3.0" description = "RAG implementation for gptme context management" optional = true python-versions = "<4.0,>=3.10" files = [ - {file = "gptme_rag-0.2.1-py3-none-any.whl", hash = "sha256:ada534d91200bdaf7341e24de7ca82bf5f76071e9c155003b3dd0abe16290598"}, - {file = "gptme_rag-0.2.1.tar.gz", hash = "sha256:bae2b60e14e3a7a4c71dd2a9dac13abc7306a00902f8cd28d11395395fc82adc"}, + {file = "gptme_rag-0.3.0-py3-none-any.whl", hash = "sha256:76e0b8ffb0367971b33024815644c2d0c6db2922424d2e81093ef4db855e87e5"}, + {file = "gptme_rag-0.3.0.tar.gz", hash = "sha256:72318084c134236080f83e72d7eb62dff6274b1594c81c68587b4577b7a9cb40"}, ] [package.dependencies] @@ -5184,4 +5184,4 @@ server = ["flask", "flask-cors"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "fdb4c7c81ec59e0a80bb1e0f36a9d49934b460cc335dcbe613f37d2ed9e7e5e2" +content-hash = "2fa48920c632e5f420cfcc3f0d37b68ce630e9b6941a9b8f2acf091640cd568c" diff --git a/pyproject.toml b/pyproject.toml index 2a33f00c..439a2daa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ youtube_transcript_api = {version = "^0.6.1", optional = true} python-xlib = {version = "^0.33", optional = true} # for X11 interaction # RAG -gptme-rag = {version = "^0.2.1", optional = true} +gptme-rag = {version = "^0.3.0", optional = true} #gptme-rag = {path = "../gptme-rag", optional = true, develop = true} # providers diff --git a/tests/test_tools_rag.py b/tests/test_tools_rag.py index 7196c49b..45a215b7 100644 --- a/tests/test_tools_rag.py +++ b/tests/test_tools_rag.py @@ -1,31 +1,23 @@ -"""Tests for the RAG tool and context enhancement functionality.""" +"""Tests for the RAG tool.""" -from dataclasses import replace -from unittest.mock import Mock, patch +from unittest.mock import patch import pytest from gptme.message import Message -from gptme.tools._rag_context import ( - Context, - RAGManager, - _clear_cache, - _get_search_results, - enhance_messages, -) -from gptme.tools.base import ToolSpec -from gptme.tools.rag import _HAS_RAG -from gptme.tools.rag import init as init_rag -from gptme.tools.rag import rag_index, rag_search, rag_status +from gptme.tools._rag_context import enhance_messages +from gptme.tools.rag import _HAS_RAG, init as init_rag, rag_index, rag_search -pytest.importorskip("gptme_rag") - -# Fixtures +@pytest.fixture(autouse=True) +def reset_rag(): + """Reset the RAG manager and init state before and after each test.""" + import gptme.tools.rag -@pytest.fixture(scope="function") -def index_path(tmp_path): - """Create a temporary index path.""" - return tmp_path + gptme.tools.rag.rag_manager = None + gptme.tools.rag._init_run = False + yield + gptme.tools.rag.rag_manager = None + gptme.tools.rag._init_run = False @pytest.fixture(scope="function") @@ -40,322 +32,59 @@ def temp_docs(tmp_path): return tmp_path -@pytest.fixture -def mock_rag_manager(index_path): - """Create a mock RAG manager that returns test contexts.""" - with patch("gptme.tools._rag_context.gptme_rag"): - manager = RAGManager( - index_path=index_path, - collection="test", - ) - # Create mock documents - mock_docs = [ - Mock( - content="This is a test document about Python functions.", - metadata={"source": "doc1.md"}, - ), - Mock( - content="Documentation about testing practices.", - metadata={"source": "doc2.md"}, - ), - ] - mock_results = {"distances": [[0.2, 0.4]]} # 1 - distance = relevance - - # Mock the indexer's search method - manager.indexer.search = Mock(return_value=(mock_docs, mock_results)) - - # Mock get_context to return actual Context objects - test_contexts = [ - Context( - content="This is a test document about Python functions.", - source="doc1.md", - relevance=0.8, - ), - Context( - content="Documentation about testing practices.", - source="doc2.md", - relevance=0.6, - ), - ] - manager.get_context = Mock(return_value=test_contexts) # type: ignore - - return manager - - -@pytest.fixture -def mock_rag_manager_no_context(mock_rag_manager): - """Create a RAG manager that returns no context.""" - mock_rag_manager.get_context = Mock(return_value=[]) - return mock_rag_manager - - -@pytest.fixture(autouse=True) -def clear_cache(): - """Clear the search cache before each test.""" - _clear_cache() - - -# RAG Tool Tests - - -@pytest.mark.timeout(func_only=True) -@pytest.mark.skipif(not _HAS_RAG, reason="gptme-rag not installed") -def test_rag_tool_init(): - """Test RAG tool initialization.""" - tool = init_rag() - assert isinstance(tool, ToolSpec) - assert tool.name == "rag" - assert tool.available is True - - def test_rag_tool_init_without_gptme_rag(): """Test RAG tool initialization when gptme-rag is not available.""" - tool = init_rag() with ( patch("gptme.tools.rag._HAS_RAG", False), - patch("gptme.tools.rag.tool", replace(tool, available=False)), + patch("gptme.tools.rag.get_project_config") as mock_config, ): + # Mock config to disable RAG + mock_config.return_value.rag = {"enabled": False} + tool = init_rag() - assert isinstance(tool, ToolSpec) assert tool.name == "rag" assert tool.available is False -@pytest.mark.slow -@pytest.mark.skipif(not _HAS_RAG, reason="gptme-rag not installed") -def test_rag_index_function(temp_docs, index_path, tmp_path): - """Test the index function.""" - with ( - patch("gptme.tools.rag.get_project_config") as mock_config, - patch("gptme.tools.rag.get_project_dir") as mock_project_dir, - ): - # Mock project dir to return the temp path - mock_project_dir.return_value = tmp_path - - # Mock config to return an object with a proper .get method - class MockConfig: - def __init__(self): - self.rag = {"index_path": str(index_path), "collection": tmp_path.name} - - def get(self, key, default=None): - return self.rag.get(key, default) - - mock_config.return_value = MockConfig() - - # Initialize RAG - init_rag() - - # Test indexing with specific path - result = rag_index(str(temp_docs)) - assert "Indexed 1 paths" in result - - # Test indexing with default path - # FIXME: this is really slow in the gptme directory, - # since it contains a lot of files (which are in gitignore, but not respected) - result = rag_index(glob="**/*.py") - assert "Indexed 1 paths" in result - - -@pytest.mark.skipif(not _HAS_RAG, reason="gptme-rag not installed") -def test_rag_search_function(temp_docs, index_path, tmp_path): - """Test the search function.""" - with ( - patch("gptme.tools.rag.get_project_config") as mock_config, - patch("gptme.tools.rag.get_project_dir") as mock_project_dir, - ): - # Mock project dir to return the temp path - mock_project_dir.return_value = tmp_path - - # Mock config to return an object with a proper .get method - class MockConfig: - def __init__(self): - self.rag = {"index_path": str(index_path), "collection": tmp_path.name} - - def get(self, key, default=None): - return self.rag.get(key, default) - - mock_config.return_value = MockConfig() - - # Initialize RAG and index documents - init_rag() - rag_index(str(temp_docs)) - - # Search for Python - result = rag_search("Python") - assert "doc1.md" in result - assert "Python functions" in result - - # Search for testing - result = rag_search("testing") - assert "doc2.md" in result - assert "testing practices" in result - - @pytest.mark.skipif(not _HAS_RAG, reason="gptme-rag not installed") -def test_rag_status_function(temp_docs, index_path, tmp_path): - """Test the status function.""" +def test_rag_tool_functionality(temp_docs): + """Test basic RAG tool functionality.""" with ( patch("gptme.tools.rag.get_project_config") as mock_config, patch("gptme.tools.rag.get_project_dir") as mock_project_dir, ): - # Mock project dir to return the temp path - mock_project_dir.return_value = tmp_path + # Mock project dir to return a path + mock_project_dir.return_value = temp_docs - # Mock config to return an object with a proper .get method + # Mock config to enable RAG class MockConfig: def __init__(self): - self.rag = {"index_path": str(index_path), "collection": tmp_path.name} - - def get(self, key, default=None): - return self.rag.get(key, default) + self.rag = {"enabled": True} mock_config.return_value = MockConfig() # Initialize RAG - init_rag() - - # Check initial status - result = rag_status() - assert "Index contains" in result - assert "0" in result - - # Index documents and check status again - rag_index(str(temp_docs)) - result = rag_status() - assert "Index contains" in result - assert "2" in result # Should have indexed 2 documents - - -# Context Enhancement Tests - - -def test_search_caching(mock_rag_manager): - """Test that search results are properly cached.""" - query = "test query" - n_results = 5 - - # First search should use the manager - docs, results = mock_rag_manager.search(query, n_results) - assert mock_rag_manager.indexer.search.call_count == 1 - - # Cache should be populated - cached = _get_search_results(query, n_results) - assert cached is not None - assert cached == (docs, results) - - # Second search should use cache - docs2, results2 = mock_rag_manager.search(query, n_results) - assert mock_rag_manager.indexer.search.call_count == 1 # No additional calls - assert (docs2, results2) == (docs, results) - - -def test_enhance_messages_with_context(mock_rag_manager): - """Test that messages are enhanced with context.""" - with patch("gptme.tools._rag_context.RAGManager", return_value=mock_rag_manager): - messages = [ - Message("system", "Initial system message"), - Message("user", "Tell me about Python functions"), - Message("assistant", "Here's what I know about functions..."), - ] - - enhanced = enhance_messages(messages) - - # Should have one extra message for the context - assert len(enhanced) == 4 + tool = init_rag() + assert tool.available is True - # Check that context was added before the user message - assert enhanced[0].role == "system" # Original system message - assert enhanced[1].role == "system" # Added context - assert "Relevant context:" in enhanced[1].content - assert "doc1.md" in enhanced[1].content - assert "doc2.md" in enhanced[1].content - assert enhanced[1].hide is True # Context should be hidden + # Test indexing + index_result = rag_index(str(temp_docs)) + assert "Indexed" in index_result - # Original messages should remain unchanged - assert enhanced[2].role == "user" - assert enhanced[3].role == "assistant" + # Test searching + search_result = rag_search("test document") + assert "test document" in search_result.lower() def test_enhance_messages_no_rag(): """Test that enhancement works even without RAG available.""" - with patch("gptme.tools._rag_context._HAS_RAG", False): - messages = [ - Message("user", "Tell me about Python"), - Message("assistant", "Python is a programming language"), - ] - - enhanced = enhance_messages(messages) - - # Should be unchanged when RAG is not available - assert len(enhanced) == len(messages) - assert enhanced == messages - - -def test_enhance_messages_error_handling(mock_rag_manager): - """Test that errors in context enhancement are handled gracefully.""" - mock_rag_manager.get_context.side_effect = Exception("Test error") - - with patch("gptme.tools._rag_context.RAGManager", return_value=mock_rag_manager): - messages = [ - Message("user", "Tell me about Python"), - Message("assistant", "Python is great"), - ] - - # Should not raise an exception - enhanced = enhance_messages(messages) - - # Messages should be unchanged when enhancement fails - assert len(enhanced) == len(messages) - assert enhanced == messages - - -def test_rag_manager_initialization(): - """Test RAG manager initialization with and without gptme-rag.""" - # Test when gptme-rag is not available - with patch("gptme.tools._rag_context._HAS_RAG", False): - with pytest.raises(ImportError): - RAGManager() - - # Test when gptme-rag is available - with patch("gptme.tools._rag_context._HAS_RAG", True): - with patch("gptme.tools._rag_context.gptme_rag") as mock_rag: - manager = RAGManager() - assert isinstance(manager, RAGManager) - mock_rag.Indexer.assert_called_once() - - -def test_get_context_with_relevance_filter(mock_rag_manager): - """Test that get_context properly filters by relevance.""" - with patch("gptme.tools._rag_context.RAGManager", return_value=mock_rag_manager): - # Create test contexts with different relevance scores - contexts = [ - Context(content="High relevance", source="high.md", relevance=0.8), - Context(content="Low relevance", source="low.md", relevance=0.4), - ] - - # Mock get_context directly instead of search - mock_rag_manager.get_context = Mock( - return_value=[ctx for ctx in contexts if ctx.relevance >= 0.7] - ) - - messages = [Message("user", "test query")] - enhanced = enhance_messages(messages) - - # Should only include the high relevance context - assert len(enhanced) == 2 # Original message + context message - assert "high.md" in enhanced[0].content - assert "low.md" not in enhanced[0].content - - -def test_auto_context_disabled(mock_rag_manager): - """Test that context enhancement respects auto_context setting.""" - mock_rag_manager.auto_context = False - mock_rag_manager.get_context = Mock(return_value=[]) # Should not be called + messages = [ + Message("user", "Tell me about Python"), + Message("assistant", "Python is a programming language"), + ] - messages = [Message("user", "Tell me about Python")] enhanced = enhance_messages(messages) - # No context should be added when auto_context is False - assert len(enhanced) == 1 - assert enhanced[0].role == "user" - assert not mock_rag_manager.get_context.called + # Should be unchanged when RAG is not available + assert len(enhanced) == len(messages) + assert enhanced == messages