Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AI Connector] Change completion subAction schema to be OpenAI compatible #200249

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
f04ab4c
[AI Connector] Change completion subAction schema to be OpenAI compat…
YulNaumenko Nov 14, 2024
af94594
-
YulNaumenko Nov 21, 2024
697c73c
Merge branch 'main' into ai-connector-inference-completion-openai
YulNaumenko Dec 2, 2024
7428e6f
Merge branch 'main' into ai-connector-inference-completion-openai
YulNaumenko Dec 5, 2024
f17c678
Merge branch 'ai-connector-inference-completion-openai' of github.com…
YulNaumenko Dec 5, 2024
2847a11
added unified completion support
YulNaumenko Dec 6, 2024
3a03150
-
YulNaumenko Dec 6, 2024
f14b3ec
Merge branch 'main' into ai-connector-inference-completion-openai
YulNaumenko Dec 10, 2024
bf34578
added dashboard and async iterator
YulNaumenko Dec 11, 2024
f3b6a75
fixed headers
YulNaumenko Dec 12, 2024
72041b5
fixed params
YulNaumenko Dec 12, 2024
9cec4cb
Merge branch 'main' into ai-connector-inference-completion-openai
YulNaumenko Dec 12, 2024
79c5c3b
made the regular stream and non-stream working
YulNaumenko Dec 16, 2024
27512b5
Merge branch 'main' into ai-connector-inference-completion-openai
YulNaumenko Dec 16, 2024
8ce2707
Merge branch 'ai-connector-inference-completion-openai' of github.com…
YulNaumenko Dec 16, 2024
cc7993b
merge fix
YulNaumenko Dec 16, 2024
166a22b
Merge remote-tracking branch 'upstream/main' into ai-connector-infere…
YulNaumenko Dec 16, 2024
e0ee923
tool calls fix
YulNaumenko Dec 16, 2024
042e813
improved
YulNaumenko Dec 16, 2024
6ded610
-
YulNaumenko Dec 16, 2024
36a6f67
streaming
YulNaumenko Dec 17, 2024
a70051b
fixed test
YulNaumenko Dec 17, 2024
f811503
fixed streaming
YulNaumenko Dec 17, 2024
2d2ebde
Merge remote-tracking branch 'upstream/main' into ai-connector-infere…
YulNaumenko Dec 17, 2024
232fbe9
fixed due to comments
YulNaumenko Dec 17, 2024
ed6107b
excluded n
YulNaumenko Dec 17, 2024
4ccb8f3
fixed tests
YulNaumenko Dec 18, 2024
699b0a9
Merge branch 'main' into ai-connector-inference-completion-openai
YulNaumenko Dec 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -295,4 +295,7 @@ export const getGenAiTokenTracking = async ({
};

/**
 * Returns true when the given connector action type produces GenAI token
 * usage that should be recorded by the token-tracking task.
 *
 * Includes the dedicated LLM connectors (.gen-ai, .bedrock, .gemini) and the
 * Elastic Inference connector (.inference).
 */
export const shouldTrackGenAiToken = (actionTypeId: string) =>
  actionTypeId === '.gen-ai' ||
  actionTypeId === '.bedrock' ||
  actionTypeId === '.gemini' ||
  actionTypeId === '.inference';
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ export enum ServiceProviderKeys {

export const INFERENCE_CONNECTOR_ID = '.inference';
export enum SUB_ACTION {
UNIFIED_COMPLETION_ASYNC_ITERATOR = 'unified_completion_async_iterator',
UNIFIED_COMPLETION_STREAM = 'unified_completion_stream',
YulNaumenko marked this conversation as resolved.
Show resolved Hide resolved
UNIFIED_COMPLETION = 'unified_completion',
COMPLETION = 'completion',
RERANK = 'rerank',
TEXT_EMBEDDING = 'text_embedding',
Expand Down
179 changes: 179 additions & 0 deletions x-pack/plugins/stack_connectors/common/inference/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,176 @@ export const ChatCompleteParamsSchema = schema.object({
input: schema.string(),
});

// A single tool invocation attached to an assistant message
// (subset of OpenAI's ChatCompletionMessageToolCall).
const AIMessageToolCall = schema.object({
  id: schema.string(),
  function: schema.object({
    arguments: schema.maybe(schema.string()),
    name: schema.maybe(schema.string()),
  }),
  type: schema.string(),
});

// Subset of OpenAI.ChatCompletionMessageParam
// https://github.com/openai/openai-node/blob/master/src/resources/chat/completions.ts
const AIMessage = schema.object({
  role: schema.string(),
  content: schema.maybe(schema.string()),
  name: schema.maybe(schema.string()),
  tool_calls: schema.maybe(schema.arrayOf(AIMessageToolCall)),
  // For role "tool": id of the tool call this message is responding to.
  tool_call_id: schema.maybe(schema.string()),
});

// Function definition exposed to the model; `parameters` is a free-form
// record (JSON-schema-like) describing the function's arguments.
const AIToolFunction = schema.object({
  name: schema.string(),
  description: schema.maybe(schema.string()),
  parameters: schema.maybe(schema.recordOf(schema.string(), schema.any())),
});

// Subset of OpenAI's tool definition (function tools).
const AITool = schema.object({
  type: schema.string(),
  function: AIToolFunction,
});

// subset of OpenAI.ChatCompletionCreateParamsBase https://github.com/openai/openai-node/blob/master/src/resources/chat/completions.ts
export const UnifiedChatCompleteParamsSchema = schema.object({
body: schema.object({
messages: schema.arrayOf(AIMessage, { defaultValue: [] }),
model: schema.maybe(schema.string()),
/**
* The maximum number of [tokens](/tokenizer) that can be generated in the chat
* completion. This value can be used to control
* [costs](https://openai.com/api/pricing/) for text generated via API.
*
* This value is now deprecated in favor of `max_completion_tokens`, and is not
* compatible with
* [o1 series models](https://platform.openai.com/docs/guides/reasoning).
*/
max_tokens: schema.maybe(schema.number()),
/**
* Developer-defined tags and values used for filtering completions in the
* [dashboard](https://platform.openai.com/chat-completions).
*/
metadata: schema.maybe(schema.recordOf(schema.string(), schema.string())),
/**
* How many chat completion choices to generate for each input message. Note that
* you will be charged based on the number of generated tokens across all of the
* choices. Keep `n` as `1` to minimize costs.
*/
n: schema.maybe(schema.number()),
/**
* Up to 4 sequences where the API will stop generating further tokens.
*/
stop: schema.maybe(
schema.nullable(schema.oneOf([schema.string(), schema.arrayOf(schema.string())]))
),
/**
* What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
* make the output more random, while lower values like 0.2 will make it more
* focused and deterministic.
*
* We generally recommend altering this or `top_p` but not both.
*/
temperature: schema.maybe(schema.number()),
/**
* Controls which (if any) tool is called by the model. `none` means the model will
* not call any tool and instead generates a message. `auto` means the model can
* pick between generating a message or calling one or more tools. `required` means
* the model must call one or more tools. Specifying a particular tool via
* `{"type": "function", "function": {"name": "my_function"}}` forces the model to
* call that tool.
*
* `none` is the default when no tools are present. `auto` is the default if tools
* are present.
*/
tool_choice: schema.maybe(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I supposed to do a transform on the OpenAI function_call to make it into tool_choice?

function_call: schema.maybe(
schema.oneOf([
schema.literal('none'),
schema.literal('auto'),
schema.object(
{
name: schema.string(),
},
{ unknowns: 'ignore' }
),
])
),

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correct, we should stop using the deprecated way of function_call usage and migrate to tool_choice and tools

schema.oneOf([
schema.string(),
schema.object({
type: schema.string(),
function: schema.object({
name: schema.string(),
}),
}),
])
),
/**
* A list of tools the model may call. Currently, only functions are supported as a
* tool. Use this to provide a list of functions the model may generate JSON inputs
* for. A max of 128 functions are supported.
*/
tools: schema.maybe(schema.arrayOf(AITool)),
/**
* An alternative to sampling with temperature, called nucleus sampling, where the
* model considers the results of the tokens with top_p probability mass. So 0.1
* means only the tokens comprising the top 10% probability mass are considered.
*
* We generally recommend altering this or `temperature` but not both.
*/
top_p: schema.maybe(schema.number()),
/**
* A unique identifier representing your end-user, which can help OpenAI to monitor
* and detect abuse.
* [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
*/
user: schema.maybe(schema.string()),
}),
// abort signal from client
signal: schema.maybe(schema.any()),
});

// Finish reasons defined by the OpenAI chat completion API.
const UnifiedFinishReason = schema.oneOf([
  schema.literal('stop'),
  schema.literal('length'),
  schema.literal('tool_calls'),
  schema.literal('content_filter'),
  schema.literal('function_call'),
]);

// Tool call echoed back inside an assistant response message.
const UnifiedResponseToolCall = schema.object({
  id: schema.maybe(schema.string()),
  index: schema.maybe(schema.number()),
  function: schema.maybe(
    schema.object({
      arguments: schema.maybe(schema.string()),
      name: schema.maybe(schema.string()),
    })
  ),
  type: schema.maybe(schema.string()),
});

// Assistant message produced for a single completion choice.
const UnifiedResponseMessage = schema.object({
  content: schema.maybe(schema.nullable(schema.string())),
  refusal: schema.maybe(schema.nullable(schema.string())),
  role: schema.maybe(schema.string()),
  tool_calls: schema.maybe(schema.arrayOf(UnifiedResponseToolCall, { defaultValue: [] })),
});

// Token accounting block reported by the provider.
const UnifiedResponseUsage = schema.object({
  completion_tokens: schema.maybe(schema.number()),
  prompt_tokens: schema.maybe(schema.number()),
  total_tokens: schema.maybe(schema.number()),
});

// Shape of a non-streaming OpenAI-compatible chat completion response.
export const UnifiedChatCompleteResponseSchema = schema.object({
  id: schema.string(),
  choices: schema.arrayOf(
    schema.object({
      finish_reason: schema.maybe(schema.nullable(UnifiedFinishReason)),
      index: schema.maybe(schema.number()),
      message: UnifiedResponseMessage,
    }),
    { defaultValue: [] }
  ),
  created: schema.maybe(schema.number()),
  model: schema.maybe(schema.string()),
  object: schema.maybe(schema.string()),
  usage: schema.maybe(schema.nullable(UnifiedResponseUsage)),
});

export const ChatCompleteResponseSchema = schema.arrayOf(
schema.object({
result: schema.string(),
Expand Down Expand Up @@ -66,3 +236,12 @@ export const TextEmbeddingResponseSchema = schema.arrayOf(
);

export const StreamingResponseSchema = schema.stream();

// Dashboard sub-action: params identify the dashboard saved object to check.
export const DashboardActionParamsSchema = schema.object({
  dashboardId: schema.string(),
});

// Dashboard sub-action response: whether the dashboard is available.
export const DashboardActionResponseSchema = schema.object({
  available: schema.boolean(),
});
10 changes: 10 additions & 0 deletions x-pack/plugins/stack_connectors/common/inference/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,19 @@ import {
SparseEmbeddingResponseSchema,
TextEmbeddingParamsSchema,
TextEmbeddingResponseSchema,
UnifiedChatCompleteParamsSchema,
UnifiedChatCompleteResponseSchema,
DashboardActionParamsSchema,
DashboardActionResponseSchema,
} from './schema';
import { ConfigProperties } from '../dynamic_config/types';

// Connector configuration and secrets, derived from their runtime schemas.
export type Config = TypeOf<typeof ConfigSchema>;
export type Secrets = TypeOf<typeof SecretsSchema>;

// OpenAI-compatible unified chat completion request/response types.
export type UnifiedChatCompleteParams = TypeOf<typeof UnifiedChatCompleteParamsSchema>;
export type UnifiedChatCompleteResponse = TypeOf<typeof UnifiedChatCompleteResponseSchema>;

// Existing (non-unified) completion sub-action types.
export type ChatCompleteParams = TypeOf<typeof ChatCompleteParamsSchema>;
export type ChatCompleteResponse = TypeOf<typeof ChatCompleteResponseSchema>;

Expand All @@ -38,6 +45,9 @@ export type TextEmbeddingResponse = TypeOf<typeof TextEmbeddingResponseSchema>;

export type StreamingResponse = TypeOf<typeof StreamingResponseSchema>;

export type DashboardActionParams = TypeOf<typeof DashboardActionParamsSchema>;
export type DashboardActionResponse = TypeOf<typeof DashboardActionResponseSchema>;

export type FieldsConfiguration = Record<string, ConfigProperties>;

export interface InferenceProvider {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,27 @@ export const DEFAULT_TEXT_EMBEDDING_BODY = {
inputType: 'ingest',
};

// Default example body for the unified completion sub-actions
// (OpenAI-compatible `messages` request format).
export const DEFAULT_UNIFIED_CHAT_COMPLETE_BODY = {
  body: {
    messages: [
      {
        role: 'user',
        content: 'Hello world',
      },
    ],
  },
};

// Maps each sub-action (task type) to the default example body shown for it.
export const DEFAULTS_BY_TASK_TYPE: Record<string, unknown> = {
  [SUB_ACTION.COMPLETION]: DEFAULT_CHAT_COMPLETE_BODY,
  [SUB_ACTION.UNIFIED_COMPLETION]: DEFAULT_UNIFIED_CHAT_COMPLETE_BODY,
  // Stream and async-iterator variants share the unified completion default.
  [SUB_ACTION.UNIFIED_COMPLETION_STREAM]: DEFAULT_UNIFIED_CHAT_COMPLETE_BODY,
  [SUB_ACTION.UNIFIED_COMPLETION_ASYNC_ITERATOR]: DEFAULT_UNIFIED_CHAT_COMPLETE_BODY,
  [SUB_ACTION.RERANK]: DEFAULT_RERANK_BODY,
  [SUB_ACTION.SPARSE_EMBEDDING]: DEFAULT_SPARSE_EMBEDDING_BODY,
  [SUB_ACTION.TEXT_EMBEDDING]: DEFAULT_TEXT_EMBEDDING_BODY,
};

// Default task type for new inference connectors; 'unified_completion' is the
// OpenAI-compatible chat completion task.
export const DEFAULT_TASK_TYPE = 'unified_completion';

export const DEFAULT_PROVIDER = 'elasticsearch';
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,16 @@ describe('OpenAI action params validation', () => {
subActionParams: { input: ['message test'], query: 'foobar' },
},
{
subAction: SUB_ACTION.COMPLETION,
subActionParams: { input: 'message test' },
subAction: SUB_ACTION.UNIFIED_COMPLETION,
subActionParams: { body: { messages: [{ role: 'user', content: 'What is Elastic?' }] } },
},
{
subAction: SUB_ACTION.UNIFIED_COMPLETION_STREAM,
subActionParams: { body: { messages: [{ role: 'user', content: 'What is Elastic?' }] } },
},
{
subAction: SUB_ACTION.UNIFIED_COMPLETION_ASYNC_ITERATOR,
subActionParams: { body: { messages: [{ role: 'user', content: 'What is Elastic?' }] } },
},
{
subAction: SUB_ACTION.TEXT_EMBEDDING,
Expand All @@ -55,6 +63,10 @@ describe('OpenAI action params validation', () => {
subAction: SUB_ACTION.SPARSE_EMBEDDING,
subActionParams: { input: 'message test' },
},
{
subAction: SUB_ACTION.COMPLETION,
subActionParams: { input: 'message test' },
},
])(
'validation succeeds when params are valid for subAction $subAction',
async ({ subAction, subActionParams }) => {
Expand All @@ -63,19 +75,25 @@ describe('OpenAI action params validation', () => {
subActionParams,
};
expect(await actionTypeModel.validateParams(actionParams)).toEqual({
errors: { input: [], subAction: [], inputType: [], query: [] },
errors: { body: [], input: [], subAction: [], inputType: [], query: [] },
});
}
);

test('params validation fails when params is a wrong object', async () => {
const actionParams = {
subAction: SUB_ACTION.COMPLETION,
subAction: SUB_ACTION.UNIFIED_COMPLETION,
subActionParams: { body: 'message {test}' },
};

expect(await actionTypeModel.validateParams(actionParams)).toEqual({
errors: { input: ['Input is required.'], inputType: [], query: [], subAction: [] },
errors: {
body: ['Messages is required.'],
inputType: [],
query: [],
subAction: [],
input: [],
},
});
});

Expand All @@ -86,6 +104,7 @@ describe('OpenAI action params validation', () => {

expect(await actionTypeModel.validateParams(actionParams)).toEqual({
errors: {
body: [],
input: [],
inputType: [],
query: [],
Expand All @@ -102,6 +121,7 @@ describe('OpenAI action params validation', () => {

expect(await actionTypeModel.validateParams(actionParams)).toEqual({
errors: {
body: [],
input: [],
inputType: [],
query: [],
Expand All @@ -118,6 +138,7 @@ describe('OpenAI action params validation', () => {

expect(await actionTypeModel.validateParams(actionParams)).toEqual({
errors: {
body: [],
input: ['Input is required.', 'Input does not have a valid Array format.'],
inputType: [],
query: ['Query is required.'],
Expand All @@ -134,6 +155,7 @@ describe('OpenAI action params validation', () => {

expect(await actionTypeModel.validateParams(actionParams)).toEqual({
errors: {
body: [],
input: [],
inputType: ['Input type is required.'],
query: [],
Expand Down
Loading
Loading