From d7c0eefee2632cd3936cb93b74c53e653f23c745 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Tue, 2 Jan 2024 12:06:16 -0500 Subject: [PATCH] Explicit dependency injection --- cookbook/maze.py | 29 +++--- src/marvin/beta/ai_flow/ai_task.py | 2 +- src/marvin/beta/assistants/applications.py | 101 ++++++++++++--------- src/marvin/beta/assistants/assistants.py | 24 ++--- src/marvin/requests.py | 20 ++-- src/marvin/utilities/tools.py | 58 +++++++++++- 6 files changed, 147 insertions(+), 87 deletions(-) diff --git a/cookbook/maze.py b/cookbook/maze.py index b33de394c..6d724c578 100644 --- a/cookbook/maze.py +++ b/cookbook/maze.py @@ -10,6 +10,7 @@ python cookbook/maze.py ``` """ + import random from enum import Enum from io import StringIO @@ -20,8 +21,6 @@ from rich.table import Table from typing_extensions import Literal -_app: AIApplication | None = None - GAME_INSTRUCTIONS = """ This is a TERROR game. You are the disembodied narrator of a maze. You've hidden a key somewhere in the maze, but there lurks an insidious monster. The user must find the key and exit the maze without encounter- @@ -115,9 +114,9 @@ def render(self) -> str: self.user_location: MazeObject.USER.value, self.exit_location: MazeObject.EXIT.value, self.key_location: MazeObject.KEY.value if self.key_location else "", - self.monster_location: MazeObject.MONSTER.value - if self.monster_location - else "", + self.monster_location: ( + MazeObject.MONSTER.value if self.monster_location else "" + ), } for row in range(self.size): @@ -162,18 +161,18 @@ def movable_directions(self) -> list[Literal["N", "S", "E", "W"]]: return directions -def look_around() -> str: - maze = Maze.model_validate(_app.state.read_all()) +def look_around(app: AIApplication) -> str: + maze = Maze.model_validate(app.state.read_all()) return ( f"The maze sprawls.\n{maze.render()}" f"The user may move {maze.movable_directions()=}" ) -def move(direction: Literal["N", "S", "E", "W"]) -> str: +def move(app: AIApplication, direction: Literal["N", "S", "E", "W"]) -> str: """moves the user in the given direction.""" print(f"Moving {direction}") - maze: Maze = Maze.model_validate(_app.state.read_all()) + maze: Maze = Maze.model_validate(app.state.read_all()) prev_location = maze.user_location match direction: case "N": @@ -195,18 +194,18 @@ def move(direction: Literal["N", "S", "E", "W"]) -> str: match maze.user_location: case maze.key_location: - _app.state.write("key_location", (-1, -1)) - _app.state.write("user_location", maze.user_location) + app.state.write("key_location", (-1, -1)) + app.state.write("user_location", maze.user_location) return "The user found the key! Now they must find the exit." case maze.monster_location: return "The user encountered the monster and died. Game over." case maze.exit_location: if maze.key_location != (-1, -1): - _app.state.write("user_location", prev_location) + app.state.write("user_location", prev_location) return "The user can't exit without the key." return "The user found the exit! They win!" - _app.state.write("user_location", maze.user_location) + app.state.write("user_location", maze.user_location) if move_monster := random.random() < 0.4: maze.shuffle_monster() return ( @@ -216,9 +215,9 @@ def move(direction: Literal["N", "S", "E", "W"]) -> str: ) -def reset_maze() -> str: +def reset_maze(app: AIApplication) -> str: """Resets the maze - only to be used when the game is over.""" - _app.state.store = Maze.create().model_dump() + app.state.store = Maze.create().model_dump() return "Resetting the maze." diff --git a/src/marvin/beta/ai_flow/ai_task.py b/src/marvin/beta/ai_flow/ai_task.py index b77704679..ebdd37d5f 100644 --- a/src/marvin/beta/ai_flow/ai_task.py +++ b/src/marvin/beta/ai_flow/ai_task.py @@ -260,7 +260,7 @@ def task_completed_with_result(result: T): self.result = result raise CancelRun() - tool.function.python_fn = task_completed_with_result + tool.function._python_fn = task_completed_with_result return tool diff --git a/src/marvin/beta/assistants/applications.py b/src/marvin/beta/assistants/applications.py index 75d17d02a..604c189d0 100644 --- a/src/marvin/beta/assistants/applications.py +++ b/src/marvin/beta/assistants/applications.py @@ -1,10 +1,11 @@ -import types +import inspect from typing import Optional, Union from pydantic import BaseModel, Field, field_validator from marvin.kv.base import StorageInterface from marvin.kv.in_memory import InMemoryKV +from marvin.requests import Tool from marvin.utilities.jinja import Environment as JinjaEnvironment from marvin.utilities.tools import tool_from_function @@ -44,6 +45,13 @@ class AIApplication(Assistant): + """ + Tools for AI Applications have a special property: if any parameter is + annotated as `AIApplication`, then the tool will be called with the + AIApplication instance as the value for that parameter. This allows tools to + access the AIApplication's state and other properties. + """ + state: StorageInterface = Field(default_factory=InMemoryKV) @field_validator("state", mode="before") @@ -55,57 +63,60 @@ def _check_state(cls, v): return InMemoryKV(store=v) else: raise ValueError( - "must be a `StorageInterface` or a `dict` that can be stored in `InMemoryKV`" + "must be a `StorageInterface` or a `dict` that can be stored in" + " `InMemoryKV`" ) return v def get_instructions(self) -> str: return JinjaEnvironment.render(APPLICATION_INSTRUCTIONS, self_=self) - def _inject_app(self, tool: AssistantTool) -> AssistantTool: - if not ((fn := getattr(tool, "function")) and hasattr(fn, "python_fn")): - return tool + def get_tools(self) -> list[AssistantTool]: + tools = [] - original_function = tool.function.python_fn + for tool in [ + write_state_key, + delete_state_key, + read_state_key, + read_state, + list_state_keys, + ] + self.tools: + if not isinstance(tool, Tool): + kwargs = None + signature = inspect.signature(tool) + parameter = None + for parameter in signature.parameters.values(): + if parameter.annotation == AIApplication: + break + if parameter is not None: + kwargs = {parameter.name: self} - tool.function.python_fn = types.FunctionType( - original_function.__code__, - dict(original_function.__globals__, _app=self), - name=original_function.__name__, - argdefs=original_function.__defaults__, - closure=original_function.__closure__, - ) + tool = tool_from_function(tool, kwargs=kwargs) + tools.append(tool) - return tool + return tools - def get_tools(self) -> list[AssistantTool]: - def write_state_key(key: str, value: StateValueType): - """Writes a key to the state in order to remember it for later.""" - return self.state.write(key, value) - - def delete_state_key(key: str): - """Deletes a key from the state.""" - return self.state.delete(key) - - def read_state_key(key: str) -> Optional[StateValueType]: - """Returns the value of a key from the state.""" - return self.state.read(key) - - def read_state() -> dict[str, StateValueType]: - """Returns the entire state.""" - return self.state.read_all() - - def list_state_keys() -> list[str]: - """Returns the list of keys in the state.""" - return self.state.list_keys() - - return [ - tool_from_function(tool) - for tool in [ - write_state_key, - delete_state_key, - read_state_key, - read_state, - list_state_keys, - ] - ] + [self._inject_app(tool) for tool in super().get_tools()] + +def write_state_key(key: str, value: StateValueType, app: AIApplication): + """Writes a key to the state in order to remember it for later.""" + return app.state.write(key, value) + + +def delete_state_key(key: str, app: AIApplication): + """Deletes a key from the state.""" + return app.state.delete(key) + + +def read_state_key(key: str, app: AIApplication) -> Optional[StateValueType]: + """Returns the value of a key from the state.""" + return app.state.read(key) + + +def read_state(app: AIApplication) -> dict[str, StateValueType]: + """Returns the entire state.""" + return app.state.read_all() + + +def list_state_keys(app: AIApplication) -> list[str]: + """Returns the list of keys in the state.""" + return app.state.list_keys() diff --git a/src/marvin/beta/assistants/assistants.py b/src/marvin/beta/assistants/assistants.py index 13169b7e3..c6ec31dc1 100644 --- a/src/marvin/beta/assistants/assistants.py +++ b/src/marvin/beta/assistants/assistants.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Callable, Optional, Union -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field import marvin.utilities.tools from marvin.requests import Tool @@ -26,7 +26,7 @@ class Assistant(BaseModel, ExposeSyncMethodsMixin): name: str = "Assistant" model: str = "gpt-4-1106-preview" instructions: Optional[str] = Field(None, repr=False) - tools: list[AssistantTool] = [] + tools: list[Union[AssistantTool, Callable]] = [] file_ids: list[str] = [] metadata: dict[str, str] = {} @@ -40,7 +40,14 @@ def clear_default_thread(self): self.default_thread = Thread() def get_tools(self) -> list[AssistantTool]: - return self.tools + return [ + ( + tool + if isinstance(tool, Tool) + else marvin.utilities.tools.tool_from_function(tool) + ) + for tool in self.tools + ] def get_instructions(self) -> str: return self.instructions or "" @@ -66,17 +73,6 @@ async def say_async( ) return run - @field_validator("tools", mode="before") - def format_tools(cls, tools: list[Union[Tool, Callable]]): - return [ - ( - tool - if isinstance(tool, Tool) - else marvin.utilities.tools.tool_from_function(tool) - ) - for tool in tools - ] - def __enter__(self): self.create() return self diff --git a/src/marvin/requests.py b/src/marvin/requests.py index 44c9834cc..95a78cbff 100644 --- a/src/marvin/requests.py +++ b/src/marvin/requests.py @@ -1,6 +1,6 @@ from typing import Any, Callable, Generic, Optional, TypeVar, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, PrivateAttr from typing_extensions import Annotated, Literal, Self from marvin.settings import settings @@ -21,18 +21,24 @@ class Function(BaseModel, Generic[T]): parameters: dict[str, Any] model: Optional[type[T]] = Field(default=None, exclude=True, repr=False) - python_fn: Optional[Callable[..., Any]] = Field( - default=None, - description="Private field that holds the executable function, if available", - exclude=True, - repr=False, - ) + + # Private field that holds the executable function, if available + _python_fn: Optional[Callable[..., Any]] = PrivateAttr(default=None) def validate_json(self: Self, json_data: Union[str, bytes, bytearray]) -> T: if self.model is None: raise ValueError("This Function was not initialized with a model.") return self.model.model_validate_json(json_data) + @classmethod + def create( + cls, *, _python_fn: Optional[Callable[..., Any]] = None, **kwargs: Any + ) -> "Function": + instance = cls(**kwargs) + if _python_fn is not None: + instance._python_fn = _python_fn + return instance + class Tool(BaseModel, Generic[T]): type: str diff --git a/src/marvin/utilities/tools.py b/src/marvin/utilities/tools.py index f36696470..c7315ac11 100644 --- a/src/marvin/utilities/tools.py +++ b/src/marvin/utilities/tools.py @@ -2,8 +2,11 @@ import inspect import json +from functools import update_wrapper from typing import Any, Callable, Optional +from pydantic import PydanticInvalidForJsonSchema + from marvin.requests import Function, Tool from marvin.utilities.asyncio import run_sync from marvin.utilities.logging import get_logger @@ -12,23 +15,68 @@ logger = get_logger("Tools") +def custom_partial(func: Callable, **fixed_kwargs: Any) -> Callable: + """ + Returns a new function with partial application of the given keyword arguments. + The new function has the same __name__ and docstring as the original, and its + signature excludes the provided kwargs. + """ + + # Define the new function with a dynamic signature + def wrapper(**kwargs): + # Merge the provided kwargs with the fixed ones, prioritizing the former + all_kwargs = {**fixed_kwargs, **kwargs} + return func(**all_kwargs) + + # Update the wrapper function's metadata to match the original function + update_wrapper(wrapper, func) + + # Modify the signature to exclude the fixed kwargs + original_sig = inspect.signature(func) + new_params = [ + param + for param in original_sig.parameters.values() + if param.name not in fixed_kwargs + ] + wrapper.__signature__ = original_sig.replace(parameters=new_params) + + return wrapper + + def tool_from_function( fn: Callable[..., Any], name: Optional[str] = None, description: Optional[str] = None, + kwargs: Optional[dict[str, Any]] = None, ): + """ + Creates an OpenAI-CLI tool from a Python function. + + If any kwargs are provided, they will be stored and provided at runtime. + Provided kwargs will be removed from the tool's parameter schema. + """ + if kwargs: + fn = custom_partial(fn, **kwargs) + model = cast_callable_to_model(fn) serializer: Callable[..., dict[str, Any]] = getattr( model, "model_json_schema", getattr(model, "schema") ) + try: + parameters = serializer() + except PydanticInvalidForJsonSchema: + raise TypeError( + "Could not create tool from function because annotations could not be" + f" serialized to JSON: {fn}" + ) return Tool( type="function", - function=Function( + function=Function.create( name=name or fn.__name__, description=description or fn.__doc__, - parameters=serializer(), - python_fn=fn, + parameters=parameters, + _python_fn=fn, ), ) @@ -49,7 +97,7 @@ def call_function_tool( if ( not tool or not tool.function - or not tool.function.python_fn + or not tool.function._python_fn or not tool.function.name ): raise ValueError(f"Could not find function '{function_name}'") @@ -58,7 +106,7 @@ def call_function_tool( logger.debug_kv( f"{tool.function.name}", f"called with arguments: {arguments}", "green" ) - output = tool.function.python_fn(**arguments) + output = tool.function._python_fn(**arguments) if inspect.isawaitable(output): output = run_sync(output) truncated_output = str(output)[:100]