diff --git a/docs/core_docs/docs/integrations/chat/index.mdx b/docs/core_docs/docs/integrations/chat/index.mdx
index 51ba69379288..c73f748e0bd0 100644
--- a/docs/core_docs/docs/integrations/chat/index.mdx
+++ b/docs/core_docs/docs/integrations/chat/index.mdx
@@ -5,9 +5,6 @@ sidebar_class_name: hidden
 
 # Chat models
 
-
-
-
 ## Features (natively supported)
 
 All ChatModels implement the Runnable interface, which comes with default implementations of all methods, i.e. `invoke`, `batch`, `stream`. This gives all ChatModels basic support for invoking, streaming and batching, which by default is implemented as below:
@@ -24,26 +21,26 @@ Some models in LangChain have also implemented a `withStructuredOutput()` method
 
 The table shows, for each integration, which features have been implemented with native support. Yellow circles (🟡) indicate partial support - for example, if the model supports tool calling but not tool messages for agents.
 
-| Model                   | Invoke | Stream | Batch | Function Calling | Tool Calling | `withStructuredOutput()` |
-| :---------------------- | :----: | :----: | :---: | :--------------: | :----------: | :----------------------: |
-| BedrockChat             | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
-| ChatAlibabaTongyi       | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
-| ChatAnthropic           | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ |
-| ChatBaiduWenxin         | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
-| ChatCloudflareWorkersAI | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
-| ChatCohere              | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
-| ChatFireworks           | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
-| ChatGoogleGenerativeAI  | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
-| ChatGoogleVertexAI      | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
-| ChatVertexAI            | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ |
-| ChatGooglePaLM          | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
-| ChatGroq                | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ |
-| ChatLlamaCpp            | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
-| ChatMinimax             | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
-| ChatMistralAI           | ✅ | ❌ | ✅ | ❌ | ✅ | ✅ |
-| ChatOllama              | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
-| ChatOpenAI              | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| ChatTencentHunyuan      | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
-| ChatTogetherAI          | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
-| ChatYandexGPT           | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
-| ChatZhipuAI             | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
+| Model                   | Invoke | Stream | Batch | Function Calling | Tool Calling                | `withStructuredOutput()` |
+| :---------------------- | :----: | :----: | :---: | :--------------: | :-------------------------: | :----------------------: |
+| BedrockChat             | ✅ | ✅ | ✅ | ❌ | 🟡 (Bedrock Anthropic only) | ❌ |
+| ChatAlibabaTongyi       | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
+| ChatAnthropic           | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ |
+| ChatBaiduWenxin         | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
+| ChatCloudflareWorkersAI | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
+| ChatCohere              | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
+| ChatFireworks           | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
+| ChatGoogleGenerativeAI  | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
+| ChatGoogleVertexAI      | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
+| ChatVertexAI            | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ |
+| ChatGooglePaLM          | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
+| ChatGroq                | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ |
+| ChatLlamaCpp            | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
+| ChatMinimax             | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
+| ChatMistralAI           | ✅ | ❌ | ✅ | ❌ | ✅ | ✅ |
+| ChatOllama              | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
+| ChatOpenAI              | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| ChatTencentHunyuan      | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
+| ChatTogetherAI          | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
+| ChatYandexGPT           | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
+| ChatZhipuAI             | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
diff --git a/libs/langchain-community/package.json b/libs/langchain-community/package.json
index 0dcc422f95ac..248d4a14a317 100644
--- a/libs/langchain-community/package.json
+++ b/libs/langchain-community/package.json
@@ -35,7 +35,7 @@
   "author": "LangChain",
   "license": "MIT",
   "dependencies": {
-    "@langchain/core": "~0.2.0",
"@langchain/core": "~0.2.0", + "@langchain/core": "~0.2.6", "@langchain/openai": "~0.1.0", "binary-extensions": "^2.2.0", "expr-eval": "^2.0.2", diff --git a/libs/langchain-community/src/chat_models/bedrock/index.ts b/libs/langchain-community/src/chat_models/bedrock/index.ts index d30123db303a..25e8269d0f61 100644 --- a/libs/langchain-community/src/chat_models/bedrock/index.ts +++ b/libs/langchain-community/src/chat_models/bedrock/index.ts @@ -5,7 +5,7 @@ import { import type { BaseChatModelParams } from "@langchain/core/language_models/chat_models"; -import { BaseBedrockInput } from "../../utils/bedrock.js"; +import { BaseBedrockInput } from "../../utils/bedrock/index.js"; import { BedrockChat as BaseBedrockChat } from "./web.js"; export interface BedrockChatFields diff --git a/libs/langchain-community/src/chat_models/bedrock/web.ts b/libs/langchain-community/src/chat_models/bedrock/web.ts index 1b6e0757772a..d7811fac0089 100644 --- a/libs/langchain-community/src/chat_models/bedrock/web.ts +++ b/libs/langchain-community/src/chat_models/bedrock/web.ts @@ -9,25 +9,34 @@ import { type BaseChatModelParams, BaseChatModel, LangSmithParams, + BaseChatModelCallOptions, } from "@langchain/core/language_models/chat_models"; +import { BaseLanguageModelInput } from "@langchain/core/language_models/base"; +import { Runnable } from "@langchain/core/runnables"; import { getEnvironmentVariable } from "@langchain/core/utils/env"; import { AIMessageChunk, BaseMessage, AIMessage, ChatMessage, + BaseMessageChunk, + isAIMessage, } from "@langchain/core/messages"; import { ChatGeneration, ChatGenerationChunk, ChatResult, } from "@langchain/core/outputs"; +import { StructuredToolInterface } from "@langchain/core/tools"; +import { isStructuredTool } from "@langchain/core/utils/function_calling"; +import { ToolCall } from "@langchain/core/messages/tool"; +import { zodToJsonSchema } from "zod-to-json-schema"; import { BaseBedrockInput, BedrockLLMInputOutputAdapter, type CredentialType, -} from "../../utils/bedrock.js"; +} from "../../utils/bedrock/index.js"; import type { SerializedFields } from "../../load/map_keys.js"; const PRELUDE_TOTAL_LENGTH_BYTES = 4; @@ -225,6 +234,8 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { streamProcessingMode: "SYNCHRONOUS" | "ASYNCHRONOUS"; }; + protected _anthropicTools?: Record[]; + get lc_aliases(): Record { return { model: "model_id", @@ -306,6 +317,15 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { this.guardrailConfig = fields?.guardrailConfig; } + override invocationParams(options?: this["ParsedCallOptions"]) { + return { + tools: this._anthropicTools, + temperature: this.temperature, + max_tokens: this.maxTokens, + stop: options?.stop, + }; + } + getLsParams(options: this["ParsedCallOptions"]): LangSmithParams { const params = this.invocationParams(options); return { @@ -323,10 +343,6 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { options: Partial, runManager?: CallbackManagerForLLMRun ): Promise { - const service = "bedrock-runtime"; - const endpointHost = - this.endpointHost ?? 
-    const provider = this.model.split(".")[0];
     if (this.streaming) {
       const stream = this._streamResponseChunks(messages, options, runManager);
       let finalResult: ChatGenerationChunk | undefined;
@@ -347,7 +363,18 @@
         llmOutput: finalResult.generationInfo,
       };
     }
+    return this._generateNonStreaming(messages, options, runManager);
+  }
 
+  async _generateNonStreaming(
+    messages: BaseMessage[],
+    options: Partial<this["ParsedCallOptions"]>,
+    _runManager?: CallbackManagerForLLMRun
+  ): Promise<ChatResult> {
+    const service = "bedrock-runtime";
+    const endpointHost =
+      this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
+    const provider = this.model.split(".")[0];
     const response = await this._signedFetch(messages, options, {
       bedrockMethod: "invoke",
       endpointHost,
@@ -393,7 +420,8 @@
           this.temperature,
           options.stop ?? this.stopSequences,
           this.modelKwargs,
-          this.guardrailConfig
+          this.guardrailConfig,
+          this._anthropicTools
         )
       : BedrockLLMInputOutputAdapter.prepareInput(
           provider,
@@ -459,97 +487,145 @@
     options: this["ParsedCallOptions"],
     runManager?: CallbackManagerForLLMRun
   ): AsyncGenerator<ChatGenerationChunk> {
-    const provider = this.model.split(".")[0];
-    const service = "bedrock-runtime";
-
-    const endpointHost =
-      this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
-
-    const bedrockMethod =
-      provider === "anthropic" ||
-      provider === "cohere" ||
-      provider === "meta" ||
-      provider === "mistral"
-        ? "invoke-with-response-stream"
-        : "invoke";
-
-    const response = await this._signedFetch(messages, options, {
-      bedrockMethod,
-      endpointHost,
-      provider,
-    });
-
-    if (response.status < 200 || response.status >= 300) {
-      throw Error(
-        `Failed to access underlying url '${endpointHost}': got ${
-          response.status
-        } ${response.statusText}: ${await response.text()}`
+    if (this._anthropicTools) {
+      const { generations } = await this._generateNonStreaming(
+        messages,
+        options
       );
-    }
+      const result = generations[0].message as AIMessage;
+      const toolCallChunks = result.tool_calls?.map(
+        (toolCall: ToolCall, index: number) => ({
+          name: toolCall.name,
+          args: JSON.stringify(toolCall.args),
+          id: toolCall.id,
+          index,
+        })
+      );
+      yield new ChatGenerationChunk({
+        message: new AIMessageChunk({
+          content: result.content,
+          additional_kwargs: result.additional_kwargs,
+          tool_call_chunks: toolCallChunks,
+        }),
+        text: generations[0].text,
+      });
+      // eslint-disable-next-line no-void
+      void runManager?.handleLLMNewToken(generations[0].text);
+    } else {
+      const provider = this.model.split(".")[0];
+      const service = "bedrock-runtime";
+
+      const endpointHost =
+        this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
+
+      const bedrockMethod =
+        provider === "anthropic" ||
+        provider === "cohere" ||
+        provider === "meta" ||
+        provider === "mistral"
+          ? "invoke-with-response-stream"
"invoke-with-response-stream" + : "invoke"; + + const response = await this._signedFetch(messages, options, { + bedrockMethod, + endpointHost, + provider, + }); - if ( - provider === "anthropic" || - provider === "cohere" || - provider === "meta" || - provider === "mistral" - ) { - const reader = response.body?.getReader(); - const decoder = new TextDecoder(); - for await (const chunk of this._readChunks(reader)) { - const event = this.codec.decode(chunk); - if ( - (event.headers[":event-type"] !== undefined && - event.headers[":event-type"].value !== "chunk") || - event.headers[":content-type"].value !== "application/json" - ) { - throw Error(`Failed to get event chunk: got ${chunk}`); - } - const body = JSON.parse(decoder.decode(event.body)); - if (body.message) { - throw new Error(body.message); - } - if (body.bytes !== undefined) { - const chunkResult = JSON.parse( - decoder.decode( - Uint8Array.from(atob(body.bytes), (m) => m.codePointAt(0) ?? 0) - ) - ); - if (this.usesMessagesApi) { - const chunk = BedrockLLMInputOutputAdapter.prepareMessagesOutput( - provider, - chunkResult + if (response.status < 200 || response.status >= 300) { + throw Error( + `Failed to access underlying url '${endpointHost}': got ${ + response.status + } ${response.statusText}: ${await response.text()}` + ); + } + + if ( + provider === "anthropic" || + provider === "cohere" || + provider === "meta" || + provider === "mistral" + ) { + const reader = response.body?.getReader(); + const decoder = new TextDecoder(); + for await (const chunk of this._readChunks(reader)) { + const event = this.codec.decode(chunk); + if ( + (event.headers[":event-type"] !== undefined && + event.headers[":event-type"].value !== "chunk") || + event.headers[":content-type"].value !== "application/json" + ) { + throw Error(`Failed to get event chunk: got ${chunk}`); + } + const body = JSON.parse(decoder.decode(event.body)); + if (body.message) { + throw new Error(body.message); + } + if (body.bytes !== undefined) { + const chunkResult = JSON.parse( + decoder.decode( + Uint8Array.from(atob(body.bytes), (m) => m.codePointAt(0) ?? 
0) + ) ); - if (chunk === undefined) { - continue; + if (this.usesMessagesApi) { + const chunk = BedrockLLMInputOutputAdapter.prepareMessagesOutput( + provider, + chunkResult + ); + if (chunk === undefined) { + continue; + } + if ( + provider === "anthropic" && + chunk.generationInfo?.usage !== undefined + ) { + // Avoid bad aggregation in chunks, rely on final Bedrock data + delete chunk.generationInfo.usage; + } + const finalMetrics = + chunk.generationInfo?.["amazon-bedrock-invocationMetrics"]; + if ( + finalMetrics != null && + typeof finalMetrics === "object" && + isAIMessage(chunk.message) + ) { + chunk.message.usage_metadata = { + input_tokens: finalMetrics.inputTokenCount, + output_tokens: finalMetrics.outputTokenCount, + total_tokens: + finalMetrics.inputTokenCount + + finalMetrics.outputTokenCount, + }; + } + if (isChatGenerationChunk(chunk)) { + yield chunk; + } + // eslint-disable-next-line no-void + void runManager?.handleLLMNewToken(chunk.text); + } else { + const text = BedrockLLMInputOutputAdapter.prepareOutput( + provider, + chunkResult + ); + yield new ChatGenerationChunk({ + text, + message: new AIMessageChunk({ content: text }), + }); + // eslint-disable-next-line no-void + void runManager?.handleLLMNewToken(text); } - if (isChatGenerationChunk(chunk)) { - yield chunk; - } - // eslint-disable-next-line no-void - void runManager?.handleLLMNewToken(chunk.text); - } else { - const text = BedrockLLMInputOutputAdapter.prepareOutput( - provider, - chunkResult - ); - yield new ChatGenerationChunk({ - text, - message: new AIMessageChunk({ content: text }), - }); - // eslint-disable-next-line no-void - void runManager?.handleLLMNewToken(text); } } + } else { + const json = await response.json(); + const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json); + yield new ChatGenerationChunk({ + text, + message: new AIMessageChunk({ content: text }), + }); + // eslint-disable-next-line no-void + void runManager?.handleLLMNewToken(text); } - } else { - const json = await response.json(); - const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json); - yield new ChatGenerationChunk({ - text, - message: new AIMessageChunk({ content: text }), - }); - // eslint-disable-next-line no-void - void runManager?.handleLLMNewToken(text); } } @@ -602,6 +678,33 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { _combineLLMOutput() { return {}; } + + override bindTools( + tools: (StructuredToolInterface | Record)[], + _kwargs?: Partial + ): Runnable< + BaseLanguageModelInput, + BaseMessageChunk, + BaseChatModelCallOptions + > { + const provider = this.model.split(".")[0]; + if (provider !== "anthropic") { + throw new Error( + "Currently, tool calling through Bedrock is only supported for Anthropic models." 
+      );
+    }
+    this._anthropicTools = tools.map((tool) => {
+      if (isStructuredTool(tool)) {
+        return {
+          name: tool.name,
+          description: tool.description,
+          input_schema: zodToJsonSchema(tool.schema),
+        };
+      }
+      return tool;
+    });
+    return this;
+  }
 }
 
 function isChatGenerationChunk(
diff --git a/libs/langchain-community/src/chat_models/tests/chatbedrock.int.test.ts b/libs/langchain-community/src/chat_models/tests/chatbedrock.int.test.ts
index 2d92d70f89c9..a202570eb895 100644
--- a/libs/langchain-community/src/chat_models/tests/chatbedrock.int.test.ts
+++ b/libs/langchain-community/src/chat_models/tests/chatbedrock.int.test.ts
@@ -4,7 +4,10 @@
 import { test, expect } from "@jest/globals";
 
 import { HumanMessage } from "@langchain/core/messages";
+import { AgentExecutor, createToolCallingAgent } from "langchain/agents";
+import { ChatPromptTemplate } from "@langchain/core/prompts";
 import { BedrockChat as BedrockChatWeb } from "../bedrock/web.js";
+import { TavilySearchResults } from "../../tools/tavily_search.js";
 
 void testChatModel(
   "Test Bedrock chat model Generating search queries: Command-r",
@@ -180,7 +183,7 @@ async function testChatModel(
   });
 
   const res = await bedrock.invoke([new HumanMessage(message)]);
-  console.log(res);
+  console.log(res, res.content);
   expect(res).toBeDefined();
 
   if (trace && guardrailIdentifier && guardrailVersion) {
@@ -320,6 +323,41 @@ async function testChatHandleLLMNewToken(
   });
 }
 
+test.skip("Tool calling agent with Anthropic", async () => {
+  const tools = [new TavilySearchResults({ maxResults: 1 })];
+  const region = process.env.BEDROCK_AWS_REGION;
+  const bedrock = new BedrockChatWeb({
+    maxTokens: 200,
+    region,
+    model: "anthropic.claude-3-sonnet-20240229-v1:0",
+    maxRetries: 0,
+    credentials: {
+      secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
+      accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
+    },
+  });
+  const prompt = ChatPromptTemplate.fromMessages([
+    ["system", "You are a helpful assistant"],
+    ["placeholder", "{chat_history}"],
+    ["human", "{input}"],
+    ["placeholder", "{agent_scratchpad}"],
+  ]);
+  const agent = await createToolCallingAgent({
+    llm: bedrock,
+    tools,
+    prompt,
+  });
+  const agentExecutor = new AgentExecutor({
+    agent,
+    tools,
+  });
+  const input = "what is the current weather in SF?";
+  const result = await agentExecutor.invoke({
+    input,
+  });
+  console.log(result);
+});
+
 test.skip.each([
   "amazon.titan-text-express-v1",
   // These models should be supported in the future
diff --git a/libs/langchain-community/src/chat_models/tests/chatbedrock.standard.int.test.ts b/libs/langchain-community/src/chat_models/tests/chatbedrock.standard.int.test.ts
index f1e1808632b9..6c796e8d2780 100644
--- a/libs/langchain-community/src/chat_models/tests/chatbedrock.standard.int.test.ts
+++ b/libs/langchain-community/src/chat_models/tests/chatbedrock.standard.int.test.ts
@@ -13,11 +13,11 @@ class BedrockChatStandardIntegrationTests extends ChatModelIntegrationTests<
     const region = process.env.BEDROCK_AWS_REGION ?? "us-east-1";
"us-east-1"; super({ Cls: BedrockChat, - chatModelHasToolCalling: false, + chatModelHasToolCalling: true, chatModelHasStructuredOutput: false, constructorArgs: { region, - model: "amazon.titan-text-express-v1", + model: "anthropic.claude-3-sonnet-20240229-v1:0", credentials: { secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY, accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID, diff --git a/libs/langchain-community/src/embeddings/bedrock.ts b/libs/langchain-community/src/embeddings/bedrock.ts index 305387007cd0..ff324e348cf7 100644 --- a/libs/langchain-community/src/embeddings/bedrock.ts +++ b/libs/langchain-community/src/embeddings/bedrock.ts @@ -3,7 +3,7 @@ import { InvokeModelCommand, } from "@aws-sdk/client-bedrock-runtime"; import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings"; -import type { CredentialType } from "../utils/bedrock.js"; +import type { CredentialType } from "../utils/bedrock/index.js"; /** * Interface that extends EmbeddingsParams and defines additional diff --git a/libs/langchain-community/src/llms/bedrock/index.ts b/libs/langchain-community/src/llms/bedrock/index.ts index 64f39a279671..06f2695294f2 100644 --- a/libs/langchain-community/src/llms/bedrock/index.ts +++ b/libs/langchain-community/src/llms/bedrock/index.ts @@ -1,6 +1,6 @@ import { defaultProvider } from "@aws-sdk/credential-provider-node"; import type { BaseLLMParams } from "@langchain/core/language_models/llms"; -import { BaseBedrockInput } from "../../utils/bedrock.js"; +import { BaseBedrockInput } from "../../utils/bedrock/index.js"; import { Bedrock as BaseBedrock } from "./web.js"; export class Bedrock extends BaseBedrock { diff --git a/libs/langchain-community/src/llms/bedrock/web.ts b/libs/langchain-community/src/llms/bedrock/web.ts index fbd4cc5a3594..8f7a5791cda9 100644 --- a/libs/langchain-community/src/llms/bedrock/web.ts +++ b/libs/langchain-community/src/llms/bedrock/web.ts @@ -14,7 +14,7 @@ import { BaseBedrockInput, BedrockLLMInputOutputAdapter, type CredentialType, -} from "../../utils/bedrock.js"; +} from "../../utils/bedrock/index.js"; import type { SerializedFields } from "../../load/map_keys.js"; const PRELUDE_TOTAL_LENGTH_BYTES = 4; diff --git a/libs/langchain-community/src/utils/bedrock/anthropic.ts b/libs/langchain-community/src/utils/bedrock/anthropic.ts new file mode 100644 index 000000000000..f6d37dc2e018 --- /dev/null +++ b/libs/langchain-community/src/utils/bedrock/anthropic.ts @@ -0,0 +1,221 @@ +import { + AIMessage, + BaseMessage, + HumanMessage, + MessageContent, + SystemMessage, + ToolMessage, + isAIMessage, +} from "@langchain/core/messages"; +import { ToolCall } from "@langchain/core/messages/tool"; + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export function extractToolCalls(content: Record[]) { + const toolCalls: ToolCall[] = []; + for (const block of content) { + if (block.type === "tool_use") { + toolCalls.push({ name: block.name, args: block.input, id: block.id }); + } + } + return toolCalls; +} + +function _formatImage(imageUrl: string) { + const regex = /^data:(image\/.+);base64,(.+)$/; + const match = imageUrl.match(regex); + if (match === null) { + throw new Error( + [ + "Anthropic only supports base64-encoded images currently.", + "Example: data:image/png;base64,/9j/4AAQSk...", + ].join("\n\n") + ); + } + return { + type: "base64", + media_type: match[1] ?? "", + data: match[2] ?? 
"", + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any; +} + +function _mergeMessages( + messages: BaseMessage[] +): (SystemMessage | HumanMessage | AIMessage)[] { + // Merge runs of human/tool messages into single human messages with content blocks. + const merged = []; + for (const message of messages) { + if (message._getType() === "tool") { + if (typeof message.content === "string") { + merged.push( + new HumanMessage({ + content: [ + { + type: "tool_result", + content: message.content, + tool_use_id: (message as ToolMessage).tool_call_id, + }, + ], + }) + ); + } else { + merged.push(new HumanMessage({ content: message.content })); + } + } else { + const previousMessage = merged[merged.length - 1]; + if ( + previousMessage?._getType() === "human" && + message._getType() === "human" + ) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + let combinedContent: Record[]; + if (typeof previousMessage.content === "string") { + combinedContent = [{ type: "text", text: previousMessage.content }]; + } else { + combinedContent = previousMessage.content; + } + if (typeof message.content === "string") { + combinedContent.push({ type: "text", text: message.content }); + } else { + combinedContent = combinedContent.concat(message.content); + } + previousMessage.content = combinedContent; + } else { + merged.push(message); + } + } + } + return merged; +} + +export function _convertLangChainToolCallToAnthropic( + toolCall: ToolCall + // eslint-disable-next-line @typescript-eslint/no-explicit-any +): Record { + if (toolCall.id === undefined) { + throw new Error(`Anthropic requires all tool calls to have an "id".`); + } + return { + type: "tool_use", + id: toolCall.id, + name: toolCall.name, + input: toolCall.args, + }; +} + +function _formatContent(content: MessageContent) { + if (typeof content === "string") { + return content; + } else { + const contentBlocks = content.map((contentPart) => { + if (contentPart.type === "image_url") { + let source; + if (typeof contentPart.image_url === "string") { + source = _formatImage(contentPart.image_url); + } else { + source = _formatImage(contentPart.image_url.url); + } + return { + type: "image" as const, // Explicitly setting the type as "image" + source, + }; + } else if (contentPart.type === "text") { + // Assuming contentPart is of type MessageContentText here + return { + type: "text" as const, // Explicitly setting the type as "text" + text: contentPart.text, + }; + } else if ( + contentPart.type === "tool_use" || + contentPart.type === "tool_result" + ) { + // TODO: Fix when SDK types are fixed + return { + ...contentPart, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any; + } else { + throw new Error("Unsupported message content format"); + } + }); + return contentBlocks; + } +} + +export function formatMessagesForAnthropic(messages: BaseMessage[]): { + system?: string; + messages: Record[]; +} { + const mergedMessages = _mergeMessages(messages); + let system: string | undefined; + if (mergedMessages.length > 0 && mergedMessages[0]._getType() === "system") { + if (typeof messages[0].content !== "string") { + throw new Error("System message content must be a string."); + } + system = messages[0].content; + } + const conversationMessages = + system !== undefined ? 
+  const formattedMessages = conversationMessages.map((message) => {
+    let role;
+    if (message._getType() === "human") {
+      role = "user" as const;
+    } else if (message._getType() === "ai") {
+      role = "assistant" as const;
+    } else if (message._getType() === "tool") {
+      role = "user" as const;
+    } else if (message._getType() === "system") {
+      throw new Error(
+        "System messages are only permitted as the first passed message."
+      );
+    } else {
+      throw new Error(`Message type "${message._getType()}" is not supported.`);
+    }
+    if (isAIMessage(message) && !!message.tool_calls?.length) {
+      if (typeof message.content === "string") {
+        if (message.content === "") {
+          return {
+            role,
+            content: message.tool_calls.map(
+              _convertLangChainToolCallToAnthropic
+            ),
+          };
+        } else {
+          return {
+            role,
+            content: [
+              { type: "text", text: message.content },
+              ...message.tool_calls.map(_convertLangChainToolCallToAnthropic),
+            ],
+          };
+        }
+      } else {
+        const { content } = message;
+        const hasMismatchedToolCalls = !message.tool_calls.every((toolCall) =>
+          content.find(
+            (contentPart) =>
+              contentPart.type === "tool_use" && contentPart.id === toolCall.id
+          )
+        );
+        if (hasMismatchedToolCalls) {
+          console.warn(
+            `The "tool_calls" field on a message is only respected if content is a string.`
+          );
+        }
+        return {
+          role,
+          content: _formatContent(message.content),
+        };
+      }
+    } else {
+      return {
+        role,
+        content: _formatContent(message.content),
+      };
+    }
+  });
+  return {
+    messages: formattedMessages,
+    system,
+  };
+}
diff --git a/libs/langchain-community/src/utils/bedrock.ts b/libs/langchain-community/src/utils/bedrock/index.ts
similarity index 87%
rename from libs/langchain-community/src/utils/bedrock.ts
rename to libs/langchain-community/src/utils/bedrock/index.ts
index b9243679c2f5..b101ce1c38b9 100644
--- a/libs/langchain-community/src/utils/bedrock.ts
+++ b/libs/langchain-community/src/utils/bedrock/index.ts
@@ -4,90 +4,14 @@ import {
   AIMessageChunk,
   BaseMessage,
 } from "@langchain/core/messages";
+import { StructuredToolInterface } from "@langchain/core/tools";
 import { ChatGeneration, ChatGenerationChunk } from "@langchain/core/outputs";
+import { extractToolCalls, formatMessagesForAnthropic } from "./anthropic.js";
 
 export type CredentialType =
   | AwsCredentialIdentity
   | Provider<AwsCredentialIdentity>;
 
-function _formatImage(imageUrl: string) {
-  const regex = /^data:(image\/.+);base64,(.+)$/;
-  const match = imageUrl.match(regex);
-  if (match === null) {
-    throw new Error(
-      [
-        "Anthropic only supports base64-encoded images currently.",
-        "Example: data:image/png;base64,/9j/4AAQSk...",
-      ].join("\n\n")
-    );
-  }
-  return {
-    type: "base64",
-    media_type: match[1] ?? "",
-    data: match[2] ?? "",
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
-  } as any;
-}
-
-function formatMessagesForAnthropic(messages: BaseMessage[]): {
-  system?: string;
-  messages: Record<string, unknown>[];
-} {
-  let system: string | undefined;
-  if (messages.length > 0 && messages[0]._getType() === "system") {
-    if (typeof messages[0].content !== "string") {
-      throw new Error("System message content must be a string.");
-    }
-    system = messages[0].content;
-  }
-  const conversationMessages =
-    system !== undefined ? messages.slice(1) : messages;
-  const formattedMessages = conversationMessages.map((message) => {
-    let role;
-    if (message._getType() === "human") {
-      role = "user" as const;
-    } else if (message._getType() === "ai") {
-      role = "assistant" as const;
-    } else if (message._getType() === "system") {
-      throw new Error(
-        "System messages are only permitted as the first passed message."
-      );
-    } else {
-      throw new Error(`Message type "${message._getType()}" is not supported.`);
-    }
-    if (typeof message.content === "string") {
-      return {
-        role,
-        content: message.content,
-      };
-    } else {
-      return {
-        role,
-        content: message.content.map((contentPart) => {
-          if (contentPart.type === "image_url") {
-            let source;
-            if (typeof contentPart.image_url === "string") {
-              source = _formatImage(contentPart.image_url);
-            } else {
-              source = _formatImage(contentPart.image_url.url);
-            }
-            return {
-              type: "image" as const,
-              source,
-            };
-          } else {
-            return contentPart;
-          }
-        }),
-      };
-    }
-  });
-  return {
-    messages: formattedMessages,
-    system,
-  };
-}
-
 /**
  * format messages for Cohere Command-R and CommandR+ via AWS Bedrock.
  *
@@ -327,7 +251,8 @@ export class BedrockLLMInputOutputAdapter {
         tagSuffix: string;
         streamProcessingMode: "SYNCHRONOUS" | "ASYNCHRONOUS";
       }
-    | undefined = undefined
+    | undefined = undefined,
+    tools: (StructuredToolInterface | Record<string, unknown>)[] = []
   ): Dict {
     const inputBody: Dict = {};
 
@@ -342,6 +267,11 @@ export class BedrockLLMInputOutputAdapter {
       inputBody.max_tokens = maxTokens;
       inputBody.temperature = temperature;
       inputBody.stop_sequences = stopSequences;
+
+      if (tools.length > 0) {
+        inputBody.tools = tools;
+      }
+
       return { ...inputBody, ...modelKwargs };
     } else if (provider === "cohere") {
       const {
         system,
@@ -516,10 +446,26 @@ function parseMessage(responseBody: any, asChunk?: boolean): ChatGeneration {
       generationInfo,
     });
   } else {
+    // TODO: the text response is discarded here, since this method's interface can only return one message
+    const toolCalls = extractToolCalls(responseBody.content);
+
+    if (toolCalls.length > 0) {
+      return {
+        message: new AIMessage({
+          content: "",
+          additional_kwargs: { id },
+          tool_calls: toolCalls,
+        }),
+        text: typeof parsedContent === "string" ? parsedContent : "",
+        generationInfo,
+      };
+    }
+
     return {
       message: new AIMessage({
         content: parsedContent,
         additional_kwargs: { id },
+        tool_calls: toolCalls,
      }),
      text: typeof parsedContent === "string" ? parsedContent : "",
      generationInfo,
diff --git a/yarn.lock b/yarn.lock
index b9f4c16cfbcf..73760cf13fb3 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -9098,7 +9098,7 @@ __metadata:
     "@gradientai/nodejs-sdk": ^1.2.0
     "@huggingface/inference": ^2.6.4
     "@jest/globals": ^29.5.0
-    "@langchain/core": ~0.2.0
+    "@langchain/core": ~0.2.6
     "@langchain/openai": ~0.1.0
     "@langchain/scripts": ~0.0.14
     "@langchain/standard-tests": "workspace:*"
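Usage sketch (not part of the diff): the snippet below shows how the new `bindTools()` path on `BedrockChat` would be exercised end to end. The tool name, schema, and prompt are illustrative; the model id and credential environment variables mirror the PR's integration tests, and an ESM context with top-level await is assumed.

```ts
import { z } from "zod";
import { AIMessage, HumanMessage } from "@langchain/core/messages";
import { DynamicStructuredTool } from "@langchain/core/tools";
import { BedrockChat } from "@langchain/community/chat_models/bedrock/web";

// Hypothetical tool for illustration; any StructuredToolInterface works.
// bindTools() converts the zod schema to Anthropic's `input_schema` via zodToJsonSchema.
const weatherTool = new DynamicStructuredTool({
  name: "get_current_weather",
  description: "Get the current weather for a city.",
  schema: z.object({
    city: z.string().describe("Name of the city"),
  }),
  func: async ({ city }) => `It is sunny in ${city}.`,
});

const model = new BedrockChat({
  // Tool calling is only wired up for Anthropic models; other providers throw.
  model: "anthropic.claude-3-sonnet-20240229-v1:0",
  region: process.env.BEDROCK_AWS_REGION,
  credentials: {
    accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
    secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
  },
});

const modelWithTools = model.bindTools([weatherTool]);

// extractToolCalls() surfaces Anthropic `tool_use` blocks as standard tool calls,
// so they appear on the message's `tool_calls` field.
const res = (await modelWithTools.invoke([
  new HumanMessage("What's the weather in Paris right now?"),
])) as AIMessage;

console.log(res.tool_calls);
// e.g. [{ name: "get_current_weather", args: { city: "Paris" }, id: "toolu_..." }]
```

One design note worth flagging for review: `bindTools()` stores the converted schemas on `this._anthropicTools` and returns `this`, so binding mutates the underlying model instance rather than producing an independent runnable, which matters when one `BedrockChat` is shared across chains.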