[Security Solution] [Elastic AI Assistant] LangChain Agents and Tools integration for ES|QL query generation via ELSER #167097

Merged
84 changes: 84 additions & 0 deletions x-pack/packages/kbn-elastic-assistant/impl/assistant/api.test.tsx
@@ -126,4 +126,88 @@ describe('fetchConnectorExecuteAction', () => {

expect(result).toBe('Test response');
});

it('returns the value of the action_input property when assistantLangChain is true, and `content` has properly prefixed and suffixed JSON with the action_input property', async () => {
const content = '```json\n{"action_input": "value from action_input"}\n```';

(mockHttp.fetch as jest.Mock).mockResolvedValue({
status: 'ok',
data: {
choices: [
{
message: {
content,
},
},
],
},
});

const testProps: FetchConnectorExecuteAction = {
assistantLangChain: true, // <-- requires response parsing
http: mockHttp,
messages,
apiConfig,
};

const result = await fetchConnectorExecuteAction(testProps);

expect(result).toBe('value from action_input');
});

it('returns the original content when assistantLangChain is true, and `content` has properly formatted JSON WITHOUT the action_input property', async () => {
const content = '```json\n{"some_key": "some value"}\n```';

(mockHttp.fetch as jest.Mock).mockResolvedValue({
status: 'ok',
data: {
choices: [
{
message: {
content,
},
},
],
},
});

const testProps: FetchConnectorExecuteAction = {
assistantLangChain: true, // <-- requires response parsing
http: mockHttp,
messages,
apiConfig,
};

const result = await fetchConnectorExecuteAction(testProps);

expect(result).toBe(content);
});

it('returns the original content when assistantLangChain is true, and `content` is not JSON', async () => {
const content = 'plain text content';

(mockHttp.fetch as jest.Mock).mockResolvedValue({
status: 'ok',
data: {
choices: [
{
message: {
content,
},
},
],
},
});

const testProps: FetchConnectorExecuteAction = {
assistantLangChain: true, // <-- requires response parsing
http: mockHttp,
messages,
apiConfig,
};

const result = await fetchConnectorExecuteAction(testProps);

expect(result).toBe(content);
});
});
4 changes: 3 additions & 1 deletion x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx
@@ -12,6 +12,7 @@ import { HttpSetup, IHttpFetchError } from '@kbn/core-http-browser';
import type { Conversation, Message } from '../assistant_context/types';
import { API_ERROR } from './translations';
import { MODEL_GPT_3_5_TURBO } from '../connectorland/models/model_selector/model_selector';
import { getFormattedMessageContent } from './helpers';

export interface FetchConnectorExecuteAction {
assistantLangChain: boolean;
@@ -78,7 +79,8 @@ export const fetchConnectorExecuteAction = async ({

if (data.choices && data.choices.length > 0 && data.choices[0].message.content) {
const result = data.choices[0].message.content.trim();
return result;

return assistantLangChain ? getFormattedMessageContent(result) : result;
Member:

A client-side change to the assistant was required because the response returned from the agent executor is JSON.

Thoughts on doing this server side, so as not to leak the agent executor abstraction to the client? Or was there a specific reason this needed to be pushed to the client?

Contributor Author:

In summary of our offline discussion:

  • Yes, it's possible to do this server side by introducing an additional JSON.parse of the response from the LLM
  • In short, the server side code in x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts would look something like the following (pseudocode):
  await executor.call({ input: latestMessage[0].content });

  const rawData = llm.getActionResultData(); // the response from the actions framework

  if (rawData.choices && rawData.choices.length > 0 && rawData.choices[0].message.content) {
    const result = rawData.choices[0].message.content.trim();

    const formatted = getFormattedMessageContent(result);
    const data = convertToRecord(formatted);

    return {
      connector_id: connectorId,
      data,
      status: 'ok',
    };
  } else {
    throw new Error('Unexpected raw response from the LLM');
  }
  • The pseudocode above is very similar to the client side code in x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx
  • If server-side parsing is implemented, the existing client side code should be extracted into a reusable function (exported by the package), for reuse on the server

In summary, it's possible, but given the above adds an additional server side JSON parse of the LLM response and requires some additional client side refactoring, we'll reconsider the above post FF.
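
For illustration only, a minimal TypeScript sketch of that reuse, assuming getFormattedMessageContent were exported from the kbn-elastic-assistant package for the server to import; the import path, helper name, and response-shaping below are hypothetical, not part of this PR:

  // Hypothetical sketch, not part of this PR: reuse the package's parsing helper on the server.
  import { getFormattedMessageContent } from '@kbn/elastic-assistant'; // assumes the package exports it

  interface RawLlmResponse {
    choices?: Array<{ message: { content: string } }>;
  }

  export const buildResponseBody = (rawData: RawLlmResponse, connectorId: string) => {
    const content = rawData.choices?.[0]?.message?.content?.trim();

    if (content == null || content.length === 0) {
      throw new Error('Unexpected raw response from the LLM');
    }

    return {
      connector_id: connectorId,
      // In the real route, this would also be converted to the record shape the client expects.
      data: getFormattedMessageContent(content), // unwraps the agent's fenced-JSON action_input
      status: 'ok',
    };
  };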

Member:

Thank you for taking the time to discuss offline and for summarizing here 🙏 , path forward sounds good to me 👍

} else {
return API_ERROR;
}
@@ -5,7 +5,11 @@
* 2.0.
*/

import { getDefaultConnector, getBlockBotConversation } from './helpers';
import {
getBlockBotConversation,
getDefaultConnector,
getFormattedMessageContent,
} from './helpers';
import { enterpriseMessaging } from './use_conversation/sample_conversations';
import { ActionConnector } from '@kbn/triggers-actions-ui-plugin/public';

@@ -190,4 +194,41 @@ describe('getBlockBotConversation', () => {
expect(result).toBeUndefined();
});
});

describe('getFormattedMessageContent', () => {
it('returns the value of the action_input property when `content` has properly prefixed and suffixed JSON with the action_input property', () => {
const content = '```json\n{"action_input": "value from action_input"}\n```';

expect(getFormattedMessageContent(content)).toBe('value from action_input');
});

it('returns the original content when `content` has properly formatted JSON WITHOUT the action_input property', () => {
const content = '```json\n{"some_key": "some value"}\n```';
expect(getFormattedMessageContent(content)).toBe(content);
});

it('returns the original content when `content` has improperly formatted JSON', () => {
const content = '```json\n{"action_input": "value from action_input",}\n```'; // <-- the trailing comma makes it invalid

expect(getFormattedMessageContent(content)).toBe(content);
});

it('returns the original content when `content` is missing the prefix', () => {
const content = '{"action_input": "value from action_input"}\n```'; // <-- missing prefix

expect(getFormattedMessageContent(content)).toBe(content);
});

it('returns the original content when `content` is missing the suffix', () => {
const content = '```json\n{"action_input": "value from action_input"}'; // <-- missing suffix

expect(getFormattedMessageContent(content)).toBe(content);
});

it('returns the original content when `content` does NOT contain a JSON string', () => {
const content = 'plain text content';

expect(getFormattedMessageContent(content)).toBe(content);
});
});
});
21 changes: 21 additions & 0 deletions x-pack/packages/kbn-elastic-assistant/impl/assistant/helpers.ts
@@ -59,3 +59,24 @@ export const getDefaultConnector = (
connectors: Array<ActionConnector<Record<string, unknown>, Record<string, unknown>>> | undefined
): ActionConnector<Record<string, unknown>, Record<string, unknown>> | undefined =>
connectors?.length === 1 ? connectors[0] : undefined;

/**
* When `content` is a JSON string, prefixed with "```json\n"
* and suffixed with "\n```", this function will attempt to parse it and return
* the `action_input` property if it exists.
*/
export const getFormattedMessageContent = (content: string): string => {
const formattedContentMatch = content.match(/```json\n([\s\S]+)\n```/);

if (formattedContentMatch) {
try {
const parsedContent = JSON.parse(formattedContentMatch[1]);

return parsedContent.action_input ?? content;
} catch {
// we don't want to throw an error here, so we'll fall back to the original content
}
}

return content;
};
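
A brief usage sketch of the helper above; the fenced-JSON input mirrors the agent executor output exercised by the tests earlier in this diff:

// The agent executor returns its answer as markdown-fenced JSON (see the tests above):
const wrapped = '```json\n{"action_input": "value from action_input"}\n```';

getFormattedMessageContent(wrapped); // => 'value from action_input'
getFormattedMessageContent('plain text content'); // => 'plain text content' (returned unchanged)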
@@ -12,7 +12,7 @@ import { ResponseBody } from '../helpers';
import { ActionsClientLlm } from '../llm/actions_client_llm';
import { mockActionResultData } from '../../../__mocks__/action_result_data';
import { langChainMessages } from '../../../__mocks__/lang_chain_messages';
import { executeCustomLlmChain } from '.';
import { callAgentExecutor } from '.';
import { loggerMock } from '@kbn/logging-mocks';
import { elasticsearchServiceMock } from '@kbn/core-elasticsearch-server-mocks';

@@ -23,11 +23,18 @@ const mockConversationChain = {
};

jest.mock('langchain/chains', () => ({
ConversationalRetrievalQAChain: {
RetrievalQAChain: {
fromLLM: jest.fn().mockImplementation(() => mockConversationChain),
},
}));

const mockCall = jest.fn();
jest.mock('langchain/agents', () => ({
initializeAgentExecutorWithOptions: jest.fn().mockImplementation(() => ({
call: mockCall,
})),
}));

const mockConnectorId = 'mock-connector-id';

// eslint-disable-next-line @typescript-eslint/no-explicit-any
@@ -42,7 +49,7 @@ const mockActions: ActionsPluginStart = {} as ActionsPluginStart;
const mockLogger = loggerMock.create();
const esClientMock = elasticsearchServiceMock.createScopedClusterClient().asCurrentUser;

describe('executeCustomLlmChain', () => {
describe('callAgentExecutor', () => {
beforeEach(() => {
jest.clearAllMocks();

@@ -52,7 +59,7 @@ });
});

it('creates an instance of ActionsClientLlm with the expected context from the request', async () => {
await executeCustomLlmChain({
await callAgentExecutor({
actions: mockActions,
connectorId: mockConnectorId,
esClient: esClientMock,
@@ -70,7 +77,7 @@ });
});

it('kicks off the chain with (only) the last message', async () => {
await executeCustomLlmChain({
await callAgentExecutor({
actions: mockActions,
connectorId: mockConnectorId,
esClient: esClientMock,
@@ -79,15 +86,15 @@ });
request: mockRequest,
});

expect(mockConversationChain.call).toHaveBeenCalledWith({
question: '\n\nDo you know my name?',
expect(mockCall).toHaveBeenCalledWith({
input: '\n\nDo you know my name?',
});
});

it('kicks off the chain with the expected message when langChainMessages has only one entry', async () => {
const onlyOneMessage = [langChainMessages[0]];

await executeCustomLlmChain({
await callAgentExecutor({
actions: mockActions,
connectorId: mockConnectorId,
esClient: esClientMock,
@@ -96,13 +103,13 @@ });
request: mockRequest,
});

expect(mockConversationChain.call).toHaveBeenCalledWith({
question: 'What is my name?',
expect(mockCall).toHaveBeenCalledWith({
input: 'What is my name?',
});
});

it('returns the expected response body', async () => {
const result: ResponseBody = await executeCustomLlmChain({
const result: ResponseBody = await callAgentExecutor({
actions: mockActions,
connectorId: mockConnectorId,
esClient: esClientMock,
@@ -7,16 +7,18 @@

import { ElasticsearchClient, KibanaRequest, Logger } from '@kbn/core/server';
import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server';
import { initializeAgentExecutorWithOptions } from 'langchain/agents';
import { RetrievalQAChain } from 'langchain/chains';
import { BufferMemory, ChatMessageHistory } from 'langchain/memory';
import { BaseMessage } from 'langchain/schema';
import { ChainTool, Tool } from 'langchain/tools';

import { ConversationalRetrievalQAChain } from 'langchain/chains';
import { ElasticsearchStore } from '../elasticsearch_store/elasticsearch_store';
import { ResponseBody } from '../helpers';
import { ActionsClientLlm } from '../llm/actions_client_llm';
import { ElasticsearchStore } from '../elasticsearch_store/elasticsearch_store';
import { KNOWLEDGE_BASE_INDEX_PATTERN } from '../../../routes/knowledge_base/constants';

export const executeCustomLlmChain = async ({
export const callAgentExecutor = async ({
actions,
connectorId,
esClient,
@@ -34,31 +36,38 @@ export const executeCustomLlmChain = async ({
}): Promise<ResponseBody> => {
const llm = new ActionsClientLlm({ actions, connectorId, request, logger });

// Chat History Memory: in-memory memory, from client local storage, first message is the system prompt
const pastMessages = langChainMessages.slice(0, -1); // all but the last message
const latestMessage = langChainMessages.slice(-1); // the last message

const memory = new BufferMemory({
chatHistory: new ChatMessageHistory(pastMessages),
memoryKey: 'chat_history',
memoryKey: 'chat_history', // this is the key expected by https://github.com/langchain-ai/langchainjs/blob/a13a8969345b0f149c1ca4a120d63508b06c52a5/langchain/src/agents/initialize.ts#L166
inputKey: 'input',
outputKey: 'output',
returnMessages: true,
});

// ELSER backed ElasticsearchStore for Knowledge Base
const esStore = new ElasticsearchStore(esClient, KNOWLEDGE_BASE_INDEX_PATTERN, logger);
const chain = RetrievalQAChain.fromLLM(llm, esStore.asRetriever());

const tools: Tool[] = [
new ChainTool({
name: 'esql-language-knowledge-base',
description:
'Call this for knowledge on how to build an ESQL query, or answer questions about the ES|QL query language.',
chain,
}),
];

// Chain w/ chat history memory and knowledge base retriever
const chain = ConversationalRetrievalQAChain.fromLLM(llm, esStore.asRetriever(), {
const executor = await initializeAgentExecutorWithOptions(tools, llm, {
agentType: 'chat-conversational-react-description',
memory,
// See `qaChainOptions` from https://js.langchain.com/docs/modules/chains/popular/chat_vector_db
qaChainOptions: { type: 'stuff' },
verbose: false,
Contributor Author:

Set this to true for verbose debugging, per the PR description.

});
await chain.call({ question: latestMessage[0].content });

// Chain w/ just knowledge base retriever
// const chain = RetrievalQAChain.fromLLM(llm, esStore.asRetriever());
// await chain.call({ query: latestMessage[0].content });
await executor.call({ input: latestMessage[0].content });

// The assistant (on the client side) expects the same response returned
// from the actions framework, so we need to return the same shape of data:
return {
connector_id: connectorId,
data: llm.getActionResultData(), // the response from the actions framework
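
The ResponseBody type imported from ../helpers is not shown in this diff; based on the fields visible here and in the review discussion above, it presumably has roughly the following shape. This is an assumption, not the actual definition.

// Assumed shape only; the real ResponseBody lives in ../helpers and is not part of this diff.
interface ResponseBody {
  connector_id: string;
  data: unknown; // the raw response captured from the actions framework
  status: string; // e.g. 'ok'
}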
@@ -20,7 +20,7 @@ jest.mock('../lib/build_response', () => ({
}));

jest.mock('../lib/langchain/execute_custom_llm_chain', () => ({
executeCustomLlmChain: jest.fn().mockImplementation(
callAgentExecutor: jest.fn().mockImplementation(
async ({
connectorId,
}: {
@@ -20,7 +20,7 @@ import {
PostActionsConnectorExecutePathParams,
} from '../schemas/post_actions_connector_execute';
import { ElasticAssistantRequestHandlerContext } from '../types';
import { executeCustomLlmChain } from '../lib/langchain/execute_custom_llm_chain';
import { callAgentExecutor } from '../lib/langchain/execute_custom_llm_chain';

export const postActionsConnectorExecuteRoute = (
router: IRouter<ElasticAssistantRequestHandlerContext>
@@ -53,7 +53,7 @@ export const postActionsConnectorExecuteRoute = (
// convert the assistant messages to LangChain messages:
const langChainMessages = getLangChainMessages(assistantMessages);

const langChainResponseBody = await executeCustomLlmChain({
const langChainResponseBody = await callAgentExecutor({
actions,
connectorId,
esClient,