From bad481bf05aa38edcf553e1273f5d692a65c9225 Mon Sep 17 00:00:00 2001 From: NamH Date: Wed, 7 Aug 2024 13:58:51 +0700 Subject: [PATCH] feat: add chunk count (#3290) * feat: add chunk count * bump cortex version --- electron/resources/version.txt | 2 +- web/helpers/atoms/ChatMessage.atom.ts | 4 ++ web/hooks/useSendMessage.ts | 21 +++++++ .../ChatActionButton/StopInferenceButton.tsx | 26 +++++--- .../components/TokenCount.tsx | 59 +++++++++++++++++++ .../SimpleTextMessage/index.tsx | 50 +--------------- 6 files changed, 105 insertions(+), 57 deletions(-) create mode 100644 web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/components/TokenCount.tsx diff --git a/electron/resources/version.txt b/electron/resources/version.txt index 5dd5d632fb..4ad94ccc12 100644 --- a/electron/resources/version.txt +++ b/electron/resources/version.txt @@ -1 +1 @@ -0.5.0-30 +0.5.0-31 diff --git a/web/helpers/atoms/ChatMessage.atom.ts b/web/helpers/atoms/ChatMessage.atom.ts index 38198823fd..289d8d3e8c 100644 --- a/web/helpers/atoms/ChatMessage.atom.ts +++ b/web/helpers/atoms/ChatMessage.atom.ts @@ -5,6 +5,10 @@ import { getActiveThreadIdAtom } from './Thread.atom' const chatMessages = atom>({}) +export const disableStopInferenceAtom = atom(false) + +export const chunkCountAtom = atom>({}) + /** * Return the chat messages for the current active thread */ diff --git a/web/hooks/useSendMessage.ts b/web/hooks/useSendMessage.ts index 2749109da4..f71438a4da 100644 --- a/web/hooks/useSendMessage.ts +++ b/web/hooks/useSendMessage.ts @@ -36,6 +36,8 @@ import useModelStart from './useModelStart' import { addNewMessageAtom, + chunkCountAtom, + disableStopInferenceAtom, getCurrentChatMessagesAtom, updateMessageAtom, } from '@/helpers/atoms/ChatMessage.atom' @@ -104,6 +106,9 @@ const useSendMessage = () => { showWarningMultipleModelModalAtom ) + const setDisableStopInference = useSetAtom(disableStopInferenceAtom) + const setChunkCount = useSetAtom(chunkCountAtom) + const validatePrerequisite = useCallback(async (): Promise => { const errorTitle = 'Failed to send message' if (!activeThread) { @@ -361,7 +366,12 @@ const useSendMessage = () => { addNewMessage(responseMessage) + let chunkCount = 1 for await (const chunk of stream) { + setChunkCount((prev) => ({ + ...prev, + [responseMessage.id]: chunkCount++, + })) const content = chunk.choices[0]?.delta?.content || '' assistantResponseMessage += content const messageContent: MessageContent = { @@ -579,6 +589,7 @@ const useSendMessage = () => { let assistantResponseMessage = '' try { if (selectedModel!.stream === true) { + setDisableStopInference(true) const stream = await chatCompletionStreaming({ messages, model: selectedModel!.model, @@ -623,7 +634,14 @@ const useSendMessage = () => { addNewMessage(responseMessage) + let chunkCount = 1 for await (const chunk of stream) { + setChunkCount((prev) => ({ + ...prev, + [responseMessage.id]: chunkCount++, + })) + // we have first chunk, enable the inference button + setDisableStopInference(false) const content = chunk.choices[0]?.delta?.content || '' assistantResponseMessage += content const messageContent: MessageContent = { @@ -737,6 +755,7 @@ const useSendMessage = () => { }) } + setDisableStopInference(false) setIsGeneratingResponse(false) shouldSummarize = false @@ -780,6 +799,8 @@ const useSendMessage = () => { chatCompletionStreaming, summarizeThread, setShowWarningMultipleModelModal, + setDisableStopInference, + setChunkCount, ] ) diff --git a/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/StopInferenceButton.tsx b/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/StopInferenceButton.tsx index a8cf9273c0..7fe2764cd9 100644 --- a/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/StopInferenceButton.tsx +++ b/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/StopInferenceButton.tsx @@ -2,20 +2,28 @@ import React from 'react' import { Button } from '@janhq/joi' +import { useAtomValue } from 'jotai' import { StopCircle } from 'lucide-react' +import { disableStopInferenceAtom } from '@/helpers/atoms/ChatMessage.atom' + type Props = { onStopInferenceClick: () => void } -const StopInferenceButton: React.FC = ({ onStopInferenceClick }) => ( - -) +const StopInferenceButton: React.FC = ({ onStopInferenceClick }) => { + const disabled = useAtomValue(disableStopInferenceAtom) + + return ( + + ) +} export default React.memo(StopInferenceButton) diff --git a/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/components/TokenCount.tsx b/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/components/TokenCount.tsx new file mode 100644 index 0000000000..9a2054fddb --- /dev/null +++ b/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/components/TokenCount.tsx @@ -0,0 +1,59 @@ +import { useEffect, useMemo, useState } from 'react' + +import { Message, TextContentBlock } from '@janhq/core' +import { useAtomValue } from 'jotai' + +import { chunkCountAtom } from '@/helpers/atoms/ChatMessage.atom' + +type Props = { + message: Message +} + +const TokenCount: React.FC = ({ message }) => { + const chunkCountMap = useAtomValue(chunkCountAtom) + const [lastTimestamp, setLastTimestamp] = useState() + const [tokenSpeed, setTokenSpeed] = useState(0) + + const receivedChunkCount = useMemo( + () => chunkCountMap[message.id] ?? 0, + [chunkCountMap, message.id] + ) + + useEffect(() => { + if (message.status !== 'in_progress') { + return + } + const currentTimestamp = Date.now() + if (!lastTimestamp) { + // If this is the first update, just set the lastTimestamp and return + if (message.content && message.content.length > 0) { + const messageContent = message.content[0] + if (messageContent && messageContent.type === 'text') { + const textContentBlock = messageContent as TextContentBlock + if (textContentBlock.text.value !== '') { + setLastTimestamp(currentTimestamp) + } + } + } + return + } + + const timeDiffInSeconds = (currentTimestamp - lastTimestamp) / 1000 + const averageTokenSpeed = receivedChunkCount / timeDiffInSeconds + + setTokenSpeed(averageTokenSpeed) + }, [message.content, lastTimestamp, receivedChunkCount, message.status]) + + if (tokenSpeed === 0) return null + + return ( +
+

+ Token count: {receivedChunkCount}, speed:{' '} + {Number(tokenSpeed).toFixed(2)}t/s +

+
+ ) +} + +export default TokenCount diff --git a/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx b/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx index e4b36fa9c6..38b098c025 100644 --- a/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx @@ -28,6 +28,8 @@ import { openFileTitle } from '@/utils/titleUtils' import EditChatInput from '../EditChatInput' import MessageToolbar from '../MessageToolbar' +import TokenCount from './components/TokenCount' + import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom' type Props = { @@ -114,9 +116,6 @@ const SimpleTextMessage: React.FC = ({ const isUser = msg.role === 'user' const { onViewFileContainer } = usePath() const parsedText = useMemo(() => marked.parse(text), [marked, text]) - const [tokenCount, setTokenCount] = useState(0) - const [lastTimestamp, setLastTimestamp] = useState() - const [tokenSpeed, setTokenSpeed] = useState(0) const codeBlockCopyEvent = useRef((e: Event) => { const target: HTMLElement = e.target as HTMLElement @@ -138,34 +137,6 @@ const SimpleTextMessage: React.FC = ({ } }, []) - useEffect(() => { - if (msg.status !== 'in_progress') { - return - } - const currentTimestamp = new Date().getTime() // Get current time in milliseconds - if (!lastTimestamp) { - // If this is the first update, just set the lastTimestamp and return - if (msg.content && msg.content.length > 0) { - const message = msg.content[0] - if (message && message.type === 'text') { - const textContentBlock = message as TextContentBlock - if (textContentBlock.text.value !== '') { - setLastTimestamp(currentTimestamp) - } - } - } - return - } - - const timeDiffInSeconds = (currentTimestamp - lastTimestamp) / 1000 // Time difference in seconds - const totalTokenCount = tokenCount + 1 - const averageTokenSpeed = totalTokenCount / timeDiffInSeconds // Calculate average token speed - - setTokenSpeed(averageTokenSpeed) - setTokenCount(totalTokenCount) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [msg.content]) - return (
= ({ )} > {isUser ? : } -
= ({ onResendMessage={onResendMessage} />
- {isLatestMessage && - (msg.status === 'in_progress' || tokenSpeed > 0) && ( -

- Token Speed: {Number(tokenSpeed).toFixed(2)}t/s -

- )} + {isLatestMessage && }
= ({ {msg.content[0]?.type === 'image_file' && (
-
- {/* */} - {/* onViewFile(`${msg.content[0]?.text.annotations[0]}`) */} - {/* } */} - {/* /> */} -