From dc76d6f8fcae3133f5dd6cc4ca40c22bdead4e90 Mon Sep 17 00:00:00 2001 From: Arvin Xu Date: Wed, 8 Jan 2025 10:21:58 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20fix=20model=20select=20no?= =?UTF-8?q?t=20auto=20update=20and=20sort=20issue=20(#5330)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- next.config.ts | 2 +- src/components/ModelSelect/index.tsx | 25 +++++++++++++------ src/database/repositories/aiInfra/index.ts | 6 +++-- src/database/server/models/aiModel.ts | 1 + src/features/ModelSelect/index.tsx | 9 ++++++- src/features/ModelSwitchPanel/index.tsx | 9 ++++++- src/store/aiInfra/slices/aiProvider/action.ts | 4 +-- .../modelList/selectors/modelProvider.ts | 1 + src/types/aiModel.ts | 4 +++ src/types/aiProvider.ts | 6 +++-- 10 files changed, 51 insertions(+), 16 deletions(-) diff --git a/next.config.ts b/next.config.ts index c758f7a15647..cc2b65eaa147 100644 --- a/next.config.ts +++ b/next.config.ts @@ -164,7 +164,7 @@ const nextConfig: NextConfig = { source: '/welcome', }, ], - serverExternalPackages: ['@electric-sql/pglite', 'shiki/wasm', 'sharp'], + serverExternalPackages: ['@electric-sql/pglite', 'sharp'], transpilePackages: ['pdfjs-dist', 'mermaid'], diff --git a/src/components/ModelSelect/index.tsx b/src/components/ModelSelect/index.tsx index e6d403e1383d..ccf26d754336 100644 --- a/src/components/ModelSelect/index.tsx +++ b/src/components/ModelSelect/index.tsx @@ -1,5 +1,5 @@ import { IconAvatarProps, ModelIcon, ProviderIcon } from '@lobehub/icons'; -import { Icon, Tooltip } from '@lobehub/ui'; +import { Avatar, Icon, Tooltip } from '@lobehub/ui'; import { Typography } from 'antd'; import { createStyles } from 'antd-style'; import { Infinity, LucideEye, LucidePaperclip, ToyBrick } from 'lucide-react'; @@ -10,6 +10,7 @@ import { useTranslation } from 'react-i18next'; import { Center, Flexbox } from 'react-layout-kit'; import { ModelAbilities } from '@/types/aiModel'; +import { AiProviderSourceType } from '@/types/aiProvider'; import { ChatModelCard } from '@/types/llm'; import { formatTokenNumber } from '@/utils/format'; @@ -153,16 +154,26 @@ export const ModelItemRender = memo(({ showInfoTag = true, }); interface ProviderItemRenderProps { + logo?: string; name: string; provider: string; + source?: AiProviderSourceType; } -export const ProviderItemRender = memo(({ provider, name }) => ( - - - {name} - -)); +export const ProviderItemRender = memo( + ({ provider, name, source, logo }) => { + return ( + + {source === 'custom' && !!logo ? ( + + ) : ( + + )} + {name} + + ); + }, +); interface LabelRendererProps { Icon: FC; diff --git a/src/database/repositories/aiInfra/index.ts b/src/database/repositories/aiInfra/index.ts index 90bd33da0cc1..94f8964a1517 100644 --- a/src/database/repositories/aiInfra/index.ts +++ b/src/database/repositories/aiInfra/index.ts @@ -62,7 +62,7 @@ export class AiInfraRepos { return list .filter((item) => item.enabled) .sort((a, b) => a.sort! - b.sort!) - .map((item) => ({ id: item.id, name: item.name, source: item.source })); + .map((item) => ({ id: item.id, logo: item.logo, name: item.name, source: item.source })); }; getEnabledModels = async () => { @@ -86,7 +86,9 @@ export class AiInfraRepos { { concurrency: 10 }, ); - return [...modelList.flat(), ...userEnabledModels] as EnabledAiModel[]; + return [...modelList.flat(), ...userEnabledModels].sort( + (a, b) => (a?.sort || -1) - (b?.sort || -1), + ) as EnabledAiModel[]; }; getAiProviderModelList = async (providerId: string) => { diff --git a/src/database/server/models/aiModel.ts b/src/database/server/models/aiModel.ts index a66703809b9d..225156754d67 100644 --- a/src/database/server/models/aiModel.ts +++ b/src/database/server/models/aiModel.ts @@ -93,6 +93,7 @@ export class AiModelModel { enabled: aiModels.enabled, id: aiModels.id, providerId: aiModels.providerId, + sort: aiModels.sort, source: aiModels.source, type: aiModels.type, }) diff --git a/src/features/ModelSelect/index.tsx b/src/features/ModelSelect/index.tsx index 7f7551b3c915..d885210e2cc2 100644 --- a/src/features/ModelSelect/index.tsx +++ b/src/features/ModelSelect/index.tsx @@ -46,7 +46,14 @@ const ModelSelect = memo(({ value, onChange, showAbility = tru } return enabledList.map((provider) => ({ - label: , + label: ( + + ), options: getChatModels(provider), })); }, [enabledList]); diff --git a/src/features/ModelSwitchPanel/index.tsx b/src/features/ModelSwitchPanel/index.tsx index 2f92814ab0dc..92837e1127cd 100644 --- a/src/features/ModelSwitchPanel/index.tsx +++ b/src/features/ModelSwitchPanel/index.tsx @@ -88,7 +88,14 @@ const ModelSwitchPanel = memo(({ children }) => { return enabledList.map((provider) => ({ children: getModelItems(provider), key: provider.id, - label: , + label: ( + + ), type: 'group', })); }, [enabledList]); diff --git a/src/store/aiInfra/slices/aiProvider/action.ts b/src/store/aiInfra/slices/aiProvider/action.ts index d147bbb4c196..4879fa39a889 100644 --- a/src/store/aiInfra/slices/aiProvider/action.ts +++ b/src/store/aiInfra/slices/aiProvider/action.ts @@ -84,7 +84,7 @@ export const createAiProviderSlice: StateCreator< await get().refreshAiProviderRuntimeState(); }, refreshAiProviderRuntimeState: async () => { - await mutate(AiProviderSwrKey.fetchAiProviderRuntimeState); + await mutate([AiProviderSwrKey.fetchAiProviderRuntimeState, true]); }, removeAiProvider: async (id) => { await aiProviderService.deleteAiProvider(id); @@ -187,8 +187,8 @@ export const createAiProviderSlice: StateCreator< // 3. 组装最终数据结构 const enabledChatModelList = data.enabledAiProviders.map((provider) => ({ + ...provider, children: getModelListByType(provider.id, 'chat'), - id: provider.id, name: provider.name || provider.id, })); diff --git a/src/store/user/slices/modelList/selectors/modelProvider.ts b/src/store/user/slices/modelList/selectors/modelProvider.ts index 0b067f6c688b..cf0c71f67243 100644 --- a/src/store/user/slices/modelList/selectors/modelProvider.ts +++ b/src/store/user/slices/modelList/selectors/modelProvider.ts @@ -105,6 +105,7 @@ const modelProviderListForModelSelect = (s: UserStore): EnabledProviderWithModel displayName: m.displayName, id: m.id, })), + source: 'builtin', })); const getModelCardById = (id: string) => (s: UserStore) => { diff --git a/src/types/aiModel.ts b/src/types/aiModel.ts index af18c287da53..42f65894881d 100644 --- a/src/types/aiModel.ts +++ b/src/types/aiModel.ts @@ -1,5 +1,7 @@ import { z } from 'zod'; +import { AiProviderSourceType } from '@/types/aiProvider'; + export type ModelPriceCurrency = 'CNY' | 'USD'; export const AiModelSourceEnum = { @@ -312,5 +314,7 @@ interface AiModelForSelect { export interface EnabledProviderWithModels { children: AiModelForSelect[]; id: string; + logo?: string; name: string; + source: AiProviderSourceType; } diff --git a/src/types/aiProvider.ts b/src/types/aiProvider.ts index 91bbeb555197..c411622d3e66 100644 --- a/src/types/aiProvider.ts +++ b/src/types/aiProvider.ts @@ -1,6 +1,6 @@ import { z } from 'zod'; -import { AiModelConfig, AiModelSourceType, AiModelType, ModelAbilities } from '@/types/aiModel'; +import { AiModelConfig, AiModelType, ModelAbilities } from '@/types/aiModel'; import { SmoothingParams } from '@/types/llm'; export const AiProviderSourceEnum = { @@ -182,8 +182,9 @@ export interface AiProviderSortMap { export interface EnabledProvider { id: string; + logo?: string; name?: string; - source: AiModelSourceType; + source: AiProviderSourceType; } export interface EnabledAiModel { @@ -193,6 +194,7 @@ export interface EnabledAiModel { displayName?: string; id: string; providerId: string; + sort?: number; type: AiModelType; }