Skip to content

Commit

Permalink
add chat context (#1108)
Browse files Browse the repository at this point in the history
* add chat context
  • Loading branch information
willydouhard authored Jun 29, 2024
1 parent 6d07ae0 commit bb0bebe
Show file tree
Hide file tree
Showing 14 changed files with 131 additions and 48 deletions.
10 changes: 9 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [Unreleased]

Nothing unreleased.
### Added

- `cl.chat_context` to help keeping track of the messages of the current thread

### Fixed

- Message are now collapsible if too long
- The Langchain callback handler should better capture chain runs
- The Llama Index callback handler should now work with other decorators

## [1.1.305] - 2024-06-26

Expand Down
5 changes: 3 additions & 2 deletions backend/chainlit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@
from chainlit.openai import instrument_openai
from chainlit.mistralai import instrument_mistralai

from literalai import ChatGeneration, CompletionGeneration, GenerationMessage

import chainlit.input_widget as input_widget
from chainlit.action import Action
from chainlit.cache import cache
from chainlit.chat_context import chat_context
from chainlit.chat_settings import ChatSettings
from chainlit.config import config
from chainlit.context import context
Expand Down Expand Up @@ -60,6 +59,7 @@
from chainlit.user_session import user_session
from chainlit.utils import make_module_getattr, wrap_user_function
from chainlit.version import __version__
from literalai import ChatGeneration, CompletionGeneration, GenerationMessage

if env_found:
logger.info("Loaded .env file")
Expand Down Expand Up @@ -370,6 +370,7 @@ def acall(self):
"ChatProfile",
"Starter",
"user_session",
"chat_context",
"CopilotFunction",
"AudioChunk",
"Action",
Expand Down
61 changes: 61 additions & 0 deletions backend/chainlit/chat_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import TYPE_CHECKING, Dict, List

from chainlit.context import context

if TYPE_CHECKING:
from chainlit.message import Message

chat_contexts: Dict[str, List["Message"]] = {}


class ChatContext:
def get(self) -> List["Message"]:
if not context.session:
return []

if context.session.id not in chat_contexts:
# Create a new chat context
chat_contexts[context.session.id] = []

return chat_contexts[context.session.id]

def add(self, message: "Message") -> None:
if not context.session:
return

if context.session.id not in chat_contexts:
chat_contexts[context.session.id] = []

chat_contexts[context.session.id].append(message)

def remove(self, message: "Message") -> bool:
if not context.session:
return False

if context.session.id not in chat_contexts:
return False

if message in chat_contexts[context.session.id]:
chat_contexts[context.session.id].remove(message)
return True

return False

def clear(self) -> None:
if context.session and context.session.id in chat_contexts:
chat_contexts[context.session.id] = []

def to_openai(self):
messages = []
for message in self.get():
if message.type == "assistant_message":
messages.append({"role": "assistant", "content": message.content})
elif message.type == "user_message":
messages.append({"role": "user", "content": message.content})
else:
messages.append({"role": "system", "content": message.content})

return messages


chat_context = ChatContext()
6 changes: 2 additions & 4 deletions backend/chainlit/discord/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ def __init__(self, step_id: str):
async def thumbs_down(self, interaction: discord.Interaction, button: Button):
if data_layer := get_data_layer():
try:
thread_id = context_var.get().session.thread_id
feedback = Feedback(forId=self.step_id, threadId=thread_id, value=0)
feedback = Feedback(forId=self.step_id, value=0)
await data_layer.upsert_feedback(feedback)
except Exception as e:
logger.error(f"Error upserting feedback: {e}")
Expand All @@ -47,8 +46,7 @@ async def thumbs_down(self, interaction: discord.Interaction, button: Button):
async def thumbs_up(self, interaction: discord.Interaction, button: Button):
if data_layer := get_data_layer():
try:
thread_id = context_var.get().session.thread_id
feedback = Feedback(forId=self.step_id, threadId=thread_id, value=1)
feedback = Feedback(forId=self.step_id, value=1)
await data_layer.upsert_feedback(feedback)
except Exception as e:
logger.error(f"Error upserting feedback: {e}")
Expand Down
4 changes: 3 additions & 1 deletion backend/chainlit/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import uuid
from typing import Any, Dict, List, Literal, Optional, Union, cast

from chainlit.chat_context import chat_context
from chainlit.config import config
from chainlit.data import get_data_layer
from chainlit.element import Element, ElementDict, File
from chainlit.logger import logger
from chainlit.message import Message
from chainlit.session import BaseSession, HTTPSession, WebsocketSession
from chainlit.session import BaseSession, WebsocketSession
from chainlit.step import StepDict
from chainlit.types import (
AskActionResponse,
Expand Down Expand Up @@ -220,6 +221,7 @@ async def process_message(self, payload: MessagePayload):
message = Message.from_dict(step_dict)
# Overwrite the created_at timestamp with the current time
message.created_at = utc_now()
chat_context.add(message)

asyncio.create_task(message._create())

Expand Down
6 changes: 2 additions & 4 deletions backend/chainlit/langchain/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,8 @@ def _start_trace(self, run: Run) -> None:
if run.run_type == "agent":
step_type = "run"
elif run.run_type == "chain":
pass
if not self.steps:
step_type = "run"
elif run.run_type == "llm":
step_type = "llm"
elif run.run_type == "retriever":
Expand All @@ -462,9 +463,6 @@ def _start_trace(self, run: Run) -> None:
elif run.run_type == "embedding":
step_type = "embedding"

if not self.steps and step_type != "llm":
step_type = "run"

disable_feedback = not self._is_annotable(run)

step = Step(
Expand Down
40 changes: 13 additions & 27 deletions backend/chainlit/llama_index/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,33 +33,22 @@ def __init__(
event_starts_to_ignore=event_starts_to_ignore,
event_ends_to_ignore=event_ends_to_ignore,
)
self.context = context_var.get()

self.steps = {}

def _get_parent_id(self, event_parent_id: Optional[str] = None) -> Optional[str]:
if event_parent_id and event_parent_id in self.steps:
return event_parent_id
elif self.context.current_step:
return self.context.current_step.id
elif self.context.session.root_message:
return self.context.session.root_message.id
elif context_var.get().current_step:
return context_var.get().current_step.id
elif context_var.get().session.root_message:
root_message = context_var.get().session.root_message
if root_message:
return root_message.id
return None
else:
return None

def _restore_context(self) -> None:
"""Restore Chainlit context in the current thread
Chainlit context is local to the main thread, and LlamaIndex
runs the callbacks in its own threads, so they don't have a
Chainlit context by default.
This method restores the context in which the callback handler
has been created (it's always created in the main thread), so
that we can actually send messages.
"""
context_var.set(self.context)

def on_event_start(
self,
event_type: CBEventType,
Expand All @@ -69,8 +58,6 @@ def on_event_start(
**kwargs: Any,
) -> str:
"""Run when an event starts and return id of event."""
self._restore_context()

step_type: StepType = "undefined"
if event_type == CBEventType.RETRIEVE:
step_type = "tool"
Expand All @@ -88,10 +75,11 @@ def on_event_start(
id=event_id,
disable_feedback=True,
)

self.steps[event_id] = step
step.start = utc_now()
step.input = payload or {}
self.context.loop.create_task(step.send())
context_var.get().loop.create_task(step.send())
return event_id

def on_event_end(
Expand All @@ -107,8 +95,6 @@ def on_event_end(
if payload is None or step is None:
return

self._restore_context()

step.end = utc_now()

if event_type == CBEventType.QUERY:
Expand All @@ -127,7 +113,7 @@ def on_event_end(
for idx, source in enumerate(source_nodes)
]
step.output = f"Retrieved the following sources: {source_refs}"
self.context.loop.create_task(step.update())
context_var.get().loop.create_task(step.update())

elif event_type == CBEventType.RETRIEVE:
sources = payload.get(EventPayload.NODES)
Expand All @@ -144,7 +130,7 @@ def on_event_end(
for idx, source in enumerate(sources)
]
step.output = f"Retrieved the following sources: {source_refs}"
self.context.loop.create_task(step.update())
context_var.get().loop.create_task(step.update())

elif event_type == CBEventType.LLM:
formatted_messages = payload.get(
Expand Down Expand Up @@ -195,11 +181,11 @@ def on_event_end(
token_count=token_count,
)

self.context.loop.create_task(step.update())
context_var.get().loop.create_task(step.update())

else:
step.output = payload
self.context.loop.create_task(step.update())
context_var.get().loop.create_task(step.update())

self.steps.pop(event_id, None)

Expand Down
4 changes: 3 additions & 1 deletion backend/chainlit/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Dict, List, Optional, Union, cast

from chainlit.action import Action
from chainlit.chat_context import chat_context
from chainlit.config import config
from chainlit.context import context, local_steps
from chainlit.data import get_data_layer
Expand Down Expand Up @@ -127,7 +128,7 @@ async def remove(self):
Remove a message already sent to the UI.
"""
trace_event("remove_message")

chat_context.remove(self)
step_dict = self.to_dict()
data_layer = get_data_layer()
if data_layer:
Expand Down Expand Up @@ -169,6 +170,7 @@ async def send(self):
self.streaming = False

step_dict = await self._create()
chat_context.add(self)
await context.emitter.send_step(step_dict)

return self
Expand Down
8 changes: 3 additions & 5 deletions backend/chainlit/slack/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ async def handle_app_mentions(event, say):
async def handle_message(message, say):
user = await get_user(message["user"])
thread_name = f"{user.identifier} Slack DM"
await process_slack_message(message, say, thread_name)
await process_slack_message(message, say, thread_name, True)


@slack_app.block_action("thumbdown")
Expand All @@ -341,8 +341,7 @@ async def thumb_down(ack, context, body):
step_id = body["actions"][0]["value"]

if data_layer := get_data_layer():
thread_id = context_var.get().session.thread_id
feedback = Feedback(forId=step_id, threadId=thread_id, value=0)
feedback = Feedback(forId=step_id, value=0)
await data_layer.upsert_feedback(feedback)

text = body["message"]["text"]
Expand All @@ -368,8 +367,7 @@ async def thumb_up(ack, context, body):
step_id = body["actions"][0]["value"]

if data_layer := get_data_layer():
thread_id = context_var.get().session.thread_id
feedback = Feedback(forId=step_id, threadId=thread_id, value=1)
feedback = Feedback(forId=step_id, value=1)
await data_layer.upsert_feedback(feedback)

text = body["message"]["text"]
Expand Down
6 changes: 6 additions & 0 deletions backend/chainlit/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from chainlit.auth import get_current_user, require_login
from chainlit.config import config
from chainlit.context import init_ws_context
from chainlit.chat_context import chat_context
from chainlit.data import get_data_layer
from chainlit.element import Element
from chainlit.logger import logger
Expand Down Expand Up @@ -183,6 +184,11 @@ async def connection_successful(sid):
{"interaction": "resume", "thread_id": thread.get("id")},
)
await config.code.on_chat_resume(thread)

for step in thread.get("steps", []):
if "message" in step["type"]:
chat_context.add(Message.from_dict(step))

await context.emitter.resume_thread(thread)
return

Expand Down
1 change: 0 additions & 1 deletion backend/chainlit/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from chainlit.step import StepDict

from dataclasses_json import DataClassJsonMixin
from literalai import ChatGeneration, CompletionGeneration
from pydantic import BaseModel
from pydantic.dataclasses import dataclass

Expand Down
8 changes: 8 additions & 0 deletions cypress/e2e/chat_context/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import chainlit as cl


@cl.on_message
async def main():
await cl.Message(
content=f"Chat context length: {len(cl.chat_context.get())}"
).send()
17 changes: 17 additions & 0 deletions cypress/e2e/chat_context/spec.cy.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { runTestServer, submitMessage } from '../../support/testUtils';

describe('Chat Context', () => {
before(() => {
runTestServer();
});

it('should be able to store data related per user session', () => {
submitMessage('Hello 1');

cy.get('.step').eq(1).should('contain', 'Chat context length: 1');

submitMessage('Hello 2');

cy.get('.step').eq(3).should('contain', 'Chat context length: 3');
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ const MessageContent = memo(
);

const collapse =
!message.type.includes('message') &&
(lineCount > COLLAPSE_MIN_LINES || contentLength > COLLAPSE_MIN_LENGTH);
lineCount > COLLAPSE_MIN_LINES || contentLength > COLLAPSE_MIN_LENGTH;
const messageContent = collapse ? (
<Collapse defaultExpandAll={preserveSize}>{markdownContent}</Collapse>
) : (
Expand Down

0 comments on commit bb0bebe

Please sign in to comment.