From 1ced1e449dba89802cef786a9aeca7ea36118015 Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Wed, 27 Mar 2024 11:55:42 +0100 Subject: [PATCH 1/9] migrate to literal score --- backend/chainlit/data/__init__.py | 98 +++++++++++-------- backend/chainlit/server.py | 21 +++- backend/chainlit/types.py | 15 +-- backend/pyproject.toml | 4 +- .../messages/components/FeedbackButtons.tsx | 46 ++++++--- .../organisms/chat/Messages/container.tsx | 7 ++ .../organisms/chat/Messages/index.tsx | 29 ++++++ .../organisms/threadHistory/Thread.tsx | 37 +++++++ .../sidebar/filters/FeedbackSelect.tsx | 19 ++-- frontend/src/types/messageContext.ts | 5 + libs/copilot/src/chat/messages/container.tsx | 7 ++ libs/copilot/src/chat/messages/index.tsx | 27 +++++ libs/react-client/src/api/index.tsx | 8 ++ libs/react-client/src/types/feedback.ts | 1 - 14 files changed, 250 insertions(+), 74 deletions(-) diff --git a/backend/chainlit/data/__init__.py b/backend/chainlit/data/__init__.py index 320cb13ba9..7c74122b05 100644 --- a/backend/chainlit/data/__init__.py +++ b/backend/chainlit/data/__init__.py @@ -2,7 +2,7 @@ import json import os from collections import deque -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union, Literal, cast import aiofiles from chainlit.config import config @@ -11,13 +11,9 @@ from chainlit.session import WebsocketSession from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter from chainlit.user import PersistedUser, User, UserDict -from literalai import Attachment -from literalai import Feedback as ClientFeedback -from literalai import PageInfo, PaginatedResponse -from literalai import Step as ClientStep -from literalai.step import StepDict as ClientStepDict -from literalai.thread import NumberListFilter, StringFilter, StringListFilter -from literalai.thread import ThreadFilter as ClientThreadFilter +from literalai import Score as LiteralScore, PageInfo, PaginatedResponse, Attachment, Step as LiteralStep +from literalai.step import StepDict as LiteralStepDict +from literalai.filter import threads_filters as LiteralThreadsFilters if TYPE_CHECKING: from chainlit.element import Element, ElementDict @@ -57,6 +53,13 @@ async def get_user(self, identifier: str) -> Optional["PersistedUser"]: async def create_user(self, user: "User") -> Optional["PersistedUser"]: pass + async def delete_feedback( + self, + feedback_id: str, + ) -> bool: + return True + + async def upsert_feedback( self, feedback: Feedback, @@ -98,7 +101,7 @@ async def list_threads( self, pagination: "Pagination", filters: "ThreadFilter" ) -> "PaginatedResponse[ThreadDict]": return PaginatedResponse( - data=[], pageInfo=PageInfo(hasNextPage=False, endCursor=None) + data=[], pageInfo=PageInfo(hasNextPage=False, startCursor=None, endCursor=None) ) async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": @@ -146,20 +149,19 @@ def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict": "threadId": attachment.thread_id, } - def feedback_to_feedback_dict( - self, feedback: Optional[ClientFeedback] + def score_to_feedback_dict( + self, score: Optional[LiteralScore] ) -> "Optional[FeedbackDict]": - if not feedback: + if not score: return None return { - "id": feedback.id or "", - "forId": feedback.step_id or "", - "value": feedback.value or 0, # type: ignore - "comment": feedback.comment, - "strategy": "BINARY", + "id": score.id or "", + "forId": score.step_id or "", + "value": cast(Literal[0, 1], score.value), + "comment": score.comment, } - def step_to_step_dict(self, step: ClientStep) -> "StepDict": + def step_to_step_dict(self, step: LiteralStep) -> "StepDict": metadata = step.metadata or {} input = (step.input or {}).get("content") or ( json.dumps(step.input) if step.input and step.input != {} else "" @@ -167,12 +169,15 @@ def step_to_step_dict(self, step: ClientStep) -> "StepDict": output = (step.output or {}).get("content") or ( json.dumps(step.output) if step.output and step.output != {} else "" ) + + user_feedback = next((s for s in step.scores if s.type == "HUMAN" and s.name == "user-feedback"), None) if step.scores else None + return { "createdAt": step.created_at, "id": step.id or "", "threadId": step.thread_id or "", "parentId": step.parent_id, - "feedback": self.feedback_to_feedback_dict(step.feedback), + "feedback": self.score_to_feedback_dict(user_feedback), "start": step.start_time, "end": step.end_time, "type": step.type or "undefined", @@ -186,7 +191,6 @@ def step_to_step_dict(self, step: ClientStep) -> "StepDict": "language": metadata.get("language"), "isError": metadata.get("isError", False), "waitForAnswer": metadata.get("waitForAnswer", False), - "feedback": self.feedback_to_feedback_dict(step.feedback), } async def get_user(self, identifier: str) -> Optional[PersistedUser]: @@ -215,26 +219,38 @@ async def create_user(self, user: User) -> Optional[PersistedUser]: createdAt=_user.created_at or "", ) + async def delete_feedback( + self, + feedback_id: str, + ): + if feedback_id: + await self.client.api.delete_score( + id=feedback_id, + ) + return True + return False + + async def upsert_feedback( self, feedback: Feedback, ): if feedback.id: - await self.client.api.update_feedback( + await self.client.api.update_score( id=feedback.id, update_params={ "comment": feedback.comment, - "strategy": feedback.strategy, "value": feedback.value, }, ) return feedback.id else: - created = await self.client.api.create_feedback( + created = await self.client.api.create_score( step_id=feedback.forId, value=feedback.value, comment=feedback.comment, - strategy=feedback.strategy, + name="user-feedback", + type="HUMAN", ) return created.id or "" @@ -307,7 +323,7 @@ async def create_step(self, step_dict: "StepDict"): "showInput": step_dict.get("showInput"), } - step: ClientStepDict = { + step: LiteralStepDict = { "createdAt": step_dict.get("createdAt"), "startTime": step_dict.get("start"), "endTime": step_dict.get("end"), @@ -349,22 +365,26 @@ async def delete_thread(self, thread_id: str): async def list_threads( self, pagination: "Pagination", filters: "ThreadFilter" ) -> "PaginatedResponse[ThreadDict]": - if not filters.userIdentifier: - raise ValueError("userIdentifier is required") - - client_filters = ClientThreadFilter( - participantsIdentifier=StringListFilter( - operator="in", value=[filters.userIdentifier] - ), - ) + if not filters.userId: + raise ValueError("userId is required") + + literal_filters: LiteralThreadsFilters = [ + { + "field": "participantId", + "operator": "eq", + "value": filters.userId, + } + ] + if filters.search: - client_filters.search = StringFilter(operator="ilike", value=filters.search) - if filters.feedback: - client_filters.feedbacksValue = NumberListFilter( - operator="in", value=[filters.feedback] - ) + literal_filters.append({"field": "stepOutput", "operator": "ilike", "value": filters.search, "path": "content"}) + + + if filters.feedback is not None: + literal_filters.append({"field": "scoreValue", "operator": "eq", "value": filters.feedback, "path": "user-feedback"}) + return await self.client.api.list_threads( - first=pagination.first, after=pagination.cursor, filters=client_filters + first=pagination.first, after=pagination.cursor, filters=literal_filters, order_by={"column": "createdAt", "direction": "DESC"} ) async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index 37646fd7f8..8f74748953 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -41,6 +41,7 @@ GetThreadsRequest, Theme, UpdateFeedbackRequest, + DeleteFeedbackRequest, ) from chainlit.user import PersistedUser, User from fastapi import ( @@ -551,6 +552,24 @@ async def update_feedback( return JSONResponse(content={"success": True, "feedbackId": feedback_id}) +@app.delete("/feedback") +async def delete_feedback( + request: Request, + payload: DeleteFeedbackRequest, + current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)], +): + """Delete a feedback.""" + + data_layer = get_data_layer() + + if not data_layer: + raise HTTPException(status_code=400, detail="Data persistence is not enabled") + + feedback_id = payload.feedbackId + + await data_layer.delete_feedback(feedback_id) + return JSONResponse(content={"success": True}) + @app.post("/project/threads") async def get_user_threads( @@ -566,7 +585,7 @@ async def get_user_threads( if not data_layer: raise HTTPException(status_code=400, detail="Data persistence is not enabled") - payload.filter.userIdentifier = current_user.identifier + payload.filter.userId = current_user.id res = await data_layer.list_threads(payload.pagination, payload.filter) return JSONResponse(content=res.to_dict()) diff --git a/backend/chainlit/types.py b/backend/chainlit/types.py index 28058bff28..56be377ebe 100644 --- a/backend/chainlit/types.py +++ b/backend/chainlit/types.py @@ -33,8 +33,8 @@ class Pagination(BaseModel): class ThreadFilter(BaseModel): - feedback: Optional[Literal[-1, 0, 1]] = None - userIdentifier: Optional[str] = None + feedback: Optional[Literal[0, 1]] = None + userId: Optional[str] = None search: Optional[str] = None @@ -122,6 +122,9 @@ def is_chat(self): class DeleteThreadRequest(BaseModel): threadId: str +class DeleteFeedbackRequest(BaseModel): + feedbackId: str + class GetThreadsRequest(BaseModel): pagination: Pagination @@ -146,16 +149,16 @@ class ChatProfile(DataClassJsonMixin): class FeedbackDict(TypedDict): - value: Literal[-1, 0, 1] - strategy: FeedbackStrategy + forId: str + id: Optional[str] + value: Literal[0, 1] comment: Optional[str] @dataclass class Feedback: forId: str - value: Literal[-1, 0, 1] - strategy: FeedbackStrategy = "BINARY" + value: Literal[0, 1] id: Optional[str] = None comment: Optional[str] = None diff --git a/backend/pyproject.toml b/backend/pyproject.toml index a0a6ba0c36..918e74fe9e 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "chainlit" -version = "1.0.401" +version = "1.0.500" keywords = ['LLM', 'Agents', 'gen ai', 'chat ui', 'chatbot ui', 'openai', 'copilot', 'langchain', 'conversational ai'] description = "Build Conversational AI." authors = ["Chainlit"] @@ -23,7 +23,7 @@ chainlit = 'chainlit.cli:cli' [tool.poetry.dependencies] python = ">=3.8.1,<4.0.0" httpx = ">=0.23.0" -literalai = "0.0.300" +literalai = "0.0.400" dataclasses_json = "^0.5.7" fastapi = ">=0.100" # Starlette >= 0.33.0 breaks socketio (alway 404) diff --git a/frontend/src/components/molecules/messages/components/FeedbackButtons.tsx b/frontend/src/components/molecules/messages/components/FeedbackButtons.tsx index 608cb20d08..e075406c2a 100644 --- a/frontend/src/components/molecules/messages/components/FeedbackButtons.tsx +++ b/frontend/src/components/molecules/messages/components/FeedbackButtons.tsx @@ -1,6 +1,7 @@ import { MessageContext } from 'contexts/MessageContext'; import { useContext, useState } from 'react'; import { useMemo } from 'react'; +import { useRecoilValue } from 'recoil'; import StickyNote2Outlined from '@mui/icons-material/StickyNote2Outlined'; import ThumbDownAlt from '@mui/icons-material/ThumbDownAlt'; @@ -11,6 +12,8 @@ import IconButton from '@mui/material/IconButton'; import Stack from '@mui/material/Stack'; import Tooltip from '@mui/material/Tooltip'; +import { firstUserInteraction } from '@chainlit/react-client'; + import Dialog from 'components/atoms/Dialog'; import { AccentButton } from 'components/atoms/buttons/AccentButton'; import { TextInput } from 'components/atoms/inputs'; @@ -24,18 +27,30 @@ interface Props { } const FeedbackButtons = ({ message }: Props) => { - const { onFeedbackUpdated } = useContext(MessageContext); + const { onFeedbackUpdated, onFeedbackDeleted } = useContext(MessageContext); const [showFeedbackDialog, setShowFeedbackDialog] = useState(); const [commentInput, setCommentInput] = useState(); + const firstInteraction = useRecoilValue(firstUserInteraction); - const [feedback, setFeedback] = useState(message.feedback?.value || 0); + const [feedback, setFeedback] = useState(message.feedback?.value); const [comment, setComment] = useState(message.feedback?.comment); - const DownIcon = feedback === -1 ? ThumbDownAlt : ThumbDownAltOutlined; + const DownIcon = feedback === 0 ? ThumbDownAlt : ThumbDownAltOutlined; const UpIcon = feedback === 1 ? ThumbUpAlt : ThumbUpAltOutlined; - const handleFeedbackChanged = (feedback: number, comment?: string) => { - onFeedbackUpdated && + const handleFeedbackChanged = (feedback?: number, comment?: string) => { + if (feedback === undefined) { + if (onFeedbackDeleted && message.feedback?.id) { + onFeedbackDeleted( + message, + () => { + setFeedback(undefined); + setComment(undefined); + }, + message.feedback.id + ); + } + } else if (onFeedbackUpdated) { onFeedbackUpdated( message, () => { @@ -43,23 +58,24 @@ const FeedbackButtons = ({ message }: Props) => { setComment(comment); }, { - ...(message.feedback || { strategy: 'BINARY' }), + ...(message.feedback || {}), forId: message.id, value: feedback, comment } ); + } }; - const handleFeedbackClick = (status: number) => { - if (feedback === status) { - handleFeedbackChanged(0); + const handleFeedbackClick = (nextValue: number) => { + if (feedback === nextValue) { + handleFeedbackChanged(undefined); } else { - setShowFeedbackDialog(status); + setShowFeedbackDialog(nextValue); } }; - const disabled = !!message.streaming; + const disabled = !!message.streaming || !firstInteraction; const buttons = useMemo(() => { const iconSx = { @@ -92,7 +108,7 @@ const FeedbackButtons = ({ message }: Props) => { disabled={disabled} className={`negative-feedback-${feedback === -1 ? 'on' : 'off'}`} onClick={() => { - handleFeedbackClick(-1); + handleFeedbackClick(0); }} > @@ -138,11 +154,11 @@ const FeedbackButtons = ({ message }: Props) => { onClose={() => { setShowFeedbackDialog(undefined); }} - open={!!showFeedbackDialog} + open={showFeedbackDialog !== undefined} title={ - {showFeedbackDialog === -1 ? : } - Provide additional feedback + {showFeedbackDialog === 0 ? : } + Add a comment } content={ diff --git a/frontend/src/components/organisms/chat/Messages/container.tsx b/frontend/src/components/organisms/chat/Messages/container.tsx index a58f5b1617..9c937a3e7a 100644 --- a/frontend/src/components/organisms/chat/Messages/container.tsx +++ b/frontend/src/components/organisms/chat/Messages/container.tsx @@ -36,6 +36,11 @@ interface Props { onSuccess: () => void, feedback: IFeedback ) => void; + onFeedbackDeleted: ( + message: IStep, + onSuccess: () => void, + feedback: string + ) => void; callAction?: (action: IAction) => void; setAutoScroll?: (autoScroll: boolean) => void; } @@ -50,6 +55,7 @@ const MessageContainer = memo( elements, messages, onFeedbackUpdated, + onFeedbackDeleted, callAction, setAutoScroll }: Props) => { @@ -164,6 +170,7 @@ const MessageContainer = memo( onElementRefClick, onError, onFeedbackUpdated, + onFeedbackDeleted, onPlaygroundButtonClick }; }, [ diff --git a/frontend/src/components/organisms/chat/Messages/index.tsx b/frontend/src/components/organisms/chat/Messages/index.tsx index 7ff1d8ec28..0ac12ea380 100644 --- a/frontend/src/components/organisms/chat/Messages/index.tsx +++ b/frontend/src/components/organisms/chat/Messages/index.tsx @@ -106,6 +106,34 @@ const Messages = ({ [] ); + const onFeedbackDeleted = useCallback( + async (message: IStep, onSuccess: () => void, feedbackId: string) => { + try { + toast.promise(apiClient.deleteFeedback(feedbackId, accessToken), { + loading: t('components.organisms.chat.Messages.index.updating'), + success: () => { + setMessages((prev) => + updateMessageById(prev, message.id, { + ...message, + feedback: undefined + }) + ); + onSuccess(); + return t( + 'components.organisms.chat.Messages.index.feedbackUpdated' + ); + }, + error: (err) => { + return {err.message}; + } + }); + } catch (err) { + console.log(err); + } + }, + [] + ); + return !idToResume && !messages.length && projectSettings?.ui.show_readme_as_default ? ( @@ -125,6 +153,7 @@ const Messages = ({ messages={messages} autoScroll={autoScroll} onFeedbackUpdated={onFeedbackUpdated} + onFeedbackDeleted={onFeedbackDeleted} callAction={callActionWithToast} setAutoScroll={setAutoScroll} /> diff --git a/frontend/src/components/organisms/threadHistory/Thread.tsx b/frontend/src/components/organisms/threadHistory/Thread.tsx index 8063224245..2f113a5397 100644 --- a/frontend/src/components/organisms/threadHistory/Thread.tsx +++ b/frontend/src/components/organisms/threadHistory/Thread.tsx @@ -1,4 +1,5 @@ import { useCallback, useEffect, useState } from 'react'; +import { useTranslation } from 'react-i18next'; import { Link } from 'react-router-dom'; import { useRecoilValue } from 'recoil'; import { toast } from 'sonner'; @@ -31,6 +32,7 @@ const Thread = ({ thread, error, isLoading }: Props) => { const accessToken = useRecoilValue(accessTokenState); const [steps, setSteps] = useState([]); const apiClient = useRecoilValue(apiClientState); + const { t } = useTranslation(); useEffect(() => { if (!thread) return; @@ -72,6 +74,40 @@ const Thread = ({ thread, error, isLoading }: Props) => { [] ); + const onFeedbackDeleted = useCallback( + async (message: IStep, onSuccess: () => void, feedbackId: string) => { + try { + toast.promise(apiClient.deleteFeedback(feedbackId, accessToken), { + loading: t('components.organisms.chat.Messages.index.updating'), + success: () => { + setSteps((prev) => + prev.map((step) => { + if (step.id === message.id) { + return { + ...step, + feedback: undefined + }; + } + return step; + }) + ); + + onSuccess(); + return t( + 'components.organisms.chat.Messages.index.feedbackUpdated' + ); + }, + error: (err) => { + return {err.message}; + } + }); + } catch (err) { + console.log(err); + } + }, + [] + ); + if (isLoading) { return ( <> @@ -150,6 +186,7 @@ const Thread = ({ thread, error, isLoading }: Props) => { actions={actions} elements={(elements || []) as IMessageElement[]} onFeedbackUpdated={onFeedbackUpdated} + onFeedbackDeleted={onFeedbackDeleted} messages={messages} autoScroll={true} /> diff --git a/frontend/src/components/organisms/threadHistory/sidebar/filters/FeedbackSelect.tsx b/frontend/src/components/organisms/threadHistory/sidebar/filters/FeedbackSelect.tsx index 53f0bb020a..e6ce544f8a 100644 --- a/frontend/src/components/organisms/threadHistory/sidebar/filters/FeedbackSelect.tsx +++ b/frontend/src/components/organisms/threadHistory/sidebar/filters/FeedbackSelect.tsx @@ -13,10 +13,9 @@ import Stack from '@mui/material/Stack'; import { threadsFiltersState } from 'state/threads'; -export enum FEEDBACKS { - ALL = 0, +export enum Feedback { POSITIVE = 1, - NEGATIVE = -1 + NEGATIVE = 0 } export default function FeedbackSelect() { @@ -25,12 +24,12 @@ export default function FeedbackSelect() { const { t } = useTranslation(); - const handleChange = (feedback: number) => { + const handleChange = (feedback?: number) => { setFilters((prev) => ({ ...prev, feedback })); setAnchorEl(null); }; - const renderMenuItem = (label: string, feedback: number) => { + const renderMenuItem = (label: string, feedback?: number) => { return ( handleChange(feedback)} @@ -53,9 +52,9 @@ export default function FeedbackSelect() { const sx = { width: 16, height: 16 }; switch (filters.feedback) { - case FEEDBACKS.POSITIVE: + case Feedback.POSITIVE: return ; - case FEEDBACKS.NEGATIVE: + case Feedback.NEGATIVE: return ; default: return ; @@ -102,19 +101,19 @@ export default function FeedbackSelect() { t( 'components.organisms.threadHistory.sidebar.filters.FeedbackSelect.feedbackAll' ), - FEEDBACKS.ALL + undefined )} {renderMenuItem( t( 'components.organisms.threadHistory.sidebar.filters.FeedbackSelect.feedbackPositive' ), - FEEDBACKS.POSITIVE + Feedback.POSITIVE )} {renderMenuItem( t( 'components.organisms.threadHistory.sidebar.filters.FeedbackSelect.feedbackNegative' ), - FEEDBACKS.NEGATIVE + Feedback.NEGATIVE )} diff --git a/frontend/src/types/messageContext.ts b/frontend/src/types/messageContext.ts index 34f272a4d0..a5d436a6f5 100644 --- a/frontend/src/types/messageContext.ts +++ b/frontend/src/types/messageContext.ts @@ -30,6 +30,11 @@ interface IMessageContext { onSuccess: () => void, feedback: IFeedback ) => void; + onFeedbackDeleted?: ( + message: IStep, + onSuccess: () => void, + feedbackId: string + ) => void; onError: (error: string) => void; } diff --git a/libs/copilot/src/chat/messages/container.tsx b/libs/copilot/src/chat/messages/container.tsx index 391d850bf0..e73dd03afa 100644 --- a/libs/copilot/src/chat/messages/container.tsx +++ b/libs/copilot/src/chat/messages/container.tsx @@ -33,6 +33,11 @@ interface Props { onSuccess: () => void, feedback: IFeedback ) => void; + onFeedbackDeleted: ( + message: IStep, + onSuccess: () => void, + feedbackId: string + ) => void; callAction?: (action: IAction) => void; setAutoScroll?: (autoScroll: boolean) => void; } @@ -47,6 +52,7 @@ const MessageContainer = memo( elements, messages, onFeedbackUpdated, + onFeedbackDeleted, callAction, setAutoScroll }: Props) => { @@ -115,6 +121,7 @@ const MessageContainer = memo( onElementRefClick, onError, onFeedbackUpdated, + onFeedbackDeleted, onPlaygroundButtonClick }; }, [ diff --git a/libs/copilot/src/chat/messages/index.tsx b/libs/copilot/src/chat/messages/index.tsx index a8806d05c2..af3f5ac367 100644 --- a/libs/copilot/src/chat/messages/index.tsx +++ b/libs/copilot/src/chat/messages/index.tsx @@ -96,6 +96,32 @@ const Messages = ({ [] ); + const onFeedbackDeleted = useCallback( + async (message: IStep, onSuccess: () => void, feedbackId: string) => { + try { + toast.promise(apiClient.deleteFeedback(feedbackId, accessToken), { + loading: 'Updating', + success: (res) => { + setMessages((prev) => + updateMessageById(prev, message.id, { + ...message, + feedback: undefined + }) + ); + onSuccess(); + return 'Feedback updated!'; + }, + error: (err) => { + return {err.message}; + } + }); + } catch (err) { + console.log(err); + } + }, + [] + ); + const showWelcomeScreen = !idToResume && !messages.length && @@ -122,6 +148,7 @@ const Messages = ({ messages={messages} autoScroll={autoScroll} onFeedbackUpdated={onFeedbackUpdated} + onFeedbackDeleted={onFeedbackDeleted} callAction={callActionWithToast} setAutoScroll={setAutoScroll} /> diff --git a/libs/react-client/src/api/index.tsx b/libs/react-client/src/api/index.tsx index b6f7aeb30a..f542d11a16 100644 --- a/libs/react-client/src/api/index.tsx +++ b/libs/react-client/src/api/index.tsx @@ -213,6 +213,14 @@ export class ChainlitAPI extends APIBase { return res.json(); } + async deleteFeedback( + feedbackId: string, + accessToken?: string + ): Promise<{ success: boolean }> { + const res = await this.delete(`/feedback`, { feedbackId }, accessToken); + return res.json(); + } + async listThreads( pagination: IPagination, filter: IThreadFilters, diff --git a/libs/react-client/src/types/feedback.ts b/libs/react-client/src/types/feedback.ts index 7c8b764875..625a8df4de 100644 --- a/libs/react-client/src/types/feedback.ts +++ b/libs/react-client/src/types/feedback.ts @@ -2,6 +2,5 @@ export interface IFeedback { id?: string; forId?: string; comment?: string; - strategy: 'BINARY'; value: number; } From dcccca7f62c158416f3f57e2bf2ae758d219fea5 Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Wed, 27 Mar 2024 13:25:28 +0100 Subject: [PATCH 2/9] fix tests --- backend/chainlit/server.py | 3 +++ cypress/e2e/data_layer/main.py | 2 +- .../molecules/messages/components/FeedbackButtons.tsx | 8 ++++++-- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index 8f74748953..676409ab04 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -585,6 +585,9 @@ async def get_user_threads( if not data_layer: raise HTTPException(status_code=400, detail="Data persistence is not enabled") + if not isinstance(current_user, PersistedUser): + raise HTTPException(status_code=400, detail="User not persisted") + payload.filter.userId = current_user.id res = await data_layer.list_threads(payload.pagination, payload.filter) diff --git a/cypress/e2e/data_layer/main.py b/cypress/e2e/data_layer/main.py index 9518adbb51..a971193e35 100644 --- a/cypress/e2e/data_layer/main.py +++ b/cypress/e2e/data_layer/main.py @@ -81,7 +81,7 @@ async def list_threads( ) -> cl_data.PaginatedResponse[cl_data.ThreadDict]: return cl_data.PaginatedResponse( data=[t for t in thread_history if t["id"] not in deleted_thread_ids], - pageInfo=cl_data.PageInfo(hasNextPage=False, endCursor=None), + pageInfo=cl_data.PageInfo(hasNextPage=False, startCursor=None, endCursor=None), ) async def get_thread(self, thread_id: str): diff --git a/frontend/src/components/molecules/messages/components/FeedbackButtons.tsx b/frontend/src/components/molecules/messages/components/FeedbackButtons.tsx index e075406c2a..cde0681837 100644 --- a/frontend/src/components/molecules/messages/components/FeedbackButtons.tsx +++ b/frontend/src/components/molecules/messages/components/FeedbackButtons.tsx @@ -12,7 +12,7 @@ import IconButton from '@mui/material/IconButton'; import Stack from '@mui/material/Stack'; import Tooltip from '@mui/material/Tooltip'; -import { firstUserInteraction } from '@chainlit/react-client'; +import { firstUserInteraction, useChatSession } from '@chainlit/react-client'; import Dialog from 'components/atoms/Dialog'; import { AccentButton } from 'components/atoms/buttons/AccentButton'; @@ -31,6 +31,7 @@ const FeedbackButtons = ({ message }: Props) => { const [showFeedbackDialog, setShowFeedbackDialog] = useState(); const [commentInput, setCommentInput] = useState(); const firstInteraction = useRecoilValue(firstUserInteraction); + const { idToResume } = useChatSession(); const [feedback, setFeedback] = useState(message.feedback?.value); const [comment, setComment] = useState(message.feedback?.comment); @@ -75,7 +76,10 @@ const FeedbackButtons = ({ message }: Props) => { } }; - const disabled = !!message.streaming || !firstInteraction; + const isPersisted = firstInteraction || idToResume; + const isStreaming = !!message.streaming; + + const disabled = isStreaming || !isPersisted; const buttons = useMemo(() => { const iconSx = { From 2ebc408ec7bef7d93af90e7c11d403677b3193c4 Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Wed, 27 Mar 2024 15:30:03 +0100 Subject: [PATCH 3/9] enhance langchain llm step display --- backend/chainlit/langchain/callbacks.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/backend/chainlit/langchain/callbacks.py b/backend/chainlit/langchain/callbacks.py index 8c6d16049c..0876444929 100644 --- a/backend/chainlit/langchain/callbacks.py +++ b/backend/chainlit/langchain/callbacks.py @@ -533,8 +533,7 @@ def _on_run_update(self, run: Run) -> None: break current_step.language = "json" - current_step.output = json.dumps(message_completion) - completion = message_completion.get("content", "") + current_step.output = json.dumps(message_completion, indent=4, ensure_ascii=False) else: completion_start = self.completion_generations[str(run.id)] completion = generation.get("text", "") From 0c9a3dfd7b5027d0f95ce1a74f1732816a199215 Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Wed, 27 Mar 2024 16:59:30 +0100 Subject: [PATCH 4/9] correctly display new lines --- frontend/src/components/molecules/Code.tsx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/frontend/src/components/molecules/Code.tsx b/frontend/src/components/molecules/Code.tsx index d37c6cb079..050f2bd44f 100644 --- a/frontend/src/components/molecules/Code.tsx +++ b/frontend/src/components/molecules/Code.tsx @@ -54,7 +54,7 @@ const Code = ({ children, ...props }: any) => { const codeChildren = props.node?.children?.[0]; const className = codeChildren?.properties?.className?.[0]; const match = /language-(\w+)/.exec(className || ''); - const code = codeChildren?.children?.[0]?.value; + const code = codeChildren?.children?.[0]?.value.replace(/\\n/g, '\n'); const showSyntaxHighlighter = match && code; @@ -99,6 +99,7 @@ const Code = ({ children, ...props }: any) => { alignItems: 'center', borderTopLeftRadius: '4px', borderTopRightRadius: '4px', + color: 'text.secondary', background: isDarkMode ? grey[900] : grey[200] }} > From 51034949d18a5ebdcd3cc57a40e75c3a2980b87f Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Wed, 27 Mar 2024 18:02:03 +0100 Subject: [PATCH 5/9] move new line cleaning to backend --- backend/chainlit/step.py | 4 ++-- frontend/src/components/molecules/Code.tsx | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/chainlit/step.py b/backend/chainlit/step.py index 9ce9752304..46c9a263c3 100644 --- a/backend/chainlit/step.py +++ b/backend/chainlit/step.py @@ -194,13 +194,13 @@ def _process_content(self, content, set_language=False): if set_language: self.language = "json" except TypeError: - processed_content = str(content) + processed_content = str(content).replace("\\n", "\n") if set_language: self.language = "text" elif isinstance(content, str): processed_content = content else: - processed_content = str(content) + processed_content = str(content).replace("\\n", "\n") if set_language: self.language = "text" return processed_content diff --git a/frontend/src/components/molecules/Code.tsx b/frontend/src/components/molecules/Code.tsx index 050f2bd44f..58e73707b3 100644 --- a/frontend/src/components/molecules/Code.tsx +++ b/frontend/src/components/molecules/Code.tsx @@ -54,7 +54,7 @@ const Code = ({ children, ...props }: any) => { const codeChildren = props.node?.children?.[0]; const className = codeChildren?.properties?.className?.[0]; const match = /language-(\w+)/.exec(className || ''); - const code = codeChildren?.children?.[0]?.value.replace(/\\n/g, '\n'); + const code = codeChildren?.children?.[0]?.value; const showSyntaxHighlighter = match && code; From 5981af29a964d197f6c60207ffd2fba684071e91 Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Sat, 30 Mar 2024 16:06:47 +0100 Subject: [PATCH 6/9] fix thread dict --- backend/chainlit/data/__init__.py | 83 ++++++++++++++++++++----------- backend/chainlit/types.py | 4 +- 2 files changed, 56 insertions(+), 31 deletions(-) diff --git a/backend/chainlit/data/__init__.py b/backend/chainlit/data/__init__.py index 7c74122b05..f8689e1b88 100644 --- a/backend/chainlit/data/__init__.py +++ b/backend/chainlit/data/__init__.py @@ -2,7 +2,7 @@ import json import os from collections import deque -from typing import TYPE_CHECKING, Dict, List, Optional, Union, Literal, cast +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, cast import aiofiles from chainlit.config import config @@ -11,9 +11,11 @@ from chainlit.session import WebsocketSession from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter from chainlit.user import PersistedUser, User, UserDict -from literalai import Score as LiteralScore, PageInfo, PaginatedResponse, Attachment, Step as LiteralStep -from literalai.step import StepDict as LiteralStepDict +from literalai import Attachment, PageInfo, PaginatedResponse +from literalai import Score as LiteralScore +from literalai import Step as LiteralStep from literalai.filter import threads_filters as LiteralThreadsFilters +from literalai.step import StepDict as LiteralStepDict if TYPE_CHECKING: from chainlit.element import Element, ElementDict @@ -59,7 +61,6 @@ async def delete_feedback( ) -> bool: return True - async def upsert_feedback( self, feedback: Feedback, @@ -101,7 +102,8 @@ async def list_threads( self, pagination: "Pagination", filters: "ThreadFilter" ) -> "PaginatedResponse[ThreadDict]": return PaginatedResponse( - data=[], pageInfo=PageInfo(hasNextPage=False, startCursor=None, endCursor=None) + data=[], + pageInfo=PageInfo(hasNextPage=False, startCursor=None, endCursor=None), ) async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": @@ -157,7 +159,7 @@ def score_to_feedback_dict( return { "id": score.id or "", "forId": score.step_id or "", - "value": cast(Literal[0, 1], score.value), + "value": cast(Literal[0, 1], score.value), "comment": score.comment, } @@ -169,9 +171,20 @@ def step_to_step_dict(self, step: LiteralStep) -> "StepDict": output = (step.output or {}).get("content") or ( json.dumps(step.output) if step.output and step.output != {} else "" ) - - user_feedback = next((s for s in step.scores if s.type == "HUMAN" and s.name == "user-feedback"), None) if step.scores else None - + + user_feedback = ( + next( + ( + s + for s in step.scores + if s.type == "HUMAN" and s.name == "user-feedback" + ), + None, + ) + if step.scores + else None + ) + return { "createdAt": step.created_at, "id": step.id or "", @@ -230,7 +243,6 @@ async def delete_feedback( return True return False - async def upsert_feedback( self, feedback: Feedback, @@ -354,10 +366,16 @@ async def get_thread_author(self, thread_id: str) -> str: thread = await self.get_thread(thread_id) if not thread: return "" - user = thread.get("user") - if not user: + user_id = thread.get("user_id") + if not user_id: + return "" + + user = await self.client.api.get_user(id=user_id) + + if not user or not user.identifier: return "" - return user.get("identifier") or "" + + return user.identifier async def delete_thread(self, thread_id: str): await self.client.api.delete_thread(id=thread_id) @@ -372,19 +390,35 @@ async def list_threads( { "field": "participantId", "operator": "eq", - "value": filters.userId, + "value": filters.userId, } ] - - if filters.search: - literal_filters.append({"field": "stepOutput", "operator": "ilike", "value": filters.search, "path": "content"}) + if filters.search: + literal_filters.append( + { + "field": "stepOutput", + "operator": "ilike", + "value": filters.search, + "path": "content", + } + ) if filters.feedback is not None: - literal_filters.append({"field": "scoreValue", "operator": "eq", "value": filters.feedback, "path": "user-feedback"}) + literal_filters.append( + { + "field": "scoreValue", + "operator": "eq", + "value": filters.feedback, + "path": "user-feedback", + } + ) return await self.client.api.list_threads( - first=pagination.first, after=pagination.cursor, filters=literal_filters, order_by={"column": "createdAt", "direction": "DESC"} + first=pagination.first, + after=pagination.cursor, + filters=literal_filters, + order_by={"column": "createdAt", "direction": "DESC"}, ) async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": @@ -403,15 +437,6 @@ async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": step.generation = None steps.append(self.step_to_step_dict(step)) - user = None # type: Optional["UserDict"] - - if thread.user: - user = { - "id": thread.user.id or "", - "identifier": thread.user.identifier or "", - "metadata": thread.user.metadata, - } - return { "createdAt": thread.created_at or "", "id": thread.id, @@ -419,7 +444,7 @@ async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": "steps": steps, "elements": elements, "metadata": thread.metadata, - "user": user, + "user_id": thread.participant_id, "tags": thread.tags, } diff --git a/backend/chainlit/types.py b/backend/chainlit/types.py index 56be377ebe..9177fdc568 100644 --- a/backend/chainlit/types.py +++ b/backend/chainlit/types.py @@ -3,7 +3,6 @@ if TYPE_CHECKING: from chainlit.element import ElementDict - from chainlit.user import UserDict from chainlit.step import StepDict from dataclasses_json import DataClassJsonMixin @@ -20,7 +19,7 @@ class ThreadDict(TypedDict): id: str createdAt: str name: Optional[str] - user: Optional["UserDict"] + user_id: Optional[str] tags: Optional[List[str]] metadata: Optional[Dict] steps: List["StepDict"] @@ -122,6 +121,7 @@ def is_chat(self): class DeleteThreadRequest(BaseModel): threadId: str + class DeleteFeedbackRequest(BaseModel): feedbackId: str From 7146d5054e6975f576553c5cf0894b9b85c54893 Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Sat, 30 Mar 2024 16:41:03 +0100 Subject: [PATCH 7/9] bump sdk version --- backend/chainlit/data/__init__.py | 12 ++++-------- backend/chainlit/socket.py | 2 +- backend/chainlit/types.py | 1 + backend/pyproject.toml | 2 +- frontend/tests/message.spec.tsx | 3 ++- libs/react-client/src/types/thread.ts | 4 ++-- 6 files changed, 11 insertions(+), 13 deletions(-) diff --git a/backend/chainlit/data/__init__.py b/backend/chainlit/data/__init__.py index 64b3c5bb32..cecad7e6e0 100644 --- a/backend/chainlit/data/__init__.py +++ b/backend/chainlit/data/__init__.py @@ -366,16 +366,11 @@ async def get_thread_author(self, thread_id: str) -> str: thread = await self.get_thread(thread_id) if not thread: return "" - user_id = thread.get("user_id") - if not user_id: + user_identifier = thread.get("user_identifier") + if not user_identifier: return "" - user = await self.client.api.get_user(id=user_id) - - if not user or not user.identifier: - return "" - - return user.identifier + return user_identifier async def delete_thread(self, thread_id: str): await self.client.api.delete_thread(id=thread_id) @@ -445,6 +440,7 @@ async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": "elements": elements, "metadata": thread.metadata, "user_id": thread.participant_id, + "user_identifier": thread.participant_identifier, "tags": thread.tags, } diff --git a/backend/chainlit/socket.py b/backend/chainlit/socket.py index 773bd1434a..f4b0cc2a4e 100644 --- a/backend/chainlit/socket.py +++ b/backend/chainlit/socket.py @@ -42,7 +42,7 @@ async def resume_thread(session: WebsocketSession): if not thread: return - author = thread.get("user").get("identifier") if thread["user"] else None + author = thread.get("user_identifier") user_is_author = author == session.user.identifier if user_is_author: diff --git a/backend/chainlit/types.py b/backend/chainlit/types.py index 9177fdc568..11eecc16cf 100644 --- a/backend/chainlit/types.py +++ b/backend/chainlit/types.py @@ -20,6 +20,7 @@ class ThreadDict(TypedDict): createdAt: str name: Optional[str] user_id: Optional[str] + user_identifier: Optional[str] tags: Optional[List[str]] metadata: Optional[Dict] steps: List["StepDict"] diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 918e74fe9e..9109d683ed 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -23,7 +23,7 @@ chainlit = 'chainlit.cli:cli' [tool.poetry.dependencies] python = ">=3.8.1,<4.0.0" httpx = ">=0.23.0" -literalai = "0.0.400" +literalai = "0.0.401" dataclasses_json = "^0.5.7" fastapi = ">=0.100" # Starlette >= 0.33.0 breaks socketio (alway 404) diff --git a/frontend/tests/message.spec.tsx b/frontend/tests/message.spec.tsx index c3169b860e..c326ca000f 100644 --- a/frontend/tests/message.spec.tsx +++ b/frontend/tests/message.spec.tsx @@ -24,7 +24,8 @@ describe('Message', () => { name: 'bar', createdAt: '12/12/2002', start: '12/12/2002', - end: '12/12/2002' + end: '12/12/2002', + disableFeedback: true } ], waitForAnswer: false, diff --git a/libs/react-client/src/types/thread.ts b/libs/react-client/src/types/thread.ts index 8dbdfc73e5..a6747ed589 100644 --- a/libs/react-client/src/types/thread.ts +++ b/libs/react-client/src/types/thread.ts @@ -1,12 +1,12 @@ import { IElement } from './element'; import { IStep } from './step'; -import { IUser } from './user'; export interface IThread { id: string; createdAt: number | string; name?: string; - user?: IUser; + user_id?: string; + user_identifier?: string; metadata?: Record; steps: IStep[]; elements?: IElement[]; From 0d01d1aafb76cf718ccb31ec659d79d1b31ad9b9 Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Sat, 30 Mar 2024 16:57:35 +0100 Subject: [PATCH 8/9] fix data layer test --- cypress/e2e/data_layer/main.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cypress/e2e/data_layer/main.py b/cypress/e2e/data_layer/main.py index a971193e35..cc414684dd 100644 --- a/cypress/e2e/data_layer/main.py +++ b/cypress/e2e/data_layer/main.py @@ -10,14 +10,14 @@ create_step_counter = 0 -user_dict = {"id": "test", "createdAt": now, "identifier": "admin"} thread_history = [ { "id": "test1", "name": "thread 1", "createdAt": now, - "user": user_dict, + "user_id": "test", + "user_identifier": "admin", "steps": [ { "id": "test1", @@ -38,7 +38,8 @@ { "id": "test2", "createdAt": now, - "user": user_dict, + "user_id": "test", + "user_identifier": "admin", "name": "thread 2", "steps": [ { From 439d49ca203deec4947e2032f894f9472b1d4b0d Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Sat, 30 Mar 2024 17:07:45 +0100 Subject: [PATCH 9/9] fix casing --- backend/chainlit/data/__init__.py | 6 +++--- backend/chainlit/socket.py | 2 +- backend/chainlit/types.py | 4 ++-- cypress/e2e/data_layer/main.py | 12 +++++++----- libs/react-client/src/types/thread.ts | 4 ++-- 5 files changed, 15 insertions(+), 13 deletions(-) diff --git a/backend/chainlit/data/__init__.py b/backend/chainlit/data/__init__.py index cecad7e6e0..d8210d9ff0 100644 --- a/backend/chainlit/data/__init__.py +++ b/backend/chainlit/data/__init__.py @@ -366,7 +366,7 @@ async def get_thread_author(self, thread_id: str) -> str: thread = await self.get_thread(thread_id) if not thread: return "" - user_identifier = thread.get("user_identifier") + user_identifier = thread.get("userIdentifier") if not user_identifier: return "" @@ -439,8 +439,8 @@ async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": "steps": steps, "elements": elements, "metadata": thread.metadata, - "user_id": thread.participant_id, - "user_identifier": thread.participant_identifier, + "userId": thread.participant_id, + "userIdentifier": thread.participant_identifier, "tags": thread.tags, } diff --git a/backend/chainlit/socket.py b/backend/chainlit/socket.py index f4b0cc2a4e..2217849032 100644 --- a/backend/chainlit/socket.py +++ b/backend/chainlit/socket.py @@ -42,7 +42,7 @@ async def resume_thread(session: WebsocketSession): if not thread: return - author = thread.get("user_identifier") + author = thread.get("userIdentifier") user_is_author = author == session.user.identifier if user_is_author: diff --git a/backend/chainlit/types.py b/backend/chainlit/types.py index 11eecc16cf..56427778d7 100644 --- a/backend/chainlit/types.py +++ b/backend/chainlit/types.py @@ -19,8 +19,8 @@ class ThreadDict(TypedDict): id: str createdAt: str name: Optional[str] - user_id: Optional[str] - user_identifier: Optional[str] + userId: Optional[str] + userIdentifier: Optional[str] tags: Optional[List[str]] metadata: Optional[Dict] steps: List["StepDict"] diff --git a/cypress/e2e/data_layer/main.py b/cypress/e2e/data_layer/main.py index cc414684dd..057edb4bb2 100644 --- a/cypress/e2e/data_layer/main.py +++ b/cypress/e2e/data_layer/main.py @@ -16,8 +16,8 @@ "id": "test1", "name": "thread 1", "createdAt": now, - "user_id": "test", - "user_identifier": "admin", + "userId": "test", + "userIdentifier": "admin", "steps": [ { "id": "test1", @@ -38,8 +38,8 @@ { "id": "test2", "createdAt": now, - "user_id": "test", - "user_identifier": "admin", + "userId": "test", + "userIdentifier": "admin", "name": "thread 2", "steps": [ { @@ -82,7 +82,9 @@ async def list_threads( ) -> cl_data.PaginatedResponse[cl_data.ThreadDict]: return cl_data.PaginatedResponse( data=[t for t in thread_history if t["id"] not in deleted_thread_ids], - pageInfo=cl_data.PageInfo(hasNextPage=False, startCursor=None, endCursor=None), + pageInfo=cl_data.PageInfo( + hasNextPage=False, startCursor=None, endCursor=None + ), ) async def get_thread(self, thread_id: str): diff --git a/libs/react-client/src/types/thread.ts b/libs/react-client/src/types/thread.ts index a6747ed589..fe0cc2bbc2 100644 --- a/libs/react-client/src/types/thread.ts +++ b/libs/react-client/src/types/thread.ts @@ -5,8 +5,8 @@ export interface IThread { id: string; createdAt: number | string; name?: string; - user_id?: string; - user_identifier?: string; + userId?: string; + userIdentifier?: string; metadata?: Record; steps: IStep[]; elements?: IElement[];