From 8a8054743cb91a607dc5239175a25312bef2f23f Mon Sep 17 00:00:00 2001 From: James Date: Fri, 2 Aug 2024 13:56:49 +0700 Subject: [PATCH] fix: add back normalize message function Signed-off-by: James --- web/hooks/useSendMessage.ts | 54 ++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/web/hooks/useSendMessage.ts b/web/hooks/useSendMessage.ts index a8f1187677..3946c5fd8c 100644 --- a/web/hooks/useSendMessage.ts +++ b/web/hooks/useSendMessage.ts @@ -22,6 +22,8 @@ import { inferenceErrorAtom } from '@/screens/HubScreen2/components/InferenceErr import { showWarningMultipleModelModalAtom } from '@/screens/HubScreen2/components/WarningMultipleModelModal' import { concurrentModelWarningThreshold } from '@/screens/Settings/MyModels/ModelItem' +import { Stack } from '@/utils/Stack' + import useCortex from './useCortex' import useEngineInit from './useEngineInit' @@ -47,28 +49,29 @@ import { updateThreadTitleAtom, } from '@/helpers/atoms/Thread.atom' -// TODO: NamH add this back -// const normalizeMessages = (messages: Message[]): Message[] => { -// const stack = new Stack() -// for (const message of messages) { -// if (stack.isEmpty()) { -// stack.push(message) -// continue -// } -// const topMessage = stack.peek() - -// if (message.role === topMessage.role) { -// // add an empty message -// stack.push({ -// role: topMessage.role === 'user' ? 'assistant' : 'user', -// content: '.', // some model requires not empty message -// }) -// } -// stack.push(message) -// } - -// return stack.reverseOutput() -// } +const normalizeMessages = ( + messages: ChatCompletionMessageParam[] +): ChatCompletionMessageParam[] => { + const stack = new Stack() + for (const message of messages) { + if (stack.isEmpty()) { + stack.push(message) + continue + } + const topMessage = stack.peek() + + if (message.role === topMessage.role) { + // add an empty message + stack.push({ + role: topMessage.role === 'user' ? 'assistant' : 'user', + content: '.', // some model requires not empty message + }) + } + stack.push(message) + } + + return stack.reverseOutput() +} const useSendMessage = () => { const createMessage = useMessageCreateMutation() @@ -285,7 +288,7 @@ const useSendMessage = () => { content: activeThread!.assistants[0].instructions ?? '', } - const messages: ChatCompletionMessageParam[] = currentMessages + let messages: ChatCompletionMessageParam[] = currentMessages .map((msg) => { switch (msg.role) { case 'user': @@ -305,7 +308,7 @@ const useSendMessage = () => { }) .filter((msg) => msg != null) as ChatCompletionMessageParam[] messages.unshift(systemMessage) - + messages = normalizeMessages(messages) const modelOptions: Record = {} if (selectedModel!.frequency_penalty) { modelOptions.frequency_penalty = selectedModel!.frequency_penalty @@ -540,7 +543,7 @@ const useSendMessage = () => { content: activeThread!.assistants[0].instructions ?? '', } - const messages: ChatCompletionMessageParam[] = currentMessages + let messages: ChatCompletionMessageParam[] = currentMessages .map((msg) => { switch (msg.role) { case 'user': @@ -564,6 +567,7 @@ const useSendMessage = () => { content: message, }) messages.unshift(systemMessage) + messages = normalizeMessages(messages) const modelOptions: Record = {} if (selectedModel!.frequency_penalty) { modelOptions.frequency_penalty = selectedModel!.frequency_penalty