From b844045d231658b9e40fa0582936c6746e7a7ef4 Mon Sep 17 00:00:00 2001
From: ryanhex53 <ouyang.em@gmail.com>
Date: Tue, 5 Nov 2024 07:44:12 +0000
Subject: [PATCH 1/2] Custom model names can include the `@` symbol by itself.

To specify the model's provider, append it after the model name using `@` as before.

This format supports cases like `google vertex ai` with a model name like `claude-3-5-sonnet@20240620`.

For instance, `claude-3-5-sonnet@20240620@vertex-ai` will be split by `split(/@(?!.*@)/)` into:

`[ 'claude-3-5-sonnet@20240620', 'vertex-ai' ]`, where the former is the model name and the latter is the custom provider.
---
 app/api/common.ts               | 2 +-
 app/components/chat.tsx         | 2 +-
 app/components/model-config.tsx | 6 ++++--
 app/store/access.ts             | 2 +-
 app/utils/model.ts              | 8 ++++----
 5 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/app/api/common.ts b/app/api/common.ts
index b4c792d6ff0..322dedeedfe 100644
--- a/app/api/common.ts
+++ b/app/api/common.ts
@@ -71,7 +71,7 @@ export async function requestOpenai(req: NextRequest) {
         .filter((v) => !!v && !v.startsWith("-") && v.includes(modelName))
         .forEach((m) => {
           const [fullName, displayName] = m.split("=");
-          const [_, providerName] = fullName.split("@");
+          const [_, providerName] = fullName.split(/@(?!.*@)/);
           if (providerName === "azure" && !displayName) {
             const [_, deployId] = (serverConfig?.azureUrl ?? "").split(
               "deployments/",
diff --git a/app/components/chat.tsx b/app/components/chat.tsx
index 3d5b6a4f2c4..2ff08253a1c 100644
--- a/app/components/chat.tsx
+++ b/app/components/chat.tsx
@@ -645,7 +645,7 @@ export function ChatActions(props: {
           onClose={() => setShowModelSelector(false)}
           onSelection={(s) => {
             if (s.length === 0) return;
-            const [model, providerName] = s[0].split("@");
+            const [model, providerName] = s[0].split(/@(?!.*@)/);
             chatStore.updateCurrentSession((session) => {
               session.mask.modelConfig.model = model as ModelType;
               session.mask.modelConfig.providerName =
diff --git a/app/components/model-config.tsx b/app/components/model-config.tsx
index f2297e10b49..0eac916eb3e 100644
--- a/app/components/model-config.tsx
+++ b/app/components/model-config.tsx
@@ -28,7 +28,8 @@ export function ModelConfigList(props: {
           value={value}
           align="left"
           onChange={(e) => {
-            const [model, providerName] = e.currentTarget.value.split("@");
+            const [model, providerName] =
+              e.currentTarget.value.split(/@(?!.*@)/);
             props.updateConfig((config) => {
               config.model = ModalConfigValidator.model(model);
               config.providerName = providerName as ServiceProvider;
@@ -247,7 +248,8 @@ export function ModelConfigList(props: {
           aria-label={Locale.Settings.CompressModel.Title}
           value={compressModelValue}
           onChange={(e) => {
-            const [model, providerName] = e.currentTarget.value.split("@");
+            const [model, providerName] =
+              e.currentTarget.value.split(/@(?!.*@)/);
             props.updateConfig((config) => {
               config.compressModel = ModalConfigValidator.model(model);
               config.compressProviderName = providerName as ServiceProvider;
diff --git a/app/store/access.ts b/app/store/access.ts
index 3b0e6357bc1..4e2cb160362 100644
--- a/app/store/access.ts
+++ b/app/store/access.ts
@@ -226,7 +226,7 @@ export const useAccessStore = createPersistStore(
         .then((res) => {
           const defaultModel = res.defaultModel ?? "";
           if (defaultModel !== "") {
-            const [model, providerName] = defaultModel.split("@");
+            const [model, providerName] = defaultModel.split(/@(?!.*@)/);
             DEFAULT_CONFIG.modelConfig.model = model;
             DEFAULT_CONFIG.modelConfig.providerName = providerName;
           }
diff --git a/app/utils/model.ts b/app/utils/model.ts
index 0b62b53be09..0b95713e1c3 100644
--- a/app/utils/model.ts
+++ b/app/utils/model.ts
@@ -79,10 +79,10 @@ export function collectModelTable(
         );
       } else {
         // 1. find model by name, and set available value
-        const [customModelName, customProviderName] = name.split("@");
+        const [customModelName, customProviderName] = name.split(/@(?!.*@)/);
         let count = 0;
         for (const fullName in modelTable) {
-          const [modelName, providerName] = fullName.split("@");
+          const [modelName, providerName] = fullName.split(/@(?!.*@)/);
           if (
             customModelName == modelName &&
             (customProviderName === undefined ||
@@ -102,7 +102,7 @@ export function collectModelTable(
         }
         // 2. if model not exists, create new model with available value
         if (count === 0) {
-          let [customModelName, customProviderName] = name.split("@");
+          let [customModelName, customProviderName] = name.split(/@(?!.*@)/);
           const provider = customProvider(
             customProviderName || customModelName,
           );
@@ -139,7 +139,7 @@ export function collectModelTableWithDefaultModel(
       for (const key of Object.keys(modelTable)) {
         if (
           modelTable[key].available &&
-          key.split("@").shift() == defaultModel
+          key.split(/@(?!.*@)/).shift() == defaultModel
         ) {
           modelTable[key].isDefault = true;
           break;

From 8e2484fcdf476a1248ae91541d6d491e5881b49b Mon Sep 17 00:00:00 2001
From: ryanhex53 <ouyang.em@gmail.com>
Date: Tue, 5 Nov 2024 13:52:54 +0000
Subject: [PATCH 2/2] Refactor: Replace all provider split occurrences with
 getModelProvider utility method

---
 app/api/common.ts               |  4 ++--
 app/components/chat.tsx         |  3 ++-
 app/components/model-config.tsx | 11 +++++++----
 app/store/access.ts             |  5 +++--
 app/utils/model.ts              | 19 +++++++++++++++----
 test/model-provider.test.ts     | 31 +++++++++++++++++++++++++++++++
 6 files changed, 60 insertions(+), 13 deletions(-)
 create mode 100644 test/model-provider.test.ts

diff --git a/app/api/common.ts b/app/api/common.ts
index 322dedeedfe..495a12ccdbb 100644
--- a/app/api/common.ts
+++ b/app/api/common.ts
@@ -1,8 +1,8 @@
 import { NextRequest, NextResponse } from "next/server";
 import { getServerSideConfig } from "../config/server";
 import { OPENAI_BASE_URL, ServiceProvider } from "../constant";
-import { isModelAvailableInServer } from "../utils/model";
 import { cloudflareAIGatewayUrl } from "../utils/cloudflare";
+import { getModelProvider, isModelAvailableInServer } from "../utils/model";
 
 const serverConfig = getServerSideConfig();
 
@@ -71,7 +71,7 @@ export async function requestOpenai(req: NextRequest) {
         .filter((v) => !!v && !v.startsWith("-") && v.includes(modelName))
         .forEach((m) => {
           const [fullName, displayName] = m.split("=");
-          const [_, providerName] = fullName.split(/@(?!.*@)/);
+          const [_, providerName] = getModelProvider(fullName);
           if (providerName === "azure" && !displayName) {
             const [_, deployId] = (serverConfig?.azureUrl ?? "").split(
               "deployments/",
diff --git a/app/components/chat.tsx b/app/components/chat.tsx
index 2ff08253a1c..cee54d8914f 100644
--- a/app/components/chat.tsx
+++ b/app/components/chat.tsx
@@ -120,6 +120,7 @@ import { createTTSPlayer } from "../utils/audio";
 import { MsEdgeTTS, OUTPUT_FORMAT } from "../utils/ms_edge_tts";
 
 import { isEmpty } from "lodash-es";
+import { getModelProvider } from "../utils/model";
 
 const localStorage = safeLocalStorage();
 
@@ -645,7 +646,7 @@ export function ChatActions(props: {
           onClose={() => setShowModelSelector(false)}
           onSelection={(s) => {
             if (s.length === 0) return;
-            const [model, providerName] = s[0].split(/@(?!.*@)/);
+            const [model, providerName] = getModelProvider(s[0]);
             chatStore.updateCurrentSession((session) => {
               session.mask.modelConfig.model = model as ModelType;
               session.mask.modelConfig.providerName =
diff --git a/app/components/model-config.tsx b/app/components/model-config.tsx
index 0eac916eb3e..e845bfeac7a 100644
--- a/app/components/model-config.tsx
+++ b/app/components/model-config.tsx
@@ -7,6 +7,7 @@ import { ListItem, Select } from "./ui-lib";
 import { useAllModels } from "../utils/hooks";
 import { groupBy } from "lodash-es";
 import styles from "./model-config.module.scss";
+import { getModelProvider } from "../utils/model";
 
 export function ModelConfigList(props: {
   modelConfig: ModelConfig;
@@ -28,8 +29,9 @@ export function ModelConfigList(props: {
           value={value}
           align="left"
           onChange={(e) => {
-            const [model, providerName] =
-              e.currentTarget.value.split(/@(?!.*@)/);
+            const [model, providerName] = getModelProvider(
+              e.currentTarget.value,
+            );
             props.updateConfig((config) => {
               config.model = ModalConfigValidator.model(model);
               config.providerName = providerName as ServiceProvider;
@@ -248,8 +250,9 @@ export function ModelConfigList(props: {
           aria-label={Locale.Settings.CompressModel.Title}
           value={compressModelValue}
           onChange={(e) => {
-            const [model, providerName] =
-              e.currentTarget.value.split(/@(?!.*@)/);
+            const [model, providerName] = getModelProvider(
+              e.currentTarget.value,
+            );
             props.updateConfig((config) => {
               config.compressModel = ModalConfigValidator.model(model);
               config.compressProviderName = providerName as ServiceProvider;
diff --git a/app/store/access.ts b/app/store/access.ts
index 4e2cb160362..4796b2fe84e 100644
--- a/app/store/access.ts
+++ b/app/store/access.ts
@@ -21,6 +21,7 @@ import { getClientConfig } from "../config/client";
 import { createPersistStore } from "../utils/store";
 import { ensure } from "../utils/clone";
 import { DEFAULT_CONFIG } from "./config";
+import { getModelProvider } from "../utils/model";
 
 let fetchState = 0; // 0 not fetch, 1 fetching, 2 done
 
@@ -226,9 +227,9 @@ export const useAccessStore = createPersistStore(
         .then((res) => {
           const defaultModel = res.defaultModel ?? "";
           if (defaultModel !== "") {
-            const [model, providerName] = defaultModel.split(/@(?!.*@)/);
+            const [model, providerName] = getModelProvider(defaultModel);
             DEFAULT_CONFIG.modelConfig.model = model;
-            DEFAULT_CONFIG.modelConfig.providerName = providerName;
+            DEFAULT_CONFIG.modelConfig.providerName = providerName as any;
           }
 
           return res;
diff --git a/app/utils/model.ts b/app/utils/model.ts
index 0b95713e1c3..a1b7df1b61e 100644
--- a/app/utils/model.ts
+++ b/app/utils/model.ts
@@ -37,6 +37,17 @@ const sortModelTable = (models: ReturnType<typeof collectModels>) =>
     }
   });
 
+/**
+ * get model name and provider from a formatted string,
+ * e.g. `gpt-4@OpenAi` or `claude-3-5-sonnet@20240620@Google`
+ * @param modelWithProvider model name with provider separated by last `@` char,
+ * @returns [model, provider] tuple, if no `@` char found, provider is undefined
+ */
+export function getModelProvider(modelWithProvider: string): [string, string?] {
+  const [model, provider] = modelWithProvider.split(/@(?!.*@)/);
+  return [model, provider];
+}
+
 export function collectModelTable(
   models: readonly LLMModel[],
   customModels: string,
@@ -79,10 +90,10 @@ export function collectModelTable(
         );
       } else {
         // 1. find model by name, and set available value
-        const [customModelName, customProviderName] = name.split(/@(?!.*@)/);
+        const [customModelName, customProviderName] = getModelProvider(name);
         let count = 0;
         for (const fullName in modelTable) {
-          const [modelName, providerName] = fullName.split(/@(?!.*@)/);
+          const [modelName, providerName] = getModelProvider(fullName);
           if (
             customModelName == modelName &&
             (customProviderName === undefined ||
@@ -102,7 +113,7 @@ export function collectModelTable(
         }
         // 2. if model not exists, create new model with available value
         if (count === 0) {
-          let [customModelName, customProviderName] = name.split(/@(?!.*@)/);
+          let [customModelName, customProviderName] = getModelProvider(name);
           const provider = customProvider(
             customProviderName || customModelName,
           );
@@ -139,7 +150,7 @@ export function collectModelTableWithDefaultModel(
       for (const key of Object.keys(modelTable)) {
         if (
           modelTable[key].available &&
-          key.split(/@(?!.*@)/).shift() == defaultModel
+          getModelProvider(key)[0] == defaultModel
         ) {
           modelTable[key].isDefault = true;
           break;
diff --git a/test/model-provider.test.ts b/test/model-provider.test.ts
new file mode 100644
index 00000000000..41f14be026c
--- /dev/null
+++ b/test/model-provider.test.ts
@@ -0,0 +1,31 @@
+import { getModelProvider } from "../app/utils/model";
+
+describe("getModelProvider", () => {
+  test("should return model and provider when input contains '@'", () => {
+    const input = "model@provider";
+    const [model, provider] = getModelProvider(input);
+    expect(model).toBe("model");
+    expect(provider).toBe("provider");
+  });
+
+  test("should return model and undefined provider when input does not contain '@'", () => {
+    const input = "model";
+    const [model, provider] = getModelProvider(input);
+    expect(model).toBe("model");
+    expect(provider).toBeUndefined();
+  });
+
+  test("should handle multiple '@' characters correctly", () => {
+    const input = "model@provider@extra";
+    const [model, provider] = getModelProvider(input);
+    expect(model).toBe("model@provider");
+    expect(provider).toBe("extra");
+  });
+
+  test("should return empty strings when input is empty", () => {
+    const input = "";
+    const [model, provider] = getModelProvider(input);
+    expect(model).toBe("");
+    expect(provider).toBeUndefined();
+  });
+});