diff --git a/app/components/model-config.tsx b/app/components/model-config.tsx
index 325e1e3bb5f..f79e0e8f6f1 100644
--- a/app/components/model-config.tsx
+++ b/app/components/model-config.tsx
@@ -2,7 +2,7 @@ import { ALL_MODELS, ModalConfigValidator, ModelConfig } from "../store";
import Locale from "../locales";
import { InputRange } from "./input-range";
-import { List, ListItem, Select } from "./ui-lib";
+import { ListItem, Select } from "./ui-lib";
export function ModelConfigList(props: {
modelConfig: ModelConfig;
@@ -109,6 +109,21 @@ export function ModelConfigList(props: {
>
+
+
+ props.updateConfig(
+ (config) => (config.template = e.currentTarget.value),
+ )
+ }
+ >
+
+
`当前版本:${x}`,
IsLatest: "已是最新版本",
diff --git a/app/locales/en.ts b/app/locales/en.ts
index 265f5cff3f7..8e56147c827 100644
--- a/app/locales/en.ts
+++ b/app/locales/en.ts
@@ -116,6 +116,12 @@ const en: LocaleType = {
Title: "Font Size",
SubTitle: "Adjust font size of chat content",
},
+
+ InputTemplate: {
+ Title: "Input Template",
+ SubTitle: "Newest message will be filled to this template",
+ },
+
Update: {
Version: (x: string) => `Version: ${x}`,
IsLatest: "Latest version",
diff --git a/app/masks/index.ts b/app/masks/index.ts
index 07c6a3e8cda..b9cb23f20b5 100644
--- a/app/masks/index.ts
+++ b/app/masks/index.ts
@@ -9,7 +9,7 @@ export const BUILTIN_MASK_ID = 100000;
export const BUILTIN_MASK_STORE = {
buildinId: BUILTIN_MASK_ID,
- masks: {} as Record,
+ masks: {} as Record,
get(id?: number) {
if (!id) return undefined;
return this.masks[id] as Mask | undefined;
@@ -21,6 +21,6 @@ export const BUILTIN_MASK_STORE = {
},
};
-export const BUILTIN_MASKS: Mask[] = [...CN_MASKS, ...EN_MASKS].map((m) =>
- BUILTIN_MASK_STORE.add(m),
+export const BUILTIN_MASKS: BuiltinMask[] = [...CN_MASKS, ...EN_MASKS].map(
+ (m) => BUILTIN_MASK_STORE.add(m),
);
diff --git a/app/masks/typing.ts b/app/masks/typing.ts
index 510d94a2c20..1ded6a90295 100644
--- a/app/masks/typing.ts
+++ b/app/masks/typing.ts
@@ -1,5 +1,7 @@
+import { ModelConfig } from "../store";
import { type Mask } from "../store/mask";
-export type BuiltinMask = Omit & {
- builtin: true;
+export type BuiltinMask = Omit & {
+ builtin: Boolean;
+ modelConfig: Partial;
};
diff --git a/app/store/chat.ts b/app/store/chat.ts
index 629eeab94de..d311c88ff9d 100644
--- a/app/store/chat.ts
+++ b/app/store/chat.ts
@@ -3,11 +3,11 @@ import { persist } from "zustand/middleware";
import { trimTopic } from "../utils";
-import Locale from "../locales";
+import Locale, { getLang } from "../locales";
import { showToast } from "../components/ui-lib";
-import { ModelType } from "./config";
+import { ModelConfig, ModelType, useAppConfig } from "./config";
import { createEmptyMask, Mask } from "./mask";
-import { StoreKey } from "../constant";
+import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant";
import { api, RequestMessage } from "../client/api";
import { ChatControllerPool } from "../client/controller";
import { prettyObject } from "../utils/format";
@@ -106,6 +106,29 @@ function countMessages(msgs: ChatMessage[]) {
return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0);
}
+function fillTemplateWith(input: string, modelConfig: ModelConfig) {
+ const vars = {
+ model: modelConfig.model,
+ time: new Date().toLocaleString(),
+ lang: getLang(),
+ input: input,
+ };
+
+ let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
+
+ // must contains {{input}}
+ const inputVar = "{{input}}";
+ if (!output.includes(inputVar)) {
+ output += "\n" + inputVar;
+ }
+
+ Object.entries(vars).forEach(([name, value]) => {
+ output = output.replaceAll(`{{${name}}}`, value);
+ });
+
+ return output;
+}
+
export const useChatStore = create()(
persist(
(set, get) => ({
@@ -158,7 +181,16 @@ export const useChatStore = create()(
session.id = get().globalId;
if (mask) {
- session.mask = { ...mask };
+ const config = useAppConfig.getState();
+ const globalModelConfig = config.modelConfig;
+
+ session.mask = {
+ ...mask,
+ modelConfig: {
+ ...globalModelConfig,
+ ...mask.modelConfig,
+ },
+ };
session.topic = mask.name;
}
@@ -238,9 +270,12 @@ export const useChatStore = create()(
const session = get().currentSession();
const modelConfig = session.mask.modelConfig;
+ const userContent = fillTemplateWith(content, modelConfig);
+ console.log("[User Input] fill with template: ", userContent);
+
const userMessage: ChatMessage = createMessage({
role: "user",
- content,
+ content: userContent,
});
const botMessage: ChatMessage = createMessage({
@@ -250,31 +285,22 @@ export const useChatStore = create()(
model: modelConfig.model,
});
- const systemInfo = createMessage({
- role: "system",
- content: `IMPORTANT: You are a virtual assistant powered by the ${
- modelConfig.model
- } model, now time is ${new Date().toLocaleString()}}`,
- id: botMessage.id! + 1,
- });
-
// get recent messages
- const systemMessages = [];
- // if user define a mask with context prompts, wont send system info
- if (session.mask.context.length === 0) {
- systemMessages.push(systemInfo);
- }
-
const recentMessages = get().getMessagesWithMemory();
- const sendMessages = systemMessages.concat(
- recentMessages.concat(userMessage),
- );
+ const sendMessages = recentMessages.concat(userMessage);
const sessionIndex = get().currentSessionIndex;
const messageIndex = get().currentSession().messages.length + 1;
// save user's and bot's message
get().updateCurrentSession((session) => {
- session.messages = session.messages.concat([userMessage, botMessage]);
+ const savedUserMessage = {
+ ...userMessage,
+ content,
+ };
+ session.messages = session.messages.concat([
+ savedUserMessage,
+ botMessage,
+ ]);
});
// make request
@@ -350,55 +376,62 @@ export const useChatStore = create()(
getMessagesWithMemory() {
const session = get().currentSession();
const modelConfig = session.mask.modelConfig;
+ const clearContextIndex = session.clearContextIndex ?? 0;
+ const messages = session.messages.slice();
+ const totalMessageCount = session.messages.length;
- // wont send cleared context messages
- const clearedContextMessages = session.messages.slice(
- session.clearContextIndex ?? 0,
- );
- const messages = clearedContextMessages.filter((msg) => !msg.isError);
- const n = messages.length;
-
- const context = session.mask.context.slice();
+ // in-context prompts
+ const contextPrompts = session.mask.context.slice();
// long term memory
- if (
+ const shouldSendLongTermMemory =
modelConfig.sendMemory &&
session.memoryPrompt &&
- session.memoryPrompt.length > 0
- ) {
- const memoryPrompt = get().getMemoryPrompt();
- context.push(memoryPrompt);
- }
-
- // get short term and unmemorized long term memory
- const shortTermMemoryMessageIndex = Math.max(
+ session.memoryPrompt.length > 0 &&
+ session.lastSummarizeIndex <= clearContextIndex;
+ const longTermMemoryPrompts = shouldSendLongTermMemory
+ ? [get().getMemoryPrompt()]
+ : [];
+ const longTermMemoryStartIndex = session.lastSummarizeIndex;
+
+ // short term memory
+ const shortTermMemoryStartIndex = Math.max(
0,
- n - modelConfig.historyMessageCount,
+ totalMessageCount - modelConfig.historyMessageCount,
);
- const longTermMemoryMessageIndex = session.lastSummarizeIndex;
- // try to concat history messages
+ // lets concat send messages, including 4 parts:
+ // 1. long term memory: summarized memory messages
+ // 2. pre-defined in-context prompts
+ // 3. short term memory: latest n messages
+ // 4. newest input message
const memoryStartIndex = Math.min(
- shortTermMemoryMessageIndex,
- longTermMemoryMessageIndex,
+ longTermMemoryStartIndex,
+ shortTermMemoryStartIndex,
);
- const threshold = modelConfig.max_tokens;
+ // and if user has cleared history messages, we should exclude the memory too.
+ const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
+ const maxTokenThreshold = modelConfig.max_tokens;
- // get recent messages as many as possible
+ // get recent messages as much as possible
const reversedRecentMessages = [];
for (
- let i = n - 1, count = 0;
- i >= memoryStartIndex && count < threshold;
+ let i = totalMessageCount - 1, tokenCount = 0;
+ i >= contextStartIndex && tokenCount < maxTokenThreshold;
i -= 1
) {
const msg = messages[i];
if (!msg || msg.isError) continue;
- count += estimateTokenLength(msg.content);
+ tokenCount += estimateTokenLength(msg.content);
reversedRecentMessages.push(msg);
}
- // concat
- const recentMessages = context.concat(reversedRecentMessages.reverse());
+ // concat all messages
+ const recentMessages = [
+ ...longTermMemoryPrompts,
+ ...contextPrompts,
+ ...reversedRecentMessages.reverse(),
+ ];
return recentMessages;
},
diff --git a/app/store/config.ts b/app/store/config.ts
index 2b8493ca7c3..b15fa914802 100644
--- a/app/store/config.ts
+++ b/app/store/config.ts
@@ -1,7 +1,7 @@
import { create } from "zustand";
import { persist } from "zustand/middleware";
import { getClientConfig } from "../config/client";
-import { StoreKey } from "../constant";
+import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant";
export enum SubmitKey {
Enter = "Enter",
@@ -39,6 +39,7 @@ export const DEFAULT_CONFIG = {
sendMemory: true,
historyMessageCount: 4,
compressMessageLengthThreshold: 1000,
+ template: DEFAULT_INPUT_TEMPLATE,
},
};
@@ -176,15 +177,16 @@ export const useAppConfig = create()(
}),
{
name: StoreKey.Config,
- version: 3,
+ version: 3.1,
migrate(persistedState, version) {
- if (version === 3) return persistedState as any;
+ if (version === 3.1) return persistedState as any;
const state = persistedState as ChatConfig;
state.modelConfig.sendMemory = true;
state.modelConfig.historyMessageCount = 4;
state.modelConfig.compressMessageLengthThreshold = 1000;
state.modelConfig.frequency_penalty = 0;
+ state.modelConfig.template = DEFAULT_INPUT_TEMPLATE;
state.dontShowMaskSplashScreen = false;
return state;
diff --git a/app/store/mask.ts b/app/store/mask.ts
index ed45241f8cd..6d6377c372d 100644
--- a/app/store/mask.ts
+++ b/app/store/mask.ts
@@ -3,7 +3,7 @@ import { persist } from "zustand/middleware";
import { BUILTIN_MASKS } from "../masks";
import { getLang, Lang } from "../locales";
import { DEFAULT_TOPIC, ChatMessage } from "./chat";
-import { ModelConfig, ModelType, useAppConfig } from "./config";
+import { ModelConfig, useAppConfig } from "./config";
import { StoreKey } from "../constant";
export type Mask = {
@@ -89,7 +89,18 @@ export const useMaskStore = create()(
const userMasks = Object.values(get().masks).sort(
(a, b) => b.id - a.id,
);
- return userMasks.concat(BUILTIN_MASKS);
+ const config = useAppConfig.getState();
+ const buildinMasks = BUILTIN_MASKS.map(
+ (m) =>
+ ({
+ ...m,
+ modelConfig: {
+ ...config.modelConfig,
+ ...m.modelConfig,
+ },
+ } as Mask),
+ );
+ return userMasks.concat(buildinMasks);
},
search(text) {
return Object.values(get().masks);