feat: support importing tools from LangChain #1745

Merged · 9 commits · Sep 11, 2024
Changes from 3 commits
2 changes: 1 addition & 1 deletion memgpt/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.3.24"
+__version__ = "0.4.0"
 
 from memgpt.client.admin import Admin
 from memgpt.client.client import create_client
76 changes: 75 additions & 1 deletion memgpt/client/client.py
@@ -1,8 +1,10 @@
+import logging
 import time
 from typing import Dict, Generator, List, Optional, Union
 
 import requests
 
+import memgpt.utils
 from memgpt.config import MemGPTConfig
 from memgpt.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
 from memgpt.data_sources.connectors import DataConnector
@@ -1292,6 +1294,10 @@ def __init__(
         self.interface = QueuingInterface(debug=debug)
         self.server = SyncServer(default_interface_factory=lambda: self.interface)
 
+        # set logging levels
+        memgpt.utils.DEBUG = debug
+        logging.getLogger().setLevel(logging.CRITICAL)
+
         # create user if does not exist
         existing_user = self.server.get_user(self.user_id)
         if not existing_user:
@@ -1617,7 +1623,10 @@ def send_message(
         messages = self.interface.to_list()
         for m in messages:
             assert isinstance(m, Message), f"Expected Message object, got {type(m)}"
-        return MemGPTResponse(messages=messages, usage=usage)
+        memgpt_messages = []
+        for m in messages:
+            memgpt_messages += m.to_memgpt_message()
+        return MemGPTResponse(messages=memgpt_messages, usage=usage)
 
     def user_message(self, agent_id: str, message: str) -> MemGPTResponse:
         """
@@ -2177,3 +2186,68 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:
             models (List[EmbeddingConfig]): List of embedding models
         """
         return [self.server.server_embedding_config]
+
+    def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool] = True) -> List[Block]:
+        """
+        List available blocks
+
+        Args:
+            label (str): Label of the block
+            templates_only (bool): List only templates
+
+        Returns:
+            blocks (List[Block]): List of blocks
+        """
+        return self.server.get_blocks(label=label, template=templates_only)
+
+    def create_block(self, name: str, text: str, label: Optional[str] = None) -> Block:
+        """
+        Create a block
+
+        Args:
+            label (str): Label of the block
+            name (str): Name of the block
+            text (str): Text of the block
+
+        Returns:
+            block (Block): Created block
+        """
+        return self.server.create_block(CreateBlock(label=label, name=name, value=text, user_id=self.user_id), user_id=self.user_id)
+
+    def update_block(self, block_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Block:
+        """
+        Update a block
+
+        Args:
+            block_id (str): ID of the block
+            name (str): Name of the block
+            text (str): Text of the block
+
+        Returns:
+            block (Block): Updated block
+        """
+        return self.server.update_block(UpdateBlock(id=block_id, name=name, value=text))
+
+    def get_block(self, block_id: str) -> Block:
+        """
+        Get a block
+
+        Args:
+            block_id (str): ID of the block
+
+        Returns:
+            block (Block): Block
+        """
+        return self.server.get_block(block_id)
+
+    def delete_block(self, id: str) -> Block:
+        """
+        Delete a block
+
+        Args:
+            id (str): ID of the block
+
+        Returns:
+            block (Block): Deleted block
+        """
+        return self.server.delete_block(id)
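A minimal sketch of the new block-management surface, assuming `create_client()` resolves to this local client; the block label, name, and text are illustrative and not taken from the diff:

```python
from memgpt import create_client

client = create_client()

# create a block, read it back by id, then update and delete it
block = client.create_block(name="editor_persona", text="I am a concise editor.", label="persona")
print(client.get_block(block.id).value)

client.update_block(block_id=block.id, text="I am a very concise editor.")

# templates_only=True (the default) forwards template=True to the server;
# templates_only=False skips the template filter, so everything is listed
print([b.name for b in client.list_blocks(templates_only=False)])

client.delete_block(block.id)
```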
30 changes: 28 additions & 2 deletions memgpt/functions/schema_generator.py
@@ -144,7 +144,15 @@ def generate_schema_from_args_schema(
     properties = {}
     required = []
     for field_name, field in args_schema.__fields__.items():
-        properties[field_name] = {"type": field.type_.__name__, "description": field.field_info.description}
+        if field.type_.__name__ == "str":
+            field_type = "string"
+        elif field.type_.__name__ == "int":
+            field_type = "integer"
+        elif field.type_.__name__ == "bool":
+            field_type = "boolean"
+        else:
+            field_type = field.type_.__name__
+        properties[field_name] = {"type": field_type, "description": field.field_info.description}
         if field.required:
             required.append(field_name)
 
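A quick sketch of what the mapping above produces, assuming an `args_schema` written against the pydantic v1 field API that this function reads (`field.type_`, `field.field_info`, `field.required`); the `SearchInput` model is invented for illustration:

```python
from pydantic.v1 import BaseModel, Field  # v1 shim; plain pydantic if v1 is installed

from memgpt.functions.schema_generator import generate_schema_from_args_schema

class SearchInput(BaseModel):
    query: str = Field(..., description="Search query string")
    max_results: int = Field(5, description="Maximum number of results")

schema = generate_schema_from_args_schema(SearchInput, name="search", description="Run a search")
print(schema["parameters"]["properties"]["query"]["type"])        # "string" (was "str" before this change)
print(schema["parameters"]["properties"]["max_results"]["type"])  # "integer" (was "int")
```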
@@ -158,7 +166,24 @@ def generate_schema_from_args_schema(
     return function_call_json
 
 
-def generate_tool_wrapper(tool_name: str) -> str:
+def generate_langchain_tool_wrapper(tool_name: str) -> str:
+    import_statement = f"from langchain_community.tools import {tool_name}"
+    tool_instantiation = f"tool = {tool_name}()"
+    run_call = f"return tool._run(**kwargs)"
+    func_name = f"run_{tool_name.lower()}"
+
+    # Combine all parts into the wrapper function
+    wrapper_function_str = f"""
+def {func_name}(**kwargs):
+    del kwargs['self']
+    {import_statement}
+    {tool_instantiation}
+    {run_call}
+"""
+    return func_name, wrapper_function_str
+
+
+def generate_crewai_tool_wrapper(tool_name: str) -> str:
     import_statement = f"from crewai_tools import {tool_name}"
     tool_instantiation = f"tool = {tool_name}()"
     run_call = f"return tool._run(**kwargs)"
@@ -167,6 +192,7 @@ def generate_tool_wrapper(tool_name: str) -> str:
     # Combine all parts into the wrapper function
     wrapper_function_str = f"""
 def {func_name}(**kwargs):
+    del kwargs['self']
     {import_statement}
     {tool_instantiation}
     {run_call}
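To make the renamed helpers concrete, here is the wrapper source the LangChain variant emits for an arbitrary tool class name (the class name is just an example). The added `del kwargs['self']` line presumably strips the agent `self` argument that MemGPT passes to tool functions before the remaining kwargs reach the tool's `_run`:

```python
from memgpt.functions.schema_generator import generate_langchain_tool_wrapper

func_name, source = generate_langchain_tool_wrapper("WikipediaQueryRun")
print(func_name)  # run_wikipediaqueryrun
print(source)
# def run_wikipediaqueryrun(**kwargs):
#     del kwargs['self']
#     from langchain_community.tools import WikipediaQueryRun
#     tool = WikipediaQueryRun()
#     return tool._run(**kwargs)
```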
1 change: 1 addition & 0 deletions memgpt/llm_api/llm_api_tools.py
@@ -42,6 +42,7 @@
     AgentChunkStreamingInterface,
     AgentRefreshStreamingInterface,
 )
+from memgpt.utils import json_dumps
 
 LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local"]
 
19 changes: 16 additions & 3 deletions memgpt/metadata.py
@@ -306,7 +306,17 @@ def to_record(self) -> Block:
                 user_id=self.user_id,
             )
         else:
-            raise ValueError(f"Block with label {self.label} is not supported")
+            return Block(
+                id=self.id,
+                value=self.value,
+                limit=self.limit,
+                name=self.name,
+                template=self.template,
+                label=self.label,
+                metadata_=self.metadata_,
+                description=self.description,
+                user_id=self.user_id,
+            )
 
 
 class ToolModel(Base):
@@ -725,13 +735,13 @@ def get_blocks(
         self,
         user_id: Optional[str],
         label: Optional[str] = None,
-        template: bool = True,
+        template: Optional[bool] = None,
        name: Optional[str] = None,
         id: Optional[str] = None,
     ) -> List[Block]:
         """List available blocks"""
         with self.session_maker() as session:
-            query = session.query(BlockModel).filter(BlockModel.template == template)
+            query = session.query(BlockModel)
 
             if user_id:
                 query = query.filter(BlockModel.user_id == user_id)
@@ -745,6 +755,9 @@ def get_blocks(
             if id:
                 query = query.filter(BlockModel.id == id)
 
+            if template:
+                query = query.filter(BlockModel.template == template)
+
             results = query.all()
 
             if len(results) == 0:
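A small sketch of the relaxed default, assuming `MetadataStore` is constructed from a loaded `MemGPTConfig`; the user id is illustrative:

```python
from memgpt.config import MemGPTConfig
from memgpt.metadata import MetadataStore

ms = MetadataStore(MemGPTConfig.load())

# template now defaults to None, so template and non-template blocks both come back
all_blocks = ms.get_blocks(user_id="user-123")

# passing template=True restores the old templates-only behavior
template_blocks = ms.get_blocks(user_id="user-123", template=True)
```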
46 changes: 44 additions & 2 deletions memgpt/schemas/tool.py
@@ -3,8 +3,9 @@
 from pydantic import Field
 
 from memgpt.functions.schema_generator import (
+    generate_crewai_tool_wrapper,
+    generate_langchain_tool_wrapper,
     generate_schema_from_args_schema,
-    generate_tool_wrapper,
 )
 from memgpt.schemas.memgpt_base import MemGPTBase
 from memgpt.schemas.openai.chat_completions import ToolCall
@@ -56,6 +57,40 @@ def to_dict(self):
             )
         )
 
+    @classmethod
+    def from_langchain(cls, langchain_tool) -> "Tool":
+        """
+        Class method to create an instance of Tool from a LangChain tool (must be from langchain_community.tools).
+
+        Args:
+            langchain_tool (LangchainTool): An instance of a LangChain tool (from langchain_community.tools)
+
+        Returns:
+            Tool: A MemGPT Tool initialized with attributes derived from the provided LangChain tool object.
+        """
+        description = langchain_tool.description
+        source_type = "python"
+        tags = ["langchain"]
+        # NOTE: langchain tools may come from different packages
+        wrapper_func_name, wrapper_function_str = generate_langchain_tool_wrapper(langchain_tool.__class__.__name__)
+        json_schema = generate_schema_from_args_schema(langchain_tool.args_schema, name=wrapper_func_name, description=description)
+
+        # append heartbeat (necessary for triggering another reasoning step after this tool call)
+        json_schema["parameters"]["properties"]["request_heartbeat"] = {
+            "type": "boolean",
+            "description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.",
+        }
+        json_schema["parameters"]["required"].append("request_heartbeat")
+
+        return cls(
+            name=wrapper_func_name,
+            description=description,
+            source_type=source_type,
+            tags=tags,
+            source_code=wrapper_function_str,
+            json_schema=json_schema,
+        )
+
     @classmethod
     def from_crewai(cls, crewai_tool) -> "Tool":
         """
@@ -71,9 +106,16 @@ def from_crewai(cls, crewai_tool) -> "Tool":
         description = crewai_tool.description
         source_type = "python"
         tags = ["crew-ai"]
-        wrapper_func_name, wrapper_function_str = generate_tool_wrapper(crewai_tool.__class__.__name__)
+        wrapper_func_name, wrapper_function_str = generate_crewai_tool_wrapper(crewai_tool.__class__.__name__)
         json_schema = generate_schema_from_args_schema(crewai_tool.args_schema, name=wrapper_func_name, description=description)
 
+        # append heartbeat (necessary for triggering another reasoning step after this tool call)
+        json_schema["parameters"]["properties"]["request_heartbeat"] = {
+            "type": "boolean",
+            "description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.",
+        }
+        json_schema["parameters"]["required"].append("request_heartbeat")
+
         return cls(
             name=wrapper_func_name,
             description=description,
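A minimal end-to-end sketch of the new import path. The Wikipedia tool is an arbitrary choice (it needs the `wikipedia` package installed), only `Tool.from_langchain` itself comes from this diff, and how the resulting `Tool` gets attached to an agent is out of scope here:

```python
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper

from memgpt.schemas.tool import Tool

langchain_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
memgpt_tool = Tool.from_langchain(langchain_tool)

print(memgpt_tool.name)         # run_wikipediaqueryrun (name of the generated wrapper)
print(memgpt_tool.tags)         # ['langchain']
print(memgpt_tool.source_code)  # wrapper that re-imports the tool and calls _run(**kwargs)
print(memgpt_tool.json_schema["parameters"]["required"])  # includes "request_heartbeat"
```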
4 changes: 2 additions & 2 deletions memgpt/server/server.py
@@ -889,7 +889,7 @@ def get_blocks(
         self,
         user_id: Optional[str] = None,
         label: Optional[str] = None,
-        template: Optional[bool] = True,
+        template: Optional[bool] = None,
         name: Optional[str] = None,
         id: Optional[str] = None,
     ):
@@ -900,7 +900,7 @@ def get_block(self, block_id: str):
 
         blocks = self.get_blocks(id=block_id)
         if blocks is None or len(blocks) == 0:
-            return None
+            raise ValueError("Block does not exist")
         if len(blocks) > 1:
             raise ValueError("Multiple blocks with the same id")
         return blocks[0]
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pymemgpt"
-version = "0.3.24"
+version = "0.4.0"
 packages = [
     {include = "memgpt"}
 ]