Skip to content

Commit

Permalink
refactor(rag): simplify configuration and initialization (#266)
Browse files Browse the repository at this point in the history
* refactor(rag): simplify configuration and initialization

- Reduce configuration complexity to just an enable flag
- Use default paths and collection names
- Add initialization state tracking
- Improve code organization and documentation

* fix: add proper logging cleanup to prevent errors during shutdown

* feat: add custom index path and collection support to RAGManager

Makes RAGManager more flexible by allowing custom index paths and collection names
to be specified during initialization, while maintaining backward compatibility
with default values.

- Added index_path and collection parameters to RAGManager.__init__
- Maintains defaults (index path ~/.cache/gptme/rag, collection 'default')
- Allows for better testing and customization of RAG functionality

* test: simplify RAG tests

Major refactoring of the RAG tests to make them more focused and maintainable:
- Added reset_rag fixture to ensure clean state between tests
- Simplified test cases to focus on core functionality
- Removed complex mocking in favor of simpler integration tests
- Improved test organization and readability

* build: bump gptme-rag to 0.3.0

Required for the new custom index path and collection features in RAGManager.
  • Loading branch information
ErikBjare authored Nov 22, 2024
1 parent 2ade667 commit 84e5ab6
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 348 deletions.
3 changes: 3 additions & 0 deletions gptme.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
files = ["README.md", "Makefile"]
#files = ["README.md", "Makefile", "gptme/cli.py", "docs/*.rst", "docs/*.md"]

[rag]
enabled = true
12 changes: 11 additions & 1 deletion gptme/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,25 @@ def init(model: str | None, interactive: bool, tool_allowlist: list[str] | None)

def init_logging(verbose):
# log init
handler = RichHandler()
logging.basicConfig(
level=logging.DEBUG if verbose else logging.INFO,
format="%(message)s",
datefmt="[%X]",
handlers=[RichHandler()],
handlers=[handler],
)
# set httpx logging to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING)

# Register cleanup handler
import atexit

def cleanup_logging():
logging.getLogger().removeHandler(handler)
logging.shutdown()

atexit.register(cleanup_logging)


def _prompt_api_key() -> tuple[str, str, str]: # pragma: no cover
api_key = input("Your OpenAI, Anthropic, or OpenRouter API key: ").strip()
Expand Down
10 changes: 3 additions & 7 deletions gptme/tools/_rag_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import (
Any,
)
from typing import Any

from ..config import get_project_config
from ..message import Message
Expand Down Expand Up @@ -73,10 +71,8 @@ def __init__(self, index_path: Path | None = None, collection: str | None = None
self.config = config.rag if config and config.rag else {}

# Use config values if not overridden by parameters
self.index_path = Path(
index_path or self.config.get("index_path", "~/.cache/gptme/rag")
).expanduser()
self.collection = collection or self.config.get("collection", "default")
self.index_path = index_path or Path("~/.cache/gptme/rag").expanduser()
self.collection = collection or "default"

# Initialize the indexer
self.indexer = gptme_rag.Indexer(
Expand Down
48 changes: 23 additions & 25 deletions gptme/tools/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,40 +14,26 @@
Configure RAG in your ``gptme.toml``::
[rag]
# Storage configuration
index_path = "~/.cache/gptme/rag" # Where to store the index
collection = "gptme_docs" # Collection name for documents
# Context enhancement settings
max_tokens = 2000 # Maximum tokens for context window
auto_context = true # Enable automatic context enhancement
min_relevance = 0.5 # Minimum relevance score for including context
max_results = 5 # Maximum number of results to consider
# Cache configuration
[rag.cache]
max_embeddings = 10000 # Maximum number of cached embeddings
max_searches = 1000 # Maximum number of cached search results
embedding_ttl = 86400 # Embedding cache TTL in seconds (24h)
search_ttl = 3600 # Search cache TTL in seconds (1h)
enabled = true
.. rubric:: Features
1. Manual Search and Indexing
- Index project documentation with ``rag_index``
- Search indexed documents with ``rag_search``
- Check index status with ``rag_status``
2. Automatic Context Enhancement
- Automatically adds relevant context to user messages
- Retrieves semantically similar documents
- Manages token budget to avoid context overflow
- Preserves conversation flow with hidden context messages
3. Performance Optimization
- Intelligent caching system for embeddings and search results
- Configurable cache sizes and TTLs
- Automatic cache invalidation
- Memory-efficient storage
.. rubric:: Benefits
Expand All @@ -59,12 +45,13 @@
"""

import logging
from dataclasses import replace
from pathlib import Path

from ..config import get_project_config
from ..util import get_project_dir
from ._rag_context import _HAS_RAG, RAGManager
from .base import ToolSpec, ToolUse
from ._rag_context import RAGManager, _HAS_RAG

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -119,20 +106,31 @@ def rag_status() -> str:
return f"Index contains {rag_manager.get_document_count()} documents"


_init_run = False


def init() -> ToolSpec:
"""Initialize the RAG tool."""
if not _HAS_RAG:
global _init_run
if _init_run:
return tool
_init_run = True

if not tool.available:
return tool

project_dir = get_project_dir()
index_path = Path("~/.cache/gptme/rag").expanduser()
collection = "default"
if project_dir and (config := get_project_config(project_dir)):
index_path = Path(config.rag.get("index_path", index_path)).expanduser()
collection = config.rag.get("collection", project_dir.name)
enabled = config.rag.get("enabled", False)
if not enabled:
logger.debug("RAG not enabled in the project configuration")
return replace(tool, available=False)
else:
logger.debug("Project configuration not found, not enabling")
return replace(tool, available=False)

global rag_manager
rag_manager = RAGManager(index_path=index_path, collection=collection)
rag_manager = RAGManager()
return tool


Expand Down
10 changes: 5 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ youtube_transcript_api = {version = "^0.6.1", optional = true}
python-xlib = {version = "^0.33", optional = true} # for X11 interaction

# RAG
gptme-rag = {version = "^0.2.1", optional = true}
gptme-rag = {version = "^0.3.0", optional = true}
#gptme-rag = {path = "../gptme-rag", optional = true, develop = true}

# providers
Expand Down
Loading

0 comments on commit 84e5ab6

Please sign in to comment.