-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
call storage.search in user context search instead of memory.search
- Loading branch information
Showing
2 changed files
with
145 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
import asyncio | ||
import threading | ||
from mcp import ClientSession, StdioServerParameters | ||
from mcp.client.stdio import stdio_client | ||
from crewai.tools.base_tool import BaseTool | ||
from pydantic import BaseModel, Field, create_model, ConfigDict | ||
from typing import Type, Dict, Any, Union | ||
from contextlib import AsyncExitStack | ||
|
||
def create_pydantic_model_from_dict(model_name: str, schema_dict: Dict[str, Any]) -> Type[BaseModel]: | ||
fields = {} | ||
type_mapping = { | ||
'string': str, | ||
'number': float, | ||
'integer': int, | ||
'boolean': bool, | ||
'object': dict, | ||
'array': list, | ||
} | ||
properties = schema_dict.get('properties', {}) | ||
required_fields = schema_dict.get('required', []) | ||
for field_name, field_info in properties.items(): | ||
json_type = field_info.get('type', 'string') | ||
python_type = type_mapping.get(json_type, Any) | ||
description = field_info.get('description', '') | ||
default = field_info.get('default', ...) | ||
is_required = field_name in required_fields | ||
if not is_required: | ||
python_type = Union[python_type, None] | ||
default = None if default is ... else default | ||
field = (python_type, Field(default, description=description)) | ||
fields[field_name] = field | ||
model = create_model(model_name, **fields) | ||
return model | ||
|
||
class AsyncioEventLoopThread(threading.Thread): | ||
def __init__(self): | ||
super().__init__() | ||
self.loop = asyncio.new_event_loop() | ||
self._stop_event = threading.Event() | ||
def run(self): | ||
asyncio.set_event_loop(self.loop) | ||
self.loop.run_forever() | ||
def stop(self): | ||
self.loop.call_soon_threadsafe(self.loop.stop) | ||
self._stop_event.set() | ||
def schedule_coroutine(self, coro): | ||
return asyncio.run_coroutine_threadsafe(coro, self.loop) | ||
|
||
class MCPClient: | ||
def __init__(self, server_params: StdioServerParameters, loop_thread: AsyncioEventLoopThread): | ||
self.server_params = server_params | ||
self.loop_thread = loop_thread | ||
self.initialized = False | ||
self.client_session = None | ||
self.read = None | ||
self.write = None | ||
self._init_future = None | ||
self._exit_stack = AsyncExitStack() | ||
self._init_future = self.loop_thread.schedule_coroutine(self._async_init()) | ||
async def _async_init(self): | ||
await self._exit_stack.__aenter__() | ||
self.stdio_client = stdio_client(self.server_params) | ||
self.read, self.write = await self._exit_stack.enter_async_context(self.stdio_client) | ||
self.client_session = ClientSession(self.read, self.write) | ||
await self._exit_stack.enter_async_context(self.client_session) | ||
await self.client_session.initialize() | ||
self.initialized = True | ||
def call_tool(self, tool_name: str, tool_input: dict = None): | ||
future = self.loop_thread.schedule_coroutine(self._call_tool_async(tool_name, tool_input)) | ||
return future.result() | ||
async def _call_tool_async(self, tool_name: str, tool_input: dict = None): | ||
if not self.initialized: | ||
await asyncio.wrap_future(self._init_future) | ||
return await self.client_session.call_tool(tool_name, tool_input) | ||
def close(self): | ||
future = self.loop_thread.schedule_coroutine(self._async_close()) | ||
future.result() | ||
async def _async_close(self): | ||
await self._exit_stack.aclose() | ||
self.initialized = False | ||
|
||
class MCPTool(BaseTool): | ||
name: str | ||
description: str | ||
args_schema: Type[BaseModel] | ||
client: 'MCPClient' | ||
def __init__(self, name: str, description: str, args_schema: Type[BaseModel], client: 'MCPClient'): | ||
self.name = name | ||
self.description = description | ||
self.args_schema = args_schema | ||
self.client = client | ||
def _run(self, **kwargs): | ||
validated_inputs = self.args_schema(**kwargs) | ||
result = self.client.call_tool(self.name, validated_inputs.dict()) | ||
return result | ||
model_config = ConfigDict(arbitrary_types_allowed=True) | ||
|
||
def initialise_tools_sync(client: MCPClient): | ||
future = client.loop_thread.schedule_coroutine(initialise_tools(client)) | ||
return future.result() | ||
|
||
async def initialise_tools(client: MCPClient): | ||
if not client.initialized: | ||
await asyncio.wrap_future(client._init_future) | ||
tools_list = await client.client_session.list_tools() | ||
available_tools = [tool.model_dump() for tool in tools_list] | ||
mcp_tools = [] | ||
for tool in available_tools: | ||
mcp_tools.append(MCPTool( | ||
name=tool['name'], | ||
description=tool['description'], | ||
args_schema=create_pydantic_model_from_dict(f"{tool['name']}Input", tool['inputSchema']), | ||
client=client | ||
)) | ||
return mcp_tools | ||
|
||
class MCPStdioServerParams(StdioServerParameters): | ||
command: str | ||
args: list[str] = [] | ||
env: dict[str, str] = None | ||
|
||
def get_persistent_mcp_client(params: StdioServerParameters): | ||
loop_thread = AsyncioEventLoopThread() | ||
loop_thread.start() | ||
client = MCPClient(params, loop_thread) | ||
return client, loop_thread | ||
|
||
if __name__ == "__main__": | ||
params = MCPStdioServerParams( | ||
command="/opt/homebrew/bin/npx", | ||
args=["-y", "@modelcontextprotocol/server-filesystem", "/Users/burnerlee/Projects/dashwave/nucleon"] | ||
) | ||
client, loop_thread = get_persistent_mcp_client(params) | ||
try: | ||
tools = initialise_tools_sync(client) | ||
selected_tool = tools[0] | ||
result = selected_tool._run( | ||
path="/Users/burnerlee/Projects/dashwave/nucleon/README.md" | ||
) | ||
print(result) | ||
finally: | ||
client.close() | ||
loop_thread.stop() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters