community[patch]: support stream for wenxin and zhipu chat #5573

Merged · 3 commits · May 29, 2024

Changes from 2 commits
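What this adds in practice: ChatBaiduWenxin and ChatZhipuAI now implement _streamResponseChunks, so calling .stream() on either model yields incremental AIMessageChunk objects instead of a single buffered reply. A minimal usage sketch — the model name, prompt, and import path are illustrative (mirroring the skipped integration tests added in this PR), not part of the diff itself:

import { ChatBaiduWenxin } from "@langchain/community/chat_models/baiduwenxin";

const chat = new ChatBaiduWenxin({ modelName: "ERNIE-Bot" });

// .stream() now returns an async iterable of AIMessageChunk objects.
const stream = await chat.stream(`Translate "I love programming" into Chinese.`);

const pieces: string[] = [];
for await (const chunk of stream) {
  // Each chunk carries a partial piece of the answer in `content`.
  pieces.push(chunk.content as string);
}
console.log(pieces.join(""));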
111 changes: 105 additions & 6 deletions libs/langchain-community/src/chat_models/baiduwenxin.ts
@@ -2,10 +2,20 @@ import {
BaseChatModel,
type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { AIMessage, BaseMessage, ChatMessage } from "@langchain/core/messages";
import { ChatGeneration, ChatResult } from "@langchain/core/outputs";
import {
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
} from "@langchain/core/messages";
import {
ChatGeneration,
ChatGenerationChunk,
ChatResult,
} from "@langchain/core/outputs";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { createStream } from "../utils/stream.js";

/**
* Type representing the role of a message in the Wenxin chat model.
@@ -54,6 +64,12 @@ interface ChatCompletionResponse {
usage: TokenUsage;
}

interface ChatCompletionStreamResponse extends ChatCompletionResponse {
is_end: boolean;
is_truncated: boolean;
sentence_id: number;
}

/**
* Interface defining the input to the ChatBaiduWenxin class.
*/
@@ -347,6 +363,13 @@ export class ChatBaiduWenxin
};
}

private _ensureMessages(messages: BaseMessage[]): WenxinMessage[] {
return messages.map((message) => ({
role: messageToWenxinRole(message),
content: message.text,
}));
}

/** @ignore */
async _generate(
messages: BaseMessage[],
@@ -366,10 +389,7 @@
messages = messages.filter((message) => message !== systemMessage);
params.system = systemMessage.text;
}
const messagesMapped: WenxinMessage[] = messages.map((message) => ({
role: messageToWenxinRole(message),
content: message.text,
}));
const messagesMapped = this._ensureMessages(messages);

const data = params.stream
? await new Promise<ChatCompletionResponse>((resolve, reject) => {
@@ -596,6 +616,85 @@
return this.caller.call(makeCompletionRequest);
}

private async getFullApiUrl() {
if (!this.accessToken) {
this.accessToken = await this.getAccessToken();
}
return `${this.apiUrl}?access_token=${this.accessToken}`;
}

private async *createWenxinStream(
request: ChatCompletionRequest,
signal?: AbortSignal
) {
const url = await this.getFullApiUrl();
const response = await fetch(url, {
method: "POST",
headers: {
Accept: "text/event-stream",
"Content-Type": "application/json",
},
body: JSON.stringify(request),
signal,
});

if (!response.body) {
throw new Error(
"Could not begin Wenxin stream. Please check the given URL and try again."
);
}

yield* createStream<ChatCompletionStreamResponse>(response.body);
}

async *_streamResponseChunks(
messages: BaseMessage[],
options?: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
const parameters = {
...this.invocationParams(),
stream: true,
};

// Wenxin requires the system message to be put in the params, not messages array
const systemMessage = messages.find(
(message) => message._getType() === "system"
);
if (systemMessage) {
// eslint-disable-next-line no-param-reassign
messages = messages.filter((message) => message !== systemMessage);
parameters.system = systemMessage.text;
}
const messagesMapped = this._ensureMessages(messages);

const stream = await this.caller.call(async () =>
this.createWenxinStream(
{
...parameters,
messages: messagesMapped,
},
options?.signal
)
);

for await (const chunk of stream) {
const { result, is_end, id } = chunk;
yield new ChatGenerationChunk({
text: result,
message: new AIMessageChunk({ content: result }),
generationInfo: is_end
? {
is_end,
request_id: id,
usage: chunk.usage,
}
: undefined,
});
await runManager?.handleLLMNewToken(result);
}
}

_llmType() {
return "baiduwenxin";
}
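For context on what the for-await loop above receives: Wenxin replies as a server-sent-event stream, and createStream (added in utils/stream.ts below) JSON-parses each "data:" line into a ChatCompletionStreamResponse. A hedged illustration of that shape — the id and text values here are invented; only the field names come from the interface added above:

// Illustrative only: two parsed SSE frames as _streamResponseChunks would see them.
// The final frame has is_end: true and (per the Wenxin API) also carries the usage
// object that gets copied into generationInfo.
const exampleFrames = [
  { id: "as-xxxxxxxx", result: "我爱", is_end: false, is_truncated: false, sentence_id: 0 },
  { id: "as-xxxxxxxx", result: "编程。", is_end: true, is_truncated: false, sentence_id: 1 },
];

(The two hunks that follow add skipped chat.stream integration tests for ChatBaiduWenxin and ChatZhipuAI; the file headers for those test files were not captured in this view.)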
@@ -21,6 +21,20 @@ interface TestConfig {
shouldThrow?: boolean;
}

test.skip("Test chat.stream work fine", async () => {
const chat = new ChatBaiduWenxin({
modelName: "ERNIE-Bot",
});
const stream = await chat.stream(
`Translate "I love programming" into Chinese.`
);
const chunks = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
expect(chunks.length).toBeGreaterThan(0);
});

const runTest = async ({
modelName,
config,
@@ -20,6 +20,21 @@ interface TestConfig {
shouldThrow?: boolean;
}

test.skip("Test chat.stream work fine", async () => {
const chat = new ChatZhipuAI({
modelName: "glm-3-turbo",
});
const stream = await chat.stream(
`Translate "I love programming" into Chinese.`
);
const chunks = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
console.log(chunks);
expect(chunks.length).toBeGreaterThan(0);
});

const runTest = async ({
modelName,
config,
72 changes: 71 additions & 1 deletion libs/langchain-community/src/chat_models/zhipuai.ts
@@ -6,12 +6,14 @@ import {
AIMessage,
type BaseMessage,
ChatMessage,
AIMessageChunk,
} from "@langchain/core/messages";
import { type ChatResult } from "@langchain/core/outputs";
import { ChatGenerationChunk, type ChatResult } from "@langchain/core/outputs";
import { type CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { getEnvironmentVariable } from "@langchain/core/utils/env";

import { encodeApiKey } from "../utils/zhipuai.js";
import { createStream } from "../utils/stream.js";

export type ZhipuMessageRole = "system" | "assistant" | "user";

@@ -452,6 +454,74 @@ export class ChatZhipuAI extends BaseChatModel implements ChatZhipuAIParams {
return this.caller.call(makeCompletionRequest);
}

private async *createZhipuStream(
request: ChatCompletionRequest,
signal?: AbortSignal
) {
const response = await fetch(this.apiUrl, {
method: "POST",
headers: {
Accept: "text/event-stream",
Authorization: `Bearer ${encodeApiKey(this.zhipuAIApiKey)}`,
"Content-Type": "application/json",
},
body: JSON.stringify(request),
signal,
});

if (!response.body) {
throw new Error(
"Could not begin Zhipu stream. Please check the given URL and try again."
);
}

yield* createStream<ChatCompletionResponse>(response.body);
}

async *_streamResponseChunks(
messages: BaseMessage[],
options?: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
const parameters = {
...this.invocationParams(),
stream: true,
};

const messagesMapped: ZhipuMessage[] = messages.map((message) => ({
role: messageToRole(message),
content: message.content as string,
}));

const stream = await this.caller.call(async () =>
this.createZhipuStream(
{
...parameters,
messages: messagesMapped,
},
options?.signal
)
);

for await (const chunk of stream) {
const { choices, id } = chunk;
const text = choices[0]?.delta?.content ?? "";
const finished = !!choices[0]?.finish_reason;
yield new ChatGenerationChunk({
text,
message: new AIMessageChunk({ content: text }),
generationInfo: finished
? {
finished,
request_id: id,
usage: chunk.usage,
}
: undefined,
});
await runManager?.handleLLMNewToken(text);
}
}

_llmType(): string {
return "zhipuai";
}
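Unlike Wenxin, Zhipu's streamed payload is OpenAI-style: the delta text lives at choices[0].delta.content and the end of the stream is signalled by finish_reason rather than an is_end flag, which is what _streamResponseChunks above extracts. A small consumer sketch — model name, prompt, and import path are illustrative, following the skipped integration test added in this PR:

import { ChatZhipuAI } from "@langchain/community/chat_models/zhipuai";

const chat = new ChatZhipuAI({ modelName: "glm-3-turbo" });
const stream = await chat.stream(`Translate "I love programming" into Chinese.`);

const pieces: string[] = [];
for await (const chunk of stream) {
  // chunk.content is the delta text pulled out of choices[0].delta.content.
  pieces.push(chunk.content as string);
}
console.log(pieces.join(""));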
24 changes: 24 additions & 0 deletions libs/langchain-community/src/utils/stream.ts
@@ -0,0 +1,24 @@
import { IterableReadableStream } from "@langchain/core/utils/stream";

export async function* createStream<T = unknown>(
responseBody: ReadableStream<Uint8Array>
): AsyncGenerator<T> {
const stream = IterableReadableStream.fromReadableStream(responseBody);
const decoder = new TextDecoder("utf-8");
let extra = "";
for await (const chunk of stream) {
const decoded = extra + decoder.decode(chunk);
const lines = decoded.split("\n");
extra = lines.pop() || "";
for (const line of lines) {
if (!line.startsWith("data:")) {
continue;
}
try {
yield JSON.parse(line.slice("data:".length).trim());
} catch (e) {
console.warn(`Received a non-JSON parseable chunk: ${line}`);
}
}
}
}
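A quick way to see what this helper does: feed it a hand-built ReadableStream so the SSE framing can be exercised without a live endpoint. This is a sketch under a few assumptions — a runtime with global ReadableStream/TextEncoder (Node 18+ or a browser), a relative import placed next to this file, and invented payload values:

import { createStream } from "./stream.js";

async function demo() {
  const encoder = new TextEncoder();
  const body = new ReadableStream<Uint8Array>({
    start(controller) {
      // One frame split across two network chunks (the split is what `extra` buffers),
      // followed by a complete terminating frame.
      controller.enqueue(encoder.encode('data: {"result":"我爱'));
      controller.enqueue(encoder.encode('编程","is_end":false}\n'));
      controller.enqueue(encoder.encode('data: {"result":"。","is_end":true}\n'));
      controller.close();
    },
  });

  for await (const frame of createStream<{ result: string; is_end: boolean }>(body)) {
    console.log(frame); // { result: "我爱编程", is_end: false }, then { result: "。", is_end: true }
  }
}

Lines that do not start with "data:" (blank keep-alives, SSE comments) are skipped, and a trailing partial line is carried over to the next chunk via extra.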