diff --git a/src/database/repositories/aiInfra/index.test.ts b/src/database/repositories/aiInfra/index.test.ts index 19ef7d319668c..d251b4381f048 100644 --- a/src/database/repositories/aiInfra/index.test.ts +++ b/src/database/repositories/aiInfra/index.test.ts @@ -1,70 +1,91 @@ -import { describe, expect, it, vi } from 'vitest'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; import { DEFAULT_MODEL_PROVIDER_LIST } from '@/config/modelProviders'; +import { clientDB, initializeDB } from '@/database/client/db'; +import { AiProviderModel } from '@/database/server/models/aiProvider'; +import { LobeChatDatabase } from '@/database/type'; +import { AiProviderModelListItem } from '@/types/aiModel'; +import { + AiProviderDetailItem, + AiProviderListItem, + AiProviderRuntimeConfig, + EnabledAiModel, + EnabledProvider, +} from '@/types/aiProvider'; import { AiInfraRepos } from './index'; -describe('AiInfraRepos', () => { - const mockDb = { - query: vi.fn(), - }; - - const mockUserId = 'test-user-id'; - const mockProviderConfigs = { - openai: { - enabled: true, - }, - }; - - const mockAiProviderModel = { - getAiProviderById: vi.fn(), - getAiProviderList: vi.fn(), - getAiProviderRuntimeConfig: vi.fn(), - }; - - const mockAiModelModel = { - getAllModels: vi.fn(), - getModelListByProviderId: vi.fn(), - }; +const userId = 'test-user-id'; +const mockProviderConfigs = { + openai: { enabled: true }, + anthropic: { enabled: false }, +}; + +let repo: AiInfraRepos; + +beforeEach(async () => { + await initializeDB(); + vi.clearAllMocks(); + repo = new AiInfraRepos(clientDB as any, userId, mockProviderConfigs); +}); + +describe('AiInfraRepos', () => { describe('getAiProviderList', () => { it('should merge builtin and user providers correctly', async () => { - const repo = new AiInfraRepos(mockDb as any, mockUserId, mockProviderConfigs); - repo.aiProviderModel = mockAiProviderModel as any; - const mockUserProviders = [ - { - description: 'Custom OpenAI', - enabled: true, - id: 'openai', - name: 'Custom OpenAI', - sort: 1, - source: 'builtin' as const, - }, - ]; + { id: 'openai', enabled: true, name: 'Custom OpenAI' }, + { id: 'custom', enabled: true, name: 'Custom Provider' }, + ] as AiProviderListItem[]; - mockAiProviderModel.getAiProviderList.mockResolvedValue(mockUserProviders); + vi.spyOn(repo.aiProviderModel, 'getAiProviderList').mockResolvedValueOnce(mockUserProviders); const result = await repo.getAiProviderList(); - expect(result[0]).toEqual( - expect.objectContaining({ - description: 'Custom OpenAI', - enabled: true, - id: 'openai', - name: 'Custom OpenAI', - sort: 1, - source: 'builtin', - }), + expect(result).toBeDefined(); + expect(result.length).toBeGreaterThan(0); + // Verify the merge logic + const openaiProvider = result.find((p) => p.id === 'openai'); + expect(openaiProvider).toMatchObject({ enabled: true, name: 'Custom OpenAI' }); + }); + + it('should sort providers according to DEFAULT_MODEL_PROVIDER_LIST order', async () => { + vi.spyOn(repo.aiProviderModel, 'getAiProviderList').mockResolvedValue([]); + + const result = await repo.getAiProviderList(); + + expect(result).toEqual( + expect.arrayContaining( + DEFAULT_MODEL_PROVIDER_LIST.map((item) => + expect.objectContaining({ + id: item.id, + source: 'builtin', + }), + ), + ), ); }); }); describe('getUserEnabledProviderList', () => { it('should return only enabled providers', async () => { - const repo = new AiInfraRepos(mockDb as any, mockUserId, mockProviderConfigs); - repo.aiProviderModel = mockAiProviderModel as any; + const mockProviders = [ + { id: 'openai', enabled: true, name: 'OpenAI', sort: 1 }, + { id: 'anthropic', enabled: false, name: 'Anthropic', sort: 2 }, + ] as AiProviderListItem[]; + + vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders); + + const result = await repo.getUserEnabledProviderList(); + + expect(result).toHaveLength(1); + expect(result[0]).toMatchObject({ + id: 'openai', + name: 'OpenAI', + }); + }); + it('should return only enabled provider', async () => { const mockProviders = [ { enabled: true, @@ -84,7 +105,7 @@ describe('AiInfraRepos', () => { }, ]; - vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders); + vi.spyOn(repo.aiProviderModel, 'getAiProviderList').mockResolvedValue(mockProviders); const result = await repo.getUserEnabledProviderList(); @@ -100,11 +121,29 @@ describe('AiInfraRepos', () => { }); describe('getEnabledModels', () => { - it('should merge builtin and user models correctly', async () => { - const repo = new AiInfraRepos(mockDb as any, mockUserId, mockProviderConfigs); - repo.aiProviderModel = mockAiProviderModel as any; - (repo as any).aiModelModel = mockAiModelModel; + it('should merge and filter enabled models', async () => { + const mockProviders = [{ id: 'openai', enabled: true }] as AiProviderListItem[]; + const mockAllModels = [ + { id: 'gpt-4', providerId: 'openai', enabled: true }, + ] as EnabledAiModel[]; + + vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders); + vi.spyOn(repo.aiModelModel, 'getAllModels').mockResolvedValue(mockAllModels); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue([ + { id: 'gpt-4', enabled: true, type: 'chat' }, + ]); + + const result = await repo.getEnabledModels(); + expect(result).toBeDefined(); + expect(result.length).toBeGreaterThan(0); + expect(result[0]).toMatchObject({ + id: 'gpt-4', + providerId: 'openai', + }); + }); + + it('should merge builtin and user models correctly', async () => { const mockProviders = [ { enabled: true, id: 'openai', name: 'OpenAI', sort: 1, source: 'builtin' as const }, ]; @@ -118,11 +157,12 @@ describe('AiInfraRepos', () => { providerId: 'openai', sort: 1, type: 'chat' as const, + contextWindowTokens: 10, }, ]; vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders); - mockAiModelModel.getAllModels.mockResolvedValue(mockAllModels); + vi.spyOn(repo.aiModelModel, 'getAllModels').mockResolvedValue(mockAllModels); vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue([ { abilities: {}, @@ -140,6 +180,7 @@ describe('AiInfraRepos', () => { abilities: { vision: true }, displayName: 'Custom GPT-4', enabled: true, + contextWindowTokens: 10, id: 'gpt-4', providerId: 'openai', sort: 1, @@ -149,10 +190,6 @@ describe('AiInfraRepos', () => { }); it('should handle case when user model not found', async () => { - const repo = new AiInfraRepos(mockDb as any, mockUserId, mockProviderConfigs); - repo.aiProviderModel = mockAiProviderModel as any; - (repo as any).aiModelModel = mockAiModelModel; - const mockProviders = [ { enabled: true, id: 'openai', name: 'OpenAI', sort: 1, source: 'builtin' as const }, ]; @@ -160,7 +197,7 @@ describe('AiInfraRepos', () => { const mockAllModels: any[] = []; vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders); - mockAiModelModel.getAllModels.mockResolvedValue(mockAllModels); + vi.spyOn(repo.aiModelModel, 'getAllModels').mockResolvedValue(mockAllModels); vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue([ { abilities: { reasoning: true }, @@ -185,14 +222,31 @@ describe('AiInfraRepos', () => { }); describe('getAiProviderModelList', () => { + it('should merge builtin and user models', async () => { + const providerId = 'openai'; + const mockUserModels = [ + { id: 'custom-gpt4', enabled: true, type: 'chat' }, + ] as AiProviderModelListItem[]; + const mockBuiltinModels = [{ id: 'gpt-4', enabled: true }]; + + vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue(mockUserModels); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue(mockBuiltinModels); + + const result = await repo.getAiProviderModelList(providerId); + + expect(result).toHaveLength(2); + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ id: 'custom-gpt4' }), + expect.objectContaining({ id: 'gpt-4' }), + ]), + ); + }); it('should merge default and custom models', async () => { - const repo = new AiInfraRepos(mockDb as any, mockUserId, mockProviderConfigs); - (repo as any).aiModelModel = mockAiModelModel; - const mockCustomModels = [ { displayName: 'Custom GPT-4', - enabled: true, + enabled: false, id: 'gpt-4', type: 'chat' as const, }, @@ -201,13 +255,13 @@ describe('AiInfraRepos', () => { const mockDefaultModels = [ { displayName: 'GPT-4', - enabled: false, + enabled: true, id: 'gpt-4', type: 'chat' as const, }, ]; - mockAiModelModel.getModelListByProviderId.mockResolvedValue(mockCustomModels); + vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue(mockCustomModels); vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue(mockDefaultModels); const result = await repo.getAiProviderModelList('openai'); @@ -215,25 +269,71 @@ describe('AiInfraRepos', () => { expect(result).toContainEqual( expect.objectContaining({ displayName: 'Custom GPT-4', - enabled: true, + enabled: false, id: 'gpt-4', }), ); }); + + it('should use builtin models', async () => { + const providerId = 'taichu'; + + vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue([]); + + const result = await repo.getAiProviderModelList(providerId); + + expect(result).toHaveLength(2); + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ id: 'taichu_llm' }), + expect.objectContaining({ id: 'taichu2_mm' }), + ]), + ); + }); + + it('should return empty if not exist provider', async () => { + const providerId = 'abc'; + + vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue([]); + + const result = await repo.getAiProviderModelList(providerId); + + expect(result).toHaveLength(0); + }); }); describe('getAiProviderRuntimeState', () => { - it('should return provider runtime state', async () => { - const repo = new AiInfraRepos(mockDb as any, mockUserId, mockProviderConfigs); - repo.aiProviderModel = mockAiProviderModel as any; + it('should return complete runtime state', async () => { + const mockRuntimeConfig = { + openai: { apiKey: 'test-key' }, + } as unknown as Record; + const mockEnabledProviders = [{ id: 'openai', name: 'OpenAI' }] as EnabledProvider[]; + const mockEnabledModels = [{ id: 'gpt-4', providerId: 'openai' }] as EnabledAiModel[]; + vi.spyOn(repo.aiProviderModel, 'getAiProviderRuntimeConfig').mockResolvedValue( + mockRuntimeConfig, + ); + vi.spyOn(repo, 'getUserEnabledProviderList').mockResolvedValue(mockEnabledProviders); + vi.spyOn(repo, 'getEnabledModels').mockResolvedValue(mockEnabledModels); + + const result = await repo.getAiProviderRuntimeState(); + + expect(result).toMatchObject({ + enabledAiProviders: mockEnabledProviders, + enabledAiModels: mockEnabledModels, + runtimeConfig: expect.any(Object), + }); + }); + it('should return provider runtime state', async () => { const mockRuntimeConfig = { openai: { apiKey: 'test-key', }, - }; + } as unknown as Record; - mockAiProviderModel.getAiProviderRuntimeConfig.mockResolvedValue(mockRuntimeConfig); + vi.spyOn(repo.aiProviderModel, 'getAiProviderRuntimeConfig').mockResolvedValue( + mockRuntimeConfig, + ); vi.spyOn(repo, 'getUserEnabledProviderList').mockResolvedValue([ { id: 'openai', logo: 'logo1', name: 'OpenAI', source: 'builtin' }, @@ -271,10 +371,24 @@ describe('AiInfraRepos', () => { }); describe('getAiProviderDetail', () => { - it('should merge provider configs correctly', async () => { - const repo = new AiInfraRepos(mockDb as any, mockUserId, mockProviderConfigs); - repo.aiProviderModel = mockAiProviderModel as any; + it('should merge provider config with user settings', async () => { + const providerId = 'openai'; + const mockProviderDetail = { + id: providerId, + customSetting: 'test', + } as unknown as AiProviderDetailItem; + vi.spyOn(repo.aiProviderModel, 'getAiProviderById').mockResolvedValue(mockProviderDetail); + + const result = await repo.getAiProviderDetail(providerId); + + expect(result).toMatchObject({ + id: providerId, + customSetting: 'test', + enabled: true, // from mockProviderConfigs + }); + }); + it('should merge provider configs correctly', async () => { const mockProviderDetail = { enabled: true, id: 'openai', @@ -284,7 +398,7 @@ describe('AiInfraRepos', () => { source: 'builtin' as const, }; - mockAiProviderModel.getAiProviderById.mockResolvedValue(mockProviderDetail); + vi.spyOn(repo.aiProviderModel, 'getAiProviderById').mockResolvedValue(mockProviderDetail); const result = await repo.getAiProviderDetail('openai'); diff --git a/src/database/repositories/aiInfra/index.ts b/src/database/repositories/aiInfra/index.ts index 8d5df23e30340..cd1b1d96a75fd 100644 --- a/src/database/repositories/aiInfra/index.ts +++ b/src/database/repositories/aiInfra/index.ts @@ -11,6 +11,7 @@ import { AiProviderListItem, AiProviderRuntimeState, EnabledAiModel, + EnabledProvider, } from '@/types/aiProvider'; import { ProviderConfig } from '@/types/user/settings'; import { merge, mergeArrayById } from '@/utils/merge'; @@ -22,7 +23,7 @@ export class AiInfraRepos { private db: LobeChatDatabase; aiProviderModel: AiProviderModel; private providerConfigs: Record; - private aiModelModel: AiModelModel; + aiModelModel: AiModelModel; constructor( db: LobeChatDatabase, @@ -70,7 +71,14 @@ export class AiInfraRepos { return list .filter((item) => item.enabled) .sort((a, b) => a.sort! - b.sort!) - .map((item) => ({ id: item.id, logo: item.logo, name: item.name, source: item.source })); + .map( + (item): EnabledProvider => ({ + id: item.id, + logo: item.logo, + name: item.name, + source: item.source, + }), + ); }; getEnabledModels = async () => { @@ -104,7 +112,7 @@ export class AiInfraRepos { ? user.contextWindowTokens : item.contextWindowTokens, displayName: user?.displayName || item.displayName, - enabled: user.enabled || item.enabled, + enabled: typeof user.enabled === 'boolean' ? user.enabled : item.enabled, id: item.id, providerId: provider.id, sort: user.sort || undefined,