From 0b74f62a338fe1f6b281a019e71279f0d2ffb81f Mon Sep 17 00:00:00 2001 From: Pierre Gayvallet Date: Tue, 17 Dec 2024 16:13:17 +0100 Subject: [PATCH] [inference] Add cancelation support for chatComplete and output (#203108) ## Summary Fix https://github.com/elastic/kibana/issues/200757 Add cancelation support for `chatComplete` and `output`, based on an abort signal. ### Examples #### response mode ```ts import { isInferenceRequestAbortedError } from '@kbn/inference-common'; try { const abortController = new AbortController(); const chatResponse = await inferenceClient.chatComplete({ connectorId: 'some-gen-ai-connector', abortSignal: abortController.signal, messages: [{ role: MessageRole.User, content: 'Do something' }], }); } catch(e) { if(isInferenceRequestAbortedError(e)) { // request was aborted, do something } else { // was another error, do something else } } // elsewhere abortController.abort() ``` #### stream mode ```ts import { isInferenceRequestAbortedError } from '@kbn/inference-common'; const abortController = new AbortController(); const events$ = inferenceClient.chatComplete({ stream: true, connectorId: 'some-gen-ai-connector', abortSignal: abortController.signal, messages: [{ role: MessageRole.User, content: 'Do something' }], }); events$.subscribe({ next: (event) => { // do something }, error: (err) => { if(isInferenceRequestAbortedError(e)) { // request was aborted, do something } else { // was another error, do something else } } }); abortController.abort(); ``` --- .github/CODEOWNERS | 2 +- .../shared/ai-infra/inference-common/index.ts | 3 + .../inference-common/src/chat_complete/api.ts | 4 + .../ai-infra/inference-common/src/errors.ts | 44 ++++ .../inference-common/src/output/api.ts | 5 +- .../plugins/shared/inference/README.md | 69 +++++ .../common/output/create_output_api.test.ts | 22 ++ .../common/output/create_output_api.ts | 3 + .../bedrock/bedrock_claude_adapter.test.ts | 19 ++ .../bedrock/bedrock_claude_adapter.ts | 3 +- .../adapters/gemini/gemini_adapter.test.ts | 19 ++ .../adapters/gemini/gemini_adapter.ts | 3 +- .../adapters/openai/openai_adapter.test.ts | 20 ++ .../adapters/openai/openai_adapter.ts | 12 +- .../server/chat_complete/api.test.mocks.ts | 26 ++ .../server/chat_complete/api.test.ts | 237 ++++++++++++++++++ .../inference/server/chat_complete/api.ts | 35 +-- .../inference/server/chat_complete/types.ts | 1 + .../utils/handle_cancellation.test.ts | 53 ++++ .../utils/handle_cancellation.ts | 39 +++ .../server/chat_complete/utils/index.ts | 3 +- .../chat_complete/utils/inference_executor.ts | 23 +- .../inference/server/routes/chat_complete.ts | 4 + .../inference/server/test_utils/index.ts | 11 + .../server/test_utils/inference_connector.ts | 19 ++ .../test_utils/inference_connector_adapter.ts | 14 ++ .../server/test_utils/inference_executor.ts | 19 ++ 27 files changed, 688 insertions(+), 24 deletions(-) create mode 100644 x-pack/platform/plugins/shared/inference/server/chat_complete/api.test.mocks.ts create mode 100644 x-pack/platform/plugins/shared/inference/server/chat_complete/api.test.ts create mode 100644 x-pack/platform/plugins/shared/inference/server/chat_complete/utils/handle_cancellation.test.ts create mode 100644 x-pack/platform/plugins/shared/inference/server/chat_complete/utils/handle_cancellation.ts create mode 100644 x-pack/platform/plugins/shared/inference/server/test_utils/index.ts create mode 100644 x-pack/platform/plugins/shared/inference/server/test_utils/inference_connector.ts create mode 100644 x-pack/platform/plugins/shared/inference/server/test_utils/inference_connector_adapter.ts create mode 100644 x-pack/platform/plugins/shared/inference/server/test_utils/inference_executor.ts diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index d9def6481fc9..627579d513e9 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1832,7 +1832,7 @@ packages/kbn-monaco/src/esql @elastic/kibana-esql #CC# /x-pack/plugins/global_search_providers/ @elastic/kibana-core # AppEx AI Infra -/x-pack/plugins/inference @elastic/appex-ai-infra @elastic/obs-ai-assistant @elastic/security-generative-ai +/x-pack/platform/plugins/shared/inference @elastic/appex-ai-infra @elastic/obs-ai-assistant @elastic/security-generative-ai /x-pack/test/functional_gen_ai/inference @elastic/appex-ai-infra # AppEx Platform Services Security diff --git a/x-pack/platform/packages/shared/ai-infra/inference-common/index.ts b/x-pack/platform/packages/shared/ai-infra/inference-common/index.ts index 603192fb96db..134b0f02811f 100644 --- a/x-pack/platform/packages/shared/ai-infra/inference-common/index.ts +++ b/x-pack/platform/packages/shared/ai-infra/inference-common/index.ts @@ -84,11 +84,14 @@ export { type InferenceTaskErrorEvent, type InferenceTaskInternalError, type InferenceTaskRequestError, + type InferenceTaskAbortedError, createInferenceInternalError, createInferenceRequestError, + createInferenceRequestAbortedError, isInferenceError, isInferenceInternalError, isInferenceRequestError, + isInferenceRequestAbortedError, } from './src/errors'; export { truncateList } from './src/truncate_list'; diff --git a/x-pack/platform/packages/shared/ai-infra/inference-common/src/chat_complete/api.ts b/x-pack/platform/packages/shared/ai-infra/inference-common/src/chat_complete/api.ts index cb91f4e53e8a..4e29d5f7dad0 100644 --- a/x-pack/platform/packages/shared/ai-infra/inference-common/src/chat_complete/api.ts +++ b/x-pack/platform/packages/shared/ai-infra/inference-common/src/chat_complete/api.ts @@ -93,6 +93,10 @@ export type ChatCompleteOptions< * Function calling mode, defaults to "native". */ functionCalling?: FunctionCallingMode; + /** + * Optional signal that can be used to forcefully abort the request. + */ + abortSignal?: AbortSignal; } & TToolOptions; /** diff --git a/x-pack/platform/packages/shared/ai-infra/inference-common/src/errors.ts b/x-pack/platform/packages/shared/ai-infra/inference-common/src/errors.ts index 5a99adc4321d..472ed50e231f 100644 --- a/x-pack/platform/packages/shared/ai-infra/inference-common/src/errors.ts +++ b/x-pack/platform/packages/shared/ai-infra/inference-common/src/errors.ts @@ -13,6 +13,7 @@ import { InferenceTaskEventBase, InferenceTaskEventType } from './inference_task export enum InferenceTaskErrorCode { internalError = 'internalError', requestError = 'requestError', + abortedError = 'requestAborted', } /** @@ -46,16 +47,37 @@ export type InferenceTaskErrorEvent = InferenceTaskEventBase >; +/** + * Inference error thrown when the request was considered invalid. + * + * Some example of reasons for invalid requests would be: + * - no connector matching the provided connectorId + * - invalid connector type for the provided connectorId + */ export type InferenceTaskRequestError = InferenceTaskError< InferenceTaskErrorCode.requestError, { status: number } >; +/** + * Inference error thrown when the request was aborted. + * + * Request abortion occurs when providing an abort signal and firing it + * before the call to the LLM completes. + */ +export type InferenceTaskAbortedError = InferenceTaskError< + InferenceTaskErrorCode.abortedError, + { status: number } +>; + export function createInferenceInternalError( message = 'An internal error occurred', meta?: Record @@ -72,16 +94,38 @@ export function createInferenceRequestError( }); } +export function createInferenceRequestAbortedError(): InferenceTaskAbortedError { + return new InferenceTaskError(InferenceTaskErrorCode.abortedError, 'Request was aborted', { + status: 499, + }); +} + +/** + * Check if the given error is an {@link InferenceTaskError} + */ export function isInferenceError( error: unknown ): error is InferenceTaskError | undefined> { return error instanceof InferenceTaskError; } +/** + * Check if the given error is an {@link InferenceTaskInternalError} + */ export function isInferenceInternalError(error: unknown): error is InferenceTaskInternalError { return isInferenceError(error) && error.code === InferenceTaskErrorCode.internalError; } +/** + * Check if the given error is an {@link InferenceTaskRequestError} + */ export function isInferenceRequestError(error: unknown): error is InferenceTaskRequestError { return isInferenceError(error) && error.code === InferenceTaskErrorCode.requestError; } + +/** + * Check if the given error is an {@link InferenceTaskAbortedError} + */ +export function isInferenceRequestAbortedError(error: unknown): error is InferenceTaskAbortedError { + return isInferenceError(error) && error.code === InferenceTaskErrorCode.abortedError; +} diff --git a/x-pack/platform/packages/shared/ai-infra/inference-common/src/output/api.ts b/x-pack/platform/packages/shared/ai-infra/inference-common/src/output/api.ts index cd90394cd67d..3ae4a6a07ee2 100644 --- a/x-pack/platform/packages/shared/ai-infra/inference-common/src/output/api.ts +++ b/x-pack/platform/packages/shared/ai-infra/inference-common/src/output/api.ts @@ -96,7 +96,10 @@ export interface OutputOptions< * Defaults to false. */ stream?: TStream; - + /** + * Optional signal that can be used to forcefully abort the request. + */ + abortSignal?: AbortSignal; /** * Optional configuration for retrying the call if an error occurs. */ diff --git a/x-pack/platform/plugins/shared/inference/README.md b/x-pack/platform/plugins/shared/inference/README.md index bba5b4cdcfc2..a52e589a9dea 100644 --- a/x-pack/platform/plugins/shared/inference/README.md +++ b/x-pack/platform/plugins/shared/inference/README.md @@ -221,6 +221,75 @@ const toolCall = toolCalls[0]; // process the tool call and eventually continue the conversation with the LLM ``` +#### Request cancellation + +Request cancellation can be done by passing an abort signal when calling the API. Firing the signal +before the request completes will cause the abortion, and the API call will throw an error. + +```ts +const abortController = new AbortController(); + +const chatResponse = await inferenceClient.chatComplete({ + connectorId: 'some-gen-ai-connector', + abortSignal: abortController.signal, + messages: [{ role: MessageRole.User, content: 'Do something' }], +}); + +// from elsewhere / before the request completes and the promise resolves: + +abortController.abort(); +``` + +The `isInferenceRequestAbortedError` helper function, exposed from `@kbn/inference-common`, can be used easily identify those errors: + +```ts +import { isInferenceRequestAbortedError } from '@kbn/inference-common'; + +try { + const abortController = new AbortController(); + const chatResponse = await inferenceClient.chatComplete({ + connectorId: 'some-gen-ai-connector', + abortSignal: abortController.signal, + messages: [{ role: MessageRole.User, content: 'Do something' }], + }); +} catch(e) { + if(isInferenceRequestAbortedError(e)) { + // request was aborted, do something + } else { + // was another error, do something else + } +} +``` + +The approach is very similar for stream mode: + +```ts +import { isInferenceRequestAbortedError } from '@kbn/inference-common'; + +const abortController = new AbortController(); +const events$ = inferenceClient.chatComplete({ + stream: true, + connectorId: 'some-gen-ai-connector', + abortSignal: abortController.signal, + messages: [{ role: MessageRole.User, content: 'Do something' }], +}); + +events$.subscribe({ + next: (event) => { + // do something + }, + error: (err) => { + if(isInferenceRequestAbortedError(e)) { + // request was aborted, do something + } else { + // was another error, do something else + } + } +}); + +abortController.abort(); +``` + ### `output` API `output` is a wrapper around the `chatComplete` API that is catered towards a specific use case: having the LLM output a structured response, based on a schema. diff --git a/x-pack/platform/plugins/shared/inference/common/output/create_output_api.test.ts b/x-pack/platform/plugins/shared/inference/common/output/create_output_api.test.ts index c65720aae2e4..d29f88009f8e 100644 --- a/x-pack/platform/plugins/shared/inference/common/output/create_output_api.test.ts +++ b/x-pack/platform/plugins/shared/inference/common/output/create_output_api.test.ts @@ -196,4 +196,26 @@ describe('createOutputApi', () => { ).toThrowError('Retry options are not supported in streaming mode'); }); }); + + it('propagates the abort signal when provided', async () => { + chatComplete.mockResolvedValue(Promise.resolve({ content: 'content', toolCalls: [] })); + + const output = createOutputApi(chatComplete); + + const abortController = new AbortController(); + + await output({ + id: 'id', + connectorId: '.my-connector', + input: 'input message', + abortSignal: abortController.signal, + }); + + expect(chatComplete).toHaveBeenCalledTimes(1); + expect(chatComplete).toHaveBeenCalledWith( + expect.objectContaining({ + abortSignal: abortController.signal, + }) + ); + }); }); diff --git a/x-pack/platform/plugins/shared/inference/common/output/create_output_api.ts b/x-pack/platform/plugins/shared/inference/common/output/create_output_api.ts index 3e65cb283dd4..7cd7e9cad144 100644 --- a/x-pack/platform/plugins/shared/inference/common/output/create_output_api.ts +++ b/x-pack/platform/plugins/shared/inference/common/output/create_output_api.ts @@ -34,6 +34,7 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI) { previousMessages, functionCalling, stream, + abortSignal, retry, }: DefaultOutputOptions): OutputCompositeResponse { if (stream && retry !== undefined) { @@ -52,6 +53,7 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI) { connectorId, stream, functionCalling, + abortSignal, system, messages, ...(schema @@ -113,6 +115,7 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI) { input, schema, system, + abortSignal, previousMessages: messages.concat( { role: MessageRole.Assistant as const, diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts index 565727b7f57f..c6114c3b09e9 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts @@ -325,5 +325,24 @@ describe('bedrockClaudeAdapter', () => { expect(tools).toEqual([]); expect(system).toEqual(addNoToolUsageDirective('some system instruction')); }); + + it('propagates the abort signal when provided', () => { + const abortController = new AbortController(); + + bedrockClaudeAdapter.chatComplete({ + logger, + executor: executorMock, + messages: [{ role: MessageRole.User, content: 'question' }], + abortSignal: abortController.signal, + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + expect(executorMock.invoke).toHaveBeenCalledWith({ + subAction: 'invokeStream', + subActionParams: expect.objectContaining({ + signal: abortController.signal, + }), + }); + }); }); }); diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts index e73d9c9344c9..e34605a4c96a 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts @@ -26,7 +26,7 @@ import { processCompletionChunks } from './process_completion_chunks'; import { addNoToolUsageDirective } from './prompts'; export const bedrockClaudeAdapter: InferenceConnectorAdapter = { - chatComplete: ({ executor, system, messages, toolChoice, tools }) => { + chatComplete: ({ executor, system, messages, toolChoice, tools, abortSignal }) => { const noToolUsage = toolChoice === ToolChoiceType.none; const subActionParams = { @@ -36,6 +36,7 @@ export const bedrockClaudeAdapter: InferenceConnectorAdapter = { toolChoice: toolChoiceToBedrock(toolChoice), temperature: 0, stopSequences: ['\n\nHuman:'], + signal: abortSignal, }; return from( diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.ts index 95a46f73d5d1..5024bd1f4c87 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.ts @@ -402,5 +402,24 @@ describe('geminiAdapter', () => { expect(tapFn).toHaveBeenCalledWith({ chunk: 1 }); expect(tapFn).toHaveBeenCalledWith({ chunk: 2 }); }); + + it('propagates the abort signal when provided', () => { + const abortController = new AbortController(); + + geminiAdapter.chatComplete({ + logger, + executor: executorMock, + messages: [{ role: MessageRole.User, content: 'question' }], + abortSignal: abortController.signal, + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + expect(executorMock.invoke).toHaveBeenCalledWith({ + subAction: 'invokeStream', + subActionParams: expect.objectContaining({ + signal: abortController.signal, + }), + }); + }); }); }); diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.ts index 80d043944906..aa62f7006eac 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.ts @@ -22,7 +22,7 @@ import { processVertexStream } from './process_vertex_stream'; import type { GenerateContentResponseChunk, GeminiMessage, GeminiToolConfig } from './types'; export const geminiAdapter: InferenceConnectorAdapter = { - chatComplete: ({ executor, system, messages, toolChoice, tools }) => { + chatComplete: ({ executor, system, messages, toolChoice, tools, abortSignal }) => { return from( executor.invoke({ subAction: 'invokeStream', @@ -32,6 +32,7 @@ export const geminiAdapter: InferenceConnectorAdapter = { tools: toolsToGemini(tools), toolConfig: toolChoiceToConfig(toolChoice), temperature: 0, + signal: abortSignal, stopSequences: ['\n\nHuman:'], }, }) diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts index 48544f1bb0fb..9b7fbc388024 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts @@ -77,6 +77,7 @@ describe('openAIAdapter', () => { }; }); }); + it('correctly formats messages ', () => { openAIAdapter.chatComplete({ ...defaultArgs, @@ -254,6 +255,25 @@ describe('openAIAdapter', () => { expect(getRequest().stream).toBe(true); expect(getRequest().body.stream).toBe(true); }); + + it('propagates the abort signal when provided', () => { + const abortController = new AbortController(); + + openAIAdapter.chatComplete({ + logger, + executor: executorMock, + messages: [{ role: MessageRole.User, content: 'question' }], + abortSignal: abortController.signal, + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + expect(executorMock.invoke).toHaveBeenCalledWith({ + subAction: 'stream', + subActionParams: expect.objectContaining({ + signal: abortController.signal, + }), + }); + }); }); describe('when handling the response', () => { diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.ts index 49b6bb514202..0529820b1bfb 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.ts @@ -43,7 +43,16 @@ import { } from '../../simulated_function_calling'; export const openAIAdapter: InferenceConnectorAdapter = { - chatComplete: ({ executor, system, messages, toolChoice, tools, functionCalling, logger }) => { + chatComplete: ({ + executor, + system, + messages, + toolChoice, + tools, + functionCalling, + logger, + abortSignal, + }) => { const stream = true; const simulatedFunctionCalling = functionCalling === 'simulated'; @@ -73,6 +82,7 @@ export const openAIAdapter: InferenceConnectorAdapter = { subAction: 'stream', subActionParams: { body: JSON.stringify(request), + signal: abortSignal, stream, }, }) diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/api.test.mocks.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/api.test.mocks.ts new file mode 100644 index 000000000000..e3248b79af40 --- /dev/null +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/api.test.mocks.ts @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export const getInferenceAdapterMock = jest.fn(); + +jest.doMock('./adapters', () => { + const actual = jest.requireActual('./adapters'); + return { + ...actual, + getInferenceAdapter: getInferenceAdapterMock, + }; +}); + +export const getInferenceExecutorMock = jest.fn(); + +jest.doMock('./utils', () => { + const actual = jest.requireActual('./utils'); + return { + ...actual, + getInferenceExecutor: getInferenceExecutorMock, + }; +}); diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/api.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/api.test.ts new file mode 100644 index 000000000000..7d557ec512fc --- /dev/null +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/api.test.ts @@ -0,0 +1,237 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { getInferenceExecutorMock, getInferenceAdapterMock } from './api.test.mocks'; + +import { of, Subject, isObservable, toArray, firstValueFrom } from 'rxjs'; +import { loggerMock, type MockedLogger } from '@kbn/logging-mocks'; +import { httpServerMock } from '@kbn/core/server/mocks'; +import { actionsMock } from '@kbn/actions-plugin/server/mocks'; +import { + type ChatCompleteAPI, + type ChatCompletionChunkEvent, + MessageRole, +} from '@kbn/inference-common'; +import { + createInferenceConnectorAdapterMock, + createInferenceConnectorMock, + createInferenceExecutorMock, + chunkEvent, +} from '../test_utils'; +import { createChatCompleteApi } from './api'; + +describe('createChatCompleteApi', () => { + let request: ReturnType; + let logger: MockedLogger; + let actions: ReturnType; + let inferenceAdapter: ReturnType; + let inferenceConnector: ReturnType; + let inferenceExecutor: ReturnType; + + let chatComplete: ChatCompleteAPI; + + beforeEach(() => { + request = httpServerMock.createKibanaRequest(); + logger = loggerMock.create(); + actions = actionsMock.createStart(); + + chatComplete = createChatCompleteApi({ request, actions, logger }); + + inferenceAdapter = createInferenceConnectorAdapterMock(); + inferenceAdapter.chatComplete.mockReturnValue(of(chunkEvent('chunk-1'))); + getInferenceAdapterMock.mockReturnValue(inferenceAdapter); + + inferenceConnector = createInferenceConnectorMock(); + + inferenceExecutor = createInferenceExecutorMock({ connector: inferenceConnector }); + getInferenceExecutorMock.mockResolvedValue(inferenceExecutor); + }); + + afterEach(() => { + getInferenceExecutorMock.mockReset(); + getInferenceAdapterMock.mockReset(); + }); + + it('calls `getInferenceExecutor` with the right parameters', async () => { + await chatComplete({ + connectorId: 'connectorId', + messages: [{ role: MessageRole.User, content: 'question' }], + }); + + expect(getInferenceExecutorMock).toHaveBeenCalledTimes(1); + expect(getInferenceExecutorMock).toHaveBeenCalledWith({ + connectorId: 'connectorId', + request, + actions, + }); + }); + + it('calls `getInferenceAdapter` with the right parameters', async () => { + await chatComplete({ + connectorId: 'connectorId', + messages: [{ role: MessageRole.User, content: 'question' }], + }); + + expect(getInferenceAdapterMock).toHaveBeenCalledTimes(1); + expect(getInferenceAdapterMock).toHaveBeenCalledWith(inferenceConnector.type); + }); + + it('calls `inferenceAdapter.chatComplete` with the right parameters', async () => { + await chatComplete({ + connectorId: 'connectorId', + messages: [{ role: MessageRole.User, content: 'question' }], + }); + + expect(inferenceAdapter.chatComplete).toHaveBeenCalledTimes(1); + expect(inferenceAdapter.chatComplete).toHaveBeenCalledWith({ + messages: [{ role: MessageRole.User, content: 'question' }], + executor: inferenceExecutor, + logger, + }); + }); + + it('throws if the connector is not compatible', async () => { + getInferenceAdapterMock.mockReturnValue(undefined); + + await expect( + chatComplete({ + connectorId: 'connectorId', + messages: [{ role: MessageRole.User, content: 'question' }], + }) + ).rejects.toThrowErrorMatchingInlineSnapshot(`"Adapter for type .gen-ai not implemented"`); + }); + + describe('response mode', () => { + it('returns a promise resolving with the response', async () => { + inferenceAdapter.chatComplete.mockReturnValue( + of(chunkEvent('chunk-1'), chunkEvent('chunk-2')) + ); + + const response = await chatComplete({ + connectorId: 'connectorId', + messages: [{ role: MessageRole.User, content: 'question' }], + }); + + expect(response).toEqual({ + content: 'chunk-1chunk-2', + toolCalls: [], + }); + }); + + describe('request cancellation', () => { + it('passes the abortSignal down to `inferenceAdapter.chatComplete`', async () => { + const abortController = new AbortController(); + + await chatComplete({ + connectorId: 'connectorId', + messages: [{ role: MessageRole.User, content: 'question' }], + abortSignal: abortController.signal, + }); + + expect(inferenceAdapter.chatComplete).toHaveBeenCalledTimes(1); + expect(inferenceAdapter.chatComplete).toHaveBeenCalledWith({ + messages: [{ role: MessageRole.User, content: 'question' }], + executor: inferenceExecutor, + abortSignal: abortController.signal, + logger, + }); + }); + + it('throws an error when the signal is triggered', async () => { + const abortController = new AbortController(); + + const subject = new Subject(); + inferenceAdapter.chatComplete.mockReturnValue(subject.asObservable()); + + subject.next(chunkEvent('chunk-1')); + + let caughtError: any; + + const promise = chatComplete({ + connectorId: 'connectorId', + messages: [{ role: MessageRole.User, content: 'question' }], + abortSignal: abortController.signal, + }).catch((err) => { + caughtError = err; + }); + + abortController.abort(); + + await promise; + + expect(caughtError).toBeInstanceOf(Error); + expect(caughtError.message).toContain('Request was aborted'); + }); + }); + }); + + describe('stream mode', () => { + it('returns an observable of events', async () => { + inferenceAdapter.chatComplete.mockReturnValue( + of(chunkEvent('chunk-1'), chunkEvent('chunk-2')) + ); + + const events$ = chatComplete({ + stream: true, + connectorId: 'connectorId', + messages: [{ role: MessageRole.User, content: 'question' }], + }); + + expect(isObservable(events$)).toBe(true); + + const events = await firstValueFrom(events$.pipe(toArray())); + expect(events).toEqual([ + { + content: 'chunk-1', + tool_calls: [], + type: 'chatCompletionChunk', + }, + { + content: 'chunk-2', + tool_calls: [], + type: 'chatCompletionChunk', + }, + { + content: 'chunk-1chunk-2', + toolCalls: [], + type: 'chatCompletionMessage', + }, + ]); + }); + + describe('request cancellation', () => { + it('throws an error when the signal is triggered', async () => { + const abortController = new AbortController(); + + const subject = new Subject(); + inferenceAdapter.chatComplete.mockReturnValue(subject.asObservable()); + + subject.next(chunkEvent('chunk-1')); + + let caughtError: any; + + const events$ = chatComplete({ + stream: true, + connectorId: 'connectorId', + messages: [{ role: MessageRole.User, content: 'question' }], + abortSignal: abortController.signal, + }); + + events$.subscribe({ + error: (err: any) => { + caughtError = err; + }, + }); + + abortController.abort(); + + expect(caughtError).toBeInstanceOf(Error); + expect(caughtError.message).toContain('Request was aborted'); + }); + }); + }); +}); diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/api.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/api.ts index e58c94759e16..0e58c255bd60 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/api.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/api.ts @@ -6,7 +6,7 @@ */ import { last, omit } from 'lodash'; -import { defer, switchMap, throwError } from 'rxjs'; +import { defer, switchMap, throwError, identity } from 'rxjs'; import type { Logger } from '@kbn/logging'; import type { KibanaRequest } from '@kbn/core-http-server'; import { @@ -17,9 +17,13 @@ import { ChatCompleteOptions, } from '@kbn/inference-common'; import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; -import { getConnectorById } from '../util/get_connector_by_id'; import { getInferenceAdapter } from './adapters'; -import { createInferenceExecutor, chunksIntoMessage, streamToResponse } from './utils'; +import { + getInferenceExecutor, + chunksIntoMessage, + streamToResponse, + handleCancellation, +} from './utils'; interface CreateChatCompleteApiOptions { request: KibanaRequest; @@ -37,18 +41,16 @@ export function createChatCompleteApi({ request, actions, logger }: CreateChatCo system, functionCalling, stream, + abortSignal, }: ChatCompleteOptions): ChatCompleteCompositeResponse< ToolOptions, boolean > => { - const obs$ = defer(async () => { - const actionsClient = await actions.getActionsClientWithRequest(request); - const connector = await getConnectorById({ connectorId, actionsClient }); - const executor = createInferenceExecutor({ actionsClient, connector }); - return { executor, connector }; + const inference$ = defer(async () => { + return await getInferenceExecutor({ connectorId, request, actions }); }).pipe( - switchMap(({ executor, connector }) => { - const connectorType = connector.type; + switchMap((executor) => { + const connectorType = executor.getConnector().type; const inferenceAdapter = getInferenceAdapter(connectorType); const messagesWithoutData = messages.map((message) => omit(message, 'data')); @@ -80,21 +82,20 @@ export function createChatCompleteApi({ request, actions, logger }: CreateChatCo tools, logger, functionCalling, + abortSignal, }); }), chunksIntoMessage({ - toolOptions: { - toolChoice, - tools, - }, + toolOptions: { toolChoice, tools }, logger, - }) + }), + abortSignal ? handleCancellation(abortSignal) : identity ); if (stream) { - return obs$; + return inference$; } else { - return streamToResponse(obs$); + return streamToResponse(inference$); } }; } diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/types.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/types.ts index 64cc542ff611..498afb9a2a17 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/types.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/types.ts @@ -29,6 +29,7 @@ export interface InferenceConnectorAdapter { messages: Message[]; system?: string; functionCalling?: FunctionCallingMode; + abortSignal?: AbortSignal; logger: Logger; } & ToolOptions ) => Observable; diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/handle_cancellation.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/handle_cancellation.test.ts new file mode 100644 index 000000000000..7fd464a7051c --- /dev/null +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/handle_cancellation.test.ts @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { of, Subject, toArray, firstValueFrom } from 'rxjs'; +import { InferenceTaskError, InferenceTaskErrorCode } from '@kbn/inference-common'; +import { handleCancellation } from './handle_cancellation'; + +describe('handleCancellation', () => { + it('mirrors the source when the abort signal is not triggered', async () => { + const abortController = new AbortController(); + + const source$ = of(1, 2, 3); + + const output$ = source$.pipe(handleCancellation(abortController.signal)); + + const events = await firstValueFrom(output$.pipe(toArray())); + expect(events).toEqual([1, 2, 3]); + }); + + it('causes the observable to error when the signal fires', () => { + const abortController = new AbortController(); + + const source$ = new Subject(); + + const output$ = source$.pipe(handleCancellation(abortController.signal)); + + let thrownError: any; + const values: number[] = []; + + output$.subscribe({ + next: (value) => { + values.push(value); + }, + error: (err) => { + thrownError = err; + }, + }); + + source$.next(1); + source$.next(2); + abortController.abort(); + source$.next(3); + + expect(values).toEqual([1, 2]); + expect(thrownError).toBeInstanceOf(InferenceTaskError); + expect(thrownError.code).toBe(InferenceTaskErrorCode.abortedError); + expect(thrownError.message).toContain('Request was aborted'); + }); +}); diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/handle_cancellation.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/handle_cancellation.ts new file mode 100644 index 000000000000..640172b150e4 --- /dev/null +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/handle_cancellation.ts @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { OperatorFunction, Observable, Subject, takeUntil } from 'rxjs'; +import { createInferenceRequestAbortedError } from '@kbn/inference-common'; + +export function handleCancellation(abortSignal: AbortSignal): OperatorFunction { + return (source$) => { + const stop$ = new Subject(); + if (abortSignal.aborted) { + stop$.next(); + } + abortSignal.addEventListener('abort', () => { + stop$.next(); + }); + + return new Observable((subscriber) => { + return source$.pipe(takeUntil(stop$)).subscribe({ + next: (value) => { + subscriber.next(value); + }, + error: (err) => { + subscriber.error(err); + }, + complete: () => { + if (abortSignal.aborted) { + subscriber.error(createInferenceRequestAbortedError()); + } else { + subscriber.complete(); + } + }, + }); + }); + }; +} diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/index.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/index.ts index d3dc2010cba3..4314a554589d 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/index.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/index.ts @@ -6,10 +6,11 @@ */ export { - createInferenceExecutor, + getInferenceExecutor, type InferenceInvokeOptions, type InferenceInvokeResult, type InferenceExecutor, } from './inference_executor'; export { chunksIntoMessage } from './chunks_into_message'; export { streamToResponse } from './stream_to_response'; +export { handleCancellation } from './handle_cancellation'; diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/inference_executor.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/inference_executor.ts index 736beb82aa68..c461e6b6cdfb 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/inference_executor.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/inference_executor.ts @@ -5,9 +5,14 @@ * 2.0. */ +import type { KibanaRequest } from '@kbn/core-http-server'; import type { ActionTypeExecutorResult } from '@kbn/actions-plugin/common'; -import type { ActionsClient } from '@kbn/actions-plugin/server'; +import type { + ActionsClient, + PluginStartContract as ActionsPluginStart, +} from '@kbn/actions-plugin/server'; import type { InferenceConnector } from '../../../common/connectors'; +import { getConnectorById } from '../../util/get_connector_by_id'; export interface InferenceInvokeOptions { subAction: string; @@ -22,6 +27,7 @@ export type InferenceInvokeResult = ActionTypeExecutorResult InferenceConnector; invoke(params: InferenceInvokeOptions): Promise; } @@ -33,6 +39,7 @@ export const createInferenceExecutor = ({ actionsClient: ActionsClient; }): InferenceExecutor => { return { + getConnector: () => connector, async invoke({ subAction, subActionParams }): Promise { return await actionsClient.execute({ actionId: connector.connectorId, @@ -44,3 +51,17 @@ export const createInferenceExecutor = ({ }, }; }; + +export const getInferenceExecutor = async ({ + connectorId, + actions, + request, +}: { + connectorId: string; + actions: ActionsPluginStart; + request: KibanaRequest; +}) => { + const actionsClient = await actions.getActionsClientWithRequest(request); + const connector = await getConnectorById({ connectorId, actionsClient }); + return createInferenceExecutor({ actionsClient, connector }); +}; diff --git a/x-pack/platform/plugins/shared/inference/server/routes/chat_complete.ts b/x-pack/platform/plugins/shared/inference/server/routes/chat_complete.ts index 84e3dd57cded..06ca5381cd83 100644 --- a/x-pack/platform/plugins/shared/inference/server/routes/chat_complete.ts +++ b/x-pack/platform/plugins/shared/inference/server/routes/chat_complete.ts @@ -109,6 +109,9 @@ export function registerChatCompleteRoute({ .getStartServices() .then(([coreStart, pluginsStart]) => pluginsStart.actions); + const abortController = new AbortController(); + request.events.aborted$.subscribe(() => abortController.abort()); + const client = createInferenceClient({ request, actions, logger }); const { connectorId, messages, system, toolChoice, tools, functionCalling } = request.body; @@ -121,6 +124,7 @@ export function registerChatCompleteRoute({ tools, functionCalling, stream, + abortSignal: abortController.signal, }); } diff --git a/x-pack/platform/plugins/shared/inference/server/test_utils/index.ts b/x-pack/platform/plugins/shared/inference/server/test_utils/index.ts new file mode 100644 index 000000000000..2eafe20bfdca --- /dev/null +++ b/x-pack/platform/plugins/shared/inference/server/test_utils/index.ts @@ -0,0 +1,11 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export { chunkEvent, tokensEvent, messageEvent } from './chat_complete_events'; +export { createInferenceConnectorMock } from './inference_connector'; +export { createInferenceConnectorAdapterMock } from './inference_connector_adapter'; +export { createInferenceExecutorMock } from './inference_executor'; diff --git a/x-pack/platform/plugins/shared/inference/server/test_utils/inference_connector.ts b/x-pack/platform/plugins/shared/inference/server/test_utils/inference_connector.ts new file mode 100644 index 000000000000..af7f35115325 --- /dev/null +++ b/x-pack/platform/plugins/shared/inference/server/test_utils/inference_connector.ts @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { InferenceConnector, InferenceConnectorType } from '../../common/connectors'; + +export const createInferenceConnectorMock = ( + parts: Partial = {} +): InferenceConnector => { + return { + type: InferenceConnectorType.OpenAI, + name: 'Inference connector', + connectorId: 'connector-id', + ...parts, + }; +}; diff --git a/x-pack/platform/plugins/shared/inference/server/test_utils/inference_connector_adapter.ts b/x-pack/platform/plugins/shared/inference/server/test_utils/inference_connector_adapter.ts new file mode 100644 index 000000000000..9e2c4516f4f1 --- /dev/null +++ b/x-pack/platform/plugins/shared/inference/server/test_utils/inference_connector_adapter.ts @@ -0,0 +1,14 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { InferenceConnectorAdapter } from '../chat_complete/types'; + +export const createInferenceConnectorAdapterMock = (): jest.Mocked => { + return { + chatComplete: jest.fn(), + }; +}; diff --git a/x-pack/platform/plugins/shared/inference/server/test_utils/inference_executor.ts b/x-pack/platform/plugins/shared/inference/server/test_utils/inference_executor.ts new file mode 100644 index 000000000000..64b5100a9db3 --- /dev/null +++ b/x-pack/platform/plugins/shared/inference/server/test_utils/inference_executor.ts @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { InferenceConnector } from '../../common/connectors'; +import { InferenceExecutor } from '../chat_complete/utils'; +import { createInferenceConnectorMock } from './inference_connector'; + +export const createInferenceExecutorMock = ({ + connector = createInferenceConnectorMock(), +}: { connector?: InferenceConnector } = {}): jest.Mocked => { + return { + getConnector: jest.fn().mockReturnValue(connector), + invoke: jest.fn(), + }; +};