diff --git a/changelog/v1.json b/changelog/v1.json index d186b77ff88c..58586b73d79e 100644 --- a/changelog/v1.json +++ b/changelog/v1.json @@ -1,4 +1,11 @@ [ + { + "children": { + "fixes": ["Fix provider enabled issue."] + }, + "date": "2025-01-08", + "version": "1.44.3" + }, { "children": { "fixes": ["Add provider id validate."] diff --git a/src/database/repositories/aiInfra/index.ts b/src/database/repositories/aiInfra/index.ts index 94f8964a1517..756af2f8d968 100644 --- a/src/database/repositories/aiInfra/index.ts +++ b/src/database/repositories/aiInfra/index.ts @@ -69,19 +69,28 @@ export class AiInfraRepos { const providers = await this.getAiProviderList(); const enabledProviders = providers.filter((item) => item.enabled); - const userEnabledModels = await this.aiModelModel.getEnabledModels(); + const allModels = await this.aiModelModel.getAllModels(); + const userEnabledModels = allModels.filter((item) => item.enabled); + const modelList = await pMap( enabledProviders, async (provider) => { const aiModels = await this.fetchBuiltinModels(provider.id); return (aiModels || []) - .filter((i) => i.enabled) - .map((item) => ({ - ...item, - abilities: item.abilities || {}, - providerId: provider.id, - })); + .map((item) => { + const user = allModels.find((m) => m.id === item.id && m.providerId === provider.id); + + const enabled = !!user ? user.enabled : item.enabled; + + return { + ...item, + abilities: item.abilities || {}, + enabled, + providerId: provider.id, + }; + }) + .filter((i) => i.enabled); }, { concurrency: 10 }, ); @@ -100,6 +109,9 @@ export class AiInfraRepos { return mergeArrayById(defaultModels, aiModels) as AiProviderModelListItem[]; }; + /** + * Fetch builtin models from config + */ private fetchBuiltinModels = async ( providerId: string, ): Promise => { diff --git a/src/database/server/models/__tests__/aiModel.test.ts b/src/database/server/models/__tests__/aiModel.test.ts index ddd399886faa..71d5a9c475c9 100644 --- a/src/database/server/models/__tests__/aiModel.test.ts +++ b/src/database/server/models/__tests__/aiModel.test.ts @@ -193,17 +193,15 @@ describe('AiModelModel', () => { }); }); - describe('getEnabledModels', () => { + describe('getAllModels', () => { it('should only return enabled models', async () => { await serverDB.insert(aiModels).values([ { id: 'model1', providerId: 'openai', enabled: true, source: 'custom', userId }, - { id: 'model2', providerId: 'openai', enabled: false, source: 'custom', userId }, + { id: 'model2', providerId: 'b', enabled: false, source: 'custom', userId }, ]); - const models = await aiProviderModel.getEnabledModels(); - expect(models).toHaveLength(1); - expect(models[0].id).toBe('model1'); - expect(models[0].enabled).toBe(true); + const models = await aiProviderModel.getAllModels(); + expect(models).toHaveLength(2); }); }); diff --git a/src/database/server/models/aiModel.ts b/src/database/server/models/aiModel.ts index 225156754d67..571e9bfdf6d2 100644 --- a/src/database/server/models/aiModel.ts +++ b/src/database/server/models/aiModel.ts @@ -8,6 +8,7 @@ import { AiProviderModelListItem, ToggleAiModelEnableParams, } from '@/types/aiModel'; +import { EnabledAiModel } from '@/types/aiProvider'; import { AiModelSelectItem, NewAiModelItem, aiModels } from '../../schemas'; @@ -83,8 +84,8 @@ export class AiModelModel { return result as AiProviderModelListItem[]; }; - getEnabledModels = async () => { - return this.db + getAllModels = async () => { + const data = await this.db .select({ abilities: aiModels.abilities, config: aiModels.config, @@ -98,7 +99,9 @@ export class AiModelModel { type: aiModels.type, }) .from(aiModels) - .where(and(eq(aiModels.userId, this.userId), eq(aiModels.enabled, true))); + .where(and(eq(aiModels.userId, this.userId))); + + return data as EnabledAiModel[]; }; findById = async (id: string) => { diff --git a/src/store/aiInfra/slices/aiProvider/action.ts b/src/store/aiInfra/slices/aiProvider/action.ts index e972b0285bf9..1db81dbb2ba8 100644 --- a/src/store/aiInfra/slices/aiProvider/action.ts +++ b/src/store/aiInfra/slices/aiProvider/action.ts @@ -215,7 +215,7 @@ export const createAiProviderSlice: StateCreator< enabledChatModelList, }, false, - 'useInitAiProviderKeyVaults', + 'useFetchAiProviderRuntimeState', ); }, }, diff --git a/src/types/aiProvider.ts b/src/types/aiProvider.ts index c411622d3e66..d97cc7669336 100644 --- a/src/types/aiProvider.ts +++ b/src/types/aiProvider.ts @@ -192,6 +192,7 @@ export interface EnabledAiModel { config?: AiModelConfig; contextWindowTokens?: number; displayName?: string; + enabled?: boolean; id: string; providerId: string; sort?: number;