Skip to content

Commit

Permalink
Merge pull request #5607 from ConnectAI-E/hotfix/summarize-model
Browse files Browse the repository at this point in the history
fix compressModel, related #5426, fix #5606 #5603 #5575
  • Loading branch information
lloydzhou authored Oct 9, 2024
2 parents 5b4d423 + 93ca303 commit cbdc611
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 9 deletions.
59 changes: 54 additions & 5 deletions app/store/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@ import {
DEFAULT_SYSTEM_TEMPLATE,
KnowledgeCutOffDate,
StoreKey,
SUMMARIZE_MODEL,
GEMINI_SUMMARIZE_MODEL,
ServiceProvider,
} from "../constant";
import Locale, { getLang } from "../locales";
import { isDalle3, safeLocalStorage } from "../utils";
import { prettyObject } from "../utils/format";
import { createPersistStore } from "../utils/store";
import { estimateTokenLength } from "../utils/token";
import { ModelConfig, ModelType, useAppConfig } from "./config";
import { useAccessStore } from "./access";
import { collectModelsWithDefaultModel } from "../utils/model";
import { createEmptyMask, Mask } from "./mask";

const localStorage = safeLocalStorage();
Expand Down Expand Up @@ -103,6 +108,35 @@ function createEmptySession(): ChatSession {
};
}

function getSummarizeModel(
currentModel: string,
providerName: string,
): string[] {
// if it is using gpt-* models, force to use 4o-mini to summarize
if (currentModel.startsWith("gpt") || currentModel.startsWith("chatgpt")) {
const configStore = useAppConfig.getState();
const accessStore = useAccessStore.getState();
const allModel = collectModelsWithDefaultModel(
configStore.models,
[configStore.customModels, accessStore.customModels].join(","),
accessStore.defaultModel,
);
const summarizeModel = allModel.find(
(m) => m.name === SUMMARIZE_MODEL && m.available,
);
if (summarizeModel) {
return [
summarizeModel.name,
summarizeModel.provider?.providerName as string,
];
}
}
if (currentModel.startsWith("gemini")) {
return [GEMINI_SUMMARIZE_MODEL, ServiceProvider.Google];
}
return [currentModel, providerName];
}

function countMessages(msgs: ChatMessage[]) {
return msgs.reduce(
(pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)),
Expand Down Expand Up @@ -579,8 +613,14 @@ export const useChatStore = createPersistStore(
return;
}

const providerName = modelConfig.compressProviderName;
const api: ClientApi = getClientApi(providerName);
// if not config compressModel, then using getSummarizeModel
const [model, providerName] = modelConfig.compressModel
? [modelConfig.compressModel, modelConfig.compressProviderName]
: getSummarizeModel(
session.mask.modelConfig.model,
session.mask.modelConfig.providerName,
);
const api: ClientApi = getClientApi(providerName as ServiceProvider);

// remove error messages if any
const messages = session.messages;
Expand Down Expand Up @@ -611,7 +651,7 @@ export const useChatStore = createPersistStore(
api.llm.chat({
messages: topicMessages,
config: {
model: modelConfig.compressModel,
model,
stream: false,
providerName,
},
Expand Down Expand Up @@ -675,7 +715,8 @@ export const useChatStore = createPersistStore(
config: {
...modelcfg,
stream: true,
model: modelConfig.compressModel,
model,
providerName,
},
onUpdate(message) {
session.memoryPrompt = message;
Expand Down Expand Up @@ -728,7 +769,7 @@ export const useChatStore = createPersistStore(
},
{
name: StoreKey.Chat,
version: 3.2,
version: 3.3,
migrate(persistedState, version) {
const state = persistedState as any;
const newState = JSON.parse(
Expand Down Expand Up @@ -784,6 +825,14 @@ export const useChatStore = createPersistStore(
config.modelConfig.compressProviderName;
});
}
// revert default summarize model for every session
if (version < 3.3) {
newState.sessions.forEach((s) => {
const config = useAppConfig.getState();
s.mask.modelConfig.compressModel = "";
s.mask.modelConfig.compressProviderName = "";
});
}

return newState as any;
},
Expand Down
8 changes: 4 additions & 4 deletions app/store/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ export const DEFAULT_CONFIG = {
sendMemory: true,
historyMessageCount: 4,
compressMessageLengthThreshold: 1000,
compressModel: "gpt-4o-mini" as ModelType,
compressProviderName: "OpenAI" as ServiceProvider,
compressModel: "",
compressProviderName: "",
enableInjectSystemPrompts: true,
template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
size: "1024x1024" as DalleSize,
Expand Down Expand Up @@ -178,7 +178,7 @@ export const useAppConfig = createPersistStore(
}),
{
name: StoreKey.Config,
version: 4,
version: 4.1,

merge(persistedState, currentState) {
const state = persistedState as ChatConfig | undefined;
Expand Down Expand Up @@ -231,7 +231,7 @@ export const useAppConfig = createPersistStore(
: config?.template ?? DEFAULT_INPUT_TEMPLATE;
}

if (version < 4) {
if (version < 4.1) {
state.modelConfig.compressModel =
DEFAULT_CONFIG.modelConfig.compressModel;
state.modelConfig.compressProviderName =
Expand Down

0 comments on commit cbdc611

Please sign in to comment.