Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ refactor: refactor the agent runtime payload #5250

Merged
merged 10 commits into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion next.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ const nextConfig: NextConfig = {
'@icons-pack/react-simple-icons',
'@lobehub/ui',
'gpt-tokenizer',
'chroma-js',
],
webVitalsAttribution: ['CLS', 'LCP'],
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ import { Checkbox, Form, FormInstance, Input } from 'antd';
import { memo, useEffect } from 'react';
import { useTranslation } from 'react-i18next';

import MaxTokenSlider from '@/components/MaxTokenSlider';
import { useIsMobile } from '@/hooks/useIsMobile';
import { ChatModelCard } from '@/types/llm';

import MaxTokenSlider from './MaxTokenSlider';

interface ModelConfigFormProps {
initialValues?: ChatModelCard;
onFormInstanceReady: (instance: FormInstance) => void;
Expand Down Expand Up @@ -66,7 +65,10 @@ const ModelConfigForm = memo<ModelConfigFormProps>(
>
<Input placeholder={t('llm.customModelCards.modelConfig.displayName.placeholder')} />
</Form.Item>
<Form.Item label={t('llm.customModelCards.modelConfig.tokens.title')} name={'contextWindowTokens'}>
<Form.Item
label={t('llm.customModelCards.modelConfig.tokens.title')}
name={'contextWindowTokens'}
>
<MaxTokenSlider />
</Form.Item>
<Form.Item
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ import { useTranslation } from 'react-i18next';
import { Flexbox } from 'react-layout-kit';
import useMergeState from 'use-merge-value';

import { useServerConfigStore } from '@/store/serverConfig';
import { serverConfigSelectors } from '@/store/serverConfig/selectors';
import { useIsMobile } from '@/hooks/useIsMobile';

const Kibi = 1024;

Expand All @@ -20,7 +19,7 @@ interface MaxTokenSliderProps {
}

const MaxTokenSlider = memo<MaxTokenSliderProps>(({ value, onChange, defaultValue }) => {
const { t } = useTranslation('setting');
const { t } = useTranslation('components');

const [token, setTokens] = useMergeState(0, {
defaultValue,
Expand All @@ -45,7 +44,7 @@ const MaxTokenSlider = memo<MaxTokenSliderProps>(({ value, onChange, defaultValu
setPowValue(exponent(value / Kibi));
};

const isMobile = useServerConfigStore(serverConfigSelectors.isMobile);
const isMobile = useIsMobile();

const marks = useMemo(() => {
return {
Expand Down Expand Up @@ -74,7 +73,7 @@ const MaxTokenSlider = memo<MaxTokenSliderProps>(({ value, onChange, defaultValu
tooltip={{
formatter: (x) => {
if (typeof x === 'undefined') return;
if (x === 0) return t('llm.customModelCards.modelConfig.tokens.unlimited');
if (x === 0) return t('MaxTokenSlider.unlimited');

let value = getRealValue(x);
if (value < 125) return value.toFixed(0) + 'K';
Expand Down
9 changes: 6 additions & 3 deletions src/components/ModelSelect/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { FC, memo } from 'react';
import { useTranslation } from 'react-i18next';
import { Center, Flexbox } from 'react-layout-kit';

import { ModelAbilities } from '@/types/aiModel';
import { ChatModelCard } from '@/types/llm';
import { formatTokenNumber } from '@/utils/format';

Expand Down Expand Up @@ -57,8 +58,10 @@ const useStyles = createStyles(({ css, token }) => ({
`,
}));

interface ModelInfoTagsProps extends ChatModelCard {
interface ModelInfoTagsProps extends ModelAbilities {
contextWindowTokens?: number | null;
directionReverse?: boolean;
isCustom?: boolean;
placement?: 'top' | 'right';
}

Expand Down Expand Up @@ -102,7 +105,7 @@ export const ModelInfoTags = memo<ModelInfoTagsProps>(
</div>
</Tooltip>
)}
{model.contextWindowTokens !== undefined && (
{typeof model.contextWindowTokens === 'number' && (
<Tooltip
overlayStyle={{ maxWidth: 'unset', pointerEvents: 'none' }}
placement={placement}
Expand All @@ -117,7 +120,7 @@ export const ModelInfoTags = memo<ModelInfoTagsProps>(
{model.contextWindowTokens === 0 ? (
<Infinity size={17} strokeWidth={1.6} />
) : (
formatTokenNumber(model.contextWindowTokens)
formatTokenNumber(model.contextWindowTokens as number)
)}
</Center>
</Tooltip>
Expand Down
10 changes: 9 additions & 1 deletion src/components/NProgress/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@ import { memo } from 'react';

const NProgress = memo(() => {
const theme = useTheme();
return <NextTopLoader color={theme.colorText} height={2} shadow={false} showSpinner={false} />;
return (
<NextTopLoader
color={theme.colorText}
height={2}
shadow={false}
showSpinner={false}
zIndex={1000}
/>
);
});

export default NProgress;
2 changes: 1 addition & 1 deletion src/const/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ export interface JWTPayload {
/**
* Represents the endpoint of provider
*/
endpoint?: string;
baseURL?: string;

azureApiVersion?: string;

Expand Down
11 changes: 11 additions & 0 deletions src/database/server/models/__tests__/user.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,17 @@ describe('UserModel', () => {
});
});

describe('getUserSettings', () => {
it('should get user settings', async () => {
await serverDB.insert(users).values({ id: userId });
await serverDB.insert(userSettings).values({ id: userId, general: { language: 'en-US' } });

const data = await userModel.getUserSettings();

expect(data).toMatchObject({ id: userId, general: { language: 'en-US' } });
});
});

describe('deleteSetting', () => {
it('should delete user settings', async () => {
await serverDB.insert(users).values({ id: userId });
Expand Down
4 changes: 4 additions & 0 deletions src/database/server/models/user.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ export class UserModel {
};
};

getUserSettings = async () => {
return this.db.query.userSettings.findFirst({ where: eq(userSettings.id, this.userId) });
};

updateUser = async (value: Partial<UserItem>) => {
return this.db
.update(users)
Expand Down
20 changes: 10 additions & 10 deletions src/libs/agent-runtime/AgentRuntime.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ describe('AgentRuntime', () => {
describe('Azure OpenAI provider', () => {
it('should initialize correctly', async () => {
const jwtPayload = {
apikey: 'user-azure-key',
endpoint: 'user-azure-endpoint',
apiKey: 'user-azure-key',
baseURL: 'user-azure-endpoint',
apiVersion: '2024-06-01',
};

Expand All @@ -90,8 +90,8 @@ describe('AgentRuntime', () => {
});
it('should initialize with azureOpenAIParams correctly', async () => {
const jwtPayload = {
apikey: 'user-openai-key',
endpoint: 'user-endpoint',
apiKey: 'user-openai-key',
baseURL: 'user-endpoint',
apiVersion: 'custom-version',
};

Expand All @@ -106,8 +106,8 @@ describe('AgentRuntime', () => {

it('should initialize with AzureAI correctly', async () => {
const jwtPayload = {
apikey: 'user-azure-key',
endpoint: 'user-azure-endpoint',
apiKey: 'user-azure-key',
baseURL: 'user-azure-endpoint',
};
const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.Azure, {
azure: jwtPayload,
Expand Down Expand Up @@ -171,7 +171,7 @@ describe('AgentRuntime', () => {

describe('Ollama provider', () => {
it('should initialize correctly', async () => {
const jwtPayload: JWTPayload = { endpoint: 'user-ollama-url' };
const jwtPayload: JWTPayload = { baseURL: 'https://user-ollama-url' };
const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.Ollama, {
ollama: jwtPayload,
});
Expand Down Expand Up @@ -255,7 +255,7 @@ describe('AgentRuntime', () => {

describe('AgentRuntime chat method', () => {
it('should run correctly', async () => {
const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', endpoint: 'user-endpoint' };
const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', baseURL: 'user-endpoint' };
const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.OpenAI, {
openai: jwtPayload,
});
Expand All @@ -271,7 +271,7 @@ describe('AgentRuntime', () => {
await runtime.chat(payload);
});
it('should handle options correctly', async () => {
const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', endpoint: 'user-endpoint' };
const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', baseURL: 'user-endpoint' };
const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.OpenAI, {
openai: jwtPayload,
});
Expand Down Expand Up @@ -300,7 +300,7 @@ describe('AgentRuntime', () => {
});

describe('callback', async () => {
const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', endpoint: 'user-endpoint' };
const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', baseURL: 'user-endpoint' };
const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.OpenAI, {
openai: jwtPayload,
});
Expand Down
6 changes: 3 additions & 3 deletions src/libs/agent-runtime/AgentRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class AgentRuntime {
ai21: Partial<ClientOptions>;
ai360: Partial<ClientOptions>;
anthropic: Partial<ClientOptions>;
azure: { apiVersion?: string; apikey?: string; endpoint?: string };
azure: { apiKey?: string; apiVersion?: string; baseURL?: string };
baichuan: Partial<ClientOptions>;
bedrock: Partial<LobeBedrockAIParams>;
cloudflare: Partial<LobeCloudflareParams>;
Expand Down Expand Up @@ -180,8 +180,8 @@ class AgentRuntime {

case ModelProvider.Azure: {
runtimeModel = new LobeAzureOpenAI(
params.azure?.endpoint,
params.azure?.apikey,
params.azure?.baseURL,
params.azure?.apiKey,
params.azure?.apiVersion,
);
break;
Expand Down
5 changes: 4 additions & 1 deletion src/libs/agent-runtime/ollama/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ describe('LobeOllamaAI', () => {
try {
new LobeOllamaAI({ baseURL: 'invalid-url' });
} catch (e) {
expect(e).toEqual(AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidOllamaArgs));
expect(e).toEqual({
error: new TypeError('Invalid URL'),
errorType: 'InvalidOllamaArgs',
});
}
});
});
Expand Down
4 changes: 2 additions & 2 deletions src/libs/agent-runtime/ollama/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ export class LobeOllamaAI implements LobeRuntimeAI {
constructor({ baseURL }: ClientOptions = {}) {
try {
if (baseURL) new URL(baseURL);
} catch {
throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidOllamaArgs);
} catch (e) {
throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidOllamaArgs, e);
}

this.client = new Ollama(!baseURL ? undefined : { host: baseURL });
Expand Down
10 changes: 10 additions & 0 deletions src/libs/agent-runtime/openai/__snapshots__/index.test.ts.snap
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ exports[`LobeOpenAI > models > should get models 1`] = `
"input": 0.5,
"output": 1.5,
},
"releasedAt": "2023-02-28",
},
{
"id": "gpt-3.5-turbo-16k",
Expand All @@ -35,6 +36,7 @@ exports[`LobeOpenAI > models > should get models 1`] = `
"input": 10,
"output": 30,
},
"releasedAt": "2024-01-23",
},
{
"contextWindowTokens": 128000,
Expand All @@ -46,6 +48,7 @@ exports[`LobeOpenAI > models > should get models 1`] = `
"input": 10,
"output": 30,
},
"releasedAt": "2024-01-23",
},
{
"contextWindowTokens": 4096,
Expand All @@ -56,6 +59,7 @@ exports[`LobeOpenAI > models > should get models 1`] = `
"input": 1.5,
"output": 2,
},
"releasedAt": "2023-08-24",
},
{
"id": "gpt-3.5-turbo-0301",
Expand All @@ -73,6 +77,7 @@ exports[`LobeOpenAI > models > should get models 1`] = `
"input": 1,
"output": 2,
},
"releasedAt": "2023-11-02",
},
{
"contextWindowTokens": 128000,
Expand All @@ -84,13 +89,15 @@ exports[`LobeOpenAI > models > should get models 1`] = `
"input": 10,
"output": 30,
},
"releasedAt": "2023-11-02",
},
{
"contextWindowTokens": 128000,
"deploymentName": "gpt-4-vision",
"description": "GPT-4 视觉预览版,专为图像分析和处理任务设计。",
"displayName": "GPT 4 Turbo with Vision Preview",
"id": "gpt-4-vision-preview",
"releasedAt": "2023-11-02",
"vision": true,
},
{
Expand All @@ -103,6 +110,7 @@ exports[`LobeOpenAI > models > should get models 1`] = `
"input": 30,
"output": 60,
},
"releasedAt": "2023-06-27",
},
{
"contextWindowTokens": 16385,
Expand All @@ -114,6 +122,7 @@ exports[`LobeOpenAI > models > should get models 1`] = `
"input": 0.5,
"output": 1.5,
},
"releasedAt": "2024-01-23",
},
{
"contextWindowTokens": 8192,
Expand All @@ -125,6 +134,7 @@ exports[`LobeOpenAI > models > should get models 1`] = `
"input": 30,
"output": 60,
},
"releasedAt": "2023-06-12",
},
]
`;
Loading
Loading