diff --git a/app/components/model-config.tsx b/app/components/model-config.tsx index 325e1e3bb5f..f79e0e8f6f1 100644 --- a/app/components/model-config.tsx +++ b/app/components/model-config.tsx @@ -2,7 +2,7 @@ import { ALL_MODELS, ModalConfigValidator, ModelConfig } from "../store"; import Locale from "../locales"; import { InputRange } from "./input-range"; -import { List, ListItem, Select } from "./ui-lib"; +import { ListItem, Select } from "./ui-lib"; export function ModelConfigList(props: { modelConfig: ModelConfig; @@ -109,6 +109,21 @@ export function ModelConfigList(props: { > + + + props.updateConfig( + (config) => (config.template = e.currentTarget.value), + ) + } + > + + `当前版本:${x}`, IsLatest: "已是最新版本", diff --git a/app/locales/en.ts b/app/locales/en.ts index 265f5cff3f7..8e56147c827 100644 --- a/app/locales/en.ts +++ b/app/locales/en.ts @@ -116,6 +116,12 @@ const en: LocaleType = { Title: "Font Size", SubTitle: "Adjust font size of chat content", }, + + InputTemplate: { + Title: "Input Template", + SubTitle: "Newest message will be filled to this template", + }, + Update: { Version: (x: string) => `Version: ${x}`, IsLatest: "Latest version", diff --git a/app/store/chat.ts b/app/store/chat.ts index 629eeab94de..2a826b6bd52 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -3,11 +3,11 @@ import { persist } from "zustand/middleware"; import { trimTopic } from "../utils"; -import Locale from "../locales"; +import Locale, { getLang } from "../locales"; import { showToast } from "../components/ui-lib"; -import { ModelType } from "./config"; +import { ModelConfig, ModelType } from "./config"; import { createEmptyMask, Mask } from "./mask"; -import { StoreKey } from "../constant"; +import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant"; import { api, RequestMessage } from "../client/api"; import { ChatControllerPool } from "../client/controller"; import { prettyObject } from "../utils/format"; @@ -106,6 +106,29 @@ function countMessages(msgs: 
ChatMessage[]) { return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0); } +function fillTemplateWith(input: string, modelConfig: ModelConfig) { + const vars = { + model: modelConfig.model, + time: new Date().toLocaleString(), + lang: getLang(), + input: input, + }; + + let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE; + + // must contain {{input}} + const inputVar = "{{input}}"; + if (!output.includes(inputVar)) { + output += "\n" + inputVar; + } + + Object.entries(vars).forEach(([name, value]) => { + output = output.replaceAll(`{{${name}}}`, value); + }); + + return output; +} + export const useChatStore = create()( persist( (set, get) => ({ @@ -238,9 +261,12 @@ export const useChatStore = create()( const session = get().currentSession(); const modelConfig = session.mask.modelConfig; + const userContent = fillTemplateWith(content, modelConfig); + console.log("[User Input] fill with template: ", userContent); + const userMessage: ChatMessage = createMessage({ role: "user", - content, + content: userContent, }); const botMessage: ChatMessage = createMessage({ @@ -250,31 +276,22 @@ export const useChatStore = create()( model: modelConfig.model, }); - const systemInfo = createMessage({ - role: "system", - content: `IMPORTANT: You are a virtual assistant powered by the ${ - modelConfig.model - } model, now time is ${new Date().toLocaleString()}}`, - id: botMessage.id!
+ 1, - }); - // get recent messages - const systemMessages = []; - // if user define a mask with context prompts, wont send system info - if (session.mask.context.length === 0) { - systemMessages.push(systemInfo); - } - const recentMessages = get().getMessagesWithMemory(); - const sendMessages = systemMessages.concat( - recentMessages.concat(userMessage), - ); + const sendMessages = recentMessages.concat(userMessage); const sessionIndex = get().currentSessionIndex; const messageIndex = get().currentSession().messages.length + 1; // save user's and bot's message get().updateCurrentSession((session) => { - session.messages = session.messages.concat([userMessage, botMessage]); + const savedUserMessage = { + ...userMessage, + content, + }; + session.messages = session.messages.concat([ + savedUserMessage, + botMessage, + ]); }); // make request @@ -350,55 +367,62 @@ export const useChatStore = create()( getMessagesWithMemory() { const session = get().currentSession(); const modelConfig = session.mask.modelConfig; + const clearContextIndex = session.clearContextIndex ?? 0; + const messages = session.messages.slice(); + const totalMessageCount = session.messages.length; - // wont send cleared context messages - const clearedContextMessages = session.messages.slice( - session.clearContextIndex ?? 
0, - ); - const messages = clearedContextMessages.filter((msg) => !msg.isError); - const n = messages.length; - - const context = session.mask.context.slice(); + // in-context prompts + const contextPrompts = session.mask.context.slice(); // long term memory - if ( + const shouldSendLongTermMemory = modelConfig.sendMemory && session.memoryPrompt && - session.memoryPrompt.length > 0 - ) { - const memoryPrompt = get().getMemoryPrompt(); - context.push(memoryPrompt); - } - - // get short term and unmemorized long term memory - const shortTermMemoryMessageIndex = Math.max( + session.memoryPrompt.length > 0 && + session.lastSummarizeIndex <= clearContextIndex; + const longTermMemoryPrompts = shouldSendLongTermMemory + ? [get().getMemoryPrompt()] + : []; + const longTermMemoryStartIndex = session.lastSummarizeIndex; + + // short term memory + const shortTermMemoryStartIndex = Math.max( 0, - n - modelConfig.historyMessageCount, + totalMessageCount - modelConfig.historyMessageCount, ); - const longTermMemoryMessageIndex = session.lastSummarizeIndex; - // try to concat history messages + // lets concat send messages, including 4 parts: + // 1. long term memory: summarized memory messages + // 2. pre-defined in-context prompts + // 3. short term memory: latest n messages + // 4. newest input message const memoryStartIndex = Math.min( - shortTermMemoryMessageIndex, - longTermMemoryMessageIndex, + longTermMemoryStartIndex, + shortTermMemoryStartIndex, ); - const threshold = modelConfig.max_tokens; + // and if user has cleared history messages, we should exclude the memory too. 
+ const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex); + const maxTokenThreshold = modelConfig.max_tokens; - // get recent messages as many as possible + // get recent messages as much as possible const reversedRecentMessages = []; for ( - let i = n - 1, count = 0; - i >= memoryStartIndex && count < threshold; + let i = totalMessageCount - 1, tokenCount = 0; + i >= contextStartIndex && tokenCount < maxTokenThreshold; i -= 1 ) { const msg = messages[i]; if (!msg || msg.isError) continue; - count += estimateTokenLength(msg.content); + tokenCount += estimateTokenLength(msg.content); reversedRecentMessages.push(msg); } - // concat - const recentMessages = context.concat(reversedRecentMessages.reverse()); + // concat all messages + const recentMessages = [ + ...longTermMemoryPrompts, + ...contextPrompts, + ...reversedRecentMessages.reverse(), + ]; return recentMessages; }, diff --git a/app/store/config.ts b/app/store/config.ts index 2b8493ca7c3..b15fa914802 100644 --- a/app/store/config.ts +++ b/app/store/config.ts @@ -1,7 +1,7 @@ import { create } from "zustand"; import { persist } from "zustand/middleware"; import { getClientConfig } from "../config/client"; -import { StoreKey } from "../constant"; +import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant"; export enum SubmitKey { Enter = "Enter", @@ -39,6 +39,7 @@ export const DEFAULT_CONFIG = { sendMemory: true, historyMessageCount: 4, compressMessageLengthThreshold: 1000, + template: DEFAULT_INPUT_TEMPLATE, }, }; @@ -176,15 +177,16 @@ export const useAppConfig = create()( }), { name: StoreKey.Config, - version: 3, + version: 3.1, migrate(persistedState, version) { - if (version === 3) return persistedState as any; + if (version === 3.1) return persistedState as any; const state = persistedState as ChatConfig; state.modelConfig.sendMemory = true; state.modelConfig.historyMessageCount = 4; state.modelConfig.compressMessageLengthThreshold = 1000; state.modelConfig.frequency_penalty = 0; + 
state.modelConfig.template = DEFAULT_INPUT_TEMPLATE; state.dontShowMaskSplashScreen = false; return state;