From e7cb62ea0f5a5fc74c71a72fca306309d64f1840 Mon Sep 17 00:00:00 2001 From: Zhijie He Date: Thu, 28 Nov 2024 01:47:05 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20fix=20fallback=20behavior?= =?UTF-8?q?=20of=20default=20mode=20in=20AgentRuntime=20(#4813)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🐛 fix: fix fallback behavior of default mode in AgentRuntime * ♻️ refactor: optimize fallback behavior * 🔨 chore: add unit test --- src/server/modules/AgentRuntime/index.test.ts | 22 +++++++++++++++++++ src/server/modules/AgentRuntime/index.ts | 2 +- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/server/modules/AgentRuntime/index.test.ts b/src/server/modules/AgentRuntime/index.test.ts index ce752fd53710..0cef30e17b5b 100644 --- a/src/server/modules/AgentRuntime/index.test.ts +++ b/src/server/modules/AgentRuntime/index.test.ts @@ -370,6 +370,28 @@ describe('initAgentRuntimeWithUserPayload method', () => { expect(runtime['_runtime']).toBeInstanceOf(LobeWenxinAI); }); + it('OpenAI provider: without apikey with OPENAI_PROXY_URL', async () => { + process.env.OPENAI_PROXY_URL = 'https://proxy.example.com/v1'; + + const jwtPayload: JWTPayload = {}; + const runtime = await initAgentRuntimeWithUserPayload(ModelProvider.OpenAI, jwtPayload); + expect(runtime['_runtime']).toBeInstanceOf(LobeOpenAI); + // 应返回 OPENAI_PROXY_URL + expect(runtime['_runtime'].baseURL).toBe('https://proxy.example.com/v1'); + }); + + it('Qwen AI provider: without apiKey and endpoint with OPENAI_PROXY_URL', async () => { + process.env.OPENAI_PROXY_URL = 'https://proxy.example.com/v1'; + + const jwtPayload: JWTPayload = {}; + const runtime = await initAgentRuntimeWithUserPayload(ModelProvider.Qwen, jwtPayload); + + // 假设 LobeQwenAI 是 Qwen 提供者的实现类 + expect(runtime['_runtime']).toBeInstanceOf(LobeQwenAI); + // endpoint 不存在,应返回 DEFAULT_BASE_URL + expect(runtime['_runtime'].baseURL).toBe('https://dashscope.aliyuncs.com/compatible-mode/v1'); + }); + it('Unknown Provider', async () => { const jwtPayload = {}; const runtime = await initAgentRuntimeWithUserPayload('unknown', jwtPayload); diff --git a/src/server/modules/AgentRuntime/index.ts b/src/server/modules/AgentRuntime/index.ts index 5cdceb11cea8..0e969f6d6525 100644 --- a/src/server/modules/AgentRuntime/index.ts +++ b/src/server/modules/AgentRuntime/index.ts @@ -33,7 +33,7 @@ const getLlmOptionsFromPayload = (provider: string, payload: JWTPayload) => { default: { let upperProvider = provider.toUpperCase(); - if (!llmConfig[`${upperProvider}_API_KEY`]) { + if (!( `${upperProvider}_API_KEY` in llmConfig)) { upperProvider = ModelProvider.OpenAI.toUpperCase(); // Use OpenAI options as default }