Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
bentwnghk committed Jan 8, 2025
2 parents c54111f + 45aac54 commit ab1b17c
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 17 deletions.
7 changes: 7 additions & 0 deletions changelog/v1.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
[
{
"children": {
"fixes": ["Fix provider enabled issue."]
},
"date": "2025-01-08",
"version": "1.44.3"
},
{
"children": {
"fixes": ["Add provider id validate."]
Expand Down
26 changes: 19 additions & 7 deletions src/database/repositories/aiInfra/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,28 @@ export class AiInfraRepos {
const providers = await this.getAiProviderList();
const enabledProviders = providers.filter((item) => item.enabled);

const userEnabledModels = await this.aiModelModel.getEnabledModels();
const allModels = await this.aiModelModel.getAllModels();
const userEnabledModels = allModels.filter((item) => item.enabled);

const modelList = await pMap(
enabledProviders,
async (provider) => {
const aiModels = await this.fetchBuiltinModels(provider.id);

return (aiModels || [])
.filter((i) => i.enabled)
.map<EnabledAiModel>((item) => ({
...item,
abilities: item.abilities || {},
providerId: provider.id,
}));
.map<EnabledAiModel & { enabled?: boolean | null }>((item) => {
const user = allModels.find((m) => m.id === item.id && m.providerId === provider.id);

const enabled = !!user ? user.enabled : item.enabled;

return {
...item,
abilities: item.abilities || {},
enabled,
providerId: provider.id,
};
})
.filter((i) => i.enabled);
},
{ concurrency: 10 },
);
Expand All @@ -100,6 +109,9 @@ export class AiInfraRepos {
return mergeArrayById(defaultModels, aiModels) as AiProviderModelListItem[];
};

/**
* Fetch builtin models from config
*/
private fetchBuiltinModels = async (
providerId: string,
): Promise<AiProviderModelListItem[] | undefined> => {
Expand Down
10 changes: 4 additions & 6 deletions src/database/server/models/__tests__/aiModel.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -193,17 +193,15 @@ describe('AiModelModel', () => {
});
});

describe('getEnabledModels', () => {
describe('getAllModels', () => {
it('should only return enabled models', async () => {
await serverDB.insert(aiModels).values([
{ id: 'model1', providerId: 'openai', enabled: true, source: 'custom', userId },
{ id: 'model2', providerId: 'openai', enabled: false, source: 'custom', userId },
{ id: 'model2', providerId: 'b', enabled: false, source: 'custom', userId },
]);

const models = await aiProviderModel.getEnabledModels();
expect(models).toHaveLength(1);
expect(models[0].id).toBe('model1');
expect(models[0].enabled).toBe(true);
const models = await aiProviderModel.getAllModels();
expect(models).toHaveLength(2);
});
});

Expand Down
9 changes: 6 additions & 3 deletions src/database/server/models/aiModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
AiProviderModelListItem,
ToggleAiModelEnableParams,
} from '@/types/aiModel';
import { EnabledAiModel } from '@/types/aiProvider';

import { AiModelSelectItem, NewAiModelItem, aiModels } from '../../schemas';

Expand Down Expand Up @@ -83,8 +84,8 @@ export class AiModelModel {
return result as AiProviderModelListItem[];
};

getEnabledModels = async () => {
return this.db
getAllModels = async () => {
const data = await this.db
.select({
abilities: aiModels.abilities,
config: aiModels.config,
Expand All @@ -98,7 +99,9 @@ export class AiModelModel {
type: aiModels.type,
})
.from(aiModels)
.where(and(eq(aiModels.userId, this.userId), eq(aiModels.enabled, true)));
.where(and(eq(aiModels.userId, this.userId)));

return data as EnabledAiModel[];
};

findById = async (id: string) => {
Expand Down
2 changes: 1 addition & 1 deletion src/store/aiInfra/slices/aiProvider/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ export const createAiProviderSlice: StateCreator<
enabledChatModelList,
},
false,
'useInitAiProviderKeyVaults',
'useFetchAiProviderRuntimeState',
);
},
},
Expand Down
1 change: 1 addition & 0 deletions src/types/aiProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ export interface EnabledAiModel {
config?: AiModelConfig;
contextWindowTokens?: number;
displayName?: string;
enabled?: boolean;
id: string;
providerId: string;
sort?: number;
Expand Down

0 comments on commit ab1b17c

Please sign in to comment.