Skip to content

Commit

Permalink
add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
lwshen committed Apr 19, 2024
1 parent 3cef44d commit 21f90f2
Showing 1 changed file with 226 additions and 0 deletions.
226 changes: 226 additions & 0 deletions src/libs/agent-runtime/minimax/index.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
// @vitest-environment edge-runtime
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import { ChatStreamPayload, ModelProvider } from '@/libs/agent-runtime';
import * as fetchSSEModule from '@/utils/fetch';

import { LobeMinimaxAI } from './index';

const provider = ModelProvider.Minimax;
const bizErrorType = 'MinimaxBizError';
const invalidErrorType = 'InvalidMinimaxAPIKey';

// Mock the console.error to avoid polluting test output
vi.spyOn(console, 'error').mockImplementation(() => {});

let instance: LobeMinimaxAI;

beforeEach(() => {
instance = new LobeMinimaxAI({ apiKey: 'test' });

// 使用 vi.spyOn 来模拟 fetchSSE 方法
vi.spyOn(fetchSSEModule, 'fetchSSE').mockResolvedValue(new Response());
});

afterEach(() => {
vi.clearAllMocks();
});

describe('LobeMinimaxAI', () => {
describe('init', () => {
it('should correctly initialize with an API key', async () => {
const instance = new LobeMinimaxAI({ apiKey: 'test_api_key' });
expect(instance).toBeInstanceOf(LobeMinimaxAI);
});
});

describe('chat', () => {
it('should return a StreamingTextResponse on successful API call', async () => {
const mockStream = new ReadableStream({
start(controller) {
controller.enqueue('Hello, world!');
controller.close();
},
});
vi.spyOn(fetchSSEModule, 'fetchSSE').mockResolvedValue(new Response(mockStream));

const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
temperature: 0,
});

expect(result).toBeInstanceOf(Response);
});

it('should handle text messages correctly', async () => {
const mockResponseData = {
choices: [{ delta: { content: 'Hello, world!' } }],
};
vi.spyOn(fetchSSEModule, 'fetchSSE').mockResolvedValue(
new Response(
new ReadableStream({
start(controller) {
controller.enqueue(JSON.stringify(mockResponseData));
controller.close();
},
}),
),
);

const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
temperature: 0,
});

expect(fetchSSEModule.fetchSSE).toHaveBeenCalledWith(expect.any(Function), {
onFinish: expect.any(Function),
onMessageHandle: expect.any(Function),
});
expect(result).toBeInstanceOf(Response);
});

describe('Error', () => {
it('should throw InvalidMinimaxAPIKey error on API_KEY_INVALID error', async () => {
const mockErrorResponse = {
base_resp: {
status_code: 1004,
status_msg: 'API key not valid',
},
};
vi.spyOn(fetchSSEModule, 'fetchSSE').mockResolvedValue(
new Response(
new ReadableStream({
start(controller) {
controller.enqueue(JSON.stringify(mockErrorResponse));
controller.close();
},
}),
),
);

try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
temperature: 0,
});
} catch (e) {
expect(e).toEqual({
errorType: invalidErrorType,
error: {
code: 1004,
message: 'API key not valid',
},
provider,
});
}
});

it('should throw MinimaxBizError error on other error status codes', async () => {
const mockErrorResponse = {
base_resp: {
status_code: 1001,
status_msg: 'Some error occurred',
},
};
vi.spyOn(fetchSSEModule, 'fetchSSE').mockResolvedValue(
new Response(
new ReadableStream({
start(controller) {
controller.enqueue(JSON.stringify(mockErrorResponse));
controller.close();
},
}),
),
);

try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
temperature: 0,
});
} catch (e) {
expect(e).toEqual({
errorType: bizErrorType,
error: {
code: 1001,
message: 'Some error occurred',
},
provider,
});
}
});

it('should throw MinimaxBizError error on generic errors', async () => {
const mockError = new Error('Something went wrong');
vi.spyOn(fetchSSEModule, 'fetchSSE').mockRejectedValue(mockError);

try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
temperature: 0,
});
} catch (e) {
expect(e).toEqual({
errorType: bizErrorType,
error: {
cause: undefined,
message: 'Something went wrong',
name: 'Error',
stack: mockError.stack,
},
provider,
});
}
});
});
});

describe('private methods', () => {
describe('buildCompletionsParams', () => {
it('should build the correct parameters', () => {
const payload: ChatStreamPayload = {
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
temperature: 0.5,
top_p: 0.8,
max_tokens: 100,
};

const result = instance['buildCompletionsParams'](payload);

expect(result).toEqual({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
stream: true,
temperature: 0.5,
top_p: 0.8,
max_tokens: 100,
});
});

it('should exclude temperature and top_p when they are 0', () => {
const payload: ChatStreamPayload = {
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
temperature: 0,
top_p: 0,
max_tokens: 100,
};

const result = instance['buildCompletionsParams'](payload);

expect(result).toEqual({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
stream: true,
max_tokens: 100,
});
});
});
});
});

0 comments on commit 21f90f2

Please sign in to comment.