diff --git a/next.config.ts b/next.config.ts index 28aba9f2eff4..35095f78efe4 100644 --- a/next.config.ts +++ b/next.config.ts @@ -25,7 +25,6 @@ const nextConfig: NextConfig = { '@icons-pack/react-simple-icons', '@lobehub/ui', 'gpt-tokenizer', - 'chroma-js', ], webVitalsAttribution: ['CLS', 'LCP'], }, diff --git a/src/app/(main)/settings/llm/components/ProviderModelList/ModelConfigModal/Form.tsx b/src/app/(main)/settings/llm/components/ProviderModelList/ModelConfigModal/Form.tsx index 6f3b7a1f7772..aa6e597def7f 100644 --- a/src/app/(main)/settings/llm/components/ProviderModelList/ModelConfigModal/Form.tsx +++ b/src/app/(main)/settings/llm/components/ProviderModelList/ModelConfigModal/Form.tsx @@ -2,11 +2,10 @@ import { Checkbox, Form, FormInstance, Input } from 'antd'; import { memo, useEffect } from 'react'; import { useTranslation } from 'react-i18next'; +import MaxTokenSlider from '@/components/MaxTokenSlider'; import { useIsMobile } from '@/hooks/useIsMobile'; import { ChatModelCard } from '@/types/llm'; -import MaxTokenSlider from './MaxTokenSlider'; - interface ModelConfigFormProps { initialValues?: ChatModelCard; onFormInstanceReady: (instance: FormInstance) => void; @@ -66,7 +65,10 @@ const ModelConfigForm = memo( > - + (({ value, onChange, defaultValue }) => { - const { t } = useTranslation('setting'); + const { t } = useTranslation('components'); const [token, setTokens] = useMergeState(0, { defaultValue, @@ -45,7 +44,7 @@ const MaxTokenSlider = memo(({ value, onChange, defaultValu setPowValue(exponent(value / Kibi)); }; - const isMobile = useServerConfigStore(serverConfigSelectors.isMobile); + const isMobile = useIsMobile(); const marks = useMemo(() => { return { @@ -74,7 +73,7 @@ const MaxTokenSlider = memo(({ value, onChange, defaultValu tooltip={{ formatter: (x) => { if (typeof x === 'undefined') return; - if (x === 0) return t('llm.customModelCards.modelConfig.tokens.unlimited'); + if (x === 0) return t('MaxTokenSlider.unlimited'); let value = getRealValue(x); if (value < 125) return value.toFixed(0) + 'K'; diff --git a/src/components/ModelSelect/index.tsx b/src/components/ModelSelect/index.tsx index 265eb845847e..0216225615fc 100644 --- a/src/components/ModelSelect/index.tsx +++ b/src/components/ModelSelect/index.tsx @@ -9,6 +9,7 @@ import { FC, memo } from 'react'; import { useTranslation } from 'react-i18next'; import { Center, Flexbox } from 'react-layout-kit'; +import { ModelAbilities } from '@/types/aiModel'; import { ChatModelCard } from '@/types/llm'; import { formatTokenNumber } from '@/utils/format'; @@ -57,8 +58,10 @@ const useStyles = createStyles(({ css, token }) => ({ `, })); -interface ModelInfoTagsProps extends ChatModelCard { +interface ModelInfoTagsProps extends ModelAbilities { + contextWindowTokens?: number | null; directionReverse?: boolean; + isCustom?: boolean; placement?: 'top' | 'right'; } @@ -102,7 +105,7 @@ export const ModelInfoTags = memo( )} - {model.contextWindowTokens !== undefined && ( + {typeof model.contextWindowTokens === 'number' && ( ( {model.contextWindowTokens === 0 ? ( ) : ( - formatTokenNumber(model.contextWindowTokens) + formatTokenNumber(model.contextWindowTokens as number) )} diff --git a/src/components/NProgress/index.tsx b/src/components/NProgress/index.tsx index 3820eefca717..860c18bf7866 100644 --- a/src/components/NProgress/index.tsx +++ b/src/components/NProgress/index.tsx @@ -6,7 +6,15 @@ import { memo } from 'react'; const NProgress = memo(() => { const theme = useTheme(); - return ; + return ( + + ); }); export default NProgress; diff --git a/src/const/auth.ts b/src/const/auth.ts index 0858275a4b71..5bec0201d5a2 100644 --- a/src/const/auth.ts +++ b/src/const/auth.ts @@ -28,7 +28,7 @@ export interface JWTPayload { /** * Represents the endpoint of provider */ - endpoint?: string; + baseURL?: string; azureApiVersion?: string; diff --git a/src/database/server/models/__tests__/user.test.ts b/src/database/server/models/__tests__/user.test.ts index fe7e4d33b8d7..4f2677b64ff0 100644 --- a/src/database/server/models/__tests__/user.test.ts +++ b/src/database/server/models/__tests__/user.test.ts @@ -130,6 +130,17 @@ describe('UserModel', () => { }); }); + describe('getUserSettings', () => { + it('should get user settings', async () => { + await serverDB.insert(users).values({ id: userId }); + await serverDB.insert(userSettings).values({ id: userId, general: { language: 'en-US' } }); + + const data = await userModel.getUserSettings(); + + expect(data).toMatchObject({ id: userId, general: { language: 'en-US' } }); + }); + }); + describe('deleteSetting', () => { it('should delete user settings', async () => { await serverDB.insert(users).values({ id: userId }); diff --git a/src/database/server/models/user.ts b/src/database/server/models/user.ts index e286bb622ac8..74d8edc59cfa 100644 --- a/src/database/server/models/user.ts +++ b/src/database/server/models/user.ts @@ -75,6 +75,10 @@ export class UserModel { }; }; + getUserSettings = async () => { + return this.db.query.userSettings.findFirst({ where: eq(userSettings.id, this.userId) }); + }; + updateUser = async (value: Partial) => { return this.db .update(users) diff --git a/src/libs/agent-runtime/AgentRuntime.test.ts b/src/libs/agent-runtime/AgentRuntime.test.ts index aa97ed0ed77f..ad41d93639c6 100644 --- a/src/libs/agent-runtime/AgentRuntime.test.ts +++ b/src/libs/agent-runtime/AgentRuntime.test.ts @@ -75,8 +75,8 @@ describe('AgentRuntime', () => { describe('Azure OpenAI provider', () => { it('should initialize correctly', async () => { const jwtPayload = { - apikey: 'user-azure-key', - endpoint: 'user-azure-endpoint', + apiKey: 'user-azure-key', + baseURL: 'user-azure-endpoint', apiVersion: '2024-06-01', }; @@ -90,8 +90,8 @@ describe('AgentRuntime', () => { }); it('should initialize with azureOpenAIParams correctly', async () => { const jwtPayload = { - apikey: 'user-openai-key', - endpoint: 'user-endpoint', + apiKey: 'user-openai-key', + baseURL: 'user-endpoint', apiVersion: 'custom-version', }; @@ -106,8 +106,8 @@ describe('AgentRuntime', () => { it('should initialize with AzureAI correctly', async () => { const jwtPayload = { - apikey: 'user-azure-key', - endpoint: 'user-azure-endpoint', + apiKey: 'user-azure-key', + baseURL: 'user-azure-endpoint', }; const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.Azure, { azure: jwtPayload, @@ -171,7 +171,7 @@ describe('AgentRuntime', () => { describe('Ollama provider', () => { it('should initialize correctly', async () => { - const jwtPayload: JWTPayload = { endpoint: 'user-ollama-url' }; + const jwtPayload: JWTPayload = { baseURL: 'https://user-ollama-url' }; const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.Ollama, { ollama: jwtPayload, }); @@ -255,7 +255,7 @@ describe('AgentRuntime', () => { describe('AgentRuntime chat method', () => { it('should run correctly', async () => { - const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', endpoint: 'user-endpoint' }; + const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', baseURL: 'user-endpoint' }; const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.OpenAI, { openai: jwtPayload, }); @@ -271,7 +271,7 @@ describe('AgentRuntime', () => { await runtime.chat(payload); }); it('should handle options correctly', async () => { - const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', endpoint: 'user-endpoint' }; + const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', baseURL: 'user-endpoint' }; const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.OpenAI, { openai: jwtPayload, }); @@ -300,7 +300,7 @@ describe('AgentRuntime', () => { }); describe('callback', async () => { - const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', endpoint: 'user-endpoint' }; + const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', baseURL: 'user-endpoint' }; const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.OpenAI, { openai: jwtPayload, }); diff --git a/src/libs/agent-runtime/AgentRuntime.ts b/src/libs/agent-runtime/AgentRuntime.ts index c70c5d4c153e..d82271459ab7 100644 --- a/src/libs/agent-runtime/AgentRuntime.ts +++ b/src/libs/agent-runtime/AgentRuntime.ts @@ -133,7 +133,7 @@ class AgentRuntime { ai21: Partial; ai360: Partial; anthropic: Partial; - azure: { apiVersion?: string; apikey?: string; endpoint?: string }; + azure: { apiKey?: string; apiVersion?: string; baseURL?: string }; baichuan: Partial; bedrock: Partial; cloudflare: Partial; @@ -180,8 +180,8 @@ class AgentRuntime { case ModelProvider.Azure: { runtimeModel = new LobeAzureOpenAI( - params.azure?.endpoint, - params.azure?.apikey, + params.azure?.baseURL, + params.azure?.apiKey, params.azure?.apiVersion, ); break; diff --git a/src/libs/agent-runtime/ollama/index.test.ts b/src/libs/agent-runtime/ollama/index.test.ts index d48cebfcc3de..79ab4360249a 100644 --- a/src/libs/agent-runtime/ollama/index.test.ts +++ b/src/libs/agent-runtime/ollama/index.test.ts @@ -29,7 +29,10 @@ describe('LobeOllamaAI', () => { try { new LobeOllamaAI({ baseURL: 'invalid-url' }); } catch (e) { - expect(e).toEqual(AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidOllamaArgs)); + expect(e).toEqual({ + error: new TypeError('Invalid URL'), + errorType: 'InvalidOllamaArgs', + }); } }); }); diff --git a/src/libs/agent-runtime/ollama/index.ts b/src/libs/agent-runtime/ollama/index.ts index 6f3fbababd09..47b6023caf64 100644 --- a/src/libs/agent-runtime/ollama/index.ts +++ b/src/libs/agent-runtime/ollama/index.ts @@ -22,8 +22,8 @@ export class LobeOllamaAI implements LobeRuntimeAI { constructor({ baseURL }: ClientOptions = {}) { try { if (baseURL) new URL(baseURL); - } catch { - throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidOllamaArgs); + } catch (e) { + throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidOllamaArgs, e); } this.client = new Ollama(!baseURL ? undefined : { host: baseURL }); diff --git a/src/libs/agent-runtime/openai/__snapshots__/index.test.ts.snap b/src/libs/agent-runtime/openai/__snapshots__/index.test.ts.snap index 95eb7fa4fd26..f0aea17d2160 100644 --- a/src/libs/agent-runtime/openai/__snapshots__/index.test.ts.snap +++ b/src/libs/agent-runtime/openai/__snapshots__/index.test.ts.snap @@ -12,6 +12,7 @@ exports[`LobeOpenAI > models > should get models 1`] = ` "input": 0.5, "output": 1.5, }, + "releasedAt": "2023-02-28", }, { "id": "gpt-3.5-turbo-16k", @@ -35,6 +36,7 @@ exports[`LobeOpenAI > models > should get models 1`] = ` "input": 10, "output": 30, }, + "releasedAt": "2024-01-23", }, { "contextWindowTokens": 128000, @@ -46,6 +48,7 @@ exports[`LobeOpenAI > models > should get models 1`] = ` "input": 10, "output": 30, }, + "releasedAt": "2024-01-23", }, { "contextWindowTokens": 4096, @@ -56,6 +59,7 @@ exports[`LobeOpenAI > models > should get models 1`] = ` "input": 1.5, "output": 2, }, + "releasedAt": "2023-08-24", }, { "id": "gpt-3.5-turbo-0301", @@ -73,6 +77,7 @@ exports[`LobeOpenAI > models > should get models 1`] = ` "input": 1, "output": 2, }, + "releasedAt": "2023-11-02", }, { "contextWindowTokens": 128000, @@ -84,6 +89,7 @@ exports[`LobeOpenAI > models > should get models 1`] = ` "input": 10, "output": 30, }, + "releasedAt": "2023-11-02", }, { "contextWindowTokens": 128000, @@ -91,6 +97,7 @@ exports[`LobeOpenAI > models > should get models 1`] = ` "description": "GPT-4 视觉预览版,专为图像分析和处理任务设计。", "displayName": "GPT 4 Turbo with Vision Preview", "id": "gpt-4-vision-preview", + "releasedAt": "2023-11-02", "vision": true, }, { @@ -103,6 +110,7 @@ exports[`LobeOpenAI > models > should get models 1`] = ` "input": 30, "output": 60, }, + "releasedAt": "2023-06-27", }, { "contextWindowTokens": 16385, @@ -114,6 +122,7 @@ exports[`LobeOpenAI > models > should get models 1`] = ` "input": 0.5, "output": 1.5, }, + "releasedAt": "2024-01-23", }, { "contextWindowTokens": 8192, @@ -125,6 +134,7 @@ exports[`LobeOpenAI > models > should get models 1`] = ` "input": 30, "output": 60, }, + "releasedAt": "2023-06-12", }, ] `; diff --git a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts index 814f890df875..48eef0beef3a 100644 --- a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts +++ b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts @@ -1,3 +1,5 @@ +import dayjs from 'dayjs'; +import utc from 'dayjs/plugin/utc'; import OpenAI, { ClientOptions } from 'openai'; import { Stream } from 'openai/streaming'; @@ -18,6 +20,7 @@ import type { TextToSpeechOptions, TextToSpeechPayload, } from '../../types'; +import { ChatStreamCallbacks } from '../../types'; import { AgentRuntimeError } from '../createError'; import { debugResponse, debugStream } from '../debugStream'; import { desensitizeUrl } from '../desensitizeUrl'; @@ -25,7 +28,6 @@ import { handleOpenAIError } from '../handleOpenAIError'; import { convertOpenAIMessages } from '../openaiHelpers'; import { StreamingResponse } from '../response'; import { OpenAIStream, OpenAIStreamOptions } from '../streams'; -import { ChatStreamCallbacks } from '../../types'; // the model contains the following keywords is not a chat model, so we should filter them out export const CHAT_MODELS_BLOCK_LIST = [ @@ -248,7 +250,8 @@ export const LobeOpenAICompatibleFactory = = any> if (responseMode === 'json') return Response.json(response); - const transformHandler = chatCompletion?.handleTransformResponseToStream || transformResponseToStream; + const transformHandler = + chatCompletion?.handleTransformResponseToStream || transformResponseToStream; const stream = transformHandler(response as unknown as OpenAI.ChatCompletion); const streamHandler = chatCompletion?.handleStream || OpenAIStream; @@ -278,7 +281,15 @@ export const LobeOpenAICompatibleFactory = = any> const knownModel = LOBE_DEFAULT_MODEL_LIST.find((model) => model.id === item.id); - if (knownModel) return knownModel; + if (knownModel) { + dayjs.extend(utc); + + return { + ...knownModel, + releasedAt: + knownModel.releasedAt ?? dayjs.utc(item.created * 1000).format('YYYY-MM-DD'), + }; + } return { id: item.id }; }) diff --git a/src/locales/default/components.ts b/src/locales/default/components.ts index d2eb6d41f0dd..9afc63d92c22 100644 --- a/src/locales/default/components.ts +++ b/src/locales/default/components.ts @@ -70,6 +70,9 @@ export default { GoBack: { back: '返回', }, + MaxTokenSlider: { + unlimited: '无限制', + }, ModelSelect: { featureTag: { custom: '自定义模型,默认设定同时支持函数调用与视觉识别,请根据实际情况验证上述能力的可用性', diff --git a/src/locales/default/setting.ts b/src/locales/default/setting.ts index 8a7b61a61a46..fb8fbeb2b1c3 100644 --- a/src/locales/default/setting.ts +++ b/src/locales/default/setting.ts @@ -86,7 +86,6 @@ export default { modalTitle: '自定义模型配置', tokens: { title: '最大 token 数', - unlimited: '无限制', }, vision: { extra: diff --git a/src/server/modules/AgentRuntime/index.test.ts b/src/server/modules/AgentRuntime/index.test.ts index 33f843ec2f0f..dc4601c59a79 100644 --- a/src/server/modules/AgentRuntime/index.test.ts +++ b/src/server/modules/AgentRuntime/index.test.ts @@ -70,23 +70,23 @@ vi.mock('@/config/llm', () => ({ describe('initAgentRuntimeWithUserPayload method', () => { describe('should initialize with options correctly', () => { it('OpenAI provider: with apikey and endpoint', async () => { - const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', endpoint: 'user-endpoint' }; + const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', baseURL: 'user-endpoint' }; const runtime = await initAgentRuntimeWithUserPayload(ModelProvider.OpenAI, jwtPayload); expect(runtime).toBeInstanceOf(AgentRuntime); expect(runtime['_runtime']).toBeInstanceOf(LobeOpenAI); - expect(runtime['_runtime'].baseURL).toBe(jwtPayload.endpoint); + expect(runtime['_runtime'].baseURL).toBe(jwtPayload.baseURL); }); it('Azure AI provider: with apikey, endpoint and apiversion', async () => { const jwtPayload: JWTPayload = { apiKey: 'user-azure-key', - endpoint: 'user-azure-endpoint', + baseURL: 'user-azure-endpoint', azureApiVersion: '2024-06-01', }; const runtime = await initAgentRuntimeWithUserPayload(ModelProvider.Azure, jwtPayload); expect(runtime).toBeInstanceOf(AgentRuntime); expect(runtime['_runtime']).toBeInstanceOf(LobeAzureOpenAI); - expect(runtime['_runtime'].baseURL).toBe(jwtPayload.endpoint); + expect(runtime['_runtime'].baseURL).toBe(jwtPayload.baseURL); }); it('ZhiPu AI provider: with apikey', async () => { @@ -130,11 +130,11 @@ describe('initAgentRuntimeWithUserPayload method', () => { }); it('Ollama provider: with endpoint', async () => { - const jwtPayload: JWTPayload = { endpoint: 'http://user-ollama-url' }; + const jwtPayload: JWTPayload = { baseURL: 'http://user-ollama-url' }; const runtime = await initAgentRuntimeWithUserPayload(ModelProvider.Ollama, jwtPayload); expect(runtime).toBeInstanceOf(AgentRuntime); expect(runtime['_runtime']).toBeInstanceOf(LobeOllamaAI); - expect(runtime['_runtime']['baseURL']).toEqual(jwtPayload.endpoint); + expect(runtime['_runtime']['baseURL']).toEqual(jwtPayload.baseURL); }); it('Perplexity AI provider: with apikey', async () => { @@ -220,12 +220,12 @@ describe('initAgentRuntimeWithUserPayload method', () => { it('Unknown Provider: with apikey and endpoint, should initialize to OpenAi', async () => { const jwtPayload: JWTPayload = { apiKey: 'user-unknown-key', - endpoint: 'user-unknown-endpoint', + baseURL: 'user-unknown-endpoint', }; const runtime = await initAgentRuntimeWithUserPayload('unknown', jwtPayload); expect(runtime).toBeInstanceOf(AgentRuntime); expect(runtime['_runtime']).toBeInstanceOf(LobeOpenAI); - expect(runtime['_runtime'].baseURL).toBe(jwtPayload.endpoint); + expect(runtime['_runtime'].baseURL).toBe(jwtPayload.baseURL); }); }); diff --git a/src/server/modules/AgentRuntime/index.ts b/src/server/modules/AgentRuntime/index.ts index 73f559109ddd..ead352d43ec4 100644 --- a/src/server/modules/AgentRuntime/index.ts +++ b/src/server/modules/AgentRuntime/index.ts @@ -38,23 +38,23 @@ const getLlmOptionsFromPayload = (provider: string, payload: JWTPayload) => { } const apiKey = apiKeyManager.pick(payload?.apiKey || llmConfig[`${upperProvider}_API_KEY`]); - const baseURL = payload?.endpoint || process.env[`${upperProvider}_PROXY_URL`]; + const baseURL = payload?.baseURL || process.env[`${upperProvider}_PROXY_URL`]; return baseURL ? { apiKey, baseURL } : { apiKey }; } case ModelProvider.Ollama: { - const baseURL = payload?.endpoint || process.env.OLLAMA_PROXY_URL; + const baseURL = payload?.baseURL || process.env.OLLAMA_PROXY_URL; return { baseURL }; } case ModelProvider.Azure: { const { AZURE_API_KEY, AZURE_API_VERSION, AZURE_ENDPOINT } = llmConfig; - const apikey = apiKeyManager.pick(payload?.apiKey || AZURE_API_KEY); - const endpoint = payload?.endpoint || AZURE_ENDPOINT; + const apiKey = apiKeyManager.pick(payload?.apiKey || AZURE_API_KEY); + const baseURL = payload?.baseURL || AZURE_ENDPOINT; const apiVersion = payload?.azureApiVersion || AZURE_API_VERSION; - return { apiVersion, apikey, endpoint }; + return { apiKey, apiVersion, baseURL }; } case ModelProvider.Bedrock: { diff --git a/src/services/__tests__/_auth.test.ts b/src/services/__tests__/_auth.test.ts index 96921b4d7a86..552c4415e53a 100644 --- a/src/services/__tests__/_auth.test.ts +++ b/src/services/__tests__/_auth.test.ts @@ -131,7 +131,7 @@ describe('getProviderAuthPayload', () => { expect(payload).toEqual({ apiKey: mockAzureConfig.apiKey, azureApiVersion: mockAzureConfig.apiVersion, - endpoint: mockAzureConfig.endpoint, + baseURL: mockAzureConfig.endpoint, }); }); @@ -144,7 +144,7 @@ describe('getProviderAuthPayload', () => { const payload = getProviderAuthPayload(ModelProvider.Ollama); expect(payload).toEqual({ - endpoint: mockOllamaProxyUrl, + baseURL: mockOllamaProxyUrl, }); }); @@ -152,8 +152,7 @@ describe('getProviderAuthPayload', () => { // 假设的 OpenAI 配置 const mockOpenAIConfig = { apiKey: 'openai-api-key', - baseURL: 'openai-baseURL', - endpoint: 'openai-endpoint', + baseURL: 'openai-endpoint', useAzure: true, azureApiVersion: 'openai-azure-api-version', }; @@ -164,7 +163,7 @@ describe('getProviderAuthPayload', () => { const payload = getProviderAuthPayload(ModelProvider.OpenAI); expect(payload).toEqual({ apiKey: mockOpenAIConfig.apiKey, - endpoint: mockOpenAIConfig.baseURL, + baseURL: mockOpenAIConfig.baseURL, }); }); @@ -181,7 +180,7 @@ describe('getProviderAuthPayload', () => { const payload = getProviderAuthPayload(ModelProvider.Stepfun); expect(payload).toEqual({ apiKey: mockOpenAIConfig.apiKey, - endpoint: mockOpenAIConfig.baseURL, + baseURL: mockOpenAIConfig.baseURL, }); }); diff --git a/src/services/__tests__/chat.test.ts b/src/services/__tests__/chat.test.ts index a66cb93577fb..52b0032a4266 100644 --- a/src/services/__tests__/chat.test.ts +++ b/src/services/__tests__/chat.test.ts @@ -939,6 +939,7 @@ describe('AgentRuntimeOnClient', () => { }, }, } as UserSettingsState) as unknown as UserStore; + const runtime = await initializeWithClientStore(ModelProvider.Azure, {}); expect(runtime).toBeInstanceOf(AgentRuntime); expect(runtime['_runtime']).toBeInstanceOf(LobeAzureOpenAI); diff --git a/src/services/_auth.ts b/src/services/_auth.ts index e7daccddc404..9fb4b49dddc1 100644 --- a/src/services/_auth.ts +++ b/src/services/_auth.ts @@ -45,14 +45,14 @@ export const getProviderAuthPayload = (provider: string) => { return { apiKey: azure.apiKey, azureApiVersion: azure.apiVersion, - endpoint: azure.endpoint, + baseURL: azure.endpoint, }; } case ModelProvider.Ollama: { const config = keyVaultsConfigSelectors.ollamaConfig(useUserStore.getState()); - return { endpoint: config?.baseURL }; + return { baseURL: config?.baseURL }; } case ModelProvider.Cloudflare: { @@ -69,7 +69,7 @@ export const getProviderAuthPayload = (provider: string) => { useUserStore.getState(), ); - return { apiKey: config?.apiKey, endpoint: config?.baseURL }; + return { apiKey: config?.apiKey, baseURL: config?.baseURL }; } } }; diff --git a/src/services/chat.ts b/src/services/chat.ts index 0804504e3e15..7999bdae3b9b 100644 --- a/src/services/chat.ts +++ b/src/services/chat.ts @@ -94,21 +94,20 @@ export function initializeWithClientStore(provider: string, payload: any) { default: case ModelProvider.OpenAI: { providerOptions = { - baseURL: providerAuthPayload?.endpoint, + baseURL: providerAuthPayload?.baseURL, }; break; } case ModelProvider.Azure: { providerOptions = { + apiKey: providerAuthPayload?.apiKey, apiVersion: providerAuthPayload?.azureApiVersion, - // That's a wired properity, but just remapped it - apikey: providerAuthPayload?.apiKey, }; break; } case ModelProvider.Google: { providerOptions = { - baseURL: providerAuthPayload?.endpoint, + baseURL: providerAuthPayload?.baseURL, }; break; } @@ -125,27 +124,27 @@ export function initializeWithClientStore(provider: string, payload: any) { } case ModelProvider.Ollama: { providerOptions = { - baseURL: providerAuthPayload?.endpoint, + baseURL: providerAuthPayload?.baseURL, }; break; } case ModelProvider.Perplexity: { providerOptions = { apikey: providerAuthPayload?.apiKey, - baseURL: providerAuthPayload?.endpoint, + baseURL: providerAuthPayload?.baseURL, }; break; } case ModelProvider.Anthropic: { providerOptions = { - baseURL: providerAuthPayload?.endpoint, + baseURL: providerAuthPayload?.baseURL, }; break; } case ModelProvider.Groq: { providerOptions = { apikey: providerAuthPayload?.apiKey, - baseURL: providerAuthPayload?.endpoint, + baseURL: providerAuthPayload?.baseURL, }; break; } diff --git a/src/types/aiModel.ts b/src/types/aiModel.ts new file mode 100644 index 000000000000..9d954e8c3818 --- /dev/null +++ b/src/types/aiModel.ts @@ -0,0 +1,275 @@ +import { z } from 'zod'; + +export type ModelPriceCurrency = 'CNY' | 'USD'; + +export const AiModelSourceEnum = { + Builtin: 'builtin', + Custom: 'custom', + Remote: 'remote', +} as const; +export type AiModelSourceType = (typeof AiModelSourceEnum)[keyof typeof AiModelSourceEnum]; + +export type AiModelType = + | 'chat' + | 'embedding' + | 'tts' + | 'stt' + | 'image' + | 'text2video' + | 'text2music'; + +export interface ModelAbilities { + /** + * whether model supports file upload + */ + files?: boolean; + /** + * whether model supports function call + */ + functionCall?: boolean; + /** + * whether model supports vision + */ + vision?: boolean; +} + +const AiModelAbilitiesSchema = z.object({ + // files: z.boolean().optional(), + functionCall: z.boolean().optional(), + vision: z.boolean().optional(), +}); + +// 语言模型的设置参数 +export interface LLMParams { + /** + * 控制生成文本中的惩罚系数,用于减少重复性 + * @default 0 + */ + frequency_penalty?: number; + /** + * 生成文本的最大长度 + */ + max_tokens?: number; + /** + * 控制生成文本中的惩罚系数,用于减少主题的变化 + * @default 0 + */ + presence_penalty?: number; + /** + * 生成文本的随机度量,用于控制文本的创造性和多样性 + * @default 1 + */ + temperature?: number; + /** + * 控制生成文本中最高概率的单个 token + * @default 1 + */ + top_p?: number; +} + +export interface BasicModelPricing { + /** + * the currency of the pricing + * @default USD + */ + currency?: ModelPriceCurrency; + /** + * the input pricing, e.g. $1 / 1M tokens + */ + input?: number; +} + +export interface ChatModelPricing extends BasicModelPricing { + audioInput?: number; + audioOutput?: number; + cachedAudioInput?: number; + cachedInput?: number; + /** + * the output pricing, e.g. $2 / 1M tokens + */ + output?: number; + writeCacheInput?: number; +} + +interface AIBaseModelCard { + /** + * the context window (or input + output tokens limit) + */ + contextWindowTokens?: number; + description?: string; + /** + * the name show for end user + */ + displayName?: string; + enabled?: boolean; + id: string; + /** + * whether model is legacy (deprecated but not removed yet) + */ + legacy?: boolean; + /** + * who create this model + */ + organization?: string; + + releasedAt?: string; +} + +export interface AIChatModelCard extends AIBaseModelCard { + abilities?: { + /** + * whether model supports file upload + */ + files?: boolean; + /** + * whether model supports function call + */ + functionCall?: boolean; + /** + * whether model supports vision + */ + vision?: boolean; + }; + /** + * used in azure and doubao + */ + deploymentName?: string; + maxOutput?: number; + pricing?: ChatModelPricing; + type: 'chat'; +} + +export interface AIEmbeddingModelCard extends AIBaseModelCard { + maxDimension: number; + pricing?: { + /** + * the currency of the pricing + * @default USD + */ + currency?: ModelPriceCurrency; + /** + * the input pricing, e.g. $1 / 1M tokens + */ + input?: number; + }; + type: 'embedding'; +} + +export interface AIText2ImageModelCard extends AIBaseModelCard { + pricing?: { + /** + * the currency of the pricing + * @default USD + */ + currency?: ModelPriceCurrency; + } & Record; // [resolution: string]: number; + resolutions: string[]; + type: 'image'; +} + +export interface AITTSModelCard extends AIBaseModelCard { + pricing?: { + /** + * the currency of the pricing + * @default USD + */ + currency?: ModelPriceCurrency; + /** + * the input pricing, e.g. $1 / 1M tokens + */ + input?: number; + }; + type: 'tts'; +} + +export interface AISTTModelCard extends AIBaseModelCard { + pricing?: { + /** + * the currency of the pricing + * @default USD + */ + currency?: ModelPriceCurrency; + /** + * the input pricing, e.g. $1 / 1M tokens + */ + input?: number; + }; + type: 'stt'; +} + +export interface AIRealtimeModelCard extends AIBaseModelCard { + abilities?: { + /** + * whether model supports file upload + */ + files?: boolean; + /** + * whether model supports function call + */ + functionCall?: boolean; + /** + * whether model supports vision + */ + vision?: boolean; + }; + /** + * used in azure and doubao + */ + deploymentName?: string; + maxOutput?: number; + pricing?: ChatModelPricing; + type: 'realtime'; +} + +// create +export const CreateAiModelSchema = z.object({ + abilities: AiModelAbilitiesSchema.optional(), + contextWindowTokens: z.number().optional(), + displayName: z.string().optional(), + id: z.string(), + providerId: z.string(), + releasedAt: z.string().optional(), + + // checkModel: z.string().optional(), + // homeUrl: z.string().optional(), + // modelsUrl: z.string().optional(), +}); + +export type CreateAiModelParams = z.infer; + +// List Query + +export interface AiProviderModelListItem { + abilities?: ModelAbilities; + contextWindowTokens?: number; + displayName?: string; + enabled: boolean; + id: string; + pricing?: ChatModelPricing; + releasedAt?: string; + source?: AiModelSourceType; + type: AiModelType; +} + +// Update +export const UpdateAiModelSchema = z.object({ + abilities: AiModelAbilitiesSchema.optional(), + contextWindowTokens: z.number().optional(), + displayName: z.string().optional(), +}); + +export type UpdateAiModelParams = z.infer; + +export interface AiModelSortMap { + id: string; + sort: number; +} + +export const ToggleAiModelEnableSchema = z.object({ + enabled: z.boolean(), + id: z.string(), + providerId: z.string(), + source: z.enum(['builtin', 'custom', 'remote']).optional(), +}); + +export type ToggleAiModelEnableParams = z.infer; diff --git a/src/types/aiProvider.ts b/src/types/aiProvider.ts new file mode 100644 index 000000000000..5a63b71dae54 --- /dev/null +++ b/src/types/aiProvider.ts @@ -0,0 +1,148 @@ +import { z } from 'zod'; + +import { SmoothingParams } from '@/types/llm'; + +// create +export const CreateAiProviderSchema = z.object({ + config: z.object({}).passthrough().optional(), + description: z.string().optional(), + id: z.string(), + keyVaults: z.any().optional(), + logo: z.string().optional(), + name: z.string(), + sdkType: z.enum(['openai', 'anthropic']).optional(), + // checkModel: z.string().optional(), + // homeUrl: z.string().optional(), + // modelsUrl: z.string().optional(), +}); + +export type CreateAiProviderParams = z.infer; + +// List Query + +export interface AiProviderListItem { + description?: string; + enabled: boolean; + id: string; + logo?: string; + name?: string; + sort?: number; + source: 'builtin' | 'custom'; +} + +// Detail Query + +interface AiProviderConfig { + /** + * whether provider show browser request option by default + * + * @default false + */ + defaultShowBrowserRequest?: boolean; + /** + * some provider server like stepfun and aliyun don't support browser request, + * So we should disable it + * + * @default false + */ + disableBrowserRequest?: boolean; + proxyUrl?: + | { + desc?: string; + placeholder: string; + title?: string; + } + | false; + + /** + * whether show api key in the provider config + * so provider like ollama don't need api key field + */ + showApiKey?: boolean; + + /** + * whether show checker in the provider config + */ + showChecker?: boolean; + showDeployName?: boolean; + showModelFetcher?: boolean; + /** + * whether to smoothing the output + */ + smoothing?: SmoothingParams; +} + +export interface AiProviderItem { + /** + * the default model that used for connection check + */ + checkModel?: string; + config: AiProviderConfig; + description?: string; + enabled: boolean; + enabledChatModels: string[]; + /** + * provider's website url + */ + homeUrl?: string; + id: string; + logo?: string; + /** + * the url show the all models in the provider + */ + modelsUrl?: string; + /** + * the name show for end user + */ + name: string; + /** + * default openai + */ + sdkType?: 'openai' | 'anthropic'; + source: 'builtin' | 'custom'; +} + +export interface AiProviderDetailItem { + /** + * the default model that used for connection check + */ + checkModel?: string; + config: AiProviderConfig; + description?: string; + enabled: boolean; + fetchOnClient?: boolean; + /** + * provider's website url + */ + homeUrl?: string; + id: string; + keyVaults?: Record; + logo?: string; + /** + * the url show the all models in the provider + */ + modelsUrl?: string; + /** + * the name show for end user + */ + name: string; + /** + * default openai + */ + sdkType?: 'openai' | 'anthropic'; + source: 'builtin' | 'custom'; +} + +// Update +export const UpdateAiProviderConfigSchema = z.object({ + checkModel: z.string().optional(), + fetchOnClient: z.boolean().optional(), + keyVaults: z.object({}).passthrough().optional(), +}); + +export type UpdateAiProviderConfigParams = z.infer; + +export interface AiProviderSortMap { + id: string; + sort: number; +} diff --git a/src/types/llm.ts b/src/types/llm.ts index 36bd5d77d53d..15471746453d 100644 --- a/src/types/llm.ts +++ b/src/types/llm.ts @@ -1,5 +1,7 @@ import { ReactNode } from 'react'; +import { ChatModelPricing } from '@/types/aiModel'; + export type ModelPriceCurrency = 'CNY' | 'USD'; export interface ChatModelCard { @@ -38,23 +40,7 @@ export interface ChatModelCard { */ legacy?: boolean; maxOutput?: number; - pricing?: { - cachedInput?: number; - /** - * the currency of the pricing - * @default USD - */ - currency?: ModelPriceCurrency; - /** - * the input pricing, e.g. $1 / 1M tokens - */ - input?: number; - /** - * the output pricing, e.g. $2 / 1M tokens - */ - output?: number; - writeCacheInput?: number; - }; + pricing?: ChatModelPricing; releasedAt?: string; /** diff --git a/src/utils/merge.test.ts b/src/utils/merge.test.ts new file mode 100644 index 000000000000..c3ba5907fdc8 --- /dev/null +++ b/src/utils/merge.test.ts @@ -0,0 +1,48 @@ +import { expect } from 'vitest'; + +import { AIChatModelCard } from '@/types/aiModel'; + +import { mergeArrayById } from './merge'; + +describe('mergeArrayById', () => { + it('should merge data', () => { + const data = mergeArrayById( + [ + { + contextWindowTokens: 128_000, + description: + 'o1-mini是一款针对编程、数学和科学应用场景而设计的快速、经济高效的推理模型。该模型具有128K上下文和2023年10月的知识截止日期。', + displayName: 'OpenAI o1-mini', + enabled: true, + id: 'o1-mini', + maxOutput: 65_536, + pricing: { + input: 3, + output: 12, + }, + releasedAt: '2024-09-12', + type: 'chat', + }, + ], + [{ id: 'o1-mini', displayName: 'OpenAI o1-mini ABC', type: 'chat' }], + ); + + expect(data).toEqual([ + { + contextWindowTokens: 128_000, + description: + 'o1-mini是一款针对编程、数学和科学应用场景而设计的快速、经济高效的推理模型。该模型具有128K上下文和2023年10月的知识截止日期。', + displayName: 'OpenAI o1-mini ABC', + enabled: true, + id: 'o1-mini', + maxOutput: 65_536, + pricing: { + input: 3, + output: 12, + }, + releasedAt: '2024-09-12', + type: 'chat', + }, + ]); + }); +}); diff --git a/src/utils/merge.ts b/src/utils/merge.ts index 714705988135..3ab798e81992 100644 --- a/src/utils/merge.ts +++ b/src/utils/merge.ts @@ -9,3 +9,42 @@ export const merge: typeof _merge = (target: T, source: T) => mergeWith({}, target, source, (obj, src) => { if (Array.isArray(obj)) return src; }); + +type MergeableItem = { + [key: string]: any; + id: string; +}; + +/** + * Merge two arrays based on id, preserving metadata from default items + * @param defaultItems Items with default configuration and metadata + * @param userItems User-defined items with higher priority + */ +export const mergeArrayById = (defaultItems: T[], userItems: T[]): T[] => { + // Create a map of default items for faster lookup + const defaultItemsMap = new Map(defaultItems.map((item) => [item.id, item])); + + // Process user items with default metadata + const mergedItems = userItems.map((userItem) => { + const defaultItem = defaultItemsMap.get(userItem.id); + if (!defaultItem) return userItem; + + // Merge strategy: use default value when user value is null or undefined + const mergedItem: T = { ...defaultItem }; + Object.entries(userItem).forEach(([key, value]) => { + // Only use user value if it's not null and not undefined + if (value !== null && value !== undefined) { + // @ts-expect-error + mergedItem[key] = value; + } + }); + + return mergedItem; + }); + + // Add items that only exist in default configuration + const userItemIds = new Set(userItems.map((item) => item.id)); + const onlyInDefaultItems = defaultItems.filter((item) => !userItemIds.has(item.id)); + + return [...mergedItems, ...onlyInDefaultItems]; +};