diff --git a/gptme.toml b/gptme.toml index b417d8ed..2edb4cfe 100644 --- a/gptme.toml +++ b/gptme.toml @@ -1,2 +1,5 @@ files = ["README.md", "Makefile"] #files = ["README.md", "Makefile", "gptme/cli.py", "docs/*.rst", "docs/*.md"] + +[rag] +enabled = true diff --git a/gptme/init.py b/gptme/init.py index d6b86a43..c4a0e6f0 100644 --- a/gptme/init.py +++ b/gptme/init.py @@ -71,15 +71,25 @@ def init(model: str | None, interactive: bool, tool_allowlist: list[str] | None) def init_logging(verbose): # log init + handler = RichHandler() logging.basicConfig( level=logging.DEBUG if verbose else logging.INFO, format="%(message)s", datefmt="[%X]", - handlers=[RichHandler()], + handlers=[handler], ) # set httpx logging to WARNING logging.getLogger("httpx").setLevel(logging.WARNING) + # Register cleanup handler + import atexit + + def cleanup_logging(): + logging.getLogger().removeHandler(handler) + logging.shutdown() + + atexit.register(cleanup_logging) + def _prompt_api_key() -> tuple[str, str, str]: # pragma: no cover api_key = input("Your OpenAI, Anthropic, or OpenRouter API key: ").strip() diff --git a/gptme/tools/_rag_context.py b/gptme/tools/_rag_context.py index 3ac94d5c..3c825ba9 100644 --- a/gptme/tools/_rag_context.py +++ b/gptme/tools/_rag_context.py @@ -4,9 +4,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import ( - Any, -) +from typing import Any from ..config import get_project_config from ..message import Message @@ -73,10 +71,8 @@ def __init__(self, index_path: Path | None = None, collection: str | None = None self.config = config.rag if config and config.rag else {} # Use config values if not overridden by parameters - self.index_path = Path( - index_path or self.config.get("index_path", "~/.cache/gptme/rag") - ).expanduser() - self.collection = collection or self.config.get("collection", "default") + self.index_path = index_path or Path("~/.cache/gptme/rag").expanduser() + self.collection = collection or "default" # Initialize the indexer self.indexer = gptme_rag.Indexer( diff --git a/gptme/tools/rag.py b/gptme/tools/rag.py index d256c6b1..63388cef 100644 --- a/gptme/tools/rag.py +++ b/gptme/tools/rag.py @@ -14,40 +14,26 @@ Configure RAG in your ``gptme.toml``:: [rag] - # Storage configuration - index_path = "~/.cache/gptme/rag" # Where to store the index - collection = "gptme_docs" # Collection name for documents - - # Context enhancement settings - max_tokens = 2000 # Maximum tokens for context window - auto_context = true # Enable automatic context enhancement - min_relevance = 0.5 # Minimum relevance score for including context - max_results = 5 # Maximum number of results to consider - - # Cache configuration - [rag.cache] - max_embeddings = 10000 # Maximum number of cached embeddings - max_searches = 1000 # Maximum number of cached search results - embedding_ttl = 86400 # Embedding cache TTL in seconds (24h) - search_ttl = 3600 # Search cache TTL in seconds (1h) + enabled = true .. rubric:: Features 1. Manual Search and Indexing + - Index project documentation with ``rag_index`` - Search indexed documents with ``rag_search`` - Check index status with ``rag_status`` 2. Automatic Context Enhancement + - Automatically adds relevant context to user messages - Retrieves semantically similar documents - Manages token budget to avoid context overflow - Preserves conversation flow with hidden context messages 3. Performance Optimization + - Intelligent caching system for embeddings and search results - - Configurable cache sizes and TTLs - - Automatic cache invalidation - Memory-efficient storage .. rubric:: Benefits @@ -59,12 +45,13 @@ """ import logging +from dataclasses import replace from pathlib import Path from ..config import get_project_config from ..util import get_project_dir +from ._rag_context import _HAS_RAG, RAGManager from .base import ToolSpec, ToolUse -from ._rag_context import RAGManager, _HAS_RAG logger = logging.getLogger(__name__) @@ -119,20 +106,31 @@ def rag_status() -> str: return f"Index contains {rag_manager.get_document_count()} documents" +_init_run = False + + def init() -> ToolSpec: """Initialize the RAG tool.""" - if not _HAS_RAG: + global _init_run + if _init_run: + return tool + _init_run = True + + if not tool.available: return tool project_dir = get_project_dir() - index_path = Path("~/.cache/gptme/rag").expanduser() - collection = "default" if project_dir and (config := get_project_config(project_dir)): - index_path = Path(config.rag.get("index_path", index_path)).expanduser() - collection = config.rag.get("collection", project_dir.name) + enabled = config.rag.get("enabled", False) + if not enabled: + logger.debug("RAG not enabled in the project configuration") + return replace(tool, available=False) + else: + logger.debug("Project configuration not found, not enabling") + return replace(tool, available=False) global rag_manager - rag_manager = RAGManager(index_path=index_path, collection=collection) + rag_manager = RAGManager() return tool diff --git a/poetry.lock b/poetry.lock index 88383237..1d57ef33 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "accessible-pygments" @@ -1098,13 +1098,13 @@ files = [ [[package]] name = "gptme-rag" -version = "0.2.1" +version = "0.3.0" description = "RAG implementation for gptme context management" optional = true python-versions = "<4.0,>=3.10" files = [ - {file = "gptme_rag-0.2.1-py3-none-any.whl", hash = "sha256:ada534d91200bdaf7341e24de7ca82bf5f76071e9c155003b3dd0abe16290598"}, - {file = "gptme_rag-0.2.1.tar.gz", hash = "sha256:bae2b60e14e3a7a4c71dd2a9dac13abc7306a00902f8cd28d11395395fc82adc"}, + {file = "gptme_rag-0.3.0-py3-none-any.whl", hash = "sha256:76e0b8ffb0367971b33024815644c2d0c6db2922424d2e81093ef4db855e87e5"}, + {file = "gptme_rag-0.3.0.tar.gz", hash = "sha256:72318084c134236080f83e72d7eb62dff6274b1594c81c68587b4577b7a9cb40"}, ] [package.dependencies] @@ -5184,4 +5184,4 @@ server = ["flask", "flask-cors"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "fdb4c7c81ec59e0a80bb1e0f36a9d49934b460cc335dcbe613f37d2ed9e7e5e2" +content-hash = "2fa48920c632e5f420cfcc3f0d37b68ce630e9b6941a9b8f2acf091640cd568c" diff --git a/pyproject.toml b/pyproject.toml index 2a33f00c..439a2daa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ youtube_transcript_api = {version = "^0.6.1", optional = true} python-xlib = {version = "^0.33", optional = true} # for X11 interaction # RAG -gptme-rag = {version = "^0.2.1", optional = true} +gptme-rag = {version = "^0.3.0", optional = true} #gptme-rag = {path = "../gptme-rag", optional = true, develop = true} # providers diff --git a/tests/test_tools_rag.py b/tests/test_tools_rag.py index 7196c49b..45a215b7 100644 --- a/tests/test_tools_rag.py +++ b/tests/test_tools_rag.py @@ -1,31 +1,23 @@ -"""Tests for the RAG tool and context enhancement functionality.""" +"""Tests for the RAG tool.""" -from dataclasses import replace -from unittest.mock import Mock, patch +from unittest.mock import patch import pytest from gptme.message import Message -from gptme.tools._rag_context import ( - Context, - RAGManager, - _clear_cache, - _get_search_results, - enhance_messages, -) -from gptme.tools.base import ToolSpec -from gptme.tools.rag import _HAS_RAG -from gptme.tools.rag import init as init_rag -from gptme.tools.rag import rag_index, rag_search, rag_status +from gptme.tools._rag_context import enhance_messages +from gptme.tools.rag import _HAS_RAG, init as init_rag, rag_index, rag_search -pytest.importorskip("gptme_rag") - -# Fixtures +@pytest.fixture(autouse=True) +def reset_rag(): + """Reset the RAG manager and init state before and after each test.""" + import gptme.tools.rag -@pytest.fixture(scope="function") -def index_path(tmp_path): - """Create a temporary index path.""" - return tmp_path + gptme.tools.rag.rag_manager = None + gptme.tools.rag._init_run = False + yield + gptme.tools.rag.rag_manager = None + gptme.tools.rag._init_run = False @pytest.fixture(scope="function") @@ -40,322 +32,59 @@ def temp_docs(tmp_path): return tmp_path -@pytest.fixture -def mock_rag_manager(index_path): - """Create a mock RAG manager that returns test contexts.""" - with patch("gptme.tools._rag_context.gptme_rag"): - manager = RAGManager( - index_path=index_path, - collection="test", - ) - # Create mock documents - mock_docs = [ - Mock( - content="This is a test document about Python functions.", - metadata={"source": "doc1.md"}, - ), - Mock( - content="Documentation about testing practices.", - metadata={"source": "doc2.md"}, - ), - ] - mock_results = {"distances": [[0.2, 0.4]]} # 1 - distance = relevance - - # Mock the indexer's search method - manager.indexer.search = Mock(return_value=(mock_docs, mock_results)) - - # Mock get_context to return actual Context objects - test_contexts = [ - Context( - content="This is a test document about Python functions.", - source="doc1.md", - relevance=0.8, - ), - Context( - content="Documentation about testing practices.", - source="doc2.md", - relevance=0.6, - ), - ] - manager.get_context = Mock(return_value=test_contexts) # type: ignore - - return manager - - -@pytest.fixture -def mock_rag_manager_no_context(mock_rag_manager): - """Create a RAG manager that returns no context.""" - mock_rag_manager.get_context = Mock(return_value=[]) - return mock_rag_manager - - -@pytest.fixture(autouse=True) -def clear_cache(): - """Clear the search cache before each test.""" - _clear_cache() - - -# RAG Tool Tests - - -@pytest.mark.timeout(func_only=True) -@pytest.mark.skipif(not _HAS_RAG, reason="gptme-rag not installed") -def test_rag_tool_init(): - """Test RAG tool initialization.""" - tool = init_rag() - assert isinstance(tool, ToolSpec) - assert tool.name == "rag" - assert tool.available is True - - def test_rag_tool_init_without_gptme_rag(): """Test RAG tool initialization when gptme-rag is not available.""" - tool = init_rag() with ( patch("gptme.tools.rag._HAS_RAG", False), - patch("gptme.tools.rag.tool", replace(tool, available=False)), + patch("gptme.tools.rag.get_project_config") as mock_config, ): + # Mock config to disable RAG + mock_config.return_value.rag = {"enabled": False} + tool = init_rag() - assert isinstance(tool, ToolSpec) assert tool.name == "rag" assert tool.available is False -@pytest.mark.slow -@pytest.mark.skipif(not _HAS_RAG, reason="gptme-rag not installed") -def test_rag_index_function(temp_docs, index_path, tmp_path): - """Test the index function.""" - with ( - patch("gptme.tools.rag.get_project_config") as mock_config, - patch("gptme.tools.rag.get_project_dir") as mock_project_dir, - ): - # Mock project dir to return the temp path - mock_project_dir.return_value = tmp_path - - # Mock config to return an object with a proper .get method - class MockConfig: - def __init__(self): - self.rag = {"index_path": str(index_path), "collection": tmp_path.name} - - def get(self, key, default=None): - return self.rag.get(key, default) - - mock_config.return_value = MockConfig() - - # Initialize RAG - init_rag() - - # Test indexing with specific path - result = rag_index(str(temp_docs)) - assert "Indexed 1 paths" in result - - # Test indexing with default path - # FIXME: this is really slow in the gptme directory, - # since it contains a lot of files (which are in gitignore, but not respected) - result = rag_index(glob="**/*.py") - assert "Indexed 1 paths" in result - - -@pytest.mark.skipif(not _HAS_RAG, reason="gptme-rag not installed") -def test_rag_search_function(temp_docs, index_path, tmp_path): - """Test the search function.""" - with ( - patch("gptme.tools.rag.get_project_config") as mock_config, - patch("gptme.tools.rag.get_project_dir") as mock_project_dir, - ): - # Mock project dir to return the temp path - mock_project_dir.return_value = tmp_path - - # Mock config to return an object with a proper .get method - class MockConfig: - def __init__(self): - self.rag = {"index_path": str(index_path), "collection": tmp_path.name} - - def get(self, key, default=None): - return self.rag.get(key, default) - - mock_config.return_value = MockConfig() - - # Initialize RAG and index documents - init_rag() - rag_index(str(temp_docs)) - - # Search for Python - result = rag_search("Python") - assert "doc1.md" in result - assert "Python functions" in result - - # Search for testing - result = rag_search("testing") - assert "doc2.md" in result - assert "testing practices" in result - - @pytest.mark.skipif(not _HAS_RAG, reason="gptme-rag not installed") -def test_rag_status_function(temp_docs, index_path, tmp_path): - """Test the status function.""" +def test_rag_tool_functionality(temp_docs): + """Test basic RAG tool functionality.""" with ( patch("gptme.tools.rag.get_project_config") as mock_config, patch("gptme.tools.rag.get_project_dir") as mock_project_dir, ): - # Mock project dir to return the temp path - mock_project_dir.return_value = tmp_path + # Mock project dir to return a path + mock_project_dir.return_value = temp_docs - # Mock config to return an object with a proper .get method + # Mock config to enable RAG class MockConfig: def __init__(self): - self.rag = {"index_path": str(index_path), "collection": tmp_path.name} - - def get(self, key, default=None): - return self.rag.get(key, default) + self.rag = {"enabled": True} mock_config.return_value = MockConfig() # Initialize RAG - init_rag() - - # Check initial status - result = rag_status() - assert "Index contains" in result - assert "0" in result - - # Index documents and check status again - rag_index(str(temp_docs)) - result = rag_status() - assert "Index contains" in result - assert "2" in result # Should have indexed 2 documents - - -# Context Enhancement Tests - - -def test_search_caching(mock_rag_manager): - """Test that search results are properly cached.""" - query = "test query" - n_results = 5 - - # First search should use the manager - docs, results = mock_rag_manager.search(query, n_results) - assert mock_rag_manager.indexer.search.call_count == 1 - - # Cache should be populated - cached = _get_search_results(query, n_results) - assert cached is not None - assert cached == (docs, results) - - # Second search should use cache - docs2, results2 = mock_rag_manager.search(query, n_results) - assert mock_rag_manager.indexer.search.call_count == 1 # No additional calls - assert (docs2, results2) == (docs, results) - - -def test_enhance_messages_with_context(mock_rag_manager): - """Test that messages are enhanced with context.""" - with patch("gptme.tools._rag_context.RAGManager", return_value=mock_rag_manager): - messages = [ - Message("system", "Initial system message"), - Message("user", "Tell me about Python functions"), - Message("assistant", "Here's what I know about functions..."), - ] - - enhanced = enhance_messages(messages) - - # Should have one extra message for the context - assert len(enhanced) == 4 + tool = init_rag() + assert tool.available is True - # Check that context was added before the user message - assert enhanced[0].role == "system" # Original system message - assert enhanced[1].role == "system" # Added context - assert "Relevant context:" in enhanced[1].content - assert "doc1.md" in enhanced[1].content - assert "doc2.md" in enhanced[1].content - assert enhanced[1].hide is True # Context should be hidden + # Test indexing + index_result = rag_index(str(temp_docs)) + assert "Indexed" in index_result - # Original messages should remain unchanged - assert enhanced[2].role == "user" - assert enhanced[3].role == "assistant" + # Test searching + search_result = rag_search("test document") + assert "test document" in search_result.lower() def test_enhance_messages_no_rag(): """Test that enhancement works even without RAG available.""" - with patch("gptme.tools._rag_context._HAS_RAG", False): - messages = [ - Message("user", "Tell me about Python"), - Message("assistant", "Python is a programming language"), - ] - - enhanced = enhance_messages(messages) - - # Should be unchanged when RAG is not available - assert len(enhanced) == len(messages) - assert enhanced == messages - - -def test_enhance_messages_error_handling(mock_rag_manager): - """Test that errors in context enhancement are handled gracefully.""" - mock_rag_manager.get_context.side_effect = Exception("Test error") - - with patch("gptme.tools._rag_context.RAGManager", return_value=mock_rag_manager): - messages = [ - Message("user", "Tell me about Python"), - Message("assistant", "Python is great"), - ] - - # Should not raise an exception - enhanced = enhance_messages(messages) - - # Messages should be unchanged when enhancement fails - assert len(enhanced) == len(messages) - assert enhanced == messages - - -def test_rag_manager_initialization(): - """Test RAG manager initialization with and without gptme-rag.""" - # Test when gptme-rag is not available - with patch("gptme.tools._rag_context._HAS_RAG", False): - with pytest.raises(ImportError): - RAGManager() - - # Test when gptme-rag is available - with patch("gptme.tools._rag_context._HAS_RAG", True): - with patch("gptme.tools._rag_context.gptme_rag") as mock_rag: - manager = RAGManager() - assert isinstance(manager, RAGManager) - mock_rag.Indexer.assert_called_once() - - -def test_get_context_with_relevance_filter(mock_rag_manager): - """Test that get_context properly filters by relevance.""" - with patch("gptme.tools._rag_context.RAGManager", return_value=mock_rag_manager): - # Create test contexts with different relevance scores - contexts = [ - Context(content="High relevance", source="high.md", relevance=0.8), - Context(content="Low relevance", source="low.md", relevance=0.4), - ] - - # Mock get_context directly instead of search - mock_rag_manager.get_context = Mock( - return_value=[ctx for ctx in contexts if ctx.relevance >= 0.7] - ) - - messages = [Message("user", "test query")] - enhanced = enhance_messages(messages) - - # Should only include the high relevance context - assert len(enhanced) == 2 # Original message + context message - assert "high.md" in enhanced[0].content - assert "low.md" not in enhanced[0].content - - -def test_auto_context_disabled(mock_rag_manager): - """Test that context enhancement respects auto_context setting.""" - mock_rag_manager.auto_context = False - mock_rag_manager.get_context = Mock(return_value=[]) # Should not be called + messages = [ + Message("user", "Tell me about Python"), + Message("assistant", "Python is a programming language"), + ] - messages = [Message("user", "Tell me about Python")] enhanced = enhance_messages(messages) - # No context should be added when auto_context is False - assert len(enhanced) == 1 - assert enhanced[0].role == "user" - assert not mock_rag_manager.get_context.called + # Should be unchanged when RAG is not available + assert len(enhanced) == len(messages) + assert enhanced == messages