Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community[patch]: anthropic add tool call support new tools api #5640

2 changes: 1 addition & 1 deletion libs/langchain-anthropic/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ function _formatContent(content: MessageContent) {
* @param messages The base messages to format as a prompt.
* @returns The formatted prompt.
*/
function _formatMessagesForAnthropic(messages: BaseMessage[]): {
export function formatMessagesForAnthropic(messages: BaseMessage[]): {
system?: string;
messages: AnthropicMessage[];
} {
Expand Down
1 change: 1 addition & 0 deletions libs/langchain-anthropic/src/index.ts
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
export * from "./chat_models.js";
export * from "./output_parsers.js";
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is biggest question for me - if it's inline with overall "strategy" that I can export this from Anthropic and use it in community? Any thoughts?
I only need single function from there, so that could be a minor copy-paste I assume.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think that's fine. Maybe we can prefix it with an underscore to show that it shouldn't be imported directly.

22 changes: 21 additions & 1 deletion libs/langchain-community/src/chat_models/bedrock/web.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,24 @@ import {
type BaseChatModelParams,
BaseChatModel,
LangSmithParams,
BaseChatModelCallOptions,
} from "@langchain/core/language_models/chat_models";
import { BaseLanguageModelInput } from "@langchain/core/language_models/base";
import { Runnable } from "@langchain/core/runnables";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import {
AIMessageChunk,
BaseMessage,
AIMessage,
ChatMessage,
BaseMessageChunk,
} from "@langchain/core/messages";
import {
ChatGeneration,
ChatGenerationChunk,
ChatResult,
} from "@langchain/core/outputs";
import { StructuredToolInterface } from "@langchain/core/tools";

import {
BaseBedrockInput,
Expand Down Expand Up @@ -225,6 +230,8 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
streamProcessingMode: "SYNCHRONOUS" | "ASYNCHRONOUS";
};

tools: (StructuredToolInterface | Record<string, unknown>)[] = [];

get lc_aliases(): Record<string, string> {
return {
model: "model_id",
Expand Down Expand Up @@ -393,7 +400,8 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
this.temperature,
options.stop ?? this.stopSequences,
this.modelKwargs,
this.guardrailConfig
this.guardrailConfig,
this.tools
)
: BedrockLLMInputOutputAdapter.prepareInput(
provider,
Expand Down Expand Up @@ -602,6 +610,18 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
_combineLLMOutput() {
return {};
}

bindTools(
tools: (StructuredToolInterface | Record<string, unknown>)[],
kwargs?: Partial<BaseChatModelCallOptions>
): Runnable<
BaseLanguageModelInput,
BaseMessageChunk,
BaseChatModelCallOptions
> {
this.tools = tools;
return this;
}
}

function isChatGenerationChunk(
Expand Down
119 changes: 40 additions & 79 deletions libs/langchain-community/src/utils/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,90 +4,19 @@ import {
AIMessageChunk,
BaseMessage,
} from "@langchain/core/messages";
import { StructuredToolInterface } from "@langchain/core/tools";
import { isStructuredTool } from "@langchain/core/utils/function_calling";
import { ChatGeneration, ChatGenerationChunk } from "@langchain/core/outputs";
import { zodToJsonSchema } from "zod-to-json-schema";
import {
extractToolCalls,
formatMessagesForAnthropic,
} from "@langchain/anthropic";

export type CredentialType =
| AwsCredentialIdentity
| Provider<AwsCredentialIdentity>;

function _formatImage(imageUrl: string) {
const regex = /^data:(image\/.+);base64,(.+)$/;
const match = imageUrl.match(regex);
if (match === null) {
throw new Error(
[
"Anthropic only supports base64-encoded images currently.",
"Example: ...",
].join("\n\n")
);
}
return {
type: "base64",
media_type: match[1] ?? "",
data: match[2] ?? "",
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any;
}

function formatMessagesForAnthropic(messages: BaseMessage[]): {
system?: string;
messages: Record<string, unknown>[];
} {
let system: string | undefined;
if (messages.length > 0 && messages[0]._getType() === "system") {
if (typeof messages[0].content !== "string") {
throw new Error("System message content must be a string.");
}
system = messages[0].content;
}
const conversationMessages =
system !== undefined ? messages.slice(1) : messages;
const formattedMessages = conversationMessages.map((message) => {
let role;
if (message._getType() === "human") {
role = "user" as const;
} else if (message._getType() === "ai") {
role = "assistant" as const;
} else if (message._getType() === "system") {
throw new Error(
"System messages are only permitted as the first passed message."
);
} else {
throw new Error(`Message type "${message._getType()}" is not supported.`);
}
if (typeof message.content === "string") {
return {
role,
content: message.content,
};
} else {
return {
role,
content: message.content.map((contentPart) => {
if (contentPart.type === "image_url") {
let source;
if (typeof contentPart.image_url === "string") {
source = _formatImage(contentPart.image_url);
} else {
source = _formatImage(contentPart.image_url.url);
}
return {
type: "image" as const,
source,
};
} else {
return contentPart;
}
}),
};
}
});
return {
messages: formattedMessages,
system,
};
}

/**
* format messages for Cohere Command-R and CommandR+ via AWS Bedrock.
*
Expand Down Expand Up @@ -327,7 +256,8 @@ export class BedrockLLMInputOutputAdapter {
tagSuffix: string;
streamProcessingMode: "SYNCHRONOUS" | "ASYNCHRONOUS";
}
| undefined = undefined
| undefined = undefined,
tools: (StructuredToolInterface | Record<string, unknown>)[] = []
): Dict {
const inputBody: Dict = {};

Expand All @@ -342,6 +272,21 @@ export class BedrockLLMInputOutputAdapter {
inputBody.max_tokens = maxTokens;
inputBody.temperature = temperature;
inputBody.stop_sequences = stopSequences;

if (tools.length > 0) {
inputBody.tools = tools.map((tool) => {
if (isStructuredTool(tool)) {
return {
name: tool.name,
description: tool.description,
input_schema: zodToJsonSchema(tool.schema),
};
}

return tool;
});
}
return { ...inputBody, ...modelKwargs };
} else if (provider === "cohere") {
const {
system,
Expand Down Expand Up @@ -516,10 +461,26 @@ function parseMessage(responseBody: any, asChunk?: boolean): ChatGeneration {
generationInfo,
});
} else {
// TODO: we are throwing away here the text response, as the interface of this method returns only one
const toolCalls = extractToolCalls(responseBody.content);

if (toolCalls.length > 0) {
return {
message: new AIMessage({
content: "",
additional_kwargs: { id },
tool_calls: toolCalls,
}),
text: typeof parsedContent === "string" ? parsedContent : "",
generationInfo,
};
}

return {
message: new AIMessage({
content: parsedContent,
additional_kwargs: { id },
tool_calls: toolCalls,
}),
text: typeof parsedContent === "string" ? parsedContent : "",
generationInfo,
Expand Down