diff --git a/app/components/model-config.tsx b/app/components/model-config.tsx index 325e1e3bb5f..f79e0e8f6f1 100644 --- a/app/components/model-config.tsx +++ b/app/components/model-config.tsx @@ -2,7 +2,7 @@ import { ALL_MODELS, ModalConfigValidator, ModelConfig } from "../store"; import Locale from "../locales"; import { InputRange } from "./input-range"; -import { List, ListItem, Select } from "./ui-lib"; +import { ListItem, Select } from "./ui-lib"; export function ModelConfigList(props: { modelConfig: ModelConfig; @@ -109,6 +109,21 @@ export function ModelConfigList(props: { > + + + props.updateConfig( + (config) => (config.template = e.currentTarget.value), + ) + } + > + + `当前版本:${x}`, IsLatest: "已是最新版本", diff --git a/app/locales/en.ts b/app/locales/en.ts index 265f5cff3f7..8e56147c827 100644 --- a/app/locales/en.ts +++ b/app/locales/en.ts @@ -116,6 +116,12 @@ const en: LocaleType = { Title: "Font Size", SubTitle: "Adjust font size of chat content", }, + + InputTemplate: { + Title: "Input Template", + SubTitle: "Newest message will be filled to this template", + }, + Update: { Version: (x: string) => `Version: ${x}`, IsLatest: "Latest version", diff --git a/app/masks/index.ts b/app/masks/index.ts index 07c6a3e8cda..b9cb23f20b5 100644 --- a/app/masks/index.ts +++ b/app/masks/index.ts @@ -9,7 +9,7 @@ export const BUILTIN_MASK_ID = 100000; export const BUILTIN_MASK_STORE = { buildinId: BUILTIN_MASK_ID, - masks: {} as Record, + masks: {} as Record, get(id?: number) { if (!id) return undefined; return this.masks[id] as Mask | undefined; @@ -21,6 +21,6 @@ export const BUILTIN_MASK_STORE = { }, }; -export const BUILTIN_MASKS: Mask[] = [...CN_MASKS, ...EN_MASKS].map((m) => - BUILTIN_MASK_STORE.add(m), +export const BUILTIN_MASKS: BuiltinMask[] = [...CN_MASKS, ...EN_MASKS].map( + (m) => BUILTIN_MASK_STORE.add(m), ); diff --git a/app/masks/typing.ts b/app/masks/typing.ts index 510d94a2c20..1ded6a90295 100644 --- a/app/masks/typing.ts +++ b/app/masks/typing.ts @@ -1,5 +1,7 @@ +import { ModelConfig } from "../store"; import { type Mask } from "../store/mask"; -export type BuiltinMask = Omit & { - builtin: true; +export type BuiltinMask = Omit & { + builtin: Boolean; + modelConfig: Partial; }; diff --git a/app/store/chat.ts b/app/store/chat.ts index 629eeab94de..d311c88ff9d 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -3,11 +3,11 @@ import { persist } from "zustand/middleware"; import { trimTopic } from "../utils"; -import Locale from "../locales"; +import Locale, { getLang } from "../locales"; import { showToast } from "../components/ui-lib"; -import { ModelType } from "./config"; +import { ModelConfig, ModelType, useAppConfig } from "./config"; import { createEmptyMask, Mask } from "./mask"; -import { StoreKey } from "../constant"; +import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant"; import { api, RequestMessage } from "../client/api"; import { ChatControllerPool } from "../client/controller"; import { prettyObject } from "../utils/format"; @@ -106,6 +106,29 @@ function countMessages(msgs: ChatMessage[]) { return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0); } +function fillTemplateWith(input: string, modelConfig: ModelConfig) { + const vars = { + model: modelConfig.model, + time: new Date().toLocaleString(), + lang: getLang(), + input: input, + }; + + let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE; + + // must contains {{input}} + const inputVar = "{{input}}"; + if (!output.includes(inputVar)) { + output += "\n" + inputVar; + } + + Object.entries(vars).forEach(([name, value]) => { + output = output.replaceAll(`{{${name}}}`, value); + }); + + return output; +} + export const useChatStore = create()( persist( (set, get) => ({ @@ -158,7 +181,16 @@ export const useChatStore = create()( session.id = get().globalId; if (mask) { - session.mask = { ...mask }; + const config = useAppConfig.getState(); + const globalModelConfig = config.modelConfig; + + session.mask = { + ...mask, + modelConfig: { + ...globalModelConfig, + ...mask.modelConfig, + }, + }; session.topic = mask.name; } @@ -238,9 +270,12 @@ export const useChatStore = create()( const session = get().currentSession(); const modelConfig = session.mask.modelConfig; + const userContent = fillTemplateWith(content, modelConfig); + console.log("[User Input] fill with template: ", userContent); + const userMessage: ChatMessage = createMessage({ role: "user", - content, + content: userContent, }); const botMessage: ChatMessage = createMessage({ @@ -250,31 +285,22 @@ export const useChatStore = create()( model: modelConfig.model, }); - const systemInfo = createMessage({ - role: "system", - content: `IMPORTANT: You are a virtual assistant powered by the ${ - modelConfig.model - } model, now time is ${new Date().toLocaleString()}}`, - id: botMessage.id! + 1, - }); - // get recent messages - const systemMessages = []; - // if user define a mask with context prompts, wont send system info - if (session.mask.context.length === 0) { - systemMessages.push(systemInfo); - } - const recentMessages = get().getMessagesWithMemory(); - const sendMessages = systemMessages.concat( - recentMessages.concat(userMessage), - ); + const sendMessages = recentMessages.concat(userMessage); const sessionIndex = get().currentSessionIndex; const messageIndex = get().currentSession().messages.length + 1; // save user's and bot's message get().updateCurrentSession((session) => { - session.messages = session.messages.concat([userMessage, botMessage]); + const savedUserMessage = { + ...userMessage, + content, + }; + session.messages = session.messages.concat([ + savedUserMessage, + botMessage, + ]); }); // make request @@ -350,55 +376,62 @@ export const useChatStore = create()( getMessagesWithMemory() { const session = get().currentSession(); const modelConfig = session.mask.modelConfig; + const clearContextIndex = session.clearContextIndex ?? 0; + const messages = session.messages.slice(); + const totalMessageCount = session.messages.length; - // wont send cleared context messages - const clearedContextMessages = session.messages.slice( - session.clearContextIndex ?? 0, - ); - const messages = clearedContextMessages.filter((msg) => !msg.isError); - const n = messages.length; - - const context = session.mask.context.slice(); + // in-context prompts + const contextPrompts = session.mask.context.slice(); // long term memory - if ( + const shouldSendLongTermMemory = modelConfig.sendMemory && session.memoryPrompt && - session.memoryPrompt.length > 0 - ) { - const memoryPrompt = get().getMemoryPrompt(); - context.push(memoryPrompt); - } - - // get short term and unmemorized long term memory - const shortTermMemoryMessageIndex = Math.max( + session.memoryPrompt.length > 0 && + session.lastSummarizeIndex <= clearContextIndex; + const longTermMemoryPrompts = shouldSendLongTermMemory + ? [get().getMemoryPrompt()] + : []; + const longTermMemoryStartIndex = session.lastSummarizeIndex; + + // short term memory + const shortTermMemoryStartIndex = Math.max( 0, - n - modelConfig.historyMessageCount, + totalMessageCount - modelConfig.historyMessageCount, ); - const longTermMemoryMessageIndex = session.lastSummarizeIndex; - // try to concat history messages + // lets concat send messages, including 4 parts: + // 1. long term memory: summarized memory messages + // 2. pre-defined in-context prompts + // 3. short term memory: latest n messages + // 4. newest input message const memoryStartIndex = Math.min( - shortTermMemoryMessageIndex, - longTermMemoryMessageIndex, + longTermMemoryStartIndex, + shortTermMemoryStartIndex, ); - const threshold = modelConfig.max_tokens; + // and if user has cleared history messages, we should exclude the memory too. + const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex); + const maxTokenThreshold = modelConfig.max_tokens; - // get recent messages as many as possible + // get recent messages as much as possible const reversedRecentMessages = []; for ( - let i = n - 1, count = 0; - i >= memoryStartIndex && count < threshold; + let i = totalMessageCount - 1, tokenCount = 0; + i >= contextStartIndex && tokenCount < maxTokenThreshold; i -= 1 ) { const msg = messages[i]; if (!msg || msg.isError) continue; - count += estimateTokenLength(msg.content); + tokenCount += estimateTokenLength(msg.content); reversedRecentMessages.push(msg); } - // concat - const recentMessages = context.concat(reversedRecentMessages.reverse()); + // concat all messages + const recentMessages = [ + ...longTermMemoryPrompts, + ...contextPrompts, + ...reversedRecentMessages.reverse(), + ]; return recentMessages; }, diff --git a/app/store/config.ts b/app/store/config.ts index 2b8493ca7c3..b15fa914802 100644 --- a/app/store/config.ts +++ b/app/store/config.ts @@ -1,7 +1,7 @@ import { create } from "zustand"; import { persist } from "zustand/middleware"; import { getClientConfig } from "../config/client"; -import { StoreKey } from "../constant"; +import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant"; export enum SubmitKey { Enter = "Enter", @@ -39,6 +39,7 @@ export const DEFAULT_CONFIG = { sendMemory: true, historyMessageCount: 4, compressMessageLengthThreshold: 1000, + template: DEFAULT_INPUT_TEMPLATE, }, }; @@ -176,15 +177,16 @@ export const useAppConfig = create()( }), { name: StoreKey.Config, - version: 3, + version: 3.1, migrate(persistedState, version) { - if (version === 3) return persistedState as any; + if (version === 3.1) return persistedState as any; const state = persistedState as ChatConfig; state.modelConfig.sendMemory = true; state.modelConfig.historyMessageCount = 4; state.modelConfig.compressMessageLengthThreshold = 1000; state.modelConfig.frequency_penalty = 0; + state.modelConfig.template = DEFAULT_INPUT_TEMPLATE; state.dontShowMaskSplashScreen = false; return state; diff --git a/app/store/mask.ts b/app/store/mask.ts index ed45241f8cd..6d6377c372d 100644 --- a/app/store/mask.ts +++ b/app/store/mask.ts @@ -3,7 +3,7 @@ import { persist } from "zustand/middleware"; import { BUILTIN_MASKS } from "../masks"; import { getLang, Lang } from "../locales"; import { DEFAULT_TOPIC, ChatMessage } from "./chat"; -import { ModelConfig, ModelType, useAppConfig } from "./config"; +import { ModelConfig, useAppConfig } from "./config"; import { StoreKey } from "../constant"; export type Mask = { @@ -89,7 +89,18 @@ export const useMaskStore = create()( const userMasks = Object.values(get().masks).sort( (a, b) => b.id - a.id, ); - return userMasks.concat(BUILTIN_MASKS); + const config = useAppConfig.getState(); + const buildinMasks = BUILTIN_MASKS.map( + (m) => + ({ + ...m, + modelConfig: { + ...config.modelConfig, + ...m.modelConfig, + }, + } as Mask), + ); + return userMasks.concat(buildinMasks); }, search(text) { return Object.values(get().masks);