Skip to content

Commit

Permalink
call storage.search in user context search instead of memory.search
Browse files Browse the repository at this point in the history
  • Loading branch information
burnerlee committed Dec 3, 2024
1 parent f8a8e7b commit 5d1d8ee
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 1 deletion.
144 changes: 144 additions & 0 deletions src/crewai/mcp/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import asyncio
import threading
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from crewai.tools.base_tool import BaseTool
from pydantic import BaseModel, Field, create_model, ConfigDict
from typing import Type, Dict, Any, Union
from contextlib import AsyncExitStack

def create_pydantic_model_from_dict(model_name: str, schema_dict: Dict[str, Any]) -> Type[BaseModel]:
fields = {}
type_mapping = {
'string': str,
'number': float,
'integer': int,
'boolean': bool,
'object': dict,
'array': list,
}
properties = schema_dict.get('properties', {})
required_fields = schema_dict.get('required', [])
for field_name, field_info in properties.items():
json_type = field_info.get('type', 'string')
python_type = type_mapping.get(json_type, Any)
description = field_info.get('description', '')
default = field_info.get('default', ...)
is_required = field_name in required_fields
if not is_required:
python_type = Union[python_type, None]
default = None if default is ... else default
field = (python_type, Field(default, description=description))
fields[field_name] = field
model = create_model(model_name, **fields)
return model

class AsyncioEventLoopThread(threading.Thread):
def __init__(self):
super().__init__()
self.loop = asyncio.new_event_loop()
self._stop_event = threading.Event()
def run(self):
asyncio.set_event_loop(self.loop)
self.loop.run_forever()
def stop(self):
self.loop.call_soon_threadsafe(self.loop.stop)
self._stop_event.set()
def schedule_coroutine(self, coro):
return asyncio.run_coroutine_threadsafe(coro, self.loop)

class MCPClient:
def __init__(self, server_params: StdioServerParameters, loop_thread: AsyncioEventLoopThread):
self.server_params = server_params
self.loop_thread = loop_thread
self.initialized = False
self.client_session = None
self.read = None
self.write = None
self._init_future = None
self._exit_stack = AsyncExitStack()
self._init_future = self.loop_thread.schedule_coroutine(self._async_init())
async def _async_init(self):
await self._exit_stack.__aenter__()
self.stdio_client = stdio_client(self.server_params)
self.read, self.write = await self._exit_stack.enter_async_context(self.stdio_client)
self.client_session = ClientSession(self.read, self.write)
await self._exit_stack.enter_async_context(self.client_session)
await self.client_session.initialize()
self.initialized = True
def call_tool(self, tool_name: str, tool_input: dict = None):
future = self.loop_thread.schedule_coroutine(self._call_tool_async(tool_name, tool_input))
return future.result()
async def _call_tool_async(self, tool_name: str, tool_input: dict = None):
if not self.initialized:
await asyncio.wrap_future(self._init_future)
return await self.client_session.call_tool(tool_name, tool_input)
def close(self):
future = self.loop_thread.schedule_coroutine(self._async_close())
future.result()
async def _async_close(self):
await self._exit_stack.aclose()
self.initialized = False

class MCPTool(BaseTool):
name: str
description: str
args_schema: Type[BaseModel]
client: 'MCPClient'
def __init__(self, name: str, description: str, args_schema: Type[BaseModel], client: 'MCPClient'):
self.name = name
self.description = description
self.args_schema = args_schema
self.client = client
def _run(self, **kwargs):
validated_inputs = self.args_schema(**kwargs)
result = self.client.call_tool(self.name, validated_inputs.dict())
return result
model_config = ConfigDict(arbitrary_types_allowed=True)

def initialise_tools_sync(client: MCPClient):
future = client.loop_thread.schedule_coroutine(initialise_tools(client))
return future.result()

async def initialise_tools(client: MCPClient):
if not client.initialized:
await asyncio.wrap_future(client._init_future)
tools_list = await client.client_session.list_tools()
available_tools = [tool.model_dump() for tool in tools_list]
mcp_tools = []
for tool in available_tools:
mcp_tools.append(MCPTool(
name=tool['name'],
description=tool['description'],
args_schema=create_pydantic_model_from_dict(f"{tool['name']}Input", tool['inputSchema']),
client=client
))
return mcp_tools

class MCPStdioServerParams(StdioServerParameters):
command: str
args: list[str] = []
env: dict[str, str] = None

def get_persistent_mcp_client(params: StdioServerParameters):
loop_thread = AsyncioEventLoopThread()
loop_thread.start()
client = MCPClient(params, loop_thread)
return client, loop_thread

if __name__ == "__main__":
params = MCPStdioServerParams(
command="/opt/homebrew/bin/npx",
args=["-y", "@modelcontextprotocol/server-filesystem", "/Users/burnerlee/Projects/dashwave/nucleon"]
)
client, loop_thread = get_persistent_mcp_client(params)
try:
tools = initialise_tools_sync(client)
selected_tool = tools[0]
result = selected_tool._run(
path="/Users/burnerlee/Projects/dashwave/nucleon/README.md"
)
print(result)
finally:
client.close()
loop_thread.stop()
2 changes: 1 addition & 1 deletion src/crewai/memory/user/user_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def search(
limit: int = 3,
score_threshold: float = 0.35,
):
results = super().search(
results = self.storage.search(
query=query,
limit=limit,
score_threshold=score_threshold,
Expand Down

0 comments on commit 5d1d8ee

Please sign in to comment.