Skip to content

Commit

Permalink
✨ feat: add custom stream handle support for LobeOpenAICompatibleFact…
Browse files Browse the repository at this point in the history
…ory (#5039)

* ♻️ refactor: add function call support for Spark

* ♻️ refactor: add non-stream mode support

* ⚡️ perf: using stream mode for tools call

* ✨ feat: add `handleStream` & `handleStreamResponse` for LobeOpenAICompatibleFactory, custom stream handle

* ✨ feat: add `handleTtransformResponseToStream` for custom non-stream transform handle

* ♻️ refactor: refactor qwen to LobeOpenAICompatibleFactory, enable `enable_search` for Qwen LLM

* 🔨 chore: add unit test for LobeOpenAICompatibleFactory

* 🔨 chore: add unit test for SparkAIStream

* 🔨 chore: add unit test for Qwen & Spark

* 🐛 fix: fix Qwen param range error

* 🔨 chore: add `QwenLegacyModels` array, limit `presence_penalty`

* 🐛 fix: fix typo
  • Loading branch information
hezhijie0327 authored Dec 29, 2024
1 parent cf0e8d8 commit ea7e732
Show file tree
Hide file tree
Showing 10 changed files with 570 additions and 351 deletions.
9 changes: 3 additions & 6 deletions src/config/modelProviders/spark.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ const Spark: ModelProviderCard = {
'Spark Lite 是一款轻量级大语言模型,具备极低的延迟与高效的处理能力,完全免费开放,支持实时在线搜索功能。其快速响应的特性使其在低算力设备上的推理应用和模型微调中表现出色,为用户带来出色的成本效益和智能体验,尤其在知识问答、内容生成及搜索场景下表现不俗。',
displayName: 'Spark Lite',
enabled: true,
functionCall: false,
id: 'lite',
maxOutput: 4096,
},
Expand All @@ -20,7 +19,6 @@ const Spark: ModelProviderCard = {
'Spark Pro 是一款为专业领域优化的高性能大语言模型,专注数学、编程、医疗、教育等多个领域,并支持联网搜索及内置天气、日期等插件。其优化后模型在复杂知识问答、语言理解及高层次文本创作中展现出色表现和高效性能,是适合专业应用场景的理想选择。',
displayName: 'Spark Pro',
enabled: true,
functionCall: false,
id: 'generalv3',
maxOutput: 8192,
},
Expand All @@ -30,7 +28,6 @@ const Spark: ModelProviderCard = {
'Spark Pro 128K 配置了特大上下文处理能力,能够处理多达128K的上下文信息,特别适合需通篇分析和长期逻辑关联处理的长文内容,可在复杂文本沟通中提供流畅一致的逻辑与多样的引用支持。',
displayName: 'Spark Pro 128K',
enabled: true,
functionCall: false,
id: 'pro-128k',
maxOutput: 4096,
},
Expand All @@ -40,7 +37,7 @@ const Spark: ModelProviderCard = {
'Spark Max 为功能最为全面的版本,支持联网搜索及众多内置插件。其全面优化的核心能力以及系统角色设定和函数调用功能,使其在各种复杂应用场景中的表现极为优异和出色。',
displayName: 'Spark Max',
enabled: true,
functionCall: false,
functionCall: true,
id: 'generalv3.5',
maxOutput: 8192,
},
Expand All @@ -50,7 +47,7 @@ const Spark: ModelProviderCard = {
'Spark Max 32K 配置了大上下文处理能力,更强的上下文理解和逻辑推理能力,支持32K tokens的文本输入,适用于长文档阅读、私有知识问答等场景',
displayName: 'Spark Max 32K',
enabled: true,
functionCall: false,
functionCall: true,
id: 'max-32k',
maxOutput: 8192,
},
Expand All @@ -60,7 +57,7 @@ const Spark: ModelProviderCard = {
'Spark Ultra 是星火大模型系列中最为强大的版本,在升级联网搜索链路同时,提升对文本内容的理解和总结能力。它是用于提升办公生产力和准确响应需求的全方位解决方案,是引领行业的智能产品。',
displayName: 'Spark 4.0 Ultra',
enabled: true,
functionCall: false,
functionCall: true,
id: '4.0Ultra',
maxOutput: 8192,
},
Expand Down
201 changes: 13 additions & 188 deletions src/libs/agent-runtime/qwen/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import OpenAI from 'openai';
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import Qwen from '@/config/modelProviders/qwen';
import { AgentRuntimeErrorType, ModelProvider } from '@/libs/agent-runtime';
import { LobeOpenAICompatibleRuntime } from '@/libs/agent-runtime';
import { ModelProvider } from '@/libs/agent-runtime';
import { AgentRuntimeErrorType } from '@/libs/agent-runtime';

import * as debugStreamModule from '../utils/debugStream';
import { LobeQwenAI } from './index';
Expand All @@ -16,7 +17,7 @@ const invalidErrorType = AgentRuntimeErrorType.InvalidProviderAPIKey;
// Mock the console.error to avoid polluting test output
vi.spyOn(console, 'error').mockImplementation(() => {});

let instance: LobeQwenAI;
let instance: LobeOpenAICompatibleRuntime;

beforeEach(() => {
instance = new LobeQwenAI({ apiKey: 'test' });
Expand All @@ -40,183 +41,7 @@ describe('LobeQwenAI', () => {
});
});

describe('models', () => {
it('should correctly list available models', async () => {
const instance = new LobeQwenAI({ apiKey: 'test_api_key' });
vi.spyOn(instance, 'models').mockResolvedValue(Qwen.chatModels);

const models = await instance.models();
expect(models).toEqual(Qwen.chatModels);
});
});

describe('chat', () => {
describe('Params', () => {
it('should call llms with proper options', async () => {
const mockStream = new ReadableStream();
const mockResponse = Promise.resolve(mockStream);

(instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse);

const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
temperature: 0.6,
top_p: 0.7,
});

// Assert
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
{
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
temperature: 0.6,
stream: true,
top_p: 0.7,
result_format: 'message',
},
{ headers: { Accept: '*/*' } },
);
expect(result).toBeInstanceOf(Response);
});

it('should call vlms with proper options', async () => {
const mockStream = new ReadableStream();
const mockResponse = Promise.resolve(mockStream);

(instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse);

const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-vl-plus',
temperature: 0.6,
top_p: 0.7,
});

// Assert
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
{
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-vl-plus',
stream: true,
},
{ headers: { Accept: '*/*' } },
);
expect(result).toBeInstanceOf(Response);
});

it('should transform non-streaming response to stream correctly', async () => {
const mockResponse = {
id: 'chatcmpl-fc539f49-51a8-94be-8061',
object: 'chat.completion',
created: 1719901794,
model: 'qwen-turbo',
choices: [
{
index: 0,
message: { role: 'assistant', content: 'Hello' },
finish_reason: 'stop',
logprobs: null,
},
],
} as OpenAI.ChatCompletion;
vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
mockResponse as any,
);

const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
temperature: 0.6,
stream: false,
});

const decoder = new TextDecoder();
const reader = result.body!.getReader();
const stream: string[] = [];

while (true) {
const { value, done } = await reader.read();
if (done) break;
stream.push(decoder.decode(value));
}

expect(stream).toEqual([
'id: chatcmpl-fc539f49-51a8-94be-8061\n',
'event: text\n',
'data: "Hello"\n\n',
'id: chatcmpl-fc539f49-51a8-94be-8061\n',
'event: stop\n',
'data: "stop"\n\n',
]);

expect((await reader.read()).done).toBe(true);
});

it('should set temperature to undefined if temperature is 0 or >= 2', async () => {
const temperatures = [0, 2, 3];
const expectedTemperature = undefined;

for (const temp of temperatures) {
vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
new ReadableStream() as any,
);
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
temperature: temp,
});
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.any(Array),
model: 'qwen-turbo',
temperature: expectedTemperature,
}),
expect.any(Object),
);
}
});

it('should set temperature to original temperature', async () => {
vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
new ReadableStream() as any,
);
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
temperature: 1.5,
});
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.any(Array),
model: 'qwen-turbo',
temperature: 1.5,
}),
expect.any(Object),
);
});

it('should set temperature to Float', async () => {
const createMock = vi.fn().mockResolvedValue(new ReadableStream() as any);
vi.spyOn(instance['client'].chat.completions, 'create').mockImplementation(createMock);
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
temperature: 1,
});
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.any(Array),
model: 'qwen-turbo',
temperature: expect.any(Number),
}),
expect.any(Object),
);
const callArgs = createMock.mock.calls[0][0];
expect(Number.isInteger(callArgs.temperature)).toBe(false); // Temperature is always not an integer
});
});

describe('Error', () => {
it('should return QwenBizError with an openai error response when OpenAI.APIError is thrown', async () => {
// Arrange
Expand All @@ -238,7 +63,7 @@ describe('LobeQwenAI', () => {
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
model: 'qwen-turbo-latest',
temperature: 0.999,
});
} catch (e) {
Expand Down Expand Up @@ -278,7 +103,7 @@ describe('LobeQwenAI', () => {
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
model: 'qwen-turbo-latest',
temperature: 0.999,
});
} catch (e) {
Expand All @@ -304,7 +129,8 @@ describe('LobeQwenAI', () => {

instance = new LobeQwenAI({
apiKey: 'test',
baseURL: defaultBaseURL,

baseURL: 'https://api.abc.com/v1',
});

vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);
Expand All @@ -313,13 +139,12 @@ describe('LobeQwenAI', () => {
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
model: 'qwen-turbo-latest',
temperature: 0.999,
});
} catch (e) {
expect(e).toEqual({
/* Desensitizing is unnecessary for a public-accessible gateway endpoint. */
endpoint: defaultBaseURL,
endpoint: 'https://api.***.com/v1',
error: {
cause: { message: 'api is undefined' },
stack: 'abc',
Expand All @@ -339,7 +164,7 @@ describe('LobeQwenAI', () => {
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
model: 'qwen-turbo-latest',
temperature: 0.999,
});
} catch (e) {
Expand All @@ -362,7 +187,7 @@ describe('LobeQwenAI', () => {
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
model: 'qwen-turbo-latest',
temperature: 0.999,
});
} catch (e) {
Expand Down Expand Up @@ -410,7 +235,7 @@ describe('LobeQwenAI', () => {
// 假设的测试函数调用,你可能需要根据实际情况调整
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
model: 'qwen-turbo-latest',
stream: true,
temperature: 0.999,
});
Expand Down
Loading

0 comments on commit ea7e732

Please sign in to comment.