Skip to content

Commit

Permalink
feat: simplify rag tool by simply calling gptme-rag via subprocess (#316)
Browse files Browse the repository at this point in the history

* feat: simplify rag tool by simply calling gptme-rag via subprocess if installed

* build(deps): got rid of gptme-rag dependency

* test: fixed tests for simplified rag tool

* test: fixed docstring in test
  • Loading branch information
ErikBjare authored Dec 10, 2024
1 parent a2f06df commit 8757108
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 2,096 deletions.
195 changes: 0 additions & 195 deletions gptme/tools/_rag_context.py

This file was deleted.

128 changes: 82 additions & 46 deletions gptme/tools/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
.. rubric:: Installation
The RAG tool requires the ``gptme-rag`` package. Install it with::
The RAG tool requires the ``gptme-rag`` CLI to be installed::
pip install "gptme[rag]"
pipx install gptme-rag
.. rubric:: Configuration
Expand All @@ -19,38 +19,25 @@
.. rubric:: Features
1. Manual Search and Indexing
- Index project documentation with ``rag_index``
- Search indexed documents with ``rag_search``
- Check index status with ``rag_status``
2. Automatic Context Enhancement
- Automatically adds relevant context to user messages
- Retrieves semantically similar documents
- Manages token budget to avoid context overflow
- Preserves conversation flow with hidden context messages
3. Performance Optimization
- Intelligent caching system for embeddings and search results
- Memory-efficient storage
.. rubric:: Benefits
- Better informed responses through relevant documentation
- Reduced need for manual context inclusion
- Automatic token management
- Seamless integration with conversation flow
"""

import logging
import shutil
import subprocess
from dataclasses import replace
from functools import lru_cache
from pathlib import Path

from ..config import get_project_config
from ..message import Message
from ..util import get_project_dir
from ._rag_context import _HAS_RAG, get_rag_manager
from .base import ToolSpec, ToolUse

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -80,45 +67,59 @@ def examples(tool_format):
"""


@lru_cache
def _has_gptme_rag() -> bool:
    """Check if gptme-rag is available in PATH.

    The lookup is done with ``shutil.which`` and memoized by ``lru_cache``,
    so PATH is searched only on the first call. A CLI installed afterwards is
    not noticed until ``_has_gptme_rag.cache_clear()`` is called.

    Returns:
        True if an executable named ``gptme-rag`` was found, False otherwise.
    """
    return shutil.which("gptme-rag") is not None


def _run_rag_cmd(cmd: list[str]) -> subprocess.CompletedProcess:
    """Run a gptme-rag command and handle errors.

    Args:
        cmd: Full argument vector, executable included. It is handed to
            ``subprocess.run`` as a list, so no shell is involved.

    Returns:
        The ``CompletedProcess`` with stdout and stderr captured as text.

    Raises:
        RuntimeError: If the command exits with a non-zero status. The
            message carries the captured stderr and the original
            ``CalledProcessError`` is chained as ``__cause__``.

    Only ``CalledProcessError`` is translated; any other exception raised by
    ``subprocess.run`` (for example when the executable cannot be started)
    propagates unchanged.
    """
    try:
        # check=True turns a non-zero exit status into CalledProcessError.
        return subprocess.run(cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        logger.error("gptme-rag command failed: %s", e.stderr)
        raise RuntimeError(f"gptme-rag command failed: {e.stderr}") from e


def rag_index(*paths: str, glob: str | None = None) -> str:
"""Index documents in specified paths."""
manager = get_rag_manager()
paths = paths or (".",)
kwargs = {"glob_pattern": glob} if glob else {}
total_docs = 0
for path in paths:
total_docs += manager.index_directory(Path(path), **kwargs)
return f"Indexed {len(paths)} paths ({total_docs} documents)"
cmd = ["gptme-rag", "index"]
cmd.extend(paths)
if glob:
cmd.extend(["--glob", glob])

result = _run_rag_cmd(cmd)
return result.stdout.strip()


def rag_search(query: str) -> str:
def rag_search(query: str, return_full: bool = False) -> str:
"""Search indexed documents."""
manager = get_rag_manager()
docs, _ = manager.search(query)
return "\n\n".join(
f"### {doc.metadata['source']}\n{doc.content[:200]}..." for doc in docs
)
cmd = ["gptme-rag", "search", query]
if return_full:
# shows full context of the search results
cmd.append("--show-context")

result = _run_rag_cmd(cmd)
return result.stdout.strip()


def rag_status() -> str:
"""Show index status."""
manager = get_rag_manager()
return f"Index contains {manager.get_document_count()} documents"


_init_run = False
cmd = ["gptme-rag", "status"]
result = _run_rag_cmd(cmd)
return result.stdout.strip()


def init() -> ToolSpec:
"""Initialize the RAG tool."""
global _init_run
if _init_run:
return tool
_init_run = True

if not tool.available:
return tool
# Check if gptme-rag CLI is available
if not _has_gptme_rag():
logger.debug("gptme-rag CLI not found in PATH")
return replace(tool, available=False)

# Check project configuration
project_dir = get_project_dir()
if project_dir and (config := get_project_config(project_dir)):
enabled = config.rag.get("enabled", False)
Expand All @@ -129,18 +130,53 @@ def init() -> ToolSpec:
logger.debug("Project configuration not found, not enabling")
return replace(tool, available=False)

# Initialize the shared RAG manager
get_rag_manager()
return tool


def rag_enhance_messages(messages: list[Message]) -> list[Message]:
    """Enhance messages with context from RAG.

    For every message whose role is ``"user"``, ``gptme-rag search`` is run on
    the message content with ``--show-context``, and a hidden system message
    holding the command output is inserted directly before that message.

    Args:
        messages: The conversation messages, in order.

    Returns:
        The input list itself when the gptme-rag CLI is not found or RAG is
        not enabled in the project config; otherwise a new list containing
        the original messages plus the inserted context messages.

    A failure while fetching context for a message is logged as a warning and
    that message is passed through without added context.
    """
    if not _has_gptme_rag():
        return messages

    # Load config
    config = get_project_config(Path.cwd())
    rag_config = config.rag if config and config.rag else {}

    if not rag_config.get("enabled", False):
        return messages

    # The optional limits are the same for every message, so build the
    # corresponding CLI flags once instead of on each loop iteration.
    extra_args: list[str] = []
    if max_tokens := rag_config.get("max_tokens"):
        extra_args.extend(["--max-tokens", str(max_tokens)])
    if min_relevance := rag_config.get("min_relevance"):
        extra_args.extend(["--min-relevance", str(min_relevance)])

    enhanced_messages: list[Message] = []
    for msg in messages:
        if msg.role == "user":
            # Get context using gptme-rag CLI
            # NOTE(review): msg.content is passed as a positional argument; a
            # message starting with "-" may be parsed as an option by the CLI
            # -- confirm how gptme-rag handles this.
            cmd = ["gptme-rag", "search", msg.content, "--show-context", *extra_args]
            try:
                enhanced_messages.append(
                    Message(
                        role="system",
                        content=f"Relevant context:\n\n{_run_rag_cmd(cmd).stdout}",
                        hide=True,
                    )
                )
            except Exception as e:
                # Deliberately best-effort: a failed lookup must not block
                # the conversation, so log it and continue.
                logger.warning("Error getting context: %s", e)

        enhanced_messages.append(msg)

    return enhanced_messages


tool = ToolSpec(
name="rag",
desc="RAG (Retrieval-Augmented Generation) for context-aware assistance",
instructions=instructions,
examples=examples,
functions=[rag_index, rag_search, rag_status],
available=_HAS_RAG,
available=_has_gptme_rag(),
init=init,
)

Expand Down
Loading

0 comments on commit 8757108

Please sign in to comment.