Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

various bug fixes #699

Merged
merged 2 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 29 additions & 16 deletions cookbook/slackbot/parent_app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import json
from contextlib import asynccontextmanager

from fastapi import FastAPI
Expand All @@ -14,6 +15,9 @@
from typing_extensions import TypedDict
from websockets.exceptions import ConnectionClosedError

PARENT_APP_STATE_BLOCK_NAME = "marvin-parent-app-state"
PARENT_APP_STATE = JSONBlockKV(block_name=PARENT_APP_STATE_BLOCK_NAME)


class Lesson(TypedDict):
relevance: confloat(ge=0, le=1)
Expand All @@ -22,9 +26,14 @@ class Lesson(TypedDict):

@ai_fn(model="gpt-3.5-turbo-1106")
def take_lesson_from_interaction(
transcript: str, assistant_instructions: str
transcript: str,
assistant_instructions: str,
observer_role: str = "data architect",
irrelevant_topics: str = "Anything not related to data engineering",
) -> Lesson:
"""You are an expert counselor, and you are teaching Marvin how to be a better assistant.
"""You are an expert {{ observer_role }}, and you are counseling an AI assistant named Marvin.
{{ irrelevant_topics }} is not relevant to Marvin's purpose, and has 0 relevance.
Here is the transcript of an interaction between Marvin and a user:
{{ transcript }}
Expand All @@ -33,7 +42,7 @@ def take_lesson_from_interaction(
{{ assistant_instructions }}
how directly relevant to the assistant's purpose is this interaction?
- if not at all, relevance = 0 & heuristic = None. (most of the time)
- if not at all, relevance = 0 & heuristic = None. (most commonly, this will be the case)
- if very, relevance >= 0.5, <1 & heuristic = "1 SHORT SENTENCE (max) summary of a generalizable lesson".
"""

Expand All @@ -57,24 +66,27 @@ def excerpt_from_event(event: Event) -> str:
async def update_parent_app_state(app: AIApplication, event: Event):
event_excerpt = excerpt_from_event(event)
lesson = take_lesson_from_interaction(
event_excerpt, event.payload.get("ai_instructions")
event_excerpt, event.payload.get("ai_instructions").split("START_USER_NOTES")[0]
)
if lesson["relevance"] >= 0.5 and lesson["heuristic"] is not None:
experience = f"transcript: {event_excerpt}\n\nlesson: {lesson['heuristic']}"
experience = (
f"transcript:\n\n{event_excerpt}\n\nlesson: {lesson['heuristic']!r}"
)
logger.debug_kv("💡 Learned lesson from excerpt", experience, "green")
await app.default_thread.add_async(experience)
logger.debug_kv("📝", "Updating parent app state", "green")
await app.default_thread.run_async(app)
else:
logger.debug_kv("🥱 ", "nothing special", "green")
user_id = event.payload.get("user").get("id")
current_user_state = await app.state.read(user_id)
await app.state.write(
user_id,
{
**current_user_state,
"n_interactions": current_user_state["n_interactions"] + 1,
},
current_user_state = app.state.read(user_id) or dict(
name=event.payload.get("user").get("name"),
n_interactions=0,
)
current_user_state["n_interactions"] += 1
app.state.write(user_id, state := current_user_state)
logger.debug_kv(
f"📋 Updated state for user {user_id} to", json.dumps(state), "green"
)


Expand All @@ -96,8 +108,9 @@ async def learn_from_child_interactions(
except Exception as e:
if isinstance(e, ConnectionClosedError):
logger.debug_kv("🚨 Connection closed, reconnecting...", "red")
else: # i know, i know
else:
logger.debug_kv("🚨", str(e), "red")
raise e


parent_assistant_options = dict(
Expand All @@ -108,10 +121,10 @@ async def learn_from_child_interactions(
" with the user's id as the key. The user id will be shown in the excerpt of the interaction."
" The user profiles (values) should include at least: {name: str, notes: list[str], n_interactions: int}."
" Keep NO MORE THAN 3 notes per user, but you may curate/update these over time for Marvin's maximum benefit."
" Notes must be 2 sentences or less, and must be concise and focused primarily on users' data engineering needs."
" Notes should not directly mention Marvin as an actor, they should be generally useful observations."
" Notes must be 2 sentences or less, must be concise and use inline markdown formatting for code and links."
" Each note should be a concrete and TECHNICAL observation related to the user's data engineering needs."
),
state=JSONBlockKV(block_name="marvin-parent-app-state"),
state=PARENT_APP_STATE,
)


Expand Down
56 changes: 35 additions & 21 deletions cookbook/slackbot/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from marvin.utilities.strings import count_tokens, slice_tokens
from parent_app import (
PARENT_APP_STATE,
emit_assistant_completed_event,
lifespan,
)
Expand All @@ -30,7 +31,7 @@
from prefect.tasks import task_input_hash

BOT_MENTION = r"<@(\w+)>"
CACHE = JSONBlockKV(block_name="slackbot-tool-cache")
CACHE = JSONBlockKV(block_name="slackbot-thread-cache")
USER_MESSAGE_MAX_TOKENS = 300


Expand All @@ -39,18 +40,24 @@ def cached(func: Callable) -> Callable:


async def get_notes_for_user(
user_id: str, parent_app: AIApplication, max_tokens: int = 100
) -> str | None:
json_notes: dict = parent_app.state.read(key=user_id)
get_logger("slackbot").debug_kv("📝 Notes for user", json_notes, "blue")
user_id: str, max_tokens: int = 100
) -> dict[str, str | None]:
user_name = await get_user_name(user_id)
json_notes: dict = PARENT_APP_STATE.read(key=user_id)
get_logger("slackbot").debug_kv(f"📝 Notes for {user_name}", json_notes, "blue")

if json_notes:
notes_template = Template(
"""
Here are some notes about {{ user_name }} (user id: {{ user_id }}):
START_USER_NOTES
Here are some notes about '{{ user_name }}' (user id: {{ user_id }}), which
are intended to help you understand their technical background and needs
- {{ user_name }} is recorded interacting with assistants {{ n_interactions }} time(s).
These notes have been passed down from previous interactions with this user -
they are strictly for your reference, and should not be shared with the user.
- They have interacted with assistants {{ n_interactions }} times.
{% if notes_content %}
Here are some notes gathered from those interactions:
{{ notes_content }}
Expand All @@ -65,14 +72,16 @@ async def get_notes_for_user(
break
notes_content += potential_addition

return notes_template.render(
notes = notes_template.render(
user_name=user_name,
user_id=user_id,
n_interactions=json_notes.get("n_interactions", 0),
notes_content=notes_content,
)

return None
return {user_name: notes}

return {user_name: None}


@flow
Expand All @@ -82,15 +91,12 @@ async def handle_message(payload: SlackPayload) -> Completed:
cleaned_message = re.sub(BOT_MENTION, "", user_message).strip()
thread = event.thread_ts or event.ts
if (count := count_tokens(cleaned_message)) > USER_MESSAGE_MAX_TOKENS:
exceeded_by = count - USER_MESSAGE_MAX_TOKENS
exceeded_amt = count - USER_MESSAGE_MAX_TOKENS
await task(post_slack_message)(
message=(
f"Your message was too long by {exceeded_by} tokens - please shorten it and try again.\n\n"
f" For reference, here's your message at the allowed limit:\n"
"> "
+ slice_tokens(cleaned_message, USER_MESSAGE_MAX_TOKENS).replace(
"\n", " "
)
f"Your message was too long by {exceeded_amt} tokens - please shorten it and try again."
f"\n\n For reference, here's your message at the allowed limit:\n"
f"> {slice_tokens(cleaned_message, USER_MESSAGE_MAX_TOKENS)}"
),
channel_id=event.channel,
thread_ts=thread,
Expand All @@ -106,7 +112,11 @@ async def handle_message(payload: SlackPayload) -> Completed:
if (stored_thread_data := CACHE.read(key=thread))
else Thread()
)
CACHE.write(key=thread, value=assistant_thread.model_dump())
logger.debug_kv(
"🧵 Thread data",
stored_thread_data or f"No stored thread data found for {thread}",
"blue",
)

await handle_keywords.submit(
message=cleaned_message,
Expand All @@ -117,6 +127,7 @@ async def handle_message(payload: SlackPayload) -> Completed:
f"{event.channel}/p{event.ts.replace('.', '')}"
),
)
user_name, user_notes = (await get_notes_for_user(user_id=event.user)).popitem()

with Assistant(
name="Marvin",
Expand All @@ -129,7 +140,7 @@ async def handle_message(payload: SlackPayload) -> Completed:
" in order to develop a coherent attempt to answer their questions. Think step-by-step."
" You must use your tools, as Prefect 2.x is new and you have no prior experience with it."
" Strongly prefer brevity in your responses, and format things prettily for Slack."
f"{await get_notes_for_user(event.user, parent_app := get_parent_app()) or ''}"
f"{user_notes or ''}"
),
) as ai:
logger.debug_kv(
Expand All @@ -142,6 +153,8 @@ async def handle_message(payload: SlackPayload) -> Completed:
ai_messages = await assistant_thread.get_messages_async(
after_message=user_thread_message.id
)
CACHE.write(key=thread, value=assistant_thread.model_dump())

await task(post_slack_message)(
ai_response_text := "\n\n".join(
m.content[0].text.value for m in ai_messages
Expand All @@ -150,21 +163,22 @@ async def handle_message(payload: SlackPayload) -> Completed:
thread,
)
logger.debug_kv(
success_msg := f"Responded in {channel}/{thread}",
success_msg
:= f"Responded in {await get_channel_name(channel)}/{thread}",
ai_response_text,
"green",
)
event = emit_assistant_completed_event(
child_assistant=ai,
parent_app=parent_app,
parent_app=get_parent_app(),
payload={
"messages": await assistant_thread.get_messages_async(
json_compatible=True
),
"metadata": assistant_thread.metadata,
"user": {
"id": event.user,
"name": await get_user_name(event.user),
"name": user_name,
},
"user_message": cleaned_message,
"ai_response": ai_response_text,
Expand Down
63 changes: 35 additions & 28 deletions src/marvin/kv/json_block.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, TypeVar
from typing import Mapping, Optional, TypeVar

try:
from prefect.blocks.system import JSON
Expand All @@ -8,55 +8,62 @@
"The `prefect` package is required to use the JSONBlockKV class."
" You can install it with `pip install prefect` or `pip install marvin[prefect]`."
)
from pydantic import Field
from pydantic import Field, PrivateAttr, model_validator

from marvin.kv.base import StorageInterface
from marvin.utilities.asyncio import run_sync
from marvin.utilities.asyncio import run_sync, run_sync_if_awaitable

K = TypeVar("K", bound=str)
V = TypeVar("V")


class JSONBlockKV(StorageInterface[K, V, str]):
"""
A key-value store that uses Prefect's JSON blocks under the hood.
"""
async def load_json_block(block_name: str) -> JSON:
try:
return await JSON.load(name=block_name)
except Exception as exc:
if "Unable to find block document" in str(exc):
json_block = JSON(value={})
await json_block.save(name=block_name)
return json_block
raise ObjectNotFound(f"Unable to load JSON block {block_name}") from exc


class JSONBlockKV(StorageInterface):
block_name: str = Field(default="marvin-kv")
_state: dict[K, Mapping] = PrivateAttr(default_factory=dict)

async def _load_json_block(self) -> JSON:
try:
return await JSON.load(name=self.block_name)
except Exception as exc:
if "Unable to find block document" in str(exc):
json_block = JSON(value={})
await json_block.save(name=self.block_name)
return json_block
raise ObjectNotFound(
f"Unable to load JSON block {self.block_name}"
) from exc
@model_validator(mode="after")
def load_state(self) -> "JSONBlockKV":
json_block = run_sync(load_json_block(self.block_name))
self._state = json_block.value or {}
return self

def write(self, key: K, value: V) -> str:
json_block = run_sync(self._load_json_block())
json_block.value[key] = value
run_sync(json_block.save(name=self.block_name, overwrite=True))
self._state[key] = value
json_block = run_sync(load_json_block(self.block_name))
json_block.value = self._state
run_sync_if_awaitable(json_block.save(name=self.block_name, overwrite=True))
return f"Stored {key}= {value}"

def delete(self, key: K) -> str:
json_block = run_sync(self._load_json_block())
if key in self._state:
self._state.pop(key, None)
json_block = run_sync(load_json_block(self.block_name))
if key in json_block.value:
json_block.value.pop(key)
run_sync(json_block.save(name=self.block_name, overwrite=True))
json_block.value = self._state
run_sync_if_awaitable(json_block.save(name=self.block_name, overwrite=True))
return f"Deleted {key}"

def read(self, key: K) -> Optional[V]:
json_block = run_sync(self._load_json_block())
json_block = run_sync(load_json_block(self.block_name))
return json_block.value.get(key)

def read_all(self, limit: Optional[int] = None) -> dict[K, V]:
json_block = run_sync(self._load_json_block())
return dict(list(json_block.value.items())[:limit])
json_block = run_sync(load_json_block(self.block_name))

limited_items = dict(list(json_block.value.items())[:limit])
return limited_items

def list_keys(self) -> list[K]:
json_block = run_sync(self._load_json_block())
json_block = run_sync(load_json_block(self.block_name))
return list(json_block.value.keys())
29 changes: 28 additions & 1 deletion src/marvin/tools/chroma.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import os
import uuid
from typing import TYPE_CHECKING, Any, Optional

try:
from chromadb import Documents, EmbeddingFunction, Embeddings, HttpClient
from chromadb import Documents, EmbeddingFunction, Embeddings, GetResult, HttpClient
except ImportError:
raise ImportError(
"The chromadb package is required to query Chroma. Please install"
Expand Down Expand Up @@ -127,3 +128,29 @@ async def multi_query_chroma(
for query in queries
]
return "\n".join(await asyncio.gather(*coros))[:max_characters]


def store_document(
document: str, metadata: dict[str, Any], collection_name: str = "glacial"
) -> GetResult:
"""Store a document in Chroma for future reference.
Args:
document: The document to store.
metadata: The metadata to store with the document.
Returns:
The stored document.
"""
collection = client.get_or_create_collection(
name=collection_name, embedding_function=OpenAIEmbeddingFunction()
)
doc_id = metadata.get("msg_id", str(uuid.uuid4()))

collection.add(
ids=[doc_id],
documents=[document],
metadatas=[metadata],
)

return collection.get(id=doc_id)
Loading