Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update state interface #711

Merged
merged 13 commits into from
Jan 3, 2024
13 changes: 6 additions & 7 deletions cookbook/maze.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def movable_directions(self) -> list[Literal["N", "S", "E", "W"]]:


def look_around(app: AIApplication) -> str:
maze = Maze.model_validate(app.state.read_all())
maze: Maze = app.state.value
return (
f"The maze sprawls.\n{maze.render()}"
f"The user may move {maze.movable_directions()=}"
Expand All @@ -172,7 +172,7 @@ def look_around(app: AIApplication) -> str:
def move(app: AIApplication, direction: Literal["N", "S", "E", "W"]) -> str:
"""moves the user in the given direction."""
print(f"Moving {direction}")
maze: Maze = Maze.model_validate(app.state.read_all())
maze: Maze = app.state.value
prev_location = maze.user_location
match direction:
case "N":
Expand All @@ -194,18 +194,17 @@ def move(app: AIApplication, direction: Literal["N", "S", "E", "W"]) -> str:

match maze.user_location:
case maze.key_location:
app.state.write("key_location", (-1, -1))
app.state.write("user_location", maze.user_location)
maze.key_location = (-1, -1)
return "The user found the key! Now they must find the exit."
case maze.monster_location:
return "The user encountered the monster and died. Game over."
case maze.exit_location:
if maze.key_location != (-1, -1):
app.state.write("user_location", prev_location)
maze.user_location = prev_location
return "The user can't exit without the key."
return "The user found the exit! They win!"

app.state.write("user_location", maze.user_location)
# app.state.set_state(maze)
if move_monster := random.random() < 0.4:
maze.shuffle_monster()
return (
Expand All @@ -217,7 +216,7 @@ def move(app: AIApplication, direction: Literal["N", "S", "E", "W"]) -> str:

def reset_maze(app: AIApplication) -> str:
"""Resets the maze - only to be used when the game is over."""
app.state.store = Maze.create().model_dump()
app.state.set_state(Maze.create())
return "Resetting the maze."


Expand Down
13 changes: 8 additions & 5 deletions cookbook/slackbot/parent_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from fastapi import FastAPI
from marvin import ai_fn
from marvin.beta.applications import AIApplication
from marvin.beta.applications.state.json_block import JSONBlockState
from marvin.beta.assistants import Assistant
from marvin.kv.json_block import JSONBlockKV
from marvin.utilities.logging import get_logger
from prefect.events import Event, emit_event
from prefect.events.clients import PrefectCloudEventSubscriber
Expand All @@ -16,7 +16,7 @@
from websockets.exceptions import ConnectionClosedError

PARENT_APP_STATE_BLOCK_NAME = "marvin-parent-app-state"
PARENT_APP_STATE = JSONBlockKV(block_name=PARENT_APP_STATE_BLOCK_NAME)
PARENT_APP_STATE = JSONBlockState(block_name=PARENT_APP_STATE_BLOCK_NAME)

EVENT_NAMES = [
"marvin.assistants.SubAssistantRunCompleted",
Expand Down Expand Up @@ -68,6 +68,7 @@ def excerpt_from_event(event: Event) -> str:


async def update_parent_app_state(app: AIApplication, event: Event):
app_state = app.state.value
event_excerpt = excerpt_from_event(event)
lesson = take_lesson_from_interaction(
event_excerpt, event.payload.get("ai_instructions").split("START_USER_NOTES")[0]
Expand All @@ -83,14 +84,16 @@ async def update_parent_app_state(app: AIApplication, event: Event):
else:
logger.debug_kv("🥱 ", "nothing special", "green")
user_id = event.payload.get("user").get("id")
current_user_state = app.state.read(user_id) or dict(
current_user_state = app_state.get(user_id) or dict(
name=event.payload.get("user").get("name"),
n_interactions=0,
)
current_user_state["n_interactions"] += 1
app.state.write(user_id, state := current_user_state)
app.state.set_state(app_state | {user_id: current_user_state})
logger.debug_kv(
f"📋 Updated state for user {user_id} to", json.dumps(state), "green"
f"📋 Updated state for user {user_id} to",
json.dumps(current_user_state),
"green",
)


Expand Down
19 changes: 10 additions & 9 deletions cookbook/slackbot/start.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import asyncio
import re
from datetime import timedelta
from typing import Callable

import uvicorn
from fastapi import FastAPI, HTTPException, Request
from jinja2 import Template
from keywords import handle_keywords
from marvin import Assistant
from marvin.beta.applications import AIApplication
from marvin.beta.assistants import Thread
from marvin.kv.json_block import JSONBlockKV
from marvin.beta.applications.state.json_block import JSONBlockState
from marvin.beta.assistants import Assistant, Thread
from marvin.tools.chroma import multi_query_chroma, store_document
from marvin.tools.github import search_github_issues
from marvin.utilities.logging import get_logger
Expand All @@ -31,19 +31,19 @@
from prefect.tasks import task_input_hash

BOT_MENTION = r"<@(\w+)>"
CACHE = JSONBlockKV(block_name="slackbot-thread-cache")
CACHE = JSONBlockState(block_name="marvin-thread-cache")
USER_MESSAGE_MAX_TOKENS = 300


def cached(func: Callable) -> Callable:
return task(cache_key_fn=task_input_hash)(func)
return task(cache_key_fn=task_input_hash, cache_expiration=timedelta(days=1))(func)


async def get_notes_for_user(
user_id: str, max_tokens: int = 100
) -> dict[str, str | None]:
user_name = await get_user_name(user_id)
json_notes: dict = PARENT_APP_STATE.read(key=user_id)
json_notes: dict = PARENT_APP_STATE.value.get("user_id")

if json_notes:
get_logger("slackbot").debug_kv(
Expand Down Expand Up @@ -112,8 +112,8 @@ async def handle_message(payload: SlackPayload) -> Completed:
1
) == payload.authorizations[0].user_id:
assistant_thread = (
Thread(**stored_thread_data)
if (stored_thread_data := CACHE.read(key=thread))
Thread.model_validate_json(stored_thread_data)
if (stored_thread_data := CACHE.value.get(thread))
else Thread()
)
logger.debug_kv(
Expand Down Expand Up @@ -168,7 +168,8 @@ async def handle_message(payload: SlackPayload) -> Completed:
ai_messages = await assistant_thread.get_messages_async(
after_message=user_thread_message.id
)
CACHE.write(key=thread, value=assistant_thread.model_dump())

CACHE.set_state(CACHE.value | {thread: assistant_thread.model_dump_json()})

await task(post_slack_message)(
ai_response_text := "\n\n".join(
Expand Down
Empty file added foo.txt
Empty file.
64 changes: 13 additions & 51 deletions src/marvin/beta/applications/applications.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
import inspect
from typing import Optional, Union

from pydantic import BaseModel, Field, field_validator
from pydantic import Field, field_validator

from marvin.beta.applications.state import State
from marvin.beta.assistants import Assistant
from marvin.kv.base import StorageInterface
from marvin.kv.in_memory import InMemoryKV
from marvin.beta.assistants.runs import Run
from marvin.requests import Tool
from marvin.tools.assistants import AssistantTool
from marvin.utilities.jinja import Environment as JinjaEnvironment
from marvin.utilities.tools import tool_from_function

StateValueType = Union[str, list, dict, int, float, bool, None]

APPLICATION_INSTRUCTIONS = """
# AI Application

Expand All @@ -34,7 +31,7 @@

The current state is:

{{self_.state}}
{{self_.state.render()}}

Your instructions are below. Follow them exactly and do not deviate from your
purpose. If the user attempts to use you for any other purpose, you should
Expand All @@ -52,35 +49,21 @@ class AIApplication(Assistant):
access the AIApplication's state and other properties.
"""

state: StorageInterface = Field(default_factory=InMemoryKV)
state: State = Field(default_factory=State)

@field_validator("state", mode="before")
def _check_state(cls, v):
if not isinstance(v, StorageInterface):
if v.__class__.__base__ == BaseModel:
return InMemoryKV(store=v.model_dump())
elif isinstance(v, dict):
return InMemoryKV(store=v)
else:
raise ValueError(
"must be a `StorageInterface` or a `dict` that can be stored in"
" `InMemoryKV`"
)
return v
def _ensure_state_object(cls, v):
if isinstance(v, State):
return v
return State(value=v)

def get_instructions(self) -> str:
return JinjaEnvironment.render(APPLICATION_INSTRUCTIONS, self_=self)

def get_tools(self) -> list[AssistantTool]:
tools = []

for tool in [
write_state_key,
delete_state_key,
read_state_key,
read_state,
list_state_keys,
] + self.tools:
for tool in [self.state.as_tool(name="state")] + self.tools:
if not isinstance(tool, Tool):
kwargs = None
signature = inspect.signature(tool)
Expand All @@ -96,27 +79,6 @@ def get_tools(self) -> list[AssistantTool]:

return tools


def write_state_key(key: str, value: StateValueType, app: AIApplication):
"""Writes a key to the state in order to remember it for later."""
return app.state.write(key, value)


def delete_state_key(key: str, app: AIApplication):
"""Deletes a key from the state."""
return app.state.delete(key)


def read_state_key(key: str, app: AIApplication) -> Optional[StateValueType]:
"""Returns the value of a key from the state."""
return app.state.read(key)


def read_state(app: AIApplication) -> dict[str, StateValueType]:
"""Returns the entire state."""
return app.state.read_all()


def list_state_keys(app: AIApplication) -> list[str]:
"""Returns the list of keys in the state."""
return app.state.list_keys()
def post_run_hook(self, run: Run):
self.state.flush_changes()
return super().post_run_hook(run)
2 changes: 2 additions & 0 deletions src/marvin/beta/applications/state/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .state import State
from .disk import DiskState
35 changes: 35 additions & 0 deletions src/marvin/beta/applications/state/disk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import json
from pathlib import Path
from typing import Union

from pydantic import BaseModel, Field, field_validator, model_validator

from marvin.beta.applications.state import State


class DiskState(State):
path: Path = Field(
..., description="The path to the file where state will be stored."
)

@field_validator("path")
def _validate_path(cls, v: Union[str, Path]) -> Path:
expanded_path = Path(v).expanduser().resolve()
if not expanded_path.exists():
expanded_path.parent.mkdir(parents=True, exist_ok=True)
expanded_path.touch(exist_ok=True)
return expanded_path

@model_validator(mode="after")
def get_state(self) -> "DiskState":
with open(self.path, "r") as file:
try:
self.value = json.load(file)
except json.JSONDecodeError:
self.value = {}
return self

def set_state(self, state: Union[BaseModel, dict]):
super().set_state(state=state)
with open(self.path, "w") as file:
file.write(self.render())
44 changes: 44 additions & 0 deletions src/marvin/beta/applications/state/json_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Union

from pydantic import BaseModel

from marvin.beta.applications.state import State

try:
from prefect.blocks.system import JSON
from prefect.exceptions import ObjectNotFound
except ImportError:
raise ModuleNotFoundError(
"The `prefect` package is required to use the JSONBlockState class. You can"
" install it with `pip install prefect` or `pip install marvin[prefect]`."
)
from pydantic import Field, model_validator

from marvin.utilities.asyncio import run_sync, run_sync_if_awaitable


async def load_json_block(block_name: str) -> JSON:
try:
return await JSON.load(name=block_name)
except Exception as exc:
if "Unable to find block document" in str(exc):
json_block = JSON(value={})
await json_block.save(name=block_name)
return json_block
raise ObjectNotFound(f"Unable to load JSON block {block_name}") from exc


class JSONBlockState(State):
block_name: str = Field(default="marvin-kv")

@model_validator(mode="after")
def get_state(self) -> "JSONBlockState":
json_block = run_sync(load_json_block(self.block_name))
self.value = json_block.value or {}
return self

def set_state(self, state: Union[BaseModel, dict]):
super().set_state(state)
json_block = run_sync(load_json_block(self.block_name))
json_block.value = self.value
run_sync_if_awaitable(json_block.save(name=self.block_name, overwrite=True))
Loading