Skip to content

Commit

Permalink
make auto tag thread opt in (#927)
Browse files Browse the repository at this point in the history
  • Loading branch information
willydouhard authored Apr 22, 2024
1 parent b6b4ef4 commit 8695e94
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 21 deletions.
10 changes: 7 additions & 3 deletions backend/chainlit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@
# Process and display mathematical expressions. This can clash with "$" characters in messages.
latex = false
# Automatically tag threads with the current chat profile (if a chat profile is used)
auto_tag_thread = true
# Authorize users to upload files with messages
[features.multi_modal]
enabled = true
Expand Down Expand Up @@ -206,6 +209,7 @@ class FeaturesSettings(DataClassJsonMixin):
latex: bool = False
unsafe_allow_html: bool = False
speech_to_text: Optional[SpeechToTextFeature] = None
auto_tag_thread: bool = True


@dataclass()
Expand Down Expand Up @@ -246,9 +250,9 @@ class CodeSettings:
on_message: Optional[Callable[[str], Any]] = None
author_rename: Optional[Callable[[str], str]] = None
on_settings_update: Optional[Callable[[Dict[str, Any]], Any]] = None
set_chat_profiles: Optional[
Callable[[Optional["User"]], List["ChatProfile"]]
] = None
set_chat_profiles: Optional[Callable[[Optional["User"]], List["ChatProfile"]]] = (
None
)


@dataclass()
Expand Down
43 changes: 36 additions & 7 deletions backend/chainlit/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,36 @@
import json
import os
from collections import deque
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, cast, Protocol, Any
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Protocol,
Union,
cast,
)

import aiofiles
from chainlit.config import config
from chainlit.context import context
from chainlit.logger import logger
from chainlit.session import WebsocketSession
from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter, PageInfo, PaginatedResponse
from chainlit.types import (
Feedback,
PageInfo,
PaginatedResponse,
Pagination,
ThreadDict,
ThreadFilter,
)
from chainlit.user import PersistedUser, User
from literalai import Attachment, PaginatedResponse as LiteralPaginatedResponse, Score as LiteralScore, Step as LiteralStep
from literalai import Attachment
from literalai import PaginatedResponse as LiteralPaginatedResponse
from literalai import Score as LiteralScore
from literalai import Step as LiteralStep
from literalai.filter import threads_filters as LiteralThreadsFilters
from literalai.step import StepDict as LiteralStepDict

Expand Down Expand Up @@ -226,7 +246,7 @@ async def create_user(self, user: User) -> Optional[PersistedUser]:
return PersistedUser(
id=_user.id or "",
identifier=_user.identifier or "",
metadata=_user.metadata,
metadata=user.metadata,
createdAt=_user.created_at or "",
)

Expand Down Expand Up @@ -421,8 +441,8 @@ async def list_threads(
pageInfo=PageInfo(
hasNextPage=literal_response.pageInfo.hasNextPage,
startCursor=literal_response.pageInfo.startCursor,
endCursor=literal_response.pageInfo.endCursor
),
endCursor=literal_response.pageInfo.endCursor,
),
data=literal_response.data,
)

Expand Down Expand Up @@ -470,11 +490,20 @@ async def update_thread(
tags=tags,
)


class BaseStorageClient(Protocol):
"""Base class for non-text data persistence like Azure Data Lake, S3, Google Storage, etc."""
async def upload_file(self, object_key: str, data: Union[bytes, str], mime: str = 'application/octet-stream', overwrite: bool = True) -> Dict[str, Any]:

async def upload_file(
self,
object_key: str,
data: Union[bytes, str],
mime: str = "application/octet-stream",
overwrite: bool = True,
) -> Dict[str, Any]:
pass


if api_key := os.environ.get("LITERAL_API_KEY"):
# support legacy LITERAL_SERVER variable as fallback
server = os.environ.get("LITERAL_API_URL") or os.environ.get("LITERAL_SERVER")
Expand Down
20 changes: 12 additions & 8 deletions backend/chainlit/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import uuid
from typing import Any, Dict, List, Literal, Optional, Union, cast

from chainlit.config import config
from chainlit.data import get_data_layer
from chainlit.element import Element, File
from chainlit.logger import logger
Expand Down Expand Up @@ -175,15 +176,18 @@ async def flush_thread_queues(self, interaction: str):
else:
user_id = None
try:
tags = (
[self.session.chat_profile] if self.session.chat_profile else None
should_tag_thread = (
self.session.chat_profile and config.features.auto_tag_thread
)
tags = [self.session.chat_profile] if should_tag_thread else None
asyncio.create_task(
data_layer.update_thread(
thread_id=self.session.thread_id,
name=interaction,
user_id=user_id,
tags=tags,
)
)
asyncio.create_task(data_layer.update_thread(
thread_id=self.session.thread_id,
name=interaction,
user_id=user_id,
tags=tags,
))
except Exception as e:
logger.error(f"Error updating thread: {e}")
asyncio.create_task(self.session.flush_method_queue())
Expand Down
23 changes: 22 additions & 1 deletion cypress/e2e/data_layer/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Dict, List, Optional

import chainlit.data as cl_data
from chainlit.step import StepDict
Expand Down Expand Up @@ -69,6 +69,23 @@ async def get_user(self, identifier: str):
async def create_user(self, user: cl.User):
return cl.PersistedUser(id="test", createdAt=now, identifier=user.identifier)

async def update_thread(
self,
thread_id: str,
name: Optional[str] = None,
user_id: Optional[str] = None,
metadata: Optional[Dict] = None,
tags: Optional[List[str]] = None,
):
thread = next((t for t in thread_history if t["id"] == "test2"), None)
if thread:
if name:
thread["name"] = name
if metadata:
thread["metadata"] = metadata
if tags:
thread["tags"] = tags

@cl_data.queue_until_user_message()
async def create_step(self, step_dict: StepDict):
global create_step_counter
Expand Down Expand Up @@ -131,3 +148,7 @@ def auth_callback(username: str, password: str) -> Optional[cl.User]:
@cl.on_chat_resume
async def on_chat_resume(thread: cl_data.ThreadDict):
await cl.Message(f"Welcome back to {thread['name']}").send()
if "metadata" in thread:
await cl.Message(thread["metadata"], author="metadata", language="json").send()
if "tags" in thread:
await cl.Message(thread["tags"], author="tags", language="json").send()
7 changes: 5 additions & 2 deletions cypress/e2e/data_layer/spec.cy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,14 @@ function resumeThread() {
cy.get('#resumeThread').click();
cy.get(`#chat-input`).should('exist');

cy.get('.step').should('have.length', 3);
cy.get('.step').should('have.length', 4);

cy.get('.step').eq(0).should('contain', 'Message 3');
cy.get('.step').eq(1).should('contain', 'Message 4');
cy.get('.step').eq(2).should('contain', 'Welcome back to thread 2');
// Thread name should be renamed with first interaction
cy.get('.step').eq(2).should('contain', 'Welcome back to Hello');
cy.get('.step').eq(3).should('contain', 'metadata');
cy.get('.step').eq(3).should('contain', 'chat_profile');
}

describe('Data Layer', () => {
Expand Down

0 comments on commit 8695e94

Please sign in to comment.