From 7078371405a12bbb07d17f7e0a36bb59561e195b Mon Sep 17 00:00:00 2001 From: Jeremie Pardou <571533+jrmi@users.noreply.github.com> Date: Sun, 15 Dec 2024 16:23:22 +0100 Subject: [PATCH] feat: allow to load tools from external modules --- docs/config.rst | 5 ++ docs/custom_tool.rst | 83 +++++++++++++++++ docs/index.rst | 1 + gptme/chat.py | 32 +++++-- gptme/cli.py | 12 +-- gptme/commands.py | 7 +- gptme/config.py | 4 +- gptme/llm/llm_openai.py | 2 +- gptme/prompts.py | 6 +- gptme/server/api.py | 3 +- gptme/tools/__init__.py | 182 ++++++++++++++++++++++---------------- gptme/tools/base.py | 37 ++++---- gptme/tools/computer.py | 1 + gptme/tools/patch.py | 5 +- gptme/tools/python.py | 8 ++ gptme/tools/subagent.py | 1 + gptme/util/ask_execute.py | 2 +- gptme/util/cli.py | 12 +-- tests/conftest.py | 8 ++ tests/test_prompts.py | 2 +- tests/test_tool_use.py | 3 +- tests/test_tools.py | 143 ++++++++++++++++++++++++++++++ 22 files changed, 423 insertions(+), 136 deletions(-) create mode 100644 docs/custom_tool.rst create mode 100644 tests/test_tools.py diff --git a/docs/config.rst b/docs/config.rst index 836380b0..c4685010 100644 --- a/docs/config.rst +++ b/docs/config.rst @@ -42,6 +42,11 @@ Here is an example: #MODEL = "local/" #OPENAI_BASE_URL = "http://localhost:11434/v1" + # Uncomment to change tool configuration + #TOOL_FORMAT = "markdown" # Select the tool formal. One of `markdown`, `xml`, `tool` + #TOOL_ALLOW_LIST = "save,append,patch,python" # Comma separated list of allowed tools + #TOOL_MODULES = "gptme.tools,custom.tools" # List of python comma separated python module path + The ``prompt`` section contains options for the prompt. The ``env`` section contains environment variables that gptme will fall back to if they are not set in the shell environment. This is useful for setting the default model and API keys for :doc:`providers`. diff --git a/docs/custom_tool.rst b/docs/custom_tool.rst new file mode 100644 index 00000000..6f2810b9 --- /dev/null +++ b/docs/custom_tool.rst @@ -0,0 +1,83 @@ +Creating a Custom Tool for gptme +================================= + +Introduction +------------ +In gptme, a custom tool allows you to extend the functionality of the assistant by +defining new tools that can be executed. +This guide will walk you through the process of creating and registering a custom tool. + +Creating a Custom Tool +----------------------- +To create a custom tool, you need to define a new instance of the `ToolSpec` class. +This class requires several parameters: + +- **name**: The name of the tool. +- **desc**: A description of what the tool does. +- **instructions**: Instructions on how to use the tool. +- **examples**: Example usage of the tool. +- **execute**: A function that defines the tool's behavior when executed. +- **block_types**: The block types to detects. +- **parameters**: A list of parameters that the tool accepts. + +Here is a basic example of defining a custom tool: + +.. code-block:: python + + import random + from gptme.tools import ToolSpec, Parameter, ToolUse + from gptme.message import Message + + def execute(code, args, kwargs, confirm): + + if code is None and kwargs is not None: + code = kwargs.get('side_count') + + yield Message('system', f"Result: {random.randint(1,code)}") + + def examples(tool_format): + return f""" + > User: Throw a dice and give me the result. + > Assistant: + {ToolUse("dice", [], "6").to_output(tool_format)} + > System: 3 + > assistant: The result is 3 + """.strip() + + tool = ToolSpec( + name="dice", + desc="A dice simulator.", + instructions="This tool generate a random integer value like a dice.", + examples=examples, + execute=execute, + block_types=["dice"], + parameters=[ + Parameter( + name="side_count", + type="integer", + description="The number of faces of the dice to throw.", + required=True, + ), + ], + ) + +Registering the Tool +--------------------- +To ensure your tool is available for use, you can specify the module in the `TOOL_MODULES` env variable or +setting in your :doc:`project configuration file `, which will automatically load your custom tools. + +.. code-block:: toml + + TOOL_MODULES = "gptme.tools,path.to.your.custom_tool_module" + +Don't remove the `gptme.tools` package unless you know exactly what you are doing. + +Ensure your module is in the Python path by either installing it (e.g., with `pip install .`) or +by temporarily modifying the `PYTHONPATH` environment variable. For example: + +.. code-block:: bash + + export PYTHONPATH=$PYTHONPATH:/path/to/your/module + + +This lets Python locate your module during development and testing without requiring installation. diff --git a/docs/index.rst b/docs/index.rst index 34842377..bb037ca4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -37,6 +37,7 @@ See the `README `_ fil evals bot finetuning + custom_tool arewetiny timeline alternatives diff --git a/gptme/chat.py b/gptme/chat.py index 5f3ff408..32f9506d 100644 --- a/gptme/chat.py +++ b/gptme/chat.py @@ -4,11 +4,13 @@ import re import sys import termios +from typing import cast import urllib.parse from collections.abc import Generator from pathlib import Path from .commands import action_descriptions, execute_cmd +from .config import get_config from .constants import PROMPT_USER from .init import init from .llm import reply @@ -19,11 +21,12 @@ from .tools import ( ToolFormat, ToolUse, - execute_msg, has_tool, - loaded_tools, + get_tools, + execute_msg, + ConfirmFunc, + set_tool_format, ) -from .tools.base import ConfirmFunc from .tools.browser import read_url from .util import console, path_with_tilde, print_bell from .util.ask_execute import ask_execute @@ -46,7 +49,7 @@ def chat( show_hidden: bool = False, workspace: Path | None = None, tool_allowlist: list[str] | None = None, - tool_format: ToolFormat = "markdown", + tool_format: ToolFormat | None = None, ) -> None: """ Run the chat loop. @@ -71,6 +74,15 @@ def chat( console.log(f"Using logdir {path_with_tilde(logdir)}") manager = LogManager.load(logdir, initial_msgs=initial_msgs, create=True) + config = get_config() + tool_format_with_default: ToolFormat = tool_format or cast( + ToolFormat, config.get_env("TOOL_FORMAT", "markdown") + ) + + # By defining the tool_format at the last moment we ensure we can use the + # configuration for subagent + set_tool_format(tool_format_with_default) + # change to workspace directory # use if exists, create if @log, or use given path # TODO: move this into LogManager? then just os.chdir(manager.workspace) @@ -130,8 +142,8 @@ def confirm_func(msg) -> bool: manager.log, stream, confirm_func, - tool_format, - workspace, + tool_format=tool_format_with_default, + workspace=workspace, ) ) except KeyboardInterrupt: @@ -179,7 +191,11 @@ def confirm_func(msg) -> bool: # ask for input if no prompt, generate reply, and run tools clear_interruptible() # Ensure we're not interruptible during user input for msg in step( - manager.log, stream, confirm_func, tool_format, workspace + manager.log, + stream, + confirm_func, + tool_format=tool_format_with_default, + workspace=workspace, ): # pragma: no cover manager.append(msg) # run any user-commands, if msg is from user @@ -224,7 +240,7 @@ def step( tools = None if tool_format == "tool": - tools = [t for t in loaded_tools if t.is_runnable()] + tools = [t for t in get_tools() if t.is_runnable()] # generate response msg_response = reply(msgs, get_model().model, stream, tools) diff --git a/gptme/cli.py b/gptme/cli.py index 8714255d..86843ba7 100644 --- a/gptme/cli.py +++ b/gptme/cli.py @@ -11,9 +11,9 @@ import click from pick import pick -from gptme.config import get_config from .chat import chat +from .config import get_config from .commands import _gen_help from .constants import MULTIPROMPT_SEPARATOR from .dirs import get_logs_dir @@ -22,12 +22,7 @@ from .logmanager import ConversationMeta, get_user_conversations from .message import Message from .prompts import get_prompt -from .tools import ( - ToolFormat, - ToolSpec, - init_tools, - set_tool_format, -) +from .tools import ToolFormat, init_tools, get_available_tools from .util import epoch_to_age from .util.generate_name import generate_name from .util.interrupt import handle_keyboard_interrupt, set_interruptible @@ -39,7 +34,7 @@ script_path = Path(os.path.realpath(__file__)) commands_help = "\n".join(_gen_help(incl_langtags=False)) available_tool_names = ", ".join( - sorted([tool.name for tool in ToolSpec.get_tools().values() if tool.available]) + sorted([tool.name for tool in get_available_tools() if tool.available]) ) @@ -189,7 +184,6 @@ def main( selected_tool_format: ToolFormat = ( tool_format or config.get_env("TOOL_FORMAT") or "markdown" # type: ignore ) - set_tool_format(selected_tool_format) # early init tools to generate system prompt init_tools(frozenset(tool_allowlist) if tool_allowlist else None) diff --git a/gptme/commands.py b/gptme/commands.py index e3ba2afb..0c33cbdd 100644 --- a/gptme/commands.py +++ b/gptme/commands.py @@ -15,8 +15,7 @@ print_msg, toml_to_msgs, ) -from .tools import ToolUse, execute_msg, loaded_tools -from .tools.base import ConfirmFunc, get_tool_format +from .tools import ToolUse, execute_msg, get_tools, ConfirmFunc, get_tool_format from .util.cost import log_costs from .util.export import export_chat_to_html from .util.useredit import edit_text_with_editor @@ -138,7 +137,7 @@ def handle_cmd( case "tools": manager.undo(1, quiet=True) print("Available tools:") - for tool in loaded_tools: + for tool in get_tools(): print( f""" # {tool.name} @@ -220,7 +219,7 @@ def _gen_help(incl_langtags: bool = True) -> Generator[str, None, None]: yield " /python print('hello')" yield "" yield "Supported langtags:" - for tool in loaded_tools: + for tool in get_tools(): if tool.block_types: yield f" - {tool.block_types[0]}" + ( f" (alias: {', '.join(tool.block_types[1:])})" diff --git a/gptme/config.py b/gptme/config.py index 31ffe17f..e2fde247 100644 --- a/gptme/config.py +++ b/gptme/config.py @@ -19,11 +19,11 @@ class Config: env: dict def get_env(self, key: str, default: str | None = None) -> str | None: - """Gets an enviromnent variable, checks the config file if it's not set in the environment.""" + """Gets an environment variable, checks the config file if it's not set in the environment.""" return os.environ.get(key) or self.env.get(key) or default def get_env_required(self, key: str) -> str: - """Gets an enviromnent variable, checks the config file if it's not set in the environment.""" + """Gets an environment variable, checks the config file if it's not set in the environment.""" if val := os.environ.get(key) or self.env.get(key): return val raise KeyError( # pragma: no cover diff --git a/gptme/llm/llm_openai.py b/gptme/llm/llm_openai.py index 129c62d6..38228a22 100644 --- a/gptme/llm/llm_openai.py +++ b/gptme/llm/llm_openai.py @@ -9,7 +9,7 @@ from ..config import Config from ..constants import TEMPERATURE, TOP_P from ..message import Message, msgs2dicts -from ..tools.base import Parameter, ToolSpec, ToolUse +from ..tools import Parameter, ToolSpec, ToolUse from .models import ModelMeta, Provider, get_model if TYPE_CHECKING: diff --git a/gptme/prompts.py b/gptme/prompts.py index c766453b..33c1eb19 100644 --- a/gptme/prompts.py +++ b/gptme/prompts.py @@ -199,14 +199,14 @@ def prompt_tools( examples: bool = True, tool_format: ToolFormat = "markdown" ) -> Generator[Message, None, None]: """Generate the tools overview prompt.""" - from .tools import loaded_tools # fmt: skip + from .tools import get_tools # fmt: skip - assert loaded_tools, "No tools loaded" + assert get_tools(), "No tools loaded" use_tool = tool_format == "tool" prompt = "# Tools aliases" if use_tool else "# Tools Overview" - for tool in loaded_tools: + for tool in get_tools(): if not use_tool or not tool.is_runnable(): prompt += tool.get_tool_prompt(examples, tool_format) diff --git a/gptme/server/api.py b/gptme/server/api.py index 0da6b1a9..ff734f89 100644 --- a/gptme/server/api.py +++ b/gptme/server/api.py @@ -24,8 +24,7 @@ from ..llm.models import get_model from ..logmanager import LogManager, get_user_conversations, prepare_messages from ..message import Message -from ..tools import execute_msg -from ..tools.base import ToolUse +from ..tools import ToolUse, execute_msg logger = logging.getLogger(__name__) diff --git a/gptme/tools/__init__.py b/gptme/tools/__init__.py index ea99e3a6..750cdd51 100644 --- a/gptme/tools/__init__.py +++ b/gptme/tools/__init__.py @@ -1,32 +1,23 @@ import logging -from collections.abc import Generator from functools import lru_cache +from collections.abc import Generator + +from gptme.config import get_config from ..message import Message from .base import ( - ConfirmFunc, ToolFormat, ToolSpec, ToolUse, + Parameter, + ConfirmFunc, get_tool_format, set_tool_format, ) -from .browser import tool as browser_tool -from .chats import tool as chats_tool -from .computer import tool as computer_tool -from .gh import tool as gh_tool -from .patch import tool as patch_tool -from .python import register_function -from .python import tool as python_tool -from .rag import tool as rag_tool -from .read import tool as read_tool -from .save import tool_append, tool_save -from .screenshot import tool as screenshot_tool -from .shell import tool as shell_tool -from .subagent import tool as subagent_tool -from .tmux import tool as tmux_tool -from .vision import tool as vision_tool -from .youtube import tool as youtube_tool + +import importlib +import pkgutil +import inspect logger = logging.getLogger(__name__) @@ -36,84 +27,90 @@ "ToolSpec", "ToolUse", "ToolFormat", + "Parameter", + "ConfirmFunc", # functions - "execute_msg", "get_tool_format", "set_tool_format", - # files - "read_tool", - "tool_append", - "tool_save", - "patch_tool", - # code - "shell_tool", - "python_tool", - "gh_tool", - # vision and computer use - "vision_tool", - "screenshot_tool", - "computer_tool", - # misc - "chats_tool", - "rag_tool", - "subagent_tool", - "tmux_tool", - "browser_tool", - "youtube_tool", ] -loaded_tools: list[ToolSpec] = [] +from .save import tool_save -# Tools that are disabled by default, unless explicitly enabled -# TODO: find a better way to handle this -tools_default_disabled = [ - "computer", - "subagent", -] +_loaded_tools: list[ToolSpec] = [] +_available_tools: list[ToolSpec] | None = None + + +def _discover_tools(package_names): + """Discover tools in a package or module, given the package name as a string.""" + tools = [] + for package_name in package_names: + try: + # Dynamically import the package or module + package = importlib.import_module(package_name) + except ModuleNotFoundError: + logger.warning("Module or package %s not found", package_name) + continue + + # Check if it's a package or a module + if hasattr(package, "__path__"): # It's a package + # Iterate over modules in the package + for _, module_name, _ in pkgutil.iter_modules(package.__path__): + full_module_name = f"{package_name}.{module_name}" + try: + module = importlib.import_module(full_module_name) + except ModuleNotFoundError: + logger.warning("Missing dependency for module %s", full_module_name) + continue + + # Find instances of ToolSpec in the module + for _, obj in inspect.getmembers( + module, lambda c: isinstance(c, ToolSpec) + ): + tools.append(obj) + else: # It's a single module + # Find instances of ToolSpec in the module + for _, obj in inspect.getmembers( + package, lambda c: isinstance(c, ToolSpec) + ): + tools.append(obj) + + return tools @lru_cache -def init_tools(allowlist: frozenset[str] | None = None) -> None: +def init_tools( + allowlist: frozenset[str] | None = None, +) -> None: """Runs initialization logic for tools.""" - # init python tool last - tools = list( - sorted(ToolSpec.get_tools().values(), key=lambda tool: tool.name != "python") - ) - loaded_tool_names = [tool.name for tool in loaded_tools] - for tool in tools: - if tool.name in loaded_tool_names: + + config = get_config() + + if allowlist is None: + env_allowlist = config.get_env("TOOL_ALLOW_LIST") + if env_allowlist: + allowlist = frozenset(env_allowlist.split(",")) + + for tool in get_available_tools(): + if tool in _loaded_tools: + logger.warning("Tool '%s' already loaded", tool.name) continue if allowlist and tool.name not in allowlist: continue - if tool.init: - tool = tool.init() if not tool.available: continue - if tool in loaded_tools: - continue - if tool.name in tools_default_disabled: + if tool.disabled_by_default: if not allowlist or tool.name not in allowlist: continue - _load_tool(tool) + if tool.init: + tool = tool.init() + + _loaded_tools.append(tool) for tool_name in allowlist or []: if not has_tool(tool_name): raise ValueError(f"Tool '{tool_name}' not found") -def _load_tool(tool: ToolSpec) -> None: - """Loads a tool.""" - if tool in loaded_tools: - logger.warning(f"Tool '{tool.name}' already loaded") - return - - # tool init happens in init_tools to check that spec is available - if tool.functions: - for func in tool.functions: - register_function(func) - loaded_tools.append(tool) - - def execute_msg(msg: Message, confirm: ConfirmFunc) -> Generator[Message, None, None]: """Uses any tools called in a message and returns the response.""" assert msg.role == "assistant", "Only assistant messages can be executed" @@ -128,7 +125,7 @@ def execute_msg(msg: Message, confirm: ConfirmFunc) -> Generator[Message, None, @lru_cache def get_tool_for_langtag(lang: str) -> ToolSpec | None: block_type = lang.split(" ")[0] - for tool in loaded_tools: + for tool in _loaded_tools: if block_type in tool.block_types: return tool is_filename = "." in lang or "/" in lang @@ -142,14 +139,45 @@ def is_supported_langtag(lang: str) -> bool: return bool(get_tool_for_langtag(lang)) +def get_available_tools() -> list[ToolSpec]: + global _available_tools + + if _available_tools is None: + # We need to load tools first + config = get_config() + + tool_modules: frozenset[str] = frozenset() + env_tool_modules = config.get_env("TOOL_MODULES", "gptme.tools") + + if env_tool_modules: + tool_modules = frozenset(env_tool_modules.split(",")) + + _available_tools = sorted(_discover_tools(tool_modules)) + + return _available_tools + + +def clear_tools(): + global _available_tools + global _loaded_tools + + _available_tools = None + _loaded_tools = [] + + +def get_tools() -> list[ToolSpec]: + """Returns all loaded tools""" + return _loaded_tools + + def get_tool(tool_name: str) -> ToolSpec | None: """Returns a loaded tool by name or block type.""" # check tool names - for tool in loaded_tools: + for tool in _loaded_tools: if tool.name == tool_name: return tool # check block types - for tool in loaded_tools: + for tool in _loaded_tools: if tool_name in tool.block_types: return tool return None @@ -157,7 +185,7 @@ def get_tool(tool_name: str) -> ToolSpec | None: def has_tool(tool_name: str) -> bool: """Returns True if a tool is loaded.""" - for tool in loaded_tools: + for tool in _loaded_tools: if tool.name == tool_name: return True return False diff --git a/gptme/tools/base.py b/gptme/tools/base.py index 92d18a41..7b5f29c0 100644 --- a/gptme/tools/base.py +++ b/gptme/tools/base.py @@ -20,6 +20,7 @@ import json_repair from lxml import etree + from ..codeblock import Codeblock from ..message import Message from ..util import clean_example, transform_examples_to_chat_directives @@ -144,9 +145,6 @@ def callable_signature(func: Callable) -> str: return f"{func.__name__}({args}){ret}" -_tools: dict[str, "ToolSpec"] = {} - - @dataclass(frozen=True, eq=False) class ToolSpec: """ @@ -156,12 +154,16 @@ class ToolSpec: name: The name of the tool. desc: A description of the tool. instructions: Instructions on how to use the tool. + instructions_format: Per tool format instructions when needed. examples: Example usage of the tool. functions: Functions registered in the IPython REPL. init: An optional function that is called when the tool is first loaded. execute: An optional function that is called when the tool executes a block. block_types: A list of block types that the tool will execute. available: Whether the tool is available for use. + parameters: Descriptor of parameters use by this tool. + load_priority: Influence the loading order of this tool. The higher the later. + disabled_by_default: Whether this tool should be disabled by default. """ name: str @@ -175,18 +177,8 @@ class ToolSpec: block_types: list[str] = field(default_factory=list) available: bool = True parameters: list[Parameter] = field(default_factory=list) - - def __post_init__(self): - global _tools - _tools[self.name] = self - - @classmethod - def get_tool(cls, name: str) -> "ToolSpec | None": - return _tools.get(name) - - @classmethod - def get_tools(cls) -> dict[str, "ToolSpec"]: - return _tools + load_priority: int = 0 + disabled_by_default: bool = False def get_doc(self, doc: str | None = None) -> str: """Returns an updated docstring with examples.""" @@ -215,6 +207,11 @@ def __eq__(self, other): return False return self.name == other.name + def __lt__(self, other): + if not isinstance(other, ToolSpec): + return NotImplemented + return (self.load_priority, self.name) < (other.load_priority, other.name) + def is_runnable(self): return bool(self.execute) @@ -257,7 +254,7 @@ def get_examples(self, tool_format: ToolFormat = "markdown", quote=False): def get_functions_description(self) -> str: # return a prompt with a brief description of the available functions if self.functions: - description = "This tool makes the following Python functions available in `ipython`:\n\n" + description = "The following Python functions are available using the `ipython` tool:\n\n" return description + "\n".join( f"{callable_signature(func)}: {func.__doc__ or 'No description'}" for func in self.functions @@ -503,7 +500,9 @@ def get_path( # TODO: allow using via specifying .py paths with --tools flag def load_from_file(path: Path) -> list[ToolSpec]: """Import a tool from a Python file and register the ToolSpec.""" - tools_before = set(ToolSpec.get_tools().keys()) + from . import get_tools, get_tool + + tools_before = set([t.name for t in get_tools()]) # import the python file script_dir = path.resolve().parent @@ -511,7 +510,7 @@ def load_from_file(path: Path) -> list[ToolSpec]: sys.path.append(str(script_dir)) importlib.import_module(path.stem) - tools_after = set(ToolSpec.get_tools().keys()) + tools_after = set([t.name for t in get_tools()]) tools_new = tools_after - tools_before print(f"Loaded tools {tools_new} from {path}") - return [tool for tool_name in tools_new if (tool := ToolSpec.get_tool(tool_name))] + return [tool for tool_name in tools_new if (tool := get_tool(tool_name))] diff --git a/gptme/tools/computer.py b/gptme/tools/computer.py index 8a628e02..fe67c6e9 100644 --- a/gptme/tools/computer.py +++ b/gptme/tools/computer.py @@ -290,6 +290,7 @@ def examples(tool_format): instructions=instructions, examples=examples, functions=[computer], + disabled_by_default=True, ) __doc__ = tool.get_doc(__doc__) diff --git a/gptme/tools/patch.py b/gptme/tools/patch.py index 6b4f6764..45986246 100644 --- a/gptme/tools/patch.py +++ b/gptme/tools/patch.py @@ -42,7 +42,8 @@ $UPDATED_CONTENT >>>>>>> UPDATED '''.strip()).to_output("markdown")} -""" +""", + "tool": "The `patch` parameter must be a string containing conflict markers without any code block.", } ORIGINAL = "<<<<<<< ORIGINAL\n" @@ -275,7 +276,7 @@ def execute_patch( Parameter( name="patch", type="string", - description="The patch to apply.", + description=f"The patch to apply. Use conflict markers! Example:\n{patch_content}", required=True, ), ], diff --git a/gptme/tools/python.py b/gptme/tools/python.py index 0a829f12..5c8359d3 100644 --- a/gptme/tools/python.py +++ b/gptme/tools/python.py @@ -15,6 +15,7 @@ from logging import getLogger from typing import TYPE_CHECKING, TypeVar +from . import get_tools from ..message import Message from ..util.ask_execute import print_preview from .base import ( @@ -225,6 +226,12 @@ def fib(n): def init() -> ToolSpec: + # Register python functions from other tools + for loaded_tool in get_tools(): + if loaded_tool.functions: + for func in loaded_tool.functions: + register_function(func) + python_libraries = get_installed_python_libraries() python_libraries_str = ( "\n".join(f"- {lib}" for lib in python_libraries) @@ -264,5 +271,6 @@ def init() -> ToolSpec: required=True, ), ], + load_priority=10, ) __doc__ = tool.get_doc(__doc__) diff --git a/gptme/tools/subagent.py b/gptme/tools/subagent.py index 4495df17..aab55e3c 100644 --- a/gptme/tools/subagent.py +++ b/gptme/tools/subagent.py @@ -163,5 +163,6 @@ def examples(tool_format): desc="Create and manage subagents", examples=examples, functions=[subagent, subagent_status, subagent_wait], + disabled_by_default=True, ) __doc__ = tool.get_doc(__doc__) diff --git a/gptme/util/ask_execute.py b/gptme/util/ask_execute.py index ef53a6bf..82b653fe 100644 --- a/gptme/util/ask_execute.py +++ b/gptme/util/ask_execute.py @@ -13,7 +13,7 @@ from rich.syntax import Syntax from ..message import Message -from ..tools.base import ConfirmFunc +from ..tools import ConfirmFunc from . import print_bell from .clipboard import copy, set_copytext from .prompt import get_prompt_session diff --git a/gptme/util/cli.py b/gptme/util/cli.py index d52c5ba3..386f2dcf 100644 --- a/gptme/util/cli.py +++ b/gptme/util/cli.py @@ -134,7 +134,7 @@ def tools(): def tools_list(available: bool, langtags: bool): """List available tools.""" from ..commands import _gen_help # fmt: skip - from ..tools import init_tools, loaded_tools # fmt: skip + from ..tools import init_tools, get_tools # fmt: skip # Initialize tools init_tools() @@ -150,7 +150,7 @@ def tools_list(available: bool, langtags: bool): return print("Available tools:") - for tool in loaded_tools: + for tool in get_tools(): if not available or tool.available: status = "✓" if tool.available else "✗" print( @@ -164,7 +164,7 @@ def tools_list(available: bool, langtags: bool): @click.argument("tool_name") def tools_info(tool_name: str): """Show detailed information about a tool.""" - from ..tools import get_tool, init_tools, loaded_tools # fmt: skip + from ..tools import get_tool, init_tools, get_tools # fmt: skip # Initialize tools init_tools() @@ -172,7 +172,7 @@ def tools_info(tool_name: str): tool = get_tool(tool_name) if not tool: print(f"Tool '{tool_name}' not found. Available tools:") - for t in loaded_tools: + for t in get_tools(): print(f"- {t.name}") sys.exit(1) @@ -197,7 +197,7 @@ def tools_info(tool_name: str): ) def tools_call(tool_name: str, function_name: str, arg: list[str]): """Call a tool with the given arguments.""" - from ..tools import get_tool, init_tools, loaded_tools # fmt: skip + from ..tools import get_tool, init_tools, get_tools # fmt: skip # Initialize tools init_tools() @@ -205,7 +205,7 @@ def tools_call(tool_name: str, function_name: str, arg: list[str]): tool = get_tool(tool_name) if not tool: print(f"Tool '{tool_name}' not found. Available tools:") - for t in loaded_tools: + for t in get_tools(): print(f"- {t.name}") sys.exit(1) diff --git a/tests/conftest.py b/tests/conftest.py index 112febee..5d01c170 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ import pytest from gptme.tools.rag import _has_gptme_rag +from gptme.tools import clear_tools, init_tools def pytest_sessionstart(session): @@ -28,6 +29,13 @@ def download_model(): ef._download_model_if_not_exists() # type: ignore +@pytest.fixture(autouse=True) +def clear_tools_before(): + # Clear all tools and cache to prevent test conflicts + clear_tools() + init_tools.cache_clear() + + @pytest.fixture def temp_file(): @contextmanager diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 324964c1..17cb6294 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -4,7 +4,7 @@ from gptme.tools import init_tools -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def init(): init_tools() diff --git a/tests/test_tool_use.py b/tests/test_tool_use.py index 7299626d..09f6bba9 100644 --- a/tests/test_tool_use.py +++ b/tests/test_tool_use.py @@ -1,7 +1,7 @@ import json_repair import pytest from gptme.tools import init_tools -from gptme.tools.base import ToolUse, extract_json, toolcall_re +from gptme.tools.base import ToolUse, extract_json, set_tool_format, toolcall_re @pytest.mark.parametrize( @@ -148,5 +148,6 @@ def test_toolcall_regex(content, expected_tool, expected_json): ) def test_toolcall_regex_invalid(content): # No ToolUse should be created for invalid content + set_tool_format("tool") tool_uses = list(ToolUse.iter_from_content(content)) assert len(tool_uses) == 0 diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 00000000..41feb6e2 --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,143 @@ +import pytest +from unittest.mock import patch +from gptme.tools import ( + _discover_tools, + init_tools, + get_tools, + has_tool, + get_tool, + get_available_tools, + is_supported_langtag, + get_tool_for_langtag, +) + + +def test_init_tools(): + init_tools() + + assert len(get_tools()) > 1 + + +def test_init_tools_allowlist(): + init_tools(allowlist=frozenset(("save",))) + + assert len(get_tools()) == 1 + + assert get_tools()[0].name == "save" + + # let's trigger a tool reloading + init_tools.cache_clear() + + init_tools(allowlist=frozenset(("save",))) + + assert len(get_tools()) == 1 + + +def test_init_tools_allowlist_from_env(): + # Define the behavior for get_env based on the input key + def mock_get_env(key, default=None): + if key == "TOOL_ALLOW_LIST": + return "save,patch" + return default # Return the default value for other keys + + with patch("gptme.tools.get_config") as mock_get_config: + # Mock the get_config function to return a mock object + mock_config = mock_get_config.return_value + # Mock the get_env method to return the custom_env_value + mock_config.get_env.side_effect = mock_get_env + + init_tools() + + assert len(get_tools()) == 2 + + +def test_init_tools_fails(): + with pytest.raises(ValueError): + init_tools(allowlist=frozenset(("save", "missing_tool"))) + + +def test_tool_loading_with_package(): + found = _discover_tools(["gptme.tools"]) + + found_names = [t.name for t in found] + + assert "save" in found_names + assert "python" in found_names + + +def test_tool_loading_with_module(): + found = _discover_tools(["gptme.tools.save"]) + + found_names = [t.name for t in found] + + assert "save" in found_names + assert "python" not in found_names + + +def test_tool_loading_with_missing_package(): + found = _discover_tools(["gptme.fake_"]) + + assert len(found) == 0 + + +def test_get_available_tools(): + custom_env_value = "gptme.tools.save,gptme.tools.patch" + + with patch("gptme.tools.get_config") as mock_get_config: + # Mock the get_config function to return a mock object + mock_config = mock_get_config.return_value + # Mock the get_env method to return the custom_env_value + mock_config.get_env.return_value = custom_env_value + + tools = get_available_tools() + + assert len(tools) == 3 + assert [t.name for t in tools] == ["append", "patch", "save"] + + +def test_has_tool(): + init_tools(allowlist=frozenset(("save",))) + + assert has_tool("save") + assert not has_tool("anothertool") + + +def test_get_tool(): + init_tools(allowlist=frozenset(("save",))) + + tool_save = get_tool("save") + + assert tool_save + assert tool_save.name == "save" + + assert not get_tool("anothertool") + + +def test_get_tool_for_lang_tag(): + init_tools( + allowlist=frozenset( + ( + "save", + "python", + ) + ) + ) + + assert (tool_python := get_tool_for_langtag("ipython")) + assert tool_python.name == "python" + + # Also test special use cases + assert (tool_save := get_tool_for_langtag("test.txt")) + assert tool_save.name == "save" + + assert (tool_save := get_tool_for_langtag("/src/test")) + assert tool_save.name == "save" + + assert not get_tool_for_langtag("randomtag") + + +def test_is_supported_lang_tag(): + init_tools(allowlist=frozenset(("save",))) + + assert is_supported_langtag("save") + assert not is_supported_langtag("randomtag")