core[minor],openai[patch]: Add usage metadata to AIMessage/Chunk #5586

Merged
merged 12 commits on May 31, 2024
docs/core_docs/docs/integrations/chat/openai.mdx (13 additions, 0 deletions)
@@ -125,3 +125,16 @@ You can also use the callbacks system:
### With `.generate()`

<CodeBlock language="typescript">{OpenAIGenerationInfo}</CodeBlock>

### Streaming tokens

OpenAI supports streaming token counts via an opt-in call option. This can be set by passing `{ stream_options: { include_usage: true } }`.
Setting this call option will cause the model to return an additional chunk at the end of the stream, containing the token usage.

import OpenAIStreamTokens from "@examples/models/chat/integration_openai_stream_tokens.ts";

<CodeBlock language="typescript">{OpenAIStreamTokens}</CodeBlock>

:::tip
See the LangSmith trace [here](https://smith.langchain.com/public/66bf7377-cc69-4676-91b6-25929a05e8b7/r)
:::
examples/src/models/chat/integration_openai_stream_tokens.ts (30 additions, 0 deletions)
@@ -0,0 +1,30 @@
import { AIMessageChunk } from "@langchain/core/messages";
import { ChatOpenAI } from "@langchain/openai";

// Instantiate the model
const model = new ChatOpenAI();

const response = await model.stream("Hello, how are you?", {
// Pass the stream options
stream_options: {
include_usage: true,
},
});

// Iterate over the response, concatenating the chunks into one aggregated final message
let finalResult: AIMessageChunk | undefined;
for await (const chunk of response) {
if (finalResult) {
finalResult = finalResult.concat(chunk);
} else {
finalResult = chunk;
}
}

console.log(finalResult?.usage_metadata);

/*

{ input_tokens: 13, output_tokens: 30, total_tokens: 43 }

*/
langchain-core/src/messages/ai.ts (55 additions, 3 deletions)
@@ -18,6 +18,25 @@ import {
export type AIMessageFields = BaseMessageFields & {
tool_calls?: ToolCall[];
invalid_tool_calls?: InvalidToolCall[];
usage_metadata?: UsageMetadata;
};

/**
* Usage metadata for a message, such as token counts.
*/
export type UsageMetadata = {
/**
* The count of input (or prompt) tokens.
*/
input_tokens: number;
/**
* The count of output (or completion) tokens.
*/
output_tokens: number;
/**
* The total token count.
*/
total_tokens: number;
};

/**
@@ -30,6 +49,11 @@ export class AIMessage extends BaseMessage {

invalid_tool_calls?: InvalidToolCall[] = [];

/**
* If provided, token usage information associated with the message.
*/
usage_metadata?: UsageMetadata;

get lc_aliases(): Record<string, string> {
// exclude snake case conversion to pascal case
return {
@@ -94,6 +118,7 @@ export class AIMessage extends BaseMessage {
this.invalid_tool_calls =
initParams.invalid_tool_calls ?? this.invalid_tool_calls;
}
this.usage_metadata = initParams.usage_metadata;
}

static lc_name() {
@@ -127,6 +152,11 @@ export class AIMessageChunk extends BaseMessageChunk {

tool_call_chunks?: ToolCallChunk[] = [];

/**
* If provided, token usage information associated with the message.
*/
usage_metadata?: UsageMetadata;

constructor(fields: string | AIMessageChunkFields) {
let initParams: AIMessageChunkFields;
if (typeof fields === "string") {
@@ -177,10 +207,11 @@ export class AIMessageChunk extends BaseMessageChunk {
// properties with initializers, so we have to check types twice.
super(initParams);
this.tool_call_chunks =
-   initParams?.tool_call_chunks ?? this.tool_call_chunks;
- this.tool_calls = initParams?.tool_calls ?? this.tool_calls;
+   initParams.tool_call_chunks ?? this.tool_call_chunks;
+ this.tool_calls = initParams.tool_calls ?? this.tool_calls;
this.invalid_tool_calls =
-   initParams?.invalid_tool_calls ?? this.invalid_tool_calls;
+   initParams.invalid_tool_calls ?? this.invalid_tool_calls;
+ this.usage_metadata = initParams.usage_metadata;
}

get lc_aliases(): Record<string, string> {
@@ -226,6 +257,27 @@ export class AIMessageChunk extends BaseMessageChunk {
combinedFields.tool_call_chunks = rawToolCalls;
}
}
if (
this.usage_metadata !== undefined ||
chunk.usage_metadata !== undefined
) {
const left: UsageMetadata = this.usage_metadata ?? {
input_tokens: 0,
output_tokens: 0,
total_tokens: 0,
};
const right: UsageMetadata = chunk.usage_metadata ?? {
input_tokens: 0,
output_tokens: 0,
total_tokens: 0,
};
const usage_metadata: UsageMetadata = {
input_tokens: left.input_tokens + right.input_tokens,
output_tokens: left.output_tokens + right.output_tokens,
total_tokens: left.total_tokens + right.total_tokens,
};
combinedFields.usage_metadata = usage_metadata;
}
return new AIMessageChunk(combinedFields);
}
}
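
Note: the merge above treats a missing `usage_metadata` on either side as zero and sums the counts field by field, so concatenating streamed chunks accumulates token usage. A minimal sketch of that behavior (the chunk contents and token numbers are invented for illustration):

```typescript
import { AIMessageChunk } from "@langchain/core/messages";

// Two chunks, e.g. as they might arrive from a streamed response.
const left = new AIMessageChunk({
  content: "Hello",
  usage_metadata: { input_tokens: 13, output_tokens: 2, total_tokens: 15 },
});
const right = new AIMessageChunk({
  content: " world",
  usage_metadata: { input_tokens: 0, output_tokens: 3, total_tokens: 3 },
});

// `concat` merges the contents and sums the usage counts.
const merged = left.concat(right);
console.log(merged.content); // "Hello world"
console.log(merged.usage_metadata);
// { input_tokens: 13, output_tokens: 5, total_tokens: 18 }
```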
libs/langchain-openai/src/chat_models.ts (22 additions, 0 deletions)
@@ -253,6 +253,7 @@ export interface ChatOpenAICallOptions
promptIndex?: number;
response_format?: { type: "json_object" };
seed?: number;
stream_options?: { include_usage: boolean };
}

/**
@@ -553,6 +554,9 @@ export class ChatOpenAI<
tool_choice: options?.tool_choice,
response_format: options?.response_format,
seed: options?.seed,
...(options?.stream_options !== undefined
? { stream_options: options.stream_options }
: {}),
...this.modelKwargs,
};
return params;
@@ -586,8 +590,12 @@
};
let defaultRole: OpenAIRoleEnum | undefined;
const streamIterable = await this.completionWithRetry(params, options);
let usage: OpenAIClient.Completions.CompletionUsage | undefined;
for await (const data of streamIterable) {
const choice = data?.choices[0];
if (data.usage) {
usage = data.usage;
}
if (!choice) {
continue;
}
@@ -632,6 +640,20 @@
{ chunk: generationChunk }
);
}
if (usage) {
const generationChunk = new ChatGenerationChunk({
message: new AIMessageChunk({
content: "",
usage_metadata: {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
},
}),
text: "",
});
yield generationChunk;
}
if (options.signal?.aborted) {
throw new Error("AbortError");
}
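
For context on the block above: with `include_usage` enabled, the OpenAI API streams one extra final chunk whose `choices` array is empty and whose `usage` field carries totals for the whole request; the loop captures it and re-emits it as an empty-content `AIMessageChunk`. A rough sketch of that mapping (the raw chunk below is a hand-written approximation, not output from this PR):

```typescript
// Approximate shape of the final streamed chunk when
// `stream_options: { include_usage: true }` is set (illustrative values).
const finalRawChunk = {
  object: "chat.completion.chunk",
  choices: [], // the usage-only chunk carries no delta content
  usage: {
    prompt_tokens: 13,
    completion_tokens: 30,
    total_tokens: 43,
  },
};

// The handler re-publishes those counts under the provider-agnostic
// `usage_metadata` keys on an empty-content AIMessageChunk.
const usage_metadata = {
  input_tokens: finalRawChunk.usage.prompt_tokens,
  output_tokens: finalRawChunk.usage.completion_tokens,
  total_tokens: finalRawChunk.usage.total_tokens,
};
console.log(usage_metadata);
// { input_tokens: 13, output_tokens: 30, total_tokens: 43 }
```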
libs/langchain-openai/src/tests/chat_models.int.test.ts (60 additions, 0 deletions)
@@ -1,5 +1,6 @@
import { test, jest, expect } from "@jest/globals";
import {
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
@@ -767,3 +768,62 @@ test("Test ChatOpenAI token usage reporting for streaming calls", async () => {
expect(streamingTokenUsed).toEqual(nonStreamingTokenUsed);
}
});

test("Finish reason is 'stop'", async () => {
const model = new ChatOpenAI();
const response = await model.stream("Hello, how are you?");
let finalResult: AIMessageChunk | undefined;
for await (const chunk of response) {
if (finalResult) {
finalResult = finalResult.concat(chunk);
} else {
finalResult = chunk;
}
}
expect(finalResult).toBeTruthy();
expect(finalResult?.response_metadata?.finish_reason).toBe("stop");
});

test("Streaming tokens can be found in usage_metadata field", async () => {
const model = new ChatOpenAI();
const response = await model.stream("Hello, how are you?", {
stream_options: {
include_usage: true,
},
});
let finalResult: AIMessageChunk | undefined;
for await (const chunk of response) {
if (finalResult) {
finalResult = finalResult.concat(chunk);
} else {
finalResult = chunk;
}
}
console.log({
usage_metadata: finalResult?.usage_metadata,
});
expect(finalResult).toBeTruthy();
expect(finalResult?.usage_metadata).toBeTruthy();
expect(finalResult?.usage_metadata?.input_tokens).toBeGreaterThan(0);
expect(finalResult?.usage_metadata?.output_tokens).toBeGreaterThan(0);
expect(finalResult?.usage_metadata?.total_tokens).toBeGreaterThan(0);
});

test("streaming: true tokens can be found in usage_metadata field", async () => {
const model = new ChatOpenAI({
streaming: true,
});
const response = await model.invoke("Hello, how are you?", {
stream_options: {
include_usage: true,
},
});
console.log({
usage_metadata: response?.usage_metadata,
});
expect(response).toBeTruthy();
expect(response?.usage_metadata).toBeTruthy();
expect(response?.usage_metadata?.input_tokens).toBeGreaterThan(0);
expect(response?.usage_metadata?.output_tokens).toBeGreaterThan(0);
expect(response?.usage_metadata?.total_tokens).toBeGreaterThan(0);
});