Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inject app into tools #709

Merged
merged 1 commit into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 14 additions & 15 deletions cookbook/maze.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
python cookbook/maze.py
```
"""

import random
from enum import Enum
from io import StringIO
Expand All @@ -20,8 +21,6 @@
from rich.table import Table
from typing_extensions import Literal

_app: AIApplication | None = None

GAME_INSTRUCTIONS = """
This is a TERROR game. You are the disembodied narrator of a maze. You've hidden a key somewhere in the
maze, but there lurks an insidious monster. The user must find the key and exit the maze without encounter-
Expand Down Expand Up @@ -115,9 +114,9 @@ def render(self) -> str:
self.user_location: MazeObject.USER.value,
self.exit_location: MazeObject.EXIT.value,
self.key_location: MazeObject.KEY.value if self.key_location else "",
self.monster_location: MazeObject.MONSTER.value
if self.monster_location
else "",
self.monster_location: (
MazeObject.MONSTER.value if self.monster_location else ""
),
}

for row in range(self.size):
Expand Down Expand Up @@ -162,18 +161,18 @@ def movable_directions(self) -> list[Literal["N", "S", "E", "W"]]:
return directions


def look_around() -> str:
maze = Maze.model_validate(_app.state.read_all())
def look_around(app: AIApplication) -> str:
maze = Maze.model_validate(app.state.read_all())
return (
f"The maze sprawls.\n{maze.render()}"
f"The user may move {maze.movable_directions()=}"
)


def move(direction: Literal["N", "S", "E", "W"]) -> str:
def move(app: AIApplication, direction: Literal["N", "S", "E", "W"]) -> str:
"""moves the user in the given direction."""
print(f"Moving {direction}")
maze: Maze = Maze.model_validate(_app.state.read_all())
maze: Maze = Maze.model_validate(app.state.read_all())
prev_location = maze.user_location
match direction:
case "N":
Expand All @@ -195,18 +194,18 @@ def move(direction: Literal["N", "S", "E", "W"]) -> str:

match maze.user_location:
case maze.key_location:
_app.state.write("key_location", (-1, -1))
_app.state.write("user_location", maze.user_location)
app.state.write("key_location", (-1, -1))
app.state.write("user_location", maze.user_location)
return "The user found the key! Now they must find the exit."
case maze.monster_location:
return "The user encountered the monster and died. Game over."
case maze.exit_location:
if maze.key_location != (-1, -1):
_app.state.write("user_location", prev_location)
app.state.write("user_location", prev_location)
return "The user can't exit without the key."
return "The user found the exit! They win!"

_app.state.write("user_location", maze.user_location)
app.state.write("user_location", maze.user_location)
if move_monster := random.random() < 0.4:
maze.shuffle_monster()
return (
Expand All @@ -216,9 +215,9 @@ def move(direction: Literal["N", "S", "E", "W"]) -> str:
)


def reset_maze() -> str:
def reset_maze(app: AIApplication) -> str:
"""Resets the maze - only to be used when the game is over."""
_app.state.store = Maze.create().model_dump()
app.state.store = Maze.create().model_dump()
return "Resetting the maze."


Expand Down
2 changes: 1 addition & 1 deletion src/marvin/beta/ai_flow/ai_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def task_completed_with_result(result: T):
self.result = result
raise CancelRun()

tool.function.python_fn = task_completed_with_result
tool.function._python_fn = task_completed_with_result

return tool

Expand Down
101 changes: 56 additions & 45 deletions src/marvin/beta/assistants/applications.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import types
import inspect
from typing import Optional, Union

from pydantic import BaseModel, Field, field_validator

from marvin.kv.base import StorageInterface
from marvin.kv.in_memory import InMemoryKV
from marvin.requests import Tool
from marvin.utilities.jinja import Environment as JinjaEnvironment
from marvin.utilities.tools import tool_from_function

Expand Down Expand Up @@ -44,6 +45,13 @@


class AIApplication(Assistant):
"""
Tools for AI Applications have a special property: if any parameter is
annotated as `AIApplication`, then the tool will be called with the
AIApplication instance as the value for that parameter. This allows tools to
access the AIApplication's state and other properties.
"""

state: StorageInterface = Field(default_factory=InMemoryKV)

@field_validator("state", mode="before")
Expand All @@ -55,57 +63,60 @@ def _check_state(cls, v):
return InMemoryKV(store=v)
else:
raise ValueError(
"must be a `StorageInterface` or a `dict` that can be stored in `InMemoryKV`"
"must be a `StorageInterface` or a `dict` that can be stored in"
" `InMemoryKV`"
)
return v

def get_instructions(self) -> str:
return JinjaEnvironment.render(APPLICATION_INSTRUCTIONS, self_=self)

def _inject_app(self, tool: AssistantTool) -> AssistantTool:
if not ((fn := getattr(tool, "function")) and hasattr(fn, "python_fn")):
return tool
def get_tools(self) -> list[AssistantTool]:
tools = []

original_function = tool.function.python_fn
for tool in [
write_state_key,
delete_state_key,
read_state_key,
read_state,
list_state_keys,
] + self.tools:
if not isinstance(tool, Tool):
kwargs = None
signature = inspect.signature(tool)
parameter = None
for parameter in signature.parameters.values():
if parameter.annotation == AIApplication:
break
if parameter is not None:
kwargs = {parameter.name: self}

tool.function.python_fn = types.FunctionType(
original_function.__code__,
dict(original_function.__globals__, _app=self),
name=original_function.__name__,
argdefs=original_function.__defaults__,
closure=original_function.__closure__,
)
tool = tool_from_function(tool, kwargs=kwargs)
tools.append(tool)

return tool
return tools

def get_tools(self) -> list[AssistantTool]:
def write_state_key(key: str, value: StateValueType):
"""Writes a key to the state in order to remember it for later."""
return self.state.write(key, value)

def delete_state_key(key: str):
"""Deletes a key from the state."""
return self.state.delete(key)

def read_state_key(key: str) -> Optional[StateValueType]:
"""Returns the value of a key from the state."""
return self.state.read(key)

def read_state() -> dict[str, StateValueType]:
"""Returns the entire state."""
return self.state.read_all()

def list_state_keys() -> list[str]:
"""Returns the list of keys in the state."""
return self.state.list_keys()

return [
tool_from_function(tool)
for tool in [
write_state_key,
delete_state_key,
read_state_key,
read_state,
list_state_keys,
]
] + [self._inject_app(tool) for tool in super().get_tools()]

def write_state_key(key: str, value: StateValueType, app: AIApplication):
"""Writes a key to the state in order to remember it for later."""
return app.state.write(key, value)


def delete_state_key(key: str, app: AIApplication):
"""Deletes a key from the state."""
return app.state.delete(key)


def read_state_key(key: str, app: AIApplication) -> Optional[StateValueType]:
"""Returns the value of a key from the state."""
return app.state.read(key)


def read_state(app: AIApplication) -> dict[str, StateValueType]:
"""Returns the entire state."""
return app.state.read_all()


def list_state_keys(app: AIApplication) -> list[str]:
"""Returns the list of keys in the state."""
return app.state.list_keys()
24 changes: 10 additions & 14 deletions src/marvin/beta/assistants/assistants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Callable, Optional, Union

from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field

import marvin.utilities.tools
from marvin.requests import Tool
Expand All @@ -26,7 +26,7 @@ class Assistant(BaseModel, ExposeSyncMethodsMixin):
name: str = "Assistant"
model: str = "gpt-4-1106-preview"
instructions: Optional[str] = Field(None, repr=False)
tools: list[AssistantTool] = []
tools: list[Union[AssistantTool, Callable]] = []
file_ids: list[str] = []
metadata: dict[str, str] = {}

Expand All @@ -40,7 +40,14 @@ def clear_default_thread(self):
self.default_thread = Thread()

def get_tools(self) -> list[AssistantTool]:
return self.tools
return [
(
tool
if isinstance(tool, Tool)
else marvin.utilities.tools.tool_from_function(tool)
)
for tool in self.tools
]

def get_instructions(self) -> str:
return self.instructions or ""
Expand All @@ -66,17 +73,6 @@ async def say_async(
)
return run

@field_validator("tools", mode="before")
def format_tools(cls, tools: list[Union[Tool, Callable]]):
return [
(
tool
if isinstance(tool, Tool)
else marvin.utilities.tools.tool_from_function(tool)
)
for tool in tools
]

def __enter__(self):
self.create()
return self
Expand Down
20 changes: 13 additions & 7 deletions src/marvin/requests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Callable, Generic, Optional, TypeVar, Union

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr
from typing_extensions import Annotated, Literal, Self

from marvin.settings import settings
Expand All @@ -21,18 +21,24 @@ class Function(BaseModel, Generic[T]):
parameters: dict[str, Any]

model: Optional[type[T]] = Field(default=None, exclude=True, repr=False)
python_fn: Optional[Callable[..., Any]] = Field(
default=None,
description="Private field that holds the executable function, if available",
exclude=True,
repr=False,
)

# Private field that holds the executable function, if available
_python_fn: Optional[Callable[..., Any]] = PrivateAttr(default=None)

def validate_json(self: Self, json_data: Union[str, bytes, bytearray]) -> T:
if self.model is None:
raise ValueError("This Function was not initialized with a model.")
return self.model.model_validate_json(json_data)

@classmethod
def create(
cls, *, _python_fn: Optional[Callable[..., Any]] = None, **kwargs: Any
) -> "Function":
instance = cls(**kwargs)
if _python_fn is not None:
instance._python_fn = _python_fn
return instance


class Tool(BaseModel, Generic[T]):
type: str
Expand Down
Loading