Skip to content

Commit

Permalink
🐛 fix: fix not enable models correctly (lobehub#6071)
Browse files Browse the repository at this point in the history
* fix enabled issue

* fix tests
  • Loading branch information
arvinxx authored Feb 13, 2025
1 parent 4aff396 commit b78328e
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 77 deletions.
262 changes: 188 additions & 74 deletions src/database/repositories/aiInfra/index.test.ts
Original file line number Diff line number Diff line change
@@ -1,70 +1,91 @@
import { describe, expect, it, vi } from 'vitest';
import { beforeEach, describe, expect, it, vi } from 'vitest';

import { DEFAULT_MODEL_PROVIDER_LIST } from '@/config/modelProviders';
import { clientDB, initializeDB } from '@/database/client/db';
import { AiProviderModel } from '@/database/server/models/aiProvider';
import { LobeChatDatabase } from '@/database/type';
import { AiProviderModelListItem } from '@/types/aiModel';
import {
AiProviderDetailItem,
AiProviderListItem,
AiProviderRuntimeConfig,
EnabledAiModel,
EnabledProvider,
} from '@/types/aiProvider';

import { AiInfraRepos } from './index';

describe('AiInfraRepos', () => {
const mockDb = {
query: vi.fn(),
};

const mockUserId = 'test-user-id';
const mockProviderConfigs = {
openai: {
enabled: true,
},
};

const mockAiProviderModel = {
getAiProviderById: vi.fn(),
getAiProviderList: vi.fn(),
getAiProviderRuntimeConfig: vi.fn(),
};

const mockAiModelModel = {
getAllModels: vi.fn(),
getModelListByProviderId: vi.fn(),
};
const userId = 'test-user-id';
const mockProviderConfigs = {
openai: { enabled: true },
anthropic: { enabled: false },
};

let repo: AiInfraRepos;

beforeEach(async () => {
await initializeDB();
vi.clearAllMocks();

repo = new AiInfraRepos(clientDB as any, userId, mockProviderConfigs);
});

describe('AiInfraRepos', () => {
describe('getAiProviderList', () => {
it('should merge builtin and user providers correctly', async () => {
const repo = new AiInfraRepos(mockDb as any, mockUserId, mockProviderConfigs);
repo.aiProviderModel = mockAiProviderModel as any;

const mockUserProviders = [
{
description: 'Custom OpenAI',
enabled: true,
id: 'openai',
name: 'Custom OpenAI',
sort: 1,
source: 'builtin' as const,
},
];
{ id: 'openai', enabled: true, name: 'Custom OpenAI' },
{ id: 'custom', enabled: true, name: 'Custom Provider' },
] as AiProviderListItem[];

mockAiProviderModel.getAiProviderList.mockResolvedValue(mockUserProviders);
vi.spyOn(repo.aiProviderModel, 'getAiProviderList').mockResolvedValueOnce(mockUserProviders);

const result = await repo.getAiProviderList();

expect(result[0]).toEqual(
expect.objectContaining({
description: 'Custom OpenAI',
enabled: true,
id: 'openai',
name: 'Custom OpenAI',
sort: 1,
source: 'builtin',
}),
expect(result).toBeDefined();
expect(result.length).toBeGreaterThan(0);
// Verify the merge logic
const openaiProvider = result.find((p) => p.id === 'openai');
expect(openaiProvider).toMatchObject({ enabled: true, name: 'Custom OpenAI' });
});

it('should sort providers according to DEFAULT_MODEL_PROVIDER_LIST order', async () => {
vi.spyOn(repo.aiProviderModel, 'getAiProviderList').mockResolvedValue([]);

const result = await repo.getAiProviderList();

expect(result).toEqual(
expect.arrayContaining(
DEFAULT_MODEL_PROVIDER_LIST.map((item) =>
expect.objectContaining({
id: item.id,
source: 'builtin',
}),
),
),
);
});
});

describe('getUserEnabledProviderList', () => {
it('should return only enabled providers', async () => {
const repo = new AiInfraRepos(mockDb as any, mockUserId, mockProviderConfigs);
repo.aiProviderModel = mockAiProviderModel as any;
const mockProviders = [
{ id: 'openai', enabled: true, name: 'OpenAI', sort: 1 },
{ id: 'anthropic', enabled: false, name: 'Anthropic', sort: 2 },
] as AiProviderListItem[];

vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders);

const result = await repo.getUserEnabledProviderList();

expect(result).toHaveLength(1);
expect(result[0]).toMatchObject({
id: 'openai',
name: 'OpenAI',
});
});

it('should return only enabled provider', async () => {
const mockProviders = [
{
enabled: true,
Expand All @@ -84,7 +105,7 @@ describe('AiInfraRepos', () => {
},
];

vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders);
vi.spyOn(repo.aiProviderModel, 'getAiProviderList').mockResolvedValue(mockProviders);

const result = await repo.getUserEnabledProviderList();

Expand All @@ -100,11 +121,29 @@ describe('AiInfraRepos', () => {
});

describe('getEnabledModels', () => {
it('should merge builtin and user models correctly', async () => {
const repo = new AiInfraRepos(mockDb as any, mockUserId, mockProviderConfigs);
repo.aiProviderModel = mockAiProviderModel as any;
(repo as any).aiModelModel = mockAiModelModel;
it('should merge and filter enabled models', async () => {
const mockProviders = [{ id: 'openai', enabled: true }] as AiProviderListItem[];
const mockAllModels = [
{ id: 'gpt-4', providerId: 'openai', enabled: true },
] as EnabledAiModel[];

vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders);
vi.spyOn(repo.aiModelModel, 'getAllModels').mockResolvedValue(mockAllModels);
vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue([
{ id: 'gpt-4', enabled: true, type: 'chat' },
]);

const result = await repo.getEnabledModels();

expect(result).toBeDefined();
expect(result.length).toBeGreaterThan(0);
expect(result[0]).toMatchObject({
id: 'gpt-4',
providerId: 'openai',
});
});

it('should merge builtin and user models correctly', async () => {
const mockProviders = [
{ enabled: true, id: 'openai', name: 'OpenAI', sort: 1, source: 'builtin' as const },
];
Expand All @@ -118,11 +157,12 @@ describe('AiInfraRepos', () => {
providerId: 'openai',
sort: 1,
type: 'chat' as const,
contextWindowTokens: 10,
},
];

vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders);
mockAiModelModel.getAllModels.mockResolvedValue(mockAllModels);
vi.spyOn(repo.aiModelModel, 'getAllModels').mockResolvedValue(mockAllModels);
vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue([
{
abilities: {},
Expand All @@ -140,6 +180,7 @@ describe('AiInfraRepos', () => {
abilities: { vision: true },
displayName: 'Custom GPT-4',
enabled: true,
contextWindowTokens: 10,
id: 'gpt-4',
providerId: 'openai',
sort: 1,
Expand All @@ -149,18 +190,14 @@ describe('AiInfraRepos', () => {
});

it('should handle case when user model not found', async () => {
const repo = new AiInfraRepos(mockDb as any, mockUserId, mockProviderConfigs);
repo.aiProviderModel = mockAiProviderModel as any;
(repo as any).aiModelModel = mockAiModelModel;

const mockProviders = [
{ enabled: true, id: 'openai', name: 'OpenAI', sort: 1, source: 'builtin' as const },
];

const mockAllModels: any[] = [];

vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders);
mockAiModelModel.getAllModels.mockResolvedValue(mockAllModels);
vi.spyOn(repo.aiModelModel, 'getAllModels').mockResolvedValue(mockAllModels);
vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue([
{
abilities: { reasoning: true },
Expand All @@ -185,14 +222,31 @@ describe('AiInfraRepos', () => {
});

describe('getAiProviderModelList', () => {
it('should merge builtin and user models', async () => {
const providerId = 'openai';
const mockUserModels = [
{ id: 'custom-gpt4', enabled: true, type: 'chat' },
] as AiProviderModelListItem[];
const mockBuiltinModels = [{ id: 'gpt-4', enabled: true }];

vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue(mockUserModels);
vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue(mockBuiltinModels);

const result = await repo.getAiProviderModelList(providerId);

expect(result).toHaveLength(2);
expect(result).toEqual(
expect.arrayContaining([
expect.objectContaining({ id: 'custom-gpt4' }),
expect.objectContaining({ id: 'gpt-4' }),
]),
);
});
it('should merge default and custom models', async () => {
const repo = new AiInfraRepos(mockDb as any, mockUserId, mockProviderConfigs);
(repo as any).aiModelModel = mockAiModelModel;

const mockCustomModels = [
{
displayName: 'Custom GPT-4',
enabled: true,
enabled: false,
id: 'gpt-4',
type: 'chat' as const,
},
Expand All @@ -201,39 +255,85 @@ describe('AiInfraRepos', () => {
const mockDefaultModels = [
{
displayName: 'GPT-4',
enabled: false,
enabled: true,
id: 'gpt-4',
type: 'chat' as const,
},
];

mockAiModelModel.getModelListByProviderId.mockResolvedValue(mockCustomModels);
vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue(mockCustomModels);
vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue(mockDefaultModels);

const result = await repo.getAiProviderModelList('openai');

expect(result).toContainEqual(
expect.objectContaining({
displayName: 'Custom GPT-4',
enabled: true,
enabled: false,
id: 'gpt-4',
}),
);
});

it('should use builtin models', async () => {
const providerId = 'taichu';

vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue([]);

const result = await repo.getAiProviderModelList(providerId);

expect(result).toHaveLength(2);
expect(result).toEqual(
expect.arrayContaining([
expect.objectContaining({ id: 'taichu_llm' }),
expect.objectContaining({ id: 'taichu2_mm' }),
]),
);
});

it('should return empty if not exist provider', async () => {
const providerId = 'abc';

vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue([]);

const result = await repo.getAiProviderModelList(providerId);

expect(result).toHaveLength(0);
});
});

describe('getAiProviderRuntimeState', () => {
it('should return provider runtime state', async () => {
const repo = new AiInfraRepos(mockDb as any, mockUserId, mockProviderConfigs);
repo.aiProviderModel = mockAiProviderModel as any;
it('should return complete runtime state', async () => {
const mockRuntimeConfig = {
openai: { apiKey: 'test-key' },
} as unknown as Record<string, AiProviderRuntimeConfig>;
const mockEnabledProviders = [{ id: 'openai', name: 'OpenAI' }] as EnabledProvider[];
const mockEnabledModels = [{ id: 'gpt-4', providerId: 'openai' }] as EnabledAiModel[];

vi.spyOn(repo.aiProviderModel, 'getAiProviderRuntimeConfig').mockResolvedValue(
mockRuntimeConfig,
);
vi.spyOn(repo, 'getUserEnabledProviderList').mockResolvedValue(mockEnabledProviders);
vi.spyOn(repo, 'getEnabledModels').mockResolvedValue(mockEnabledModels);

const result = await repo.getAiProviderRuntimeState();

expect(result).toMatchObject({
enabledAiProviders: mockEnabledProviders,
enabledAiModels: mockEnabledModels,
runtimeConfig: expect.any(Object),
});
});
it('should return provider runtime state', async () => {
const mockRuntimeConfig = {
openai: {
apiKey: 'test-key',
},
};
} as unknown as Record<string, AiProviderRuntimeConfig>;

mockAiProviderModel.getAiProviderRuntimeConfig.mockResolvedValue(mockRuntimeConfig);
vi.spyOn(repo.aiProviderModel, 'getAiProviderRuntimeConfig').mockResolvedValue(
mockRuntimeConfig,
);

vi.spyOn(repo, 'getUserEnabledProviderList').mockResolvedValue([
{ id: 'openai', logo: 'logo1', name: 'OpenAI', source: 'builtin' },
Expand Down Expand Up @@ -271,10 +371,24 @@ describe('AiInfraRepos', () => {
});

describe('getAiProviderDetail', () => {
it('should merge provider configs correctly', async () => {
const repo = new AiInfraRepos(mockDb as any, mockUserId, mockProviderConfigs);
repo.aiProviderModel = mockAiProviderModel as any;
it('should merge provider config with user settings', async () => {
const providerId = 'openai';
const mockProviderDetail = {
id: providerId,
customSetting: 'test',
} as unknown as AiProviderDetailItem;

vi.spyOn(repo.aiProviderModel, 'getAiProviderById').mockResolvedValue(mockProviderDetail);

const result = await repo.getAiProviderDetail(providerId);

expect(result).toMatchObject({
id: providerId,
customSetting: 'test',
enabled: true, // from mockProviderConfigs
});
});
it('should merge provider configs correctly', async () => {
const mockProviderDetail = {
enabled: true,
id: 'openai',
Expand All @@ -284,7 +398,7 @@ describe('AiInfraRepos', () => {
source: 'builtin' as const,
};

mockAiProviderModel.getAiProviderById.mockResolvedValue(mockProviderDetail);
vi.spyOn(repo.aiProviderModel, 'getAiProviderById').mockResolvedValue(mockProviderDetail);

const result = await repo.getAiProviderDetail('openai');

Expand Down
Loading

0 comments on commit b78328e

Please sign in to comment.