From ea7e732350508ecb13425c8d528bdb30e9b48019 Mon Sep 17 00:00:00 2001 From: Zhijie He Date: Sun, 29 Dec 2024 13:06:52 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20custom=20stream=20han?= =?UTF-8?q?dle=20support=20for=20LobeOpenAICompatibleFactory=20(#5039)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ♻️ refactor: add function call support for Spark * ♻️ refactor: add non-stream mode support * ⚡️ perf: using stream mode for tools call * ✨ feat: add `handleStream` & `handleStreamResponse` for LobeOpenAICompatibleFactory, custom stream handle * ✨ feat: add `handleTtransformResponseToStream` for custom non-stream transform handle * ♻️ refactor: refactor qwen to LobeOpenAICompatibleFactory, enable `enable_search` for Qwen LLM * 🔨 chore: add unit test for LobeOpenAICompatibleFactory * 🔨 chore: add unit test for SparkAIStream * 🔨 chore: add unit test for Qwen & Spark * 🐛 fix: fix Qwen param range error * 🔨 chore: add `QwenLegacyModels` array, limit `presence_penalty` * 🐛 fix: fix typo --- src/config/modelProviders/spark.ts | 9 +- src/libs/agent-runtime/qwen/index.test.ts | 201 ++---------------- src/libs/agent-runtime/qwen/index.ts | 173 ++++----------- src/libs/agent-runtime/spark/index.test.ts | 52 +++-- src/libs/agent-runtime/spark/index.ts | 4 + .../openaiCompatibleFactory/index.test.ts | 131 ++++++++++++ .../utils/openaiCompatibleFactory/index.ts | 17 +- src/libs/agent-runtime/utils/streams/index.ts | 1 + .../agent-runtime/utils/streams/spark.test.ts | 199 +++++++++++++++++ src/libs/agent-runtime/utils/streams/spark.ts | 134 ++++++++++++ 10 files changed, 570 insertions(+), 351 deletions(-) create mode 100644 src/libs/agent-runtime/utils/streams/spark.test.ts create mode 100644 src/libs/agent-runtime/utils/streams/spark.ts diff --git a/src/config/modelProviders/spark.ts b/src/config/modelProviders/spark.ts index a03c8e853aa8..dbe407464d27 100644 --- a/src/config/modelProviders/spark.ts +++ b/src/config/modelProviders/spark.ts @@ -10,7 +10,6 @@ const Spark: ModelProviderCard = { 'Spark Lite 是一款轻量级大语言模型,具备极低的延迟与高效的处理能力,完全免费开放,支持实时在线搜索功能。其快速响应的特性使其在低算力设备上的推理应用和模型微调中表现出色,为用户带来出色的成本效益和智能体验,尤其在知识问答、内容生成及搜索场景下表现不俗。', displayName: 'Spark Lite', enabled: true, - functionCall: false, id: 'lite', maxOutput: 4096, }, @@ -20,7 +19,6 @@ const Spark: ModelProviderCard = { 'Spark Pro 是一款为专业领域优化的高性能大语言模型,专注数学、编程、医疗、教育等多个领域,并支持联网搜索及内置天气、日期等插件。其优化后模型在复杂知识问答、语言理解及高层次文本创作中展现出色表现和高效性能,是适合专业应用场景的理想选择。', displayName: 'Spark Pro', enabled: true, - functionCall: false, id: 'generalv3', maxOutput: 8192, }, @@ -30,7 +28,6 @@ const Spark: ModelProviderCard = { 'Spark Pro 128K 配置了特大上下文处理能力,能够处理多达128K的上下文信息,特别适合需通篇分析和长期逻辑关联处理的长文内容,可在复杂文本沟通中提供流畅一致的逻辑与多样的引用支持。', displayName: 'Spark Pro 128K', enabled: true, - functionCall: false, id: 'pro-128k', maxOutput: 4096, }, @@ -40,7 +37,7 @@ const Spark: ModelProviderCard = { 'Spark Max 为功能最为全面的版本,支持联网搜索及众多内置插件。其全面优化的核心能力以及系统角色设定和函数调用功能,使其在各种复杂应用场景中的表现极为优异和出色。', displayName: 'Spark Max', enabled: true, - functionCall: false, + functionCall: true, id: 'generalv3.5', maxOutput: 8192, }, @@ -50,7 +47,7 @@ const Spark: ModelProviderCard = { 'Spark Max 32K 配置了大上下文处理能力,更强的上下文理解和逻辑推理能力,支持32K tokens的文本输入,适用于长文档阅读、私有知识问答等场景', displayName: 'Spark Max 32K', enabled: true, - functionCall: false, + functionCall: true, id: 'max-32k', maxOutput: 8192, }, @@ -60,7 +57,7 @@ const Spark: ModelProviderCard = { 'Spark Ultra 是星火大模型系列中最为强大的版本,在升级联网搜索链路同时,提升对文本内容的理解和总结能力。它是用于提升办公生产力和准确响应需求的全方位解决方案,是引领行业的智能产品。', displayName: 'Spark 4.0 Ultra', enabled: true, - functionCall: false, + functionCall: true, id: '4.0Ultra', maxOutput: 8192, }, diff --git a/src/libs/agent-runtime/qwen/index.test.ts b/src/libs/agent-runtime/qwen/index.test.ts index 819a30a5175d..f813686e01fb 100644 --- a/src/libs/agent-runtime/qwen/index.test.ts +++ b/src/libs/agent-runtime/qwen/index.test.ts @@ -2,8 +2,9 @@ import OpenAI from 'openai'; import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; -import Qwen from '@/config/modelProviders/qwen'; -import { AgentRuntimeErrorType, ModelProvider } from '@/libs/agent-runtime'; +import { LobeOpenAICompatibleRuntime } from '@/libs/agent-runtime'; +import { ModelProvider } from '@/libs/agent-runtime'; +import { AgentRuntimeErrorType } from '@/libs/agent-runtime'; import * as debugStreamModule from '../utils/debugStream'; import { LobeQwenAI } from './index'; @@ -16,7 +17,7 @@ const invalidErrorType = AgentRuntimeErrorType.InvalidProviderAPIKey; // Mock the console.error to avoid polluting test output vi.spyOn(console, 'error').mockImplementation(() => {}); -let instance: LobeQwenAI; +let instance: LobeOpenAICompatibleRuntime; beforeEach(() => { instance = new LobeQwenAI({ apiKey: 'test' }); @@ -40,183 +41,7 @@ describe('LobeQwenAI', () => { }); }); - describe('models', () => { - it('should correctly list available models', async () => { - const instance = new LobeQwenAI({ apiKey: 'test_api_key' }); - vi.spyOn(instance, 'models').mockResolvedValue(Qwen.chatModels); - - const models = await instance.models(); - expect(models).toEqual(Qwen.chatModels); - }); - }); - describe('chat', () => { - describe('Params', () => { - it('should call llms with proper options', async () => { - const mockStream = new ReadableStream(); - const mockResponse = Promise.resolve(mockStream); - - (instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse); - - const result = await instance.chat({ - messages: [{ content: 'Hello', role: 'user' }], - model: 'qwen-turbo', - temperature: 0.6, - top_p: 0.7, - }); - - // Assert - expect(instance['client'].chat.completions.create).toHaveBeenCalledWith( - { - messages: [{ content: 'Hello', role: 'user' }], - model: 'qwen-turbo', - temperature: 0.6, - stream: true, - top_p: 0.7, - result_format: 'message', - }, - { headers: { Accept: '*/*' } }, - ); - expect(result).toBeInstanceOf(Response); - }); - - it('should call vlms with proper options', async () => { - const mockStream = new ReadableStream(); - const mockResponse = Promise.resolve(mockStream); - - (instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse); - - const result = await instance.chat({ - messages: [{ content: 'Hello', role: 'user' }], - model: 'qwen-vl-plus', - temperature: 0.6, - top_p: 0.7, - }); - - // Assert - expect(instance['client'].chat.completions.create).toHaveBeenCalledWith( - { - messages: [{ content: 'Hello', role: 'user' }], - model: 'qwen-vl-plus', - stream: true, - }, - { headers: { Accept: '*/*' } }, - ); - expect(result).toBeInstanceOf(Response); - }); - - it('should transform non-streaming response to stream correctly', async () => { - const mockResponse = { - id: 'chatcmpl-fc539f49-51a8-94be-8061', - object: 'chat.completion', - created: 1719901794, - model: 'qwen-turbo', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Hello' }, - finish_reason: 'stop', - logprobs: null, - }, - ], - } as OpenAI.ChatCompletion; - vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue( - mockResponse as any, - ); - - const result = await instance.chat({ - messages: [{ content: 'Hello', role: 'user' }], - model: 'qwen-turbo', - temperature: 0.6, - stream: false, - }); - - const decoder = new TextDecoder(); - const reader = result.body!.getReader(); - const stream: string[] = []; - - while (true) { - const { value, done } = await reader.read(); - if (done) break; - stream.push(decoder.decode(value)); - } - - expect(stream).toEqual([ - 'id: chatcmpl-fc539f49-51a8-94be-8061\n', - 'event: text\n', - 'data: "Hello"\n\n', - 'id: chatcmpl-fc539f49-51a8-94be-8061\n', - 'event: stop\n', - 'data: "stop"\n\n', - ]); - - expect((await reader.read()).done).toBe(true); - }); - - it('should set temperature to undefined if temperature is 0 or >= 2', async () => { - const temperatures = [0, 2, 3]; - const expectedTemperature = undefined; - - for (const temp of temperatures) { - vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue( - new ReadableStream() as any, - ); - await instance.chat({ - messages: [{ content: 'Hello', role: 'user' }], - model: 'qwen-turbo', - temperature: temp, - }); - expect(instance['client'].chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - messages: expect.any(Array), - model: 'qwen-turbo', - temperature: expectedTemperature, - }), - expect.any(Object), - ); - } - }); - - it('should set temperature to original temperature', async () => { - vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue( - new ReadableStream() as any, - ); - await instance.chat({ - messages: [{ content: 'Hello', role: 'user' }], - model: 'qwen-turbo', - temperature: 1.5, - }); - expect(instance['client'].chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - messages: expect.any(Array), - model: 'qwen-turbo', - temperature: 1.5, - }), - expect.any(Object), - ); - }); - - it('should set temperature to Float', async () => { - const createMock = vi.fn().mockResolvedValue(new ReadableStream() as any); - vi.spyOn(instance['client'].chat.completions, 'create').mockImplementation(createMock); - await instance.chat({ - messages: [{ content: 'Hello', role: 'user' }], - model: 'qwen-turbo', - temperature: 1, - }); - expect(instance['client'].chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - messages: expect.any(Array), - model: 'qwen-turbo', - temperature: expect.any(Number), - }), - expect.any(Object), - ); - const callArgs = createMock.mock.calls[0][0]; - expect(Number.isInteger(callArgs.temperature)).toBe(false); // Temperature is always not an integer - }); - }); - describe('Error', () => { it('should return QwenBizError with an openai error response when OpenAI.APIError is thrown', async () => { // Arrange @@ -238,7 +63,7 @@ describe('LobeQwenAI', () => { try { await instance.chat({ messages: [{ content: 'Hello', role: 'user' }], - model: 'qwen-turbo', + model: 'qwen-turbo-latest', temperature: 0.999, }); } catch (e) { @@ -278,7 +103,7 @@ describe('LobeQwenAI', () => { try { await instance.chat({ messages: [{ content: 'Hello', role: 'user' }], - model: 'qwen-turbo', + model: 'qwen-turbo-latest', temperature: 0.999, }); } catch (e) { @@ -304,7 +129,8 @@ describe('LobeQwenAI', () => { instance = new LobeQwenAI({ apiKey: 'test', - baseURL: defaultBaseURL, + + baseURL: 'https://api.abc.com/v1', }); vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); @@ -313,13 +139,12 @@ describe('LobeQwenAI', () => { try { await instance.chat({ messages: [{ content: 'Hello', role: 'user' }], - model: 'qwen-turbo', + model: 'qwen-turbo-latest', temperature: 0.999, }); } catch (e) { expect(e).toEqual({ - /* Desensitizing is unnecessary for a public-accessible gateway endpoint. */ - endpoint: defaultBaseURL, + endpoint: 'https://api.***.com/v1', error: { cause: { message: 'api is undefined' }, stack: 'abc', @@ -339,7 +164,7 @@ describe('LobeQwenAI', () => { try { await instance.chat({ messages: [{ content: 'Hello', role: 'user' }], - model: 'qwen-turbo', + model: 'qwen-turbo-latest', temperature: 0.999, }); } catch (e) { @@ -362,7 +187,7 @@ describe('LobeQwenAI', () => { try { await instance.chat({ messages: [{ content: 'Hello', role: 'user' }], - model: 'qwen-turbo', + model: 'qwen-turbo-latest', temperature: 0.999, }); } catch (e) { @@ -410,7 +235,7 @@ describe('LobeQwenAI', () => { // 假设的测试函数调用,你可能需要根据实际情况调整 await instance.chat({ messages: [{ content: 'Hello', role: 'user' }], - model: 'qwen-turbo', + model: 'qwen-turbo-latest', stream: true, temperature: 0.999, }); diff --git a/src/libs/agent-runtime/qwen/index.ts b/src/libs/agent-runtime/qwen/index.ts index 0bfee54edd10..b0cc566f5b0b 100644 --- a/src/libs/agent-runtime/qwen/index.ts +++ b/src/libs/agent-runtime/qwen/index.ts @@ -1,129 +1,50 @@ -import { omit } from 'lodash-es'; -import OpenAI, { ClientOptions } from 'openai'; +import { ModelProvider } from '../types'; +import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory'; -import Qwen from '@/config/modelProviders/qwen'; - -import { LobeRuntimeAI } from '../BaseAI'; -import { AgentRuntimeErrorType } from '../error'; -import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types'; -import { AgentRuntimeError } from '../utils/createError'; -import { debugStream } from '../utils/debugStream'; -import { handleOpenAIError } from '../utils/handleOpenAIError'; -import { transformResponseToStream } from '../utils/openaiCompatibleFactory'; -import { StreamingResponse } from '../utils/response'; import { QwenAIStream } from '../utils/streams'; -const DEFAULT_BASE_URL = 'https://dashscope.aliyuncs.com/compatible-mode/v1'; - -/** - * Use DashScope OpenAI compatible mode for now. - * DashScope OpenAI [compatible mode](https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-vl-plus-api) currently supports base64 image input for vision models e.g. qwen-vl-plus. - * You can use images input either: - * 1. Use qwen-vl-* out of box with base64 image_url input; - * or - * 2. Set S3-* enviroment variables properly to store all uploaded files. - */ -export class LobeQwenAI implements LobeRuntimeAI { - client: OpenAI; - baseURL: string; - - constructor({ - apiKey, - baseURL = DEFAULT_BASE_URL, - ...res - }: ClientOptions & Record = {}) { - if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey); - this.client = new OpenAI({ apiKey, baseURL, ...res }); - this.baseURL = this.client.baseURL; - } - - async models() { - return Qwen.chatModels; - } - - async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) { - try { - const params = this.buildCompletionParamsByModel(payload); - - const response = await this.client.chat.completions.create( - params as OpenAI.ChatCompletionCreateParamsStreaming & { result_format: string }, - { - headers: { Accept: '*/*' }, - signal: options?.signal, - }, - ); - - if (params.stream) { - const [prod, debug] = response.tee(); - - if (process.env.DEBUG_QWEN_CHAT_COMPLETION === '1') { - debugStream(debug.toReadableStream()).catch(console.error); - } - - return StreamingResponse(QwenAIStream(prod, options?.callback), { - headers: options?.headers, - }); - } - - const stream = transformResponseToStream(response as unknown as OpenAI.ChatCompletion); - - return StreamingResponse(QwenAIStream(stream, options?.callback), { - headers: options?.headers, - }); - } catch (error) { - if ('status' in (error as any)) { - switch ((error as Response).status) { - case 401: { - throw AgentRuntimeError.chat({ - endpoint: this.baseURL, - error: error as any, - errorType: AgentRuntimeErrorType.InvalidProviderAPIKey, - provider: ModelProvider.Qwen, - }); - } - - default: { - break; - } - } - } - const { errorResult, RuntimeError } = handleOpenAIError(error); - const errorType = RuntimeError || AgentRuntimeErrorType.ProviderBizError; - - throw AgentRuntimeError.chat({ - endpoint: this.baseURL, - error: errorResult, - errorType, - provider: ModelProvider.Qwen, - }); - } - } - - private buildCompletionParamsByModel(payload: ChatStreamPayload) { - const { model, temperature, top_p, stream, messages, tools } = payload; - const isVisionModel = model.startsWith('qwen-vl'); - - const params = { - ...payload, - messages, - result_format: 'message', - stream: !!tools?.length ? false : (stream ?? true), - temperature: - temperature === 0 || temperature >= 2 ? undefined : temperature === 1 ? 0.999 : temperature, // 'temperature' must be Float - top_p: top_p && top_p >= 1 ? 0.999 : top_p, - }; - - /* Qwen-vl models temporarily do not support parameters below. */ - /* Notice: `top_p` imposes significant impact on the result,the default 1 or 0.999 is not a proper choice. */ - return isVisionModel - ? omit( - params, - 'presence_penalty', - 'frequency_penalty', - 'temperature', - 'result_format', - 'top_p', - ) - : omit(params, 'frequency_penalty'); - } -} +/* + QwenLegacyModels: A set of legacy Qwen models that do not support presence_penalty. + Currently, presence_penalty is only supported on Qwen commercial models and open-source models starting from Qwen 1.5 and later. +*/ +export const QwenLegacyModels = new Set([ + 'qwen-72b-chat', + 'qwen-14b-chat', + 'qwen-7b-chat', + 'qwen-1.8b-chat', + 'qwen-1.8b-longcontext-chat', +]); + +export const LobeQwenAI = LobeOpenAICompatibleFactory({ + baseURL: 'https://dashscope.aliyuncs.com/compatible-mode/v1', + chatCompletion: { + handlePayload: (payload) => { + const { model, presence_penalty, temperature, top_p, ...rest } = payload; + + return { + ...rest, + frequency_penalty: undefined, + model, + presence_penalty: + QwenLegacyModels.has(model) + ? undefined + : (presence_penalty !== undefined && presence_penalty >= -2 && presence_penalty <= 2) + ? presence_penalty + : undefined, + stream: !payload.tools, + temperature: (temperature !== undefined && temperature >= 0 && temperature < 2) ? temperature : undefined, + ...(model.startsWith('qwen-vl') ? { + top_p: (top_p !== undefined && top_p > 0 && top_p <= 1) ? top_p : undefined, + } : { + enable_search: true, + top_p: (top_p !== undefined && top_p > 0 && top_p < 1) ? top_p : undefined, + }), + } as any; + }, + handleStream: QwenAIStream, + }, + debug: { + chatCompletion: () => process.env.DEBUG_QWEN_CHAT_COMPLETION === '1', + }, + provider: ModelProvider.Qwen, +}); diff --git a/src/libs/agent-runtime/spark/index.test.ts b/src/libs/agent-runtime/spark/index.test.ts index 7b6b1a2b1a06..b87b65565ef7 100644 --- a/src/libs/agent-runtime/spark/index.test.ts +++ b/src/libs/agent-runtime/spark/index.test.ts @@ -2,20 +2,17 @@ import OpenAI from 'openai'; import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; -import { - ChatStreamCallbacks, - LobeOpenAICompatibleRuntime, - ModelProvider, -} from '@/libs/agent-runtime'; +import { LobeOpenAICompatibleRuntime } from '@/libs/agent-runtime'; +import { ModelProvider } from '@/libs/agent-runtime'; +import { AgentRuntimeErrorType } from '@/libs/agent-runtime'; import * as debugStreamModule from '../utils/debugStream'; import { LobeSparkAI } from './index'; const provider = ModelProvider.Spark; const defaultBaseURL = 'https://spark-api-open.xf-yun.com/v1'; - -const bizErrorType = 'ProviderBizError'; -const invalidErrorType = 'InvalidProviderAPIKey'; +const bizErrorType = AgentRuntimeErrorType.ProviderBizError; +const invalidErrorType = AgentRuntimeErrorType.InvalidProviderAPIKey; // Mock the console.error to avoid polluting test output vi.spyOn(console, 'error').mockImplementation(() => {}); @@ -46,7 +43,7 @@ describe('LobeSparkAI', () => { describe('chat', () => { describe('Error', () => { - it('should return OpenAIBizError with an openai error response when OpenAI.APIError is thrown', async () => { + it('should return QwenBizError with an openai error response when OpenAI.APIError is thrown', async () => { // Arrange const apiError = new OpenAI.APIError( 400, @@ -66,8 +63,8 @@ describe('LobeSparkAI', () => { try { await instance.chat({ messages: [{ content: 'Hello', role: 'user' }], - model: 'general', - temperature: 0, + model: 'max-32k', + temperature: 0.999, }); } catch (e) { expect(e).toEqual({ @@ -82,7 +79,7 @@ describe('LobeSparkAI', () => { } }); - it('should throw AgentRuntimeError with NoOpenAIAPIKey if no apiKey is provided', async () => { + it('should throw AgentRuntimeError with InvalidQwenAPIKey if no apiKey is provided', async () => { try { new LobeSparkAI({}); } catch (e) { @@ -90,7 +87,7 @@ describe('LobeSparkAI', () => { } }); - it('should return OpenAIBizError with the cause when OpenAI.APIError is thrown with cause', async () => { + it('should return QwenBizError with the cause when OpenAI.APIError is thrown with cause', async () => { // Arrange const errorInfo = { stack: 'abc', @@ -106,8 +103,8 @@ describe('LobeSparkAI', () => { try { await instance.chat({ messages: [{ content: 'Hello', role: 'user' }], - model: 'general', - temperature: 0, + model: 'max-32k', + temperature: 0.999, }); } catch (e) { expect(e).toEqual({ @@ -122,7 +119,7 @@ describe('LobeSparkAI', () => { } }); - it('should return OpenAIBizError with an cause response with desensitize Url', async () => { + it('should return QwenBizError with an cause response with desensitize Url', async () => { // Arrange const errorInfo = { stack: 'abc', @@ -142,8 +139,8 @@ describe('LobeSparkAI', () => { try { await instance.chat({ messages: [{ content: 'Hello', role: 'user' }], - model: 'general', - temperature: 0, + model: 'max-32k', + temperature: 0.999, }); } catch (e) { expect(e).toEqual({ @@ -158,23 +155,22 @@ describe('LobeSparkAI', () => { } }); - it('should throw an InvalidSparkAPIKey error type on 401 status code', async () => { + it('should throw an InvalidQwenAPIKey error type on 401 status code', async () => { // Mock the API call to simulate a 401 error - const error = new Error('Unauthorized') as any; + const error = new Error('InvalidApiKey') as any; error.status = 401; vi.mocked(instance['client'].chat.completions.create).mockRejectedValue(error); try { await instance.chat({ messages: [{ content: 'Hello', role: 'user' }], - model: 'general', - temperature: 0, + model: 'max-32k', + temperature: 0.999, }); } catch (e) { - // Expect the chat method to throw an error with InvalidSparkAPIKey expect(e).toEqual({ endpoint: defaultBaseURL, - error: new Error('Unauthorized'), + error: new Error('InvalidApiKey'), errorType: invalidErrorType, provider, }); @@ -191,8 +187,8 @@ describe('LobeSparkAI', () => { try { await instance.chat({ messages: [{ content: 'Hello', role: 'user' }], - model: 'general', - temperature: 0, + model: 'max-32k', + temperature: 0.999, }); } catch (e) { expect(e).toEqual({ @@ -239,9 +235,9 @@ describe('LobeSparkAI', () => { // 假设的测试函数调用,你可能需要根据实际情况调整 await instance.chat({ messages: [{ content: 'Hello', role: 'user' }], - model: 'general', + model: 'max-32k', stream: true, - temperature: 0, + temperature: 0.999, }); // 验证 debugStream 被调用 diff --git a/src/libs/agent-runtime/spark/index.ts b/src/libs/agent-runtime/spark/index.ts index 8cc8dfe1e28e..95d3f3e81d45 100644 --- a/src/libs/agent-runtime/spark/index.ts +++ b/src/libs/agent-runtime/spark/index.ts @@ -1,9 +1,13 @@ import { ModelProvider } from '../types'; import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory'; +import { transformSparkResponseToStream, SparkAIStream } from '../utils/streams'; + export const LobeSparkAI = LobeOpenAICompatibleFactory({ baseURL: 'https://spark-api-open.xf-yun.com/v1', chatCompletion: { + handleStream: SparkAIStream, + handleTransformResponseToStream: transformSparkResponseToStream, noUserId: true, }, debug: { diff --git a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.test.ts b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.test.ts index aa436cb668e0..d95009f484fe 100644 --- a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.test.ts +++ b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.test.ts @@ -1,10 +1,13 @@ // @vitest-environment node import OpenAI from 'openai'; +import type { Stream } from 'openai/streaming'; + import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { AgentRuntimeErrorType, ChatStreamCallbacks, + ChatStreamPayload, LobeOpenAICompatibleRuntime, ModelProvider, } from '@/libs/agent-runtime'; @@ -797,6 +800,134 @@ describe('LobeOpenAICompatibleFactory', () => { }); }); + it('should use custom stream handler when provided', async () => { + // Create a custom stream handler that handles both ReadableStream and OpenAI Stream + const customStreamHandler = vi.fn((stream: ReadableStream | Stream) => { + const readableStream = stream instanceof ReadableStream ? stream : stream.toReadableStream(); + return new ReadableStream({ + start(controller) { + const reader = readableStream.getReader(); + const process = async () => { + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + controller.enqueue(value); + } + } finally { + controller.close(); + } + }; + process(); + }, + }); + }); + + const LobeMockProvider = LobeOpenAICompatibleFactory({ + baseURL: 'https://api.test.com/v1', + chatCompletion: { + handleStream: customStreamHandler, + }, + provider: ModelProvider.OpenAI, + }); + + const instance = new LobeMockProvider({ apiKey: 'test' }); + + // Create a mock stream + const mockStream = new ReadableStream({ + start(controller) { + controller.enqueue({ + id: 'test-id', + choices: [{ delta: { content: 'Hello' }, index: 0 }], + created: Date.now(), + model: 'test-model', + object: 'chat.completion.chunk', + }); + controller.close(); + }, + }); + + vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue({ + tee: () => [mockStream, mockStream], + } as any); + + const payload: ChatStreamPayload = { + messages: [{ content: 'Test', role: 'user' }], + model: 'test-model', + temperature: 0.7, + }; + + await instance.chat(payload); + + expect(customStreamHandler).toHaveBeenCalled(); + }); + + it('should use custom transform handler for non-streaming response', async () => { + const customTransformHandler = vi.fn((data: OpenAI.ChatCompletion): ReadableStream => { + return new ReadableStream({ + start(controller) { + // Transform the completion to chunk format + controller.enqueue({ + id: data.id, + choices: data.choices.map((choice) => ({ + delta: { content: choice.message.content }, + index: choice.index, + })), + created: data.created, + model: data.model, + object: 'chat.completion.chunk', + }); + controller.close(); + }, + }); + }); + + const LobeMockProvider = LobeOpenAICompatibleFactory({ + baseURL: 'https://api.test.com/v1', + chatCompletion: { + handleTransformResponseToStream: customTransformHandler, + }, + provider: ModelProvider.OpenAI, + }); + + const instance = new LobeMockProvider({ apiKey: 'test' }); + + const mockResponse: OpenAI.ChatCompletion = { + id: 'test-id', + choices: [ + { + index: 0, + message: { + role: 'assistant', + content: 'Test response', + refusal: null + }, + logprobs: null, + finish_reason: 'stop', + }, + ], + created: Date.now(), + model: 'test-model', + object: 'chat.completion', + usage: { completion_tokens: 2, prompt_tokens: 1, total_tokens: 3 }, + }; + + vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue( + mockResponse as any, + ); + + const payload: ChatStreamPayload = { + messages: [{ content: 'Test', role: 'user' }], + model: 'test-model', + temperature: 0.7, + stream: false, + }; + + await instance.chat(payload); + + expect(customTransformHandler).toHaveBeenCalledWith(mockResponse); + }); + describe('DEBUG', () => { it('should call debugStream and return StreamingTextResponse when DEBUG_OPENROUTER_CHAT_COMPLETION is 1', async () => { // Arrange diff --git a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts index df80bf4c4a3b..814f890df875 100644 --- a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts +++ b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts @@ -25,6 +25,7 @@ import { handleOpenAIError } from '../handleOpenAIError'; import { convertOpenAIMessages } from '../openaiHelpers'; import { StreamingResponse } from '../response'; import { OpenAIStream, OpenAIStreamOptions } from '../streams'; +import { ChatStreamCallbacks } from '../../types'; // the model contains the following keywords is not a chat model, so we should filter them out export const CHAT_MODELS_BLOCK_LIST = [ @@ -62,10 +63,17 @@ interface OpenAICompatibleFactoryOptions = any> { payload: ChatStreamPayload, options: ConstructorOptions, ) => OpenAI.ChatCompletionCreateParamsStreaming; + handleStream?: ( + stream: Stream | ReadableStream, + callbacks?: ChatStreamCallbacks, + ) => ReadableStream; handleStreamBizErrorType?: (error: { message: string; name: string; }) => ILobeAgentRuntimeErrorType | undefined; + handleTransformResponseToStream?: ( + data: OpenAI.ChatCompletion, + ) => ReadableStream; noUserId?: boolean; }; constructorOptions?: ConstructorOptions; @@ -228,7 +236,8 @@ export const LobeOpenAICompatibleFactory = = any> debugStream(useForDebugStream).catch(console.error); } - return StreamingResponse(OpenAIStream(prod, streamOptions), { + const streamHandler = chatCompletion?.handleStream || OpenAIStream; + return StreamingResponse(streamHandler(prod, streamOptions), { headers: options?.headers, }); } @@ -239,9 +248,11 @@ export const LobeOpenAICompatibleFactory = = any> if (responseMode === 'json') return Response.json(response); - const stream = transformResponseToStream(response as unknown as OpenAI.ChatCompletion); + const transformHandler = chatCompletion?.handleTransformResponseToStream || transformResponseToStream; + const stream = transformHandler(response as unknown as OpenAI.ChatCompletion); - return StreamingResponse(OpenAIStream(stream, streamOptions), { + const streamHandler = chatCompletion?.handleStream || OpenAIStream; + return StreamingResponse(streamHandler(stream, streamOptions), { headers: options?.headers, }); } catch (error) { diff --git a/src/libs/agent-runtime/utils/streams/index.ts b/src/libs/agent-runtime/utils/streams/index.ts index e5518ce05221..a3ac8983d97e 100644 --- a/src/libs/agent-runtime/utils/streams/index.ts +++ b/src/libs/agent-runtime/utils/streams/index.ts @@ -7,3 +7,4 @@ export * from './ollama'; export * from './openai'; export * from './protocol'; export * from './qwen'; +export * from './spark'; diff --git a/src/libs/agent-runtime/utils/streams/spark.test.ts b/src/libs/agent-runtime/utils/streams/spark.test.ts new file mode 100644 index 000000000000..86626ce69aee --- /dev/null +++ b/src/libs/agent-runtime/utils/streams/spark.test.ts @@ -0,0 +1,199 @@ +import { beforeAll, describe, expect, it, vi } from 'vitest'; +import { SparkAIStream, transformSparkResponseToStream } from './spark'; +import type OpenAI from 'openai'; + +describe('SparkAIStream', () => { + beforeAll(() => {}); + + it('should transform non-streaming response to stream', async () => { + const mockResponse = { + id: "cha000ceba6@dx193d200b580b8f3532", + object: "chat.completion", + created: 1734395014, + model: "max-32k", + choices: [ + { + message: { + role: "assistant", + content: "", + refusal: null, + tool_calls: { + type: "function", + function: { + arguments: '{"city":"Shanghai"}', + name: "realtime-weather____fetchCurrentWeather" + }, + id: "call_1" + } + }, + index: 0, + logprobs: null, + finish_reason: "tool_calls" + } + ], + usage: { + prompt_tokens: 8, + completion_tokens: 0, + total_tokens: 8 + } + } as unknown as OpenAI.ChatCompletion; + + const stream = transformSparkResponseToStream(mockResponse); + const decoder = new TextDecoder(); + const chunks = []; + + // @ts-ignore + for await (const chunk of stream) { + chunks.push(chunk); + } + + expect(chunks).toHaveLength(2); + expect(chunks[0].choices[0].delta.tool_calls).toEqual([{ + function: { + arguments: '{"city":"Shanghai"}', + name: "realtime-weather____fetchCurrentWeather" + }, + id: "call_1", + index: 0, + type: "function" + }]); + expect(chunks[1].choices[0].finish_reason).toBeDefined(); + }); + + it('should transform streaming response with tool calls', async () => { + const mockStream = new ReadableStream({ + start(controller) { + controller.enqueue({ + id: "cha000b0bf9@dx193d1ffa61cb894532", + object: "chat.completion.chunk", + created: 1734395014, + model: "max-32k", + choices: [ + { + delta: { + role: "assistant", + content: "", + tool_calls: { + type: "function", + function: { + arguments: '{"city":"Shanghai"}', + name: "realtime-weather____fetchCurrentWeather" + }, + id: "call_1" + } + }, + index: 0 + } + ] + } as unknown as OpenAI.ChatCompletionChunk); + controller.close(); + } + }); + + const onToolCallMock = vi.fn(); + + const protocolStream = SparkAIStream(mockStream, { + onToolCall: onToolCallMock + }); + + const decoder = new TextDecoder(); + const chunks = []; + + // @ts-ignore + for await (const chunk of protocolStream) { + chunks.push(decoder.decode(chunk, { stream: true })); + } + + expect(chunks).toEqual([ + 'id: cha000b0bf9@dx193d1ffa61cb894532\n', + 'event: tool_calls\n', + `data: [{\"function\":{\"arguments\":\"{\\\"city\\\":\\\"Shanghai\\\"}\",\"name\":\"realtime-weather____fetchCurrentWeather\"},\"id\":\"call_1\",\"index\":0,\"type\":\"function\"}]\n\n` + ]); + + expect(onToolCallMock).toHaveBeenCalledTimes(1); + }); + + it('should handle text content in stream', async () => { + const mockStream = new ReadableStream({ + start(controller) { + controller.enqueue({ + id: "test-id", + object: "chat.completion.chunk", + created: 1734395014, + model: "max-32k", + choices: [ + { + delta: { + content: "Hello", + role: "assistant" + }, + index: 0 + } + ] + } as OpenAI.ChatCompletionChunk); + controller.enqueue({ + id: "test-id", + object: "chat.completion.chunk", + created: 1734395014, + model: "max-32k", + choices: [ + { + delta: { + content: " World", + role: "assistant" + }, + index: 0 + } + ] + } as OpenAI.ChatCompletionChunk); + controller.close(); + } + }); + + const onTextMock = vi.fn(); + + const protocolStream = SparkAIStream(mockStream, { + onText: onTextMock + }); + + const decoder = new TextDecoder(); + const chunks = []; + + // @ts-ignore + for await (const chunk of protocolStream) { + chunks.push(decoder.decode(chunk, { stream: true })); + } + + expect(chunks).toEqual([ + 'id: test-id\n', + 'event: text\n', + 'data: "Hello"\n\n', + 'id: test-id\n', + 'event: text\n', + 'data: " World"\n\n' + ]); + + expect(onTextMock).toHaveBeenNthCalledWith(1, '"Hello"'); + expect(onTextMock).toHaveBeenNthCalledWith(2, '" World"'); + }); + + it('should handle empty stream', async () => { + const mockStream = new ReadableStream({ + start(controller) { + controller.close(); + } + }); + + const protocolStream = SparkAIStream(mockStream); + + const decoder = new TextDecoder(); + const chunks = []; + + // @ts-ignore + for await (const chunk of protocolStream) { + chunks.push(decoder.decode(chunk, { stream: true })); + } + + expect(chunks).toEqual([]); + }); +}); diff --git a/src/libs/agent-runtime/utils/streams/spark.ts b/src/libs/agent-runtime/utils/streams/spark.ts new file mode 100644 index 000000000000..ee74f424df31 --- /dev/null +++ b/src/libs/agent-runtime/utils/streams/spark.ts @@ -0,0 +1,134 @@ +import OpenAI from 'openai'; +import type { Stream } from 'openai/streaming'; + +import { ChatStreamCallbacks } from '../../types'; +import { + StreamProtocolChunk, + StreamProtocolToolCallChunk, + convertIterableToStream, + createCallbacksTransformer, + createSSEProtocolTransformer, + generateToolCallId, +} from './protocol'; + +export function transformSparkResponseToStream(data: OpenAI.ChatCompletion) { + return new ReadableStream({ + start(controller) { + const chunk: OpenAI.ChatCompletionChunk = { + choices: data.choices.map((choice: OpenAI.ChatCompletion.Choice) => { + const toolCallsArray = choice.message.tool_calls + ? Array.isArray(choice.message.tool_calls) + ? choice.message.tool_calls + : [choice.message.tool_calls] + : []; // 如果不是数组,包装成数组 + + return { + delta: { + content: choice.message.content, + role: choice.message.role, + tool_calls: toolCallsArray.map( + (tool, index): OpenAI.ChatCompletionChunk.Choice.Delta.ToolCall => ({ + function: tool.function, + id: tool.id, + index, + type: tool.type, + }), + ), + }, + finish_reason: null, + index: choice.index, + logprobs: choice.logprobs, + }; + }), + created: data.created, + id: data.id, + model: data.model, + object: 'chat.completion.chunk', + }; + + controller.enqueue(chunk); + + controller.enqueue({ + choices: data.choices.map((choice: OpenAI.ChatCompletion.Choice) => ({ + delta: { + content: null, + role: choice.message.role, + }, + finish_reason: choice.finish_reason, + index: choice.index, + logprobs: choice.logprobs, + })), + created: data.created, + id: data.id, + model: data.model, + object: 'chat.completion.chunk', + system_fingerprint: data.system_fingerprint, + } as OpenAI.ChatCompletionChunk); + controller.close(); + }, + }); +} + +export const transformSparkStream = (chunk: OpenAI.ChatCompletionChunk): StreamProtocolChunk => { + const item = chunk.choices[0]; + + if (!item) { + return { data: chunk, id: chunk.id, type: 'data' }; + } + + if (item.delta?.tool_calls) { + const toolCallsArray = Array.isArray(item.delta.tool_calls) + ? item.delta.tool_calls + : [item.delta.tool_calls]; // 如果不是数组,包装成数组 + + if (toolCallsArray.length > 0) { + return { + data: toolCallsArray.map((toolCall, index) => ({ + function: toolCall.function, + id: toolCall.id || generateToolCallId(index, toolCall.function?.name), + index: typeof toolCall.index !== 'undefined' ? toolCall.index : index, + type: toolCall.type || 'function', + })), + id: chunk.id, + type: 'tool_calls', + } as StreamProtocolToolCallChunk; + } + } + + if (item.finish_reason) { + // one-api 的流式接口,会出现既有 finish_reason ,也有 content 的情况 + // {"id":"demo","model":"deepl-en","choices":[{"index":0,"delta":{"role":"assistant","content":"Introduce yourself."},"finish_reason":"stop"}]} + + if (typeof item.delta?.content === 'string' && !!item.delta.content) { + return { data: item.delta.content, id: chunk.id, type: 'text' }; + } + + return { data: item.finish_reason, id: chunk.id, type: 'stop' }; + } + + if (typeof item.delta?.content === 'string') { + return { data: item.delta.content, id: chunk.id, type: 'text' }; + } + + if (item.delta?.content === null) { + return { data: item.delta, id: chunk.id, type: 'data' }; + } + + return { + data: { delta: item.delta, id: chunk.id, index: item.index }, + id: chunk.id, + type: 'data', + }; +}; + +export const SparkAIStream = ( + stream: Stream | ReadableStream, + callbacks?: ChatStreamCallbacks, +) => { + const readableStream = + stream instanceof ReadableStream ? stream : convertIterableToStream(stream); + + return readableStream + .pipeThrough(createSSEProtocolTransformer(transformSparkStream)) + .pipeThrough(createCallbacksTransformer(callbacks)); +};