diff --git a/tests/test_tools_rag.py b/tests/test_tools_rag.py index 45a215b7..55edc740 100644 --- a/tests/test_tools_rag.py +++ b/tests/test_tools_rag.py @@ -2,21 +2,24 @@ from unittest.mock import patch +import gptme.tools._rag_context +import gptme.tools.rag import pytest from gptme.message import Message from gptme.tools._rag_context import enhance_messages -from gptme.tools.rag import _HAS_RAG, init as init_rag, rag_index, rag_search +from gptme.tools.rag import _HAS_RAG +from gptme.tools.rag import init as init_rag +from gptme.tools.rag import rag_index, rag_search @pytest.fixture(autouse=True) def reset_rag(): """Reset the RAG manager and init state before and after each test.""" - import gptme.tools.rag - gptme.tools.rag.rag_manager = None + gptme.tools._rag_context._rag_manager = None gptme.tools.rag._init_run = False yield - gptme.tools.rag.rag_manager = None + gptme.tools._rag_context._rag_manager = None gptme.tools.rag._init_run = False