From f39ae0bc95ade396df6c3a49ad1d6c8a8ca0cf39 Mon Sep 17 00:00:00 2001 From: Nathan Nowack Date: Wed, 20 Dec 2023 03:53:40 -0600 Subject: [PATCH 1/2] various bug fixes --- cookbook/slackbot/parent_app.py | 45 ++++++++++++++--------- cookbook/slackbot/start.py | 56 ++++++++++++++++++----------- src/marvin/kv/json_block.py | 63 +++++++++++++++++---------------- src/marvin/tools/chroma.py | 29 ++++++++++++++- src/marvin/utilities/asyncio.py | 25 +++++++++++++ 5 files changed, 150 insertions(+), 68 deletions(-) diff --git a/cookbook/slackbot/parent_app.py b/cookbook/slackbot/parent_app.py index b0190f50b..486305554 100644 --- a/cookbook/slackbot/parent_app.py +++ b/cookbook/slackbot/parent_app.py @@ -1,4 +1,5 @@ import asyncio +import json from contextlib import asynccontextmanager from fastapi import FastAPI @@ -14,6 +15,9 @@ from typing_extensions import TypedDict from websockets.exceptions import ConnectionClosedError +PARENT_APP_STATE_BLOCK_NAME = "marvin-parent-app-state" +PARENT_APP_STATE = JSONBlockKV(block_name=PARENT_APP_STATE_BLOCK_NAME) + class Lesson(TypedDict): relevance: confloat(ge=0, le=1) @@ -22,9 +26,14 @@ class Lesson(TypedDict): @ai_fn(model="gpt-3.5-turbo-1106") def take_lesson_from_interaction( - transcript: str, assistant_instructions: str + transcript: str, + assistant_instructions: str, + observer_role: str = "data architect", + irrelevant_topics: str = "Anything not related to data engineering", ) -> Lesson: - """You are an expert counselor, and you are teaching Marvin how to be a better assistant. + """You are an expert {{ observer_role }}, and you are counseling an AI assistant named Marvin. + + {{ irrelevant_topics }} is not relevant to Marvin's purpose, and has 0 relevance. Here is the transcript of an interaction between Marvin and a user: {{ transcript }} @@ -33,7 +42,7 @@ def take_lesson_from_interaction( {{ assistant_instructions }} how directly relevant to the assistant's purpose is this interaction? - - if not at all, relevance = 0 & heuristic = None. (most of the time) + - if not at all, relevance = 0 & heuristic = None. (most commonly, this will be the case) - if very, relevance >= 0.5, <1 & heuristic = "1 SHORT SENTENCE (max) summary of a generalizable lesson". """ @@ -57,10 +66,12 @@ def excerpt_from_event(event: Event) -> str: async def update_parent_app_state(app: AIApplication, event: Event): event_excerpt = excerpt_from_event(event) lesson = take_lesson_from_interaction( - event_excerpt, event.payload.get("ai_instructions") + event_excerpt, event.payload.get("ai_instructions").split("START_USER_NOTES")[0] ) if lesson["relevance"] >= 0.5 and lesson["heuristic"] is not None: - experience = f"transcript: {event_excerpt}\n\nlesson: {lesson['heuristic']}" + experience = ( + f"transcript:\n\n{event_excerpt}\n\nlesson: {lesson['heuristic']!r}" + ) logger.debug_kv("๐Ÿ’ก Learned lesson from excerpt", experience, "green") await app.default_thread.add_async(experience) logger.debug_kv("๐Ÿ“", "Updating parent app state", "green") @@ -68,13 +79,14 @@ async def update_parent_app_state(app: AIApplication, event: Event): else: logger.debug_kv("๐Ÿฅฑ ", "nothing special", "green") user_id = event.payload.get("user").get("id") - current_user_state = await app.state.read(user_id) - await app.state.write( - user_id, - { - **current_user_state, - "n_interactions": current_user_state["n_interactions"] + 1, - }, + current_user_state = app.state.read(user_id) or dict( + name=event.payload.get("user").get("name"), + n_interactions=0, + ) + current_user_state["n_interactions"] += 1 + app.state.write(user_id, state := current_user_state) + logger.debug_kv( + f"๐Ÿ“‹ Updated state for user {user_id} to", json.dumps(state), "green" ) @@ -96,8 +108,9 @@ async def learn_from_child_interactions( except Exception as e: if isinstance(e, ConnectionClosedError): logger.debug_kv("๐Ÿšจ Connection closed, reconnecting...", "red") - else: # i know, i know + else: logger.debug_kv("๐Ÿšจ", str(e), "red") + raise e parent_assistant_options = dict( @@ -108,10 +121,10 @@ async def learn_from_child_interactions( " with the user's id as the key. The user id will be shown in the excerpt of the interaction." " The user profiles (values) should include at least: {name: str, notes: list[str], n_interactions: int}." " Keep NO MORE THAN 3 notes per user, but you may curate/update these over time for Marvin's maximum benefit." - " Notes must be 2 sentences or less, and must be concise and focused primarily on users' data engineering needs." - " Notes should not directly mention Marvin as an actor, they should be generally useful observations." + " Notes must be 2 sentences or less, must be concise and use inline markdown formatting for code and links." + " Each note should be a concrete and TECHNICAL observation related to the user's data engineering needs." ), - state=JSONBlockKV(block_name="marvin-parent-app-state"), + state=PARENT_APP_STATE, ) diff --git a/cookbook/slackbot/start.py b/cookbook/slackbot/start.py index 57d2061e7..5770e8692 100644 --- a/cookbook/slackbot/start.py +++ b/cookbook/slackbot/start.py @@ -22,6 +22,7 @@ ) from marvin.utilities.strings import count_tokens, slice_tokens from parent_app import ( + PARENT_APP_STATE, emit_assistant_completed_event, lifespan, ) @@ -30,7 +31,7 @@ from prefect.tasks import task_input_hash BOT_MENTION = r"<@(\w+)>" -CACHE = JSONBlockKV(block_name="slackbot-tool-cache") +CACHE = JSONBlockKV(block_name="slackbot-thread-cache") USER_MESSAGE_MAX_TOKENS = 300 @@ -39,18 +40,24 @@ def cached(func: Callable) -> Callable: async def get_notes_for_user( - user_id: str, parent_app: AIApplication, max_tokens: int = 100 -) -> str | None: - json_notes: dict = parent_app.state.read(key=user_id) - get_logger("slackbot").debug_kv("๐Ÿ“ Notes for user", json_notes, "blue") + user_id: str, max_tokens: int = 100 +) -> dict[str, str | None]: user_name = await get_user_name(user_id) + json_notes: dict = PARENT_APP_STATE.read(key=user_id) + get_logger("slackbot").debug_kv(f"๐Ÿ“ Notes for {user_name}", json_notes, "blue") if json_notes: notes_template = Template( """ - Here are some notes about {{ user_name }} (user id: {{ user_id }}): + START_USER_NOTES + Here are some notes about '{{ user_name }}' (user id: {{ user_id }}), which + are intended to help you understand their technical background and needs + + - {{ user_name }} is recorded interacting with assistants {{ n_interactions }} time(s). + + These notes have been passed down from previous interactions with this user - + they are strictly for your reference, and should not be shared with the user. - - They have interacted with assistants {{ n_interactions }} times. {% if notes_content %} Here are some notes gathered from those interactions: {{ notes_content }} @@ -65,14 +72,16 @@ async def get_notes_for_user( break notes_content += potential_addition - return notes_template.render( + notes = notes_template.render( user_name=user_name, user_id=user_id, n_interactions=json_notes.get("n_interactions", 0), notes_content=notes_content, ) - return None + return {user_name: notes} + + return {user_name: None} @flow @@ -82,15 +91,12 @@ async def handle_message(payload: SlackPayload) -> Completed: cleaned_message = re.sub(BOT_MENTION, "", user_message).strip() thread = event.thread_ts or event.ts if (count := count_tokens(cleaned_message)) > USER_MESSAGE_MAX_TOKENS: - exceeded_by = count - USER_MESSAGE_MAX_TOKENS + exceeded_amt = count - USER_MESSAGE_MAX_TOKENS await task(post_slack_message)( message=( - f"Your message was too long by {exceeded_by} tokens - please shorten it and try again.\n\n" - f" For reference, here's your message at the allowed limit:\n" - "> " - + slice_tokens(cleaned_message, USER_MESSAGE_MAX_TOKENS).replace( - "\n", " " - ) + f"Your message was too long by {exceeded_amt} tokens - please shorten it and try again." + f"\n\n For reference, here's your message at the allowed limit:\n" + f"> {slice_tokens(cleaned_message, USER_MESSAGE_MAX_TOKENS)}" ), channel_id=event.channel, thread_ts=thread, @@ -106,7 +112,11 @@ async def handle_message(payload: SlackPayload) -> Completed: if (stored_thread_data := CACHE.read(key=thread)) else Thread() ) - CACHE.write(key=thread, value=assistant_thread.model_dump()) + logger.debug_kv( + "๐Ÿงต Thread data", + stored_thread_data or f"No stored thread data found for {thread}", + "blue", + ) await handle_keywords.submit( message=cleaned_message, @@ -117,6 +127,7 @@ async def handle_message(payload: SlackPayload) -> Completed: f"{event.channel}/p{event.ts.replace('.', '')}" ), ) + user_name, user_notes = (await get_notes_for_user(user_id=event.user)).popitem() with Assistant( name="Marvin", @@ -129,7 +140,7 @@ async def handle_message(payload: SlackPayload) -> Completed: " in order to develop a coherent attempt to answer their questions. Think step-by-step." " You must use your tools, as Prefect 2.x is new and you have no prior experience with it." " Strongly prefer brevity in your responses, and format things prettily for Slack." - f"{await get_notes_for_user(event.user, parent_app := get_parent_app()) or ''}" + f"{user_notes or ''}" ), ) as ai: logger.debug_kv( @@ -142,6 +153,8 @@ async def handle_message(payload: SlackPayload) -> Completed: ai_messages = await assistant_thread.get_messages_async( after_message=user_thread_message.id ) + CACHE.write(key=thread, value=assistant_thread.model_dump()) + await task(post_slack_message)( ai_response_text := "\n\n".join( m.content[0].text.value for m in ai_messages @@ -150,13 +163,14 @@ async def handle_message(payload: SlackPayload) -> Completed: thread, ) logger.debug_kv( - success_msg := f"Responded in {channel}/{thread}", + success_msg + := f"Responded in {await get_channel_name(channel)}/{thread}", ai_response_text, "green", ) event = emit_assistant_completed_event( child_assistant=ai, - parent_app=parent_app, + parent_app=get_parent_app(), payload={ "messages": await assistant_thread.get_messages_async( json_compatible=True @@ -164,7 +178,7 @@ async def handle_message(payload: SlackPayload) -> Completed: "metadata": assistant_thread.metadata, "user": { "id": event.user, - "name": await get_user_name(event.user), + "name": user_name, }, "user_message": cleaned_message, "ai_response": ai_response_text, diff --git a/src/marvin/kv/json_block.py b/src/marvin/kv/json_block.py index 10a449700..f3622f7b3 100644 --- a/src/marvin/kv/json_block.py +++ b/src/marvin/kv/json_block.py @@ -1,4 +1,4 @@ -from typing import Optional, TypeVar +from typing import Mapping, Optional, TypeVar try: from prefect.blocks.system import JSON @@ -8,55 +8,58 @@ "The `prefect` package is required to use the JSONBlockKV class." " You can install it with `pip install prefect` or `pip install marvin[prefect]`." ) -from pydantic import Field +from pydantic import Field, PrivateAttr, model_validator from marvin.kv.base import StorageInterface -from marvin.utilities.asyncio import run_sync +from marvin.utilities.asyncio import run_sync, run_sync_if_awaitable K = TypeVar("K", bound=str) V = TypeVar("V") -class JSONBlockKV(StorageInterface[K, V, str]): - """ - A key-value store that uses Prefect's JSON blocks under the hood. - """ +async def load_json_block(block_name: str) -> JSON: + try: + return await JSON.load(name=block_name) + except Exception as exc: + if "Unable to find block document" in str(exc): + json_block = JSON(value={}) + await json_block.save(name=block_name) + return json_block + raise ObjectNotFound(f"Unable to load JSON block {block_name}") from exc + +class JSONBlockKV(StorageInterface): block_name: str = Field(default="marvin-kv") + _state: dict[K, Mapping] = PrivateAttr(default_factory=dict) - async def _load_json_block(self) -> JSON: - try: - return await JSON.load(name=self.block_name) - except Exception as exc: - if "Unable to find block document" in str(exc): - json_block = JSON(value={}) - await json_block.save(name=self.block_name) - return json_block - raise ObjectNotFound( - f"Unable to load JSON block {self.block_name}" - ) from exc + @model_validator(mode="after") + def load_state(self) -> "JSONBlockKV": + json_block = run_sync(load_json_block(self.block_name)) + self._state = json_block.value or {} + return self def write(self, key: K, value: V) -> str: - json_block = run_sync(self._load_json_block()) - json_block.value[key] = value - run_sync(json_block.save(name=self.block_name, overwrite=True)) + self._state[key] = value + json_block = run_sync(load_json_block(self.block_name)) + json_block.value = self._state + run_sync_if_awaitable(json_block.save(name=self.block_name, overwrite=True)) return f"Stored {key}= {value}" def delete(self, key: K) -> str: - json_block = run_sync(self._load_json_block()) + if key in self._state: + self._state.pop(key, None) + json_block = run_sync(load_json_block(self.block_name)) if key in json_block.value: - json_block.value.pop(key) - run_sync(json_block.save(name=self.block_name, overwrite=True)) + json_block.value = self._state + run_sync_if_awaitable(json_block.save(name=self.block_name, overwrite=True)) return f"Deleted {key}" def read(self, key: K) -> Optional[V]: - json_block = run_sync(self._load_json_block()) - return json_block.value.get(key) + return self._state.get(key) def read_all(self, limit: Optional[int] = None) -> dict[K, V]: - json_block = run_sync(self._load_json_block()) - return dict(list(json_block.value.items())[:limit]) + limited_items = dict(list(self._state.items())[:limit]) + return limited_items def list_keys(self) -> list[K]: - json_block = run_sync(self._load_json_block()) - return list(json_block.value.keys()) + return list(self._state.keys()) diff --git a/src/marvin/tools/chroma.py b/src/marvin/tools/chroma.py index 919e52201..39000d61b 100644 --- a/src/marvin/tools/chroma.py +++ b/src/marvin/tools/chroma.py @@ -1,9 +1,10 @@ import asyncio import os +import uuid from typing import TYPE_CHECKING, Any, Optional try: - from chromadb import Documents, EmbeddingFunction, Embeddings, HttpClient + from chromadb import Documents, EmbeddingFunction, Embeddings, GetResult, HttpClient except ImportError: raise ImportError( "The chromadb package is required to query Chroma. Please install" @@ -127,3 +128,29 @@ async def multi_query_chroma( for query in queries ] return "\n".join(await asyncio.gather(*coros))[:max_characters] + + +def store_document( + document: str, metadata: dict[str, Any], collection_name: str = "glacial" +) -> GetResult: + """Store a document in Chroma for future reference. + + Args: + document: The document to store. + metadata: The metadata to store with the document. + + Returns: + The stored document. + """ + collection = client.get_or_create_collection( + name=collection_name, embedding_function=OpenAIEmbeddingFunction() + ) + doc_id = metadata.get("msg_id", str(uuid.uuid4())) + + collection.add( + ids=[doc_id], + documents=[document], + metadatas=[metadata], + ) + + return collection.get(id=doc_id) diff --git a/src/marvin/utilities/asyncio.py b/src/marvin/utilities/asyncio.py index c29afdadc..3dfbb1406 100644 --- a/src/marvin/utilities/asyncio.py +++ b/src/marvin/utilities/asyncio.py @@ -2,6 +2,7 @@ import asyncio import functools +import inspect from concurrent.futures import ThreadPoolExecutor from typing import Any, Callable, Coroutine, TypeVar, cast @@ -78,6 +79,30 @@ async def my_async_function(x: int) -> int: return asyncio.run(coroutine) +def run_sync_if_awaitable(obj: Any) -> Any: + """ + If the object is awaitable, run it synchronously. Otherwise, return the + object. + + Args: + obj: The object to run. + + Returns: + The return value of the object if it is awaitable, otherwise the object + itself. + + Example: + Basic usage: + ```python + async def my_async_function(x: int) -> int: + return x + 1 + + run_sync_if_awaitable(my_async_function(1)) + ``` + """ + return run_sync(obj) if inspect.isawaitable(obj) else obj + + class ExposeSyncMethodsMixin: """ A mixin that can take functions decorated with `expose_sync_method` From b4e57cea1c3d72278304512ce5cc038d308b5367 Mon Sep 17 00:00:00 2001 From: Nathan Nowack Date: Wed, 20 Dec 2023 10:02:17 -0600 Subject: [PATCH 2/2] minor tweaks --- src/marvin/kv/json_block.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/marvin/kv/json_block.py b/src/marvin/kv/json_block.py index f3622f7b3..d7b424b0c 100644 --- a/src/marvin/kv/json_block.py +++ b/src/marvin/kv/json_block.py @@ -55,11 +55,15 @@ def delete(self, key: K) -> str: return f"Deleted {key}" def read(self, key: K) -> Optional[V]: - return self._state.get(key) + json_block = run_sync(load_json_block(self.block_name)) + return json_block.value.get(key) def read_all(self, limit: Optional[int] = None) -> dict[K, V]: - limited_items = dict(list(self._state.items())[:limit]) + json_block = run_sync(load_json_block(self.block_name)) + + limited_items = dict(list(json_block.value.items())[:limit]) return limited_items def list_keys(self) -> list[K]: - return list(self._state.keys()) + json_block = run_sync(load_json_block(self.block_name)) + return list(json_block.value.keys())