Skip to content

Commit

Permalink
Merge pull request #699 from PrefectHQ/fix-json-kv
Browse files Browse the repository at this point in the history
various bug fixes
  • Loading branch information
zzstoatzz authored Dec 20, 2023
2 parents e0b4803 + b4e57ce commit 9c7fbf5
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 66 deletions.
45 changes: 29 additions & 16 deletions cookbook/slackbot/parent_app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import json
from contextlib import asynccontextmanager

from fastapi import FastAPI
Expand All @@ -14,6 +15,9 @@
from typing_extensions import TypedDict
from websockets.exceptions import ConnectionClosedError

PARENT_APP_STATE_BLOCK_NAME = "marvin-parent-app-state"
PARENT_APP_STATE = JSONBlockKV(block_name=PARENT_APP_STATE_BLOCK_NAME)


class Lesson(TypedDict):
relevance: confloat(ge=0, le=1)
Expand All @@ -22,9 +26,14 @@ class Lesson(TypedDict):

@ai_fn(model="gpt-3.5-turbo-1106")
def take_lesson_from_interaction(
transcript: str, assistant_instructions: str
transcript: str,
assistant_instructions: str,
observer_role: str = "data architect",
irrelevant_topics: str = "Anything not related to data engineering",
) -> Lesson:
"""You are an expert counselor, and you are teaching Marvin how to be a better assistant.
"""You are an expert {{ observer_role }}, and you are counseling an AI assistant named Marvin.
{{ irrelevant_topics }} is not relevant to Marvin's purpose, and has 0 relevance.
Here is the transcript of an interaction between Marvin and a user:
{{ transcript }}
Expand All @@ -33,7 +42,7 @@ def take_lesson_from_interaction(
{{ assistant_instructions }}
how directly relevant to the assistant's purpose is this interaction?
- if not at all, relevance = 0 & heuristic = None. (most of the time)
- if not at all, relevance = 0 & heuristic = None. (most commonly, this will be the case)
- if very, relevance >= 0.5, <1 & heuristic = "1 SHORT SENTENCE (max) summary of a generalizable lesson".
"""

Expand All @@ -57,24 +66,27 @@ def excerpt_from_event(event: Event) -> str:
async def update_parent_app_state(app: AIApplication, event: Event):
event_excerpt = excerpt_from_event(event)
lesson = take_lesson_from_interaction(
event_excerpt, event.payload.get("ai_instructions")
event_excerpt, event.payload.get("ai_instructions").split("START_USER_NOTES")[0]
)
if lesson["relevance"] >= 0.5 and lesson["heuristic"] is not None:
experience = f"transcript: {event_excerpt}\n\nlesson: {lesson['heuristic']}"
experience = (
f"transcript:\n\n{event_excerpt}\n\nlesson: {lesson['heuristic']!r}"
)
logger.debug_kv("💡 Learned lesson from excerpt", experience, "green")
await app.default_thread.add_async(experience)
logger.debug_kv("📝", "Updating parent app state", "green")
await app.default_thread.run_async(app)
else:
logger.debug_kv("🥱 ", "nothing special", "green")
user_id = event.payload.get("user").get("id")
current_user_state = await app.state.read(user_id)
await app.state.write(
user_id,
{
**current_user_state,
"n_interactions": current_user_state["n_interactions"] + 1,
},
current_user_state = app.state.read(user_id) or dict(
name=event.payload.get("user").get("name"),
n_interactions=0,
)
current_user_state["n_interactions"] += 1
app.state.write(user_id, state := current_user_state)
logger.debug_kv(
f"📋 Updated state for user {user_id} to", json.dumps(state), "green"
)


Expand All @@ -96,8 +108,9 @@ async def learn_from_child_interactions(
except Exception as e:
if isinstance(e, ConnectionClosedError):
logger.debug_kv("🚨 Connection closed, reconnecting...", "red")
else: # i know, i know
else:
logger.debug_kv("🚨", str(e), "red")
raise e


parent_assistant_options = dict(
Expand All @@ -108,10 +121,10 @@ async def learn_from_child_interactions(
" with the user's id as the key. The user id will be shown in the excerpt of the interaction."
" The user profiles (values) should include at least: {name: str, notes: list[str], n_interactions: int}."
" Keep NO MORE THAN 3 notes per user, but you may curate/update these over time for Marvin's maximum benefit."
" Notes must be 2 sentences or less, and must be concise and focused primarily on users' data engineering needs."
" Notes should not directly mention Marvin as an actor, they should be generally useful observations."
" Notes must be 2 sentences or less, must be concise and use inline markdown formatting for code and links."
" Each note should be a concrete and TECHNICAL observation related to the user's data engineering needs."
),
state=JSONBlockKV(block_name="marvin-parent-app-state"),
state=PARENT_APP_STATE,
)


Expand Down
56 changes: 35 additions & 21 deletions cookbook/slackbot/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from marvin.utilities.strings import count_tokens, slice_tokens
from parent_app import (
PARENT_APP_STATE,
emit_assistant_completed_event,
lifespan,
)
Expand All @@ -30,7 +31,7 @@
from prefect.tasks import task_input_hash

BOT_MENTION = r"<@(\w+)>"
CACHE = JSONBlockKV(block_name="slackbot-tool-cache")
CACHE = JSONBlockKV(block_name="slackbot-thread-cache")
USER_MESSAGE_MAX_TOKENS = 300


Expand All @@ -39,18 +40,24 @@ def cached(func: Callable) -> Callable:


async def get_notes_for_user(
user_id: str, parent_app: AIApplication, max_tokens: int = 100
) -> str | None:
json_notes: dict = parent_app.state.read(key=user_id)
get_logger("slackbot").debug_kv("📝 Notes for user", json_notes, "blue")
user_id: str, max_tokens: int = 100
) -> dict[str, str | None]:
user_name = await get_user_name(user_id)
json_notes: dict = PARENT_APP_STATE.read(key=user_id)
get_logger("slackbot").debug_kv(f"📝 Notes for {user_name}", json_notes, "blue")

if json_notes:
notes_template = Template(
"""
Here are some notes about {{ user_name }} (user id: {{ user_id }}):
START_USER_NOTES
Here are some notes about '{{ user_name }}' (user id: {{ user_id }}), which
are intended to help you understand their technical background and needs
- {{ user_name }} is recorded interacting with assistants {{ n_interactions }} time(s).
These notes have been passed down from previous interactions with this user -
they are strictly for your reference, and should not be shared with the user.
- They have interacted with assistants {{ n_interactions }} times.
{% if notes_content %}
Here are some notes gathered from those interactions:
{{ notes_content }}
Expand All @@ -65,14 +72,16 @@ async def get_notes_for_user(
break
notes_content += potential_addition

return notes_template.render(
notes = notes_template.render(
user_name=user_name,
user_id=user_id,
n_interactions=json_notes.get("n_interactions", 0),
notes_content=notes_content,
)

return None
return {user_name: notes}

return {user_name: None}


@flow
Expand All @@ -82,15 +91,12 @@ async def handle_message(payload: SlackPayload) -> Completed:
cleaned_message = re.sub(BOT_MENTION, "", user_message).strip()
thread = event.thread_ts or event.ts
if (count := count_tokens(cleaned_message)) > USER_MESSAGE_MAX_TOKENS:
exceeded_by = count - USER_MESSAGE_MAX_TOKENS
exceeded_amt = count - USER_MESSAGE_MAX_TOKENS
await task(post_slack_message)(
message=(
f"Your message was too long by {exceeded_by} tokens - please shorten it and try again.\n\n"
f" For reference, here's your message at the allowed limit:\n"
"> "
+ slice_tokens(cleaned_message, USER_MESSAGE_MAX_TOKENS).replace(
"\n", " "
)
f"Your message was too long by {exceeded_amt} tokens - please shorten it and try again."
f"\n\n For reference, here's your message at the allowed limit:\n"
f"> {slice_tokens(cleaned_message, USER_MESSAGE_MAX_TOKENS)}"
),
channel_id=event.channel,
thread_ts=thread,
Expand All @@ -106,7 +112,11 @@ async def handle_message(payload: SlackPayload) -> Completed:
if (stored_thread_data := CACHE.read(key=thread))
else Thread()
)
CACHE.write(key=thread, value=assistant_thread.model_dump())
logger.debug_kv(
"🧵 Thread data",
stored_thread_data or f"No stored thread data found for {thread}",
"blue",
)

await handle_keywords.submit(
message=cleaned_message,
Expand All @@ -117,6 +127,7 @@ async def handle_message(payload: SlackPayload) -> Completed:
f"{event.channel}/p{event.ts.replace('.', '')}"
),
)
user_name, user_notes = (await get_notes_for_user(user_id=event.user)).popitem()

with Assistant(
name="Marvin",
Expand All @@ -129,7 +140,7 @@ async def handle_message(payload: SlackPayload) -> Completed:
" in order to develop a coherent attempt to answer their questions. Think step-by-step."
" You must use your tools, as Prefect 2.x is new and you have no prior experience with it."
" Strongly prefer brevity in your responses, and format things prettily for Slack."
f"{await get_notes_for_user(event.user, parent_app := get_parent_app()) or ''}"
f"{user_notes or ''}"
),
) as ai:
logger.debug_kv(
Expand All @@ -142,6 +153,8 @@ async def handle_message(payload: SlackPayload) -> Completed:
ai_messages = await assistant_thread.get_messages_async(
after_message=user_thread_message.id
)
CACHE.write(key=thread, value=assistant_thread.model_dump())

await task(post_slack_message)(
ai_response_text := "\n\n".join(
m.content[0].text.value for m in ai_messages
Expand All @@ -150,21 +163,22 @@ async def handle_message(payload: SlackPayload) -> Completed:
thread,
)
logger.debug_kv(
success_msg := f"Responded in {channel}/{thread}",
success_msg
:= f"Responded in {await get_channel_name(channel)}/{thread}",
ai_response_text,
"green",
)
event = emit_assistant_completed_event(
child_assistant=ai,
parent_app=parent_app,
parent_app=get_parent_app(),
payload={
"messages": await assistant_thread.get_messages_async(
json_compatible=True
),
"metadata": assistant_thread.metadata,
"user": {
"id": event.user,
"name": await get_user_name(event.user),
"name": user_name,
},
"user_message": cleaned_message,
"ai_response": ai_response_text,
Expand Down
63 changes: 35 additions & 28 deletions src/marvin/kv/json_block.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, TypeVar
from typing import Mapping, Optional, TypeVar

try:
from prefect.blocks.system import JSON
Expand All @@ -8,55 +8,62 @@
"The `prefect` package is required to use the JSONBlockKV class."
" You can install it with `pip install prefect` or `pip install marvin[prefect]`."
)
from pydantic import Field
from pydantic import Field, PrivateAttr, model_validator

from marvin.kv.base import StorageInterface
from marvin.utilities.asyncio import run_sync
from marvin.utilities.asyncio import run_sync, run_sync_if_awaitable

K = TypeVar("K", bound=str)
V = TypeVar("V")


class JSONBlockKV(StorageInterface[K, V, str]):
"""
A key-value store that uses Prefect's JSON blocks under the hood.
"""
async def load_json_block(block_name: str) -> JSON:
try:
return await JSON.load(name=block_name)
except Exception as exc:
if "Unable to find block document" in str(exc):
json_block = JSON(value={})
await json_block.save(name=block_name)
return json_block
raise ObjectNotFound(f"Unable to load JSON block {block_name}") from exc


class JSONBlockKV(StorageInterface):
block_name: str = Field(default="marvin-kv")
_state: dict[K, Mapping] = PrivateAttr(default_factory=dict)

async def _load_json_block(self) -> JSON:
try:
return await JSON.load(name=self.block_name)
except Exception as exc:
if "Unable to find block document" in str(exc):
json_block = JSON(value={})
await json_block.save(name=self.block_name)
return json_block
raise ObjectNotFound(
f"Unable to load JSON block {self.block_name}"
) from exc
@model_validator(mode="after")
def load_state(self) -> "JSONBlockKV":
json_block = run_sync(load_json_block(self.block_name))
self._state = json_block.value or {}
return self

def write(self, key: K, value: V) -> str:
json_block = run_sync(self._load_json_block())
json_block.value[key] = value
run_sync(json_block.save(name=self.block_name, overwrite=True))
self._state[key] = value
json_block = run_sync(load_json_block(self.block_name))
json_block.value = self._state
run_sync_if_awaitable(json_block.save(name=self.block_name, overwrite=True))
return f"Stored {key}= {value}"

def delete(self, key: K) -> str:
json_block = run_sync(self._load_json_block())
if key in self._state:
self._state.pop(key, None)
json_block = run_sync(load_json_block(self.block_name))
if key in json_block.value:
json_block.value.pop(key)
run_sync(json_block.save(name=self.block_name, overwrite=True))
json_block.value = self._state
run_sync_if_awaitable(json_block.save(name=self.block_name, overwrite=True))
return f"Deleted {key}"

def read(self, key: K) -> Optional[V]:
json_block = run_sync(self._load_json_block())
json_block = run_sync(load_json_block(self.block_name))
return json_block.value.get(key)

def read_all(self, limit: Optional[int] = None) -> dict[K, V]:
json_block = run_sync(self._load_json_block())
return dict(list(json_block.value.items())[:limit])
json_block = run_sync(load_json_block(self.block_name))

limited_items = dict(list(json_block.value.items())[:limit])
return limited_items

def list_keys(self) -> list[K]:
json_block = run_sync(self._load_json_block())
json_block = run_sync(load_json_block(self.block_name))
return list(json_block.value.keys())
29 changes: 28 additions & 1 deletion src/marvin/tools/chroma.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import os
import uuid
from typing import TYPE_CHECKING, Any, Optional

try:
from chromadb import Documents, EmbeddingFunction, Embeddings, HttpClient
from chromadb import Documents, EmbeddingFunction, Embeddings, GetResult, HttpClient
except ImportError:
raise ImportError(
"The chromadb package is required to query Chroma. Please install"
Expand Down Expand Up @@ -127,3 +128,29 @@ async def multi_query_chroma(
for query in queries
]
return "\n".join(await asyncio.gather(*coros))[:max_characters]


def store_document(
document: str, metadata: dict[str, Any], collection_name: str = "glacial"
) -> GetResult:
"""Store a document in Chroma for future reference.
Args:
document: The document to store.
metadata: The metadata to store with the document.
Returns:
The stored document.
"""
collection = client.get_or_create_collection(
name=collection_name, embedding_function=OpenAIEmbeddingFunction()
)
doc_id = metadata.get("msg_id", str(uuid.uuid4()))

collection.add(
ids=[doc_id],
documents=[document],
metadatas=[metadata],
)

return collection.get(id=doc_id)
Loading

0 comments on commit 9c7fbf5

Please sign in to comment.