Skip to content

Commit

Permalink
[Security Assistant] AI Assistant - Better Solution for OSS models (e…
Browse files Browse the repository at this point in the history
…lastic#10416) (elastic#194166)

(cherry picked from commit 1ee648d)
  • Loading branch information
e40pud committed Oct 7, 2024
1 parent 46fd6db commit a813812
Show file tree
Hide file tree
Showing 22 changed files with 452 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@ import { useCallback, useRef, useState } from 'react';
import { ApiConfig, Replacements } from '@kbn/elastic-assistant-common';
import { useAssistantContext } from '../../assistant_context';
import { fetchConnectorExecuteAction, FetchConnectorExecuteResponse } from '../api';
import * as i18n from './translations';

/**
* TODO: This is a workaround to solve the issue with the long standing server tasks while cahtting with the assistant.
* Some models (like Llama 3.1 70B) can perform poorly and be slow which leads to a long time to handle the request.
* The `core-http-browser` has a timeout of two minutes after which it will re-try the request. In combination with the slow model it can lead to
* a situation where core http client will initiate same request again and again.
* To avoid this, we abort http request after timeout which is slightly below two minutes.
*/
const EXECUTE_ACTION_TIMEOUT = 110 * 1000; // in milliseconds

interface SendMessageProps {
apiConfig: ApiConfig;
Expand Down Expand Up @@ -38,6 +48,11 @@ export const useSendMessage = (): UseSendMessage => {
async ({ apiConfig, http, message, conversationId, replacements }: SendMessageProps) => {
setIsLoading(true);

const timeoutId = setTimeout(() => {
abortController.current.abort(i18n.FETCH_MESSAGE_TIMEOUT_ERROR);
abortController.current = new AbortController();
}, EXECUTE_ACTION_TIMEOUT);

try {
return await fetchConnectorExecuteAction({
conversationId,
Expand All @@ -52,6 +67,7 @@ export const useSendMessage = (): UseSendMessage => {
traceOptions,
});
} finally {
clearTimeout(timeoutId);
setIsLoading(false);
}
},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { i18n } from '@kbn/i18n';

export const FETCH_MESSAGE_TIMEOUT_ERROR = i18n.translate(
'xpack.elasticAssistant.assistant.useSendMessage.fetchMessageTimeoutError',
{
defaultMessage: 'Assistant could not respond in time. Please try again later.',
}
);
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ export interface AgentExecutorParams<T extends boolean> {
esClient: ElasticsearchClient;
langChainMessages: BaseMessage[];
llmType?: string;
isOssModel?: boolean;
logger: Logger;
inference: InferenceServerStart;
onNewReplacements?: (newReplacements: Replacements) => void;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ export const getDefaultAssistantGraph = ({
value: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
},
isOssModel: {
value: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
},
conversation: {
value: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) =>
y ?? x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ interface StreamGraphParams {
assistantGraph: DefaultAssistantGraph;
inputs: GraphInputs;
logger: Logger;
isOssModel?: boolean;
onLlmResponse?: OnLlmResponse;
request: KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
traceOptions?: TraceOptions;
Expand All @@ -36,6 +37,7 @@ interface StreamGraphParams {
* @param assistantGraph
* @param inputs
* @param logger
* @param isOssModel
* @param onLlmResponse
* @param request
* @param traceOptions
Expand All @@ -45,6 +47,7 @@ export const streamGraph = async ({
assistantGraph,
inputs,
logger,
isOssModel,
onLlmResponse,
request,
traceOptions,
Expand Down Expand Up @@ -80,8 +83,8 @@ export const streamGraph = async ({
};

if (
(inputs?.llmType === 'bedrock' || inputs?.llmType === 'gemini') &&
inputs?.bedrockChatEnabled
inputs.isOssModel ||
((inputs?.llmType === 'bedrock' || inputs?.llmType === 'gemini') && inputs?.bedrockChatEnabled)
) {
const stream = await assistantGraph.streamEvents(
inputs,
Expand All @@ -92,7 +95,9 @@ export const streamGraph = async ({
version: 'v2',
streamMode: 'values',
},
inputs?.llmType === 'bedrock' ? { includeNames: ['Summarizer'] } : undefined
inputs.isOssModel || inputs?.llmType === 'bedrock'
? { includeNames: ['Summarizer'] }
: undefined
);

for await (const { event, data, tags } of stream) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
inference,
langChainMessages,
llmType,
isOssModel,
logger: parentLogger,
isStream = false,
onLlmResponse,
Expand All @@ -48,7 +49,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
responseLanguage = 'English',
}) => {
const logger = parentLogger.get('defaultAssistantGraph');
const isOpenAI = llmType === 'openai';
const isOpenAI = llmType === 'openai' && !isOssModel;
const llmClass = getLlmClass(llmType, bedrockChatEnabled);

/**
Expand Down Expand Up @@ -111,7 +112,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
};

const tools: StructuredTool[] = assistantTools.flatMap(
(tool) => tool.getTool({ ...assistantToolParams, llm: createLlmInstance() }) ?? []
(tool) => tool.getTool({ ...assistantToolParams, llm: createLlmInstance(), isOssModel }) ?? []
);

// If KB enabled, fetch for any KB IndexEntries and generate a tool for each
Expand Down Expand Up @@ -166,6 +167,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
conversationId,
llmType,
isStream,
isOssModel,
input: latestMessage[0]?.content as string,
};

Expand All @@ -175,6 +177,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
assistantGraph,
inputs,
logger,
isOssModel,
onLlmResponse,
request,
traceOptions,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ interface ModelInputParams extends NodeParamsBase {
export function modelInput({ logger, state }: ModelInputParams): Partial<AgentState> {
logger.debug(() => `${NodeType.MODEL_INPUT}: Node state:\n${JSON.stringify(state, null, 2)}`);

const hasRespondStep = state.isStream && state.bedrockChatEnabled && state.llmType === 'bedrock';
const hasRespondStep =
state.isStream &&
(state.isOssModel || (state.bedrockChatEnabled && state.llmType === 'bedrock'));

return {
hasRespondStep,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,59 @@ const KB_CATCH =
export const GEMINI_SYSTEM_PROMPT = `${BASE_GEMINI_PROMPT} ${KB_CATCH}`;
export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from ESQLKnowledgeBaseTool as is. Never return <thinking> tags in the response, but make sure to include <result> tags content in the response. Do not reflect on the quality of the returned search results in your response.`;
export const GEMINI_USER_PROMPT = `Now, always using the tools at your disposal, step by step, come up with a response to this request:\n\n`;

export const STRUCTURED_SYSTEM_PROMPT = `Respond to the human as helpfully and accurately as possible. You have access to the following tools:
{tools}
The tool action_input should ALWAYS follow the tool JSON schema args.
Valid "action" values: "Final Answer" or {tool_names}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args).
Provide only ONE action per $JSON_BLOB, as shown:
\`\`\`
{{
"action": $TOOL_NAME,
"action_input": $TOOL_INPUT
}}
\`\`\`
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
\`\`\`
$JSON_BLOB
\`\`\`
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
\`\`\`
{{
"action": "Final Answer",
"action_input": "Final response to human"}}
Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation`;
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
DEFAULT_SYSTEM_PROMPT,
GEMINI_SYSTEM_PROMPT,
GEMINI_USER_PROMPT,
STRUCTURED_SYSTEM_PROMPT,
} from './nodes/translations';

export const formatPrompt = (prompt: string, additionalPrompt?: string) =>
Expand All @@ -26,61 +27,7 @@ export const systemPrompts = {
bedrock: `${DEFAULT_SYSTEM_PROMPT} ${BEDROCK_SYSTEM_PROMPT}`,
// The default prompt overwhelms gemini, do not prepend
gemini: GEMINI_SYSTEM_PROMPT,
structuredChat: `Respond to the human as helpfully and accurately as possible. You have access to the following tools:
{tools}
The tool action_input should ALWAYS follow the tool JSON schema args.
Valid "action" values: "Final Answer" or {tool_names}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args).
Provide only ONE action per $JSON_BLOB, as shown:
\`\`\`
{{
"action": $TOOL_NAME,
"action_input": $TOOL_INPUT
}}
\`\`\`
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
\`\`\`
$JSON_BLOB
\`\`\`
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
\`\`\`
{{
"action": "Final Answer",
"action_input": "Final response to human"}}
Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation`,
structuredChat: STRUCTURED_SYSTEM_PROMPT,
};

export const openAIFunctionAgentPrompt = formatPrompt(systemPrompts.openai);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export interface GraphInputs {
conversationId?: string;
llmType?: string;
isStream?: boolean;
isOssModel?: boolean;
input: string;
responseLanguage?: string;
}
Expand All @@ -31,6 +32,7 @@ export interface AgentState extends AgentStateBase {
lastNode: string;
hasRespondStep: boolean;
isStream: boolean;
isOssModel: boolean;
bedrockChatEnabled: boolean;
llmType: string;
responseLanguage: string;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import {
} from '../helpers';
import { transformESSearchToAnonymizationFields } from '../../ai_assistant_data_clients/anonymization_fields/helpers';
import { EsAnonymizationFieldsSchema } from '../../ai_assistant_data_clients/anonymization_fields/types';
import { isOpenSourceModel } from '../utils';

export const SYSTEM_PROMPT_CONTEXT_NON_I18N = (context: string) => {
return `CONTEXT:\n"""\n${context}\n"""`;
Expand Down Expand Up @@ -99,7 +100,9 @@ export const chatCompleteRoute = (
const actions = ctx.elasticAssistant.actions;
const actionsClient = await actions.getActionsClientWithRequest(request);
const connectors = await actionsClient.getBulk({ ids: [connectorId] });
actionTypeId = connectors.length > 0 ? connectors[0].actionTypeId : '.gen-ai';
const connector = connectors.length > 0 ? connectors[0] : undefined;
actionTypeId = connector?.actionTypeId ?? '.gen-ai';
const isOssModel = isOpenSourceModel(connector);

// replacements
const anonymizationFieldsRes =
Expand Down Expand Up @@ -192,6 +195,7 @@ export const chatCompleteRoute = (
actionsClient,
actionTypeId,
connectorId,
isOssModel,
conversationId: conversationId ?? newConversation?.id,
context: ctx,
getElser,
Expand Down
Loading

0 comments on commit a813812

Please sign in to comment.