From b2775b571bef0fc5653b755683abd093cab35aba Mon Sep 17 00:00:00 2001 From: Arvin Xu Date: Thu, 9 Jan 2025 13:20:46 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20fix=20some=20ai=20provide?= =?UTF-8?q?r=20known=20issues=20(#5361)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix provider url * improve fetch model list issue * fix builtin model sort and displayName * fix user enabled models * fix model name * fix model displayName name --- .../[slug]/features/ProviderConfig.tsx | 5 +- src/database/repositories/aiInfra/index.ts | 13 ++-- .../server/models/__tests__/aiModel.test.ts | 2 +- src/database/server/models/aiModel.ts | 59 +++++-------------- 4 files changed, 26 insertions(+), 53 deletions(-) diff --git a/src/app/(main)/discover/(detail)/provider/[slug]/features/ProviderConfig.tsx b/src/app/(main)/discover/(detail)/provider/[slug]/features/ProviderConfig.tsx index 0e14c3b49a4d..3e660023b229 100644 --- a/src/app/(main)/discover/(detail)/provider/[slug]/features/ProviderConfig.tsx +++ b/src/app/(main)/discover/(detail)/provider/[slug]/features/ProviderConfig.tsx @@ -10,6 +10,7 @@ import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import { FlexboxProps } from 'react-layout-kit'; +import { isServerMode } from '@/const/version'; import { DiscoverProviderItem } from '@/types/discover'; const useStyles = createStyles(({ css }) => ({ @@ -25,13 +26,13 @@ interface ProviderConfigProps extends FlexboxProps { identifier: string; } -const ProviderConfig = memo(({ data }) => { +const ProviderConfig = memo(({ data, identifier }) => { const { styles } = useStyles(); const { t } = useTranslation('discover'); const router = useRouter(); const openSettings = () => { - router.push('/settings/llm'); + router.push(!isServerMode ? '/settings/llm' : `/settings/provider/${identifier}`); }; const icon = ; diff --git a/src/database/repositories/aiInfra/index.ts b/src/database/repositories/aiInfra/index.ts index 756af2f8d968..7063a5af5303 100644 --- a/src/database/repositories/aiInfra/index.ts +++ b/src/database/repositories/aiInfra/index.ts @@ -81,13 +81,16 @@ export class AiInfraRepos { .map((item) => { const user = allModels.find((m) => m.id === item.id && m.providerId === provider.id); - const enabled = !!user ? user.enabled : item.enabled; - return { - ...item, - abilities: item.abilities || {}, - enabled, + abilities: !!user ? user.abilities : item.abilities || {}, + config: !!user ? user.config : item.config, + contextWindowTokens: !!user ? user.contextWindowTokens : item.contextWindowTokens, + displayName: user?.displayName || item.displayName, + enabled: !!user ? user.enabled : item.enabled, + id: item.id, providerId: provider.id, + sort: !!user ? user.sort : undefined, + type: item.type, }; }) .filter((i) => i.enabled); diff --git a/src/database/server/models/__tests__/aiModel.test.ts b/src/database/server/models/__tests__/aiModel.test.ts index 71d5a9c475c9..7a8dea883e36 100644 --- a/src/database/server/models/__tests__/aiModel.test.ts +++ b/src/database/server/models/__tests__/aiModel.test.ts @@ -248,7 +248,7 @@ describe('AiModelModel', () => { const allModels = await aiProviderModel.query(); expect(allModels).toHaveLength(2); - expect(allModels.find((m) => m.id === 'existing-model')?.displayName).toBe('Updated Name'); + expect(allModels.find((m) => m.id === 'existing-model')?.displayName).toBe('Old Name'); expect(allModels.find((m) => m.id === 'new-model')?.displayName).toBe('New Model'); }); }); diff --git a/src/database/server/models/aiModel.ts b/src/database/server/models/aiModel.ts index 571e9bfdf6d2..cf62186ea0a3 100644 --- a/src/database/server/models/aiModel.ts +++ b/src/database/server/models/aiModel.ts @@ -1,5 +1,4 @@ import { and, asc, desc, eq, inArray } from 'drizzle-orm/expressions'; -import pMap from 'p-map'; import { LobeChatDatabase } from '@/database/type'; import { @@ -131,51 +130,21 @@ export class AiModelModel { }; batchUpdateAiModels = async (providerId: string, models: AiProviderModelListItem[]) => { - return this.db.transaction(async (trx) => { - const records = models.map(({ id, ...model }) => ({ - ...model, - id, - providerId, - updatedAt: new Date(), - userId: this.userId, - })); + const records = models.map(({ id, ...model }) => ({ + ...model, + id, + providerId, + updatedAt: new Date(), + userId: this.userId, + })); - // 第一步:尝试插入所有记录,忽略冲突 - const insertedRecords = await trx - .insert(aiModels) - .values(records) - .onConflictDoNothing({ - target: [aiModels.id, aiModels.userId, aiModels.providerId], - }) - .returning(); - // 第二步:找出需要更新的记录(即插入时发生冲突的记录) - // 找出未能插入的记录(需要更新的记录) - const insertedIds = new Set(insertedRecords.map((r) => r.id)); - const recordsToUpdate = records.filter((r) => !insertedIds.has(r.id)); - - // 第三步:更新已存在的记录 - if (recordsToUpdate.length > 0) { - await pMap( - recordsToUpdate, - async (record) => { - await trx - .update(aiModels) - .set({ - ...record, - updatedAt: new Date(), - }) - .where( - and( - eq(aiModels.id, record.id), - eq(aiModels.userId, this.userId), - eq(aiModels.providerId, providerId), - ), - ); - }, - { concurrency: 10 }, // 限制并发数为 10 - ); - } - }); + return this.db + .insert(aiModels) + .values(records) + .onConflictDoNothing({ + target: [aiModels.id, aiModels.userId, aiModels.providerId], + }) + .returning(); }; batchToggleAiModels = async (providerId: string, models: string[], enabled: boolean) => {