# Patch summary (free text ahead of the first "diff --git" header; patch
# tools skip leading text like this when applying).
#
# Moves AIApplication state behind a pluggable key-value store:
#   * beta/assistants/applications.py: `state` changes from a plain dict to
#     a StorageInterface field that defaults to InMemoryKV. The state tools
#     delegate to the store's write/delete/read/read_all/list_keys methods,
#     `read_state_keys` is renamed `list_state_keys`, and the state tools
#     are now listed before the tools returned by super().get_tools().
#   * kv/__init__.py: new, empty package marker.
#   * kv/base.py: new abstract StorageInterface (pydantic BaseModel + ABC).
#   * kv/in_memory.py: new dict-backed InMemoryKV implementation.
#   * utilities/tools.py: call_function_tool logs through logger.debug_kv
#     instead of logger.debug.
#
# Behaviour changes visible through the tools:
#   * write/delete now return the store's own messages ("Stored k= v",
#     "Deleted k= v") instead of "Wrote k to state." / "Deleted k from state."
#   * Deleting a missing key used `del self.state[key]` on a dict (KeyError);
#     InMemoryKV.delete uses pop(key, None) and returns "Deleted k= None".
#
# NOTE(review), for the patch author:
#   * kv/base.py: the StorageInterface docstring opens a ```python fence that
#     is never closed.
#   * kv/base.py: P (ParamSpec) annotates *args/**kwargs, but the class is
#     only Generic[K, V, R]; type checkers will likely reject this - confirm.
#   * kv/in_memory.py: the overrides drop *args/**kwargs and read_all adds a
#     `limit` parameter, so signatures differ from the abstract methods.
#   * kv/in_memory.py: read_all returns the live `self.store` dict when
#     limit is None but a fresh dict when a limit is given; callers mutating
#     the result affect the store only in the first case.

diff --git a/src/marvin/beta/assistants/applications.py b/src/marvin/beta/assistants/applications.py
index eb1dfc63c..2c8238411 100644
--- a/src/marvin/beta/assistants/applications.py
+++ b/src/marvin/beta/assistants/applications.py
@@ -1,7 +1,11 @@
-from typing import Union
+from typing import Optional, Union
 
-import marvin.utilities.tools
+from pydantic import Field
+
+from marvin.kv.base import StorageInterface
+from marvin.kv.in_memory import InMemoryKV
 from marvin.utilities.jinja import Environment as JinjaEnvironment
+from marvin.utilities.tools import tool_from_function
 
 from .assistants import Assistant, AssistantTools
 
@@ -22,7 +26,7 @@
 objectives to keep track of various threads assist in long-term execution.
 
 Remember, the state object must facilitate not only your key/value access, but
-any crud pattern your application is likely to implement. You may want to create
+any CRUD pattern your application is likely to implement. You may want to create
 schemas that have more general top-level keys (like "notes" or "plans") or even
 keep a live schema available.
 
@@ -39,7 +43,7 @@
 
 
 class AIApplication(Assistant):
-    state: dict = {}
+    state: StorageInterface = Field(default_factory=InMemoryKV)
 
     def get_instructions(self) -> str:
         return JinjaEnvironment.render(APPLICATION_INSTRUCTIONS, self_=self)
@@ -47,35 +51,31 @@ def get_instructions(self) -> str:
     def get_tools(self) -> list[AssistantTools]:
         def write_state_key(key: str, value: StateValueType):
             """Writes a key to the state in order to remember it for later."""
-            self.state[key] = value
-            return f"Wrote {key} to state."
+            return self.state.write(key, value)
 
         def delete_state_key(key: str):
             """Deletes a key from the state."""
-            del self.state[key]
-            return f"Deleted {key} from state."
+            return self.state.delete(key)
 
-        def read_state_key(key: str) -> StateValueType:
-            """Returns the value of a key in the state."""
-            return self.state.get(key)
+        def read_state_key(key: str) -> Optional[StateValueType]:
+            """Returns the value of a key from the state."""
+            return self.state.read(key)
 
         def read_state() -> dict[str, StateValueType]:
             """Returns the entire state."""
-            return self.state
+            return self.state.read_all()
 
-        def read_state_keys() -> list[str]:
-            """Returns a list of all keys in the state."""
-            return list(self.state.keys())
+        def list_state_keys() -> list[str]:
+            """Returns the list of keys in the state."""
+            return self.state.list_keys()
 
-        state_tools = [
-            marvin.utilities.tools.tool_from_function(tool)
+        return [
+            tool_from_function(tool)
             for tool in [
                 write_state_key,
+                delete_state_key,
                 read_state_key,
                 read_state,
-                read_state_keys,
-                delete_state_key,
+                list_state_keys,
             ]
-        ]
-
-        return super().get_tools() + state_tools
+        ] + super().get_tools()
diff --git a/src/marvin/kv/__init__.py b/src/marvin/kv/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/marvin/kv/base.py b/src/marvin/kv/base.py
new file mode 100644
index 000000000..fa8d10ebb
--- /dev/null
+++ b/src/marvin/kv/base.py
@@ -0,0 +1,48 @@
+from abc import ABC, abstractmethod
+from typing import Generic, List, Mapping, Optional, TypeVar
+
+from pydantic import BaseModel
+from typing_extensions import ParamSpec
+
+K = TypeVar("K")  # Key type
+V = TypeVar("V")  # Value type
+R = TypeVar("R")  # Return type for write/delete operations
+P = ParamSpec("P")  # Additional parameters
+
+
+class StorageInterface(BaseModel, Generic[K, V, R], ABC):
+    """An abstract key-value store interface.
+
+    Example:
+    ```python
+    store = SomeStorageInterface()
+
+    store.write("foo", "bar")
+    store.write("baz", "qux")
+    assert store.read("foo") == "bar"
+    assert store.read_all() == {"foo": "bar", "baz": "qux"}
+    assert store.list_keys() == ["foo", "baz"]
+    store.delete("foo")
+    assert store.read("foo") is None
+    assert store.read_all() == {"baz": "qux"}
+    """
+
+    @abstractmethod
+    def write(self, key: K, value: V, *args: P.args, **kwargs: P.kwargs) -> Optional[R]:
+        pass
+
+    @abstractmethod
+    def read(self, key: K, *args: P.args, **kwargs: P.kwargs) -> Optional[V]:
+        pass
+
+    @abstractmethod
+    def read_all(self, *args: P.args, **kwargs: P.kwargs) -> Mapping[K, V]:
+        pass
+
+    @abstractmethod
+    def delete(self, key: K, *args: P.args, **kwargs: P.kwargs) -> Optional[R]:
+        pass
+
+    @abstractmethod
+    def list_keys(self, *args: P.args, **kwargs: P.kwargs) -> List[K]:
+        pass
diff --git a/src/marvin/kv/in_memory.py b/src/marvin/kv/in_memory.py
new file mode 100644
index 000000000..1619b7da3
--- /dev/null
+++ b/src/marvin/kv/in_memory.py
@@ -0,0 +1,42 @@
+from typing import Optional, TypeVar
+
+from pydantic import Field
+
+from marvin.kv.base import StorageInterface
+
+K = TypeVar("K", bound=str)
+V = TypeVar("V")
+
+
+class InMemoryKV(StorageInterface[K, V, str]):
+    """An in-memory key-value store.
+
+    Example:
+    ```python
+    from marvin.kv.in_memory import InMemoryKV
+    store = InMemoryKV()
+    store.write("key", "value")
+    assert store.read("key") == "value"
+    ```
+    """
+
+    store: dict[K, V] = Field(default_factory=dict)
+
+    def write(self, key: K, value: V) -> str:
+        self.store[key] = value
+        return f"Stored {key}= {value}"
+
+    def delete(self, key: K) -> str:
+        v = self.store.pop(key, None)
+        return f"Deleted {key}= {v}"
+
+    def read(self, key: K) -> Optional[V]:
+        return self.store.get(key)
+
+    def read_all(self, limit: Optional[int] = None) -> dict[K, V]:
+        if limit is None:
+            return self.store
+        return dict(list(self.store.items())[:limit])
+
+    def list_keys(self) -> list[K]:
+        return list(self.store.keys())
diff --git a/src/marvin/utilities/tools.py b/src/marvin/utilities/tools.py
index 6b4cb4790..0b16aefb0 100644
--- a/src/marvin/utilities/tools.py
+++ b/src/marvin/utilities/tools.py
@@ -42,14 +42,16 @@ def call_function_tool(
         raise ValueError(f"Could not find function '{function_name}'")
 
     arguments = json.loads(function_arguments_json)
-    logger.debug(f"Calling {tool.function.name} with arguments: {arguments}")
+    logger.debug_kv(
+        f"{tool.function.name}", f"called with arguments: {arguments}", "green"
+    )
     output = tool.function.python_fn(**arguments)
     if inspect.isawaitable(output):
         output = run_sync(output)
     truncated_output = str(output)[:100]
     if len(truncated_output) < len(str(output)):
         truncated_output += "..."
-    logger.debug(f"{tool.function.name} returned: {truncated_output}")
+    logger.debug_kv(f"{tool.function.name}", f"returned: {truncated_output}", "green")
     if not isinstance(output, str):
         output = json.dumps(output)
     return output