From 81a010c9f13012afa35207a9e5558e34d71799d3 Mon Sep 17 00:00:00 2001 From: Jean Philippe Wan Date: Wed, 28 Aug 2024 12:15:12 -0400 Subject: [PATCH] Passing state to tool so that we can use them in custom tools --- .../sequentialagents/ToolNode/ToolNode.ts | 46 +++++++++++++++++-- packages/components/src/Interface.ts | 4 ++ 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/packages/components/nodes/sequentialagents/ToolNode/ToolNode.ts b/packages/components/nodes/sequentialagents/ToolNode/ToolNode.ts index 44cefc8794f..45ef49c1b8f 100644 --- a/packages/components/nodes/sequentialagents/ToolNode/ToolNode.ts +++ b/packages/components/nodes/sequentialagents/ToolNode/ToolNode.ts @@ -1,5 +1,14 @@ import { flatten } from 'lodash' -import { ICommonObject, IDatabaseEntity, INode, INodeData, INodeParams, ISeqAgentNode, IUsedTool } from '../../../src/Interface' +import { + ICommonObject, + IDatabaseEntity, + INode, + INodeData, + INodeParams, + ISeqAgentNode, + IUsedTool, + IStateWithMessages +} from '../../../src/Interface' import { AIMessage, AIMessageChunk, BaseMessage, ToolMessage } from '@langchain/core/messages' import { StructuredTool } from '@langchain/core/tools' import { RunnableConfig } from '@langchain/core/runnables' @@ -9,6 +18,7 @@ import { DataSource } from 'typeorm' import { MessagesState, RunnableCallable, customGet, getVM } from '../commonUtils' import { getVars, prepareSandboxVars } from '../../../src/utils' import { ChatPromptTemplate } from '@langchain/core/prompts' +import { DynamicStructuredTool } from '../../tools/CustomTool/core' const defaultApprovalPrompt = `You are about to execute tool: {tools}. Ask if user want to proceed` @@ -350,7 +360,7 @@ class ToolNode_SeqAgents implements INode { } } -class ToolNode extends RunnableCallable { +class ToolNode extends RunnableCallable { tools: StructuredTool[] nodeData: INodeData inputQuery: string @@ -372,19 +382,45 @@ class ToolNode extends RunnableCallable this.options = options } - private async run(input: BaseMessage[] | MessagesState, config: RunnableConfig): Promise { - const message = Array.isArray(input) ? input[input.length - 1] : input.messages[input.messages.length - 1] + private async run(input: T, config: RunnableConfig): Promise { + let messages: BaseMessage[] + + // Check if input is an array of BaseMessage[] + if (Array.isArray(input)) { + messages = input + } + // Check if input is IStateWithMessages + else if ((input as IStateWithMessages).messages) { + messages = (input as IStateWithMessages).messages + } + // Handle MessagesState type + else { + messages = (input as MessagesState).messages + } + + // Get the last message + const message = messages[messages.length - 1] if (message._getType() !== 'ai') { throw new Error('ToolNode only accepts AIMessages as input.') } + // Extract all properties except messages for IStateWithMessages + const { messages: _, ...inputWithoutMessages } = Array.isArray(input) ? { messages: input } : input + const ChannelsWithoutMessages = { + state: inputWithoutMessages + } + const outputs = await Promise.all( (message as AIMessage).tool_calls?.map(async (call) => { const tool = this.tools.find((tool) => tool.name === call.name) if (tool === undefined) { throw new Error(`Tool ${call.name} not found.`) } + if (tool && tool instanceof DynamicStructuredTool) { + // @ts-ignore + tool.setFlowObject(ChannelsWithoutMessages) + } let output = await tool.invoke(call.args, config) let sourceDocuments: Document[] = [] if (output?.includes(SOURCE_DOCUMENTS_PREFIX)) { @@ -436,7 +472,7 @@ const getReturnOutput = async ( input: string, options: ICommonObject, outputs: ToolMessage[], - state: BaseMessage[] | MessagesState + state: ICommonObject ) => { const appDataSource = options.appDataSource as DataSource const databaseEntities = options.databaseEntities as IDatabaseEntity diff --git a/packages/components/src/Interface.ts b/packages/components/src/Interface.ts index 6b687fc77a0..c3d2a72d409 100644 --- a/packages/components/src/Interface.ts +++ b/packages/components/src/Interface.ts @@ -396,3 +396,7 @@ export interface IVisionChatModal { revertToOriginalModel(): void setMultiModalOption(multiModalOption: IMultiModalOption): void } +export interface IStateWithMessages extends ICommonObject { + messages: BaseMessage[] + [key: string]: any +}