Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Inference Connector] Modified getProvider to use _inference/_services ES API instead of hardcoded values. #199047

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions x-pack/plugins/stack_connectors/common/inference/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
TextEmbeddingParamsSchema,
TextEmbeddingResponseSchema,
} from './schema';
import { ConfigProperties } from '../dynamic_config/types';

export type Config = TypeOf<typeof ConfigSchema>;
export type Secrets = TypeOf<typeof SecretsSchema>;
Expand All @@ -36,3 +37,17 @@ export type TextEmbeddingParams = TypeOf<typeof TextEmbeddingParamsSchema>;
export type TextEmbeddingResponse = TypeOf<typeof TextEmbeddingResponseSchema>;

export type StreamingResponse = TypeOf<typeof StreamingResponseSchema>;

export type FieldsConfiguration = Record<string, ConfigProperties>;

export interface InferenceTaskType {
task_type: string;
configuration: FieldsConfiguration;
}

export interface InferenceProvider {
provider: string;
task_types: InferenceTaskType[];
logo?: string;
configuration: FieldsConfiguration;
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ import {
import { FormattedMessage } from '@kbn/i18n-react';

import { fieldValidators } from '@kbn/es-ui-shared-plugin/static/forms/helpers';
import { ConfigEntryView } from '../../../common/dynamic_config/types';
import { ConnectorConfigurationFormItems } from '../lib/dynamic_config/connector_configuration_form_items';
import * as i18n from './translations';
import { DEFAULT_TASK_TYPE } from './constants';
import { ConfigEntryView } from '../lib/dynamic_config/types';
import { Config } from './types';
import { TaskTypeOption } from './helpers';

Expand All @@ -52,7 +52,7 @@ interface AdditionalOptionsConnectorFieldsProps {
isEdit: boolean;
optionalProviderFormFields: ConfigEntryView[];
onSetProviderConfigEntry: (key: string, value: unknown) => Promise<void>;
onTaskTypeOptionsSelect: (taskType: string, provider?: string) => Promise<void>;
onTaskTypeOptionsSelect: (taskType: string, provider?: string) => void;
selectedTaskType?: string;
taskTypeFormFields: ConfigEntryView[];
taskTypeSchema: ConfigEntryView[];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@ import { ConnectorFormTestProvider } from '../lib/test_utils';
import { render, waitFor } from '@testing-library/react';
import userEvent from '@testing-library/user-event';
import { createStartServicesMock } from '@kbn/triggers-actions-ui-plugin/public/common/lib/kibana/kibana_react.mock';
import { DisplayType, FieldType } from '../lib/dynamic_config/types';
import { useProviders } from './providers/get_providers';
import { getTaskTypes } from './get_task_types';
import { HttpSetup } from '@kbn/core-http-browser';
import { DisplayType, FieldType } from '../../../common/dynamic_config/types';

jest.mock('./providers/get_providers');
jest.mock('./get_task_types');

const mockUseKibanaReturnValue = createStartServicesMock();
jest.mock('@kbn/triggers-actions-ui-plugin/public/common/lib/kibana', () => ({
Expand All @@ -37,13 +34,32 @@ jest.mock('@faker-js/faker', () => ({
}));

const mockProviders = useProviders as jest.Mock;
const mockTaskTypes = getTaskTypes as jest.Mock;

const providersSchemas = [
{
provider: 'openai',
logo: '', // should be openai logo here, the hardcoded uses assets/images
taskTypes: ['completion', 'text_embedding'],
task_types: [
{
task_type: 'completion',
configuration: {
user: {
display: DisplayType.TEXTBOX,
label: 'User',
order: 1,
required: false,
sensitive: false,
tooltip: 'Specifies the user issuing the request.',
type: FieldType.STRING,
validations: [],
value: '',
ui_restrictions: [],
default_value: null,
depends_on: [],
},
},
},
],
configuration: {
api_key: {
display: DisplayType.TEXTBOX,
Expand Down Expand Up @@ -106,7 +122,16 @@ const providersSchemas = [
{
provider: 'googleaistudio',
logo: '', // should be googleaistudio logo here, the hardcoded uses assets/images
taskTypes: ['completion', 'text_embedding'],
task_types: [
{
task_type: 'completion',
configuration: {},
},
{
task_type: 'text_embedding',
configuration: {},
},
],
configuration: {
api_key: {
display: DisplayType.TEXTBOX,
Expand Down Expand Up @@ -139,39 +164,6 @@ const providersSchemas = [
},
},
];
const taskTypesSchemas: Record<string, any> = {
googleaistudio: [
{
task_type: 'completion',
configuration: {},
},
{
task_type: 'text_embedding',
configuration: {},
},
],
openai: [
{
task_type: 'completion',
configuration: {
user: {
display: DisplayType.TEXTBOX,
label: 'User',
order: 1,
required: false,
sensitive: false,
tooltip: 'Specifies the user issuing the request.',
type: FieldType.STRING,
validations: [],
value: '',
ui_restrictions: [],
default_value: null,
depends_on: [],
},
},
},
],
};

const openAiConnector = {
actionTypeId: '.inference',
Expand Down Expand Up @@ -222,9 +214,6 @@ describe('ConnectorFields renders', () => {
isLoading: false,
data: providersSchemas,
});
mockTaskTypes.mockImplementation(
(http: HttpSetup, provider: string) => taskTypesSchemas[provider]
);
});
test('openai provider fields are rendered', async () => {
const { getAllByTestId } = render(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

import React, { useState, useEffect, useCallback } from 'react';
import React, { useState, useEffect, useCallback, useMemo } from 'react';
import {
EuiFormRow,
EuiSpacer,
Expand All @@ -31,12 +31,12 @@ import {
import { useKibana } from '@kbn/triggers-actions-ui-plugin/public';

import { fieldValidators } from '@kbn/es-ui-shared-plugin/static/forms/helpers';
import { ConfigEntryView } from '../../../common/dynamic_config/types';
import { InferenceTaskType } from '../../../common/inference/types';
import { ServiceProviderKeys } from '../../../common/inference/constants';
import { ConnectorConfigurationFormItems } from '../lib/dynamic_config/connector_configuration_form_items';
import { getTaskTypes } from './get_task_types';
import * as i18n from './translations';
import { DEFAULT_TASK_TYPE } from './constants';
import { ConfigEntryView } from '../lib/dynamic_config/types';
import { SelectableProvider } from './providers/selectable';
import { Config, Secrets } from './types';
import { generateInferenceEndpointId, getTaskTypeOptions, TaskTypeOption } from './helpers';
Expand Down Expand Up @@ -116,13 +116,13 @@ const InferenceAPIConnectorFields: React.FunctionComponent<ActionConnectorFields
}, [isSubmitting, config, validateFields]);

const onTaskTypeOptionsSelect = useCallback(
async (taskType: string, provider?: string) => {
(taskType: string, provider?: string) => {
// Get task type settings
const currentTaskTypes = await getTaskTypes(http, provider ?? config?.provider);
const currentProvider = providers?.find((p) => p.provider === (provider ?? config?.provider));
const currentTaskTypes = currentProvider?.task_types;
const newTaskType = currentTaskTypes?.find((p) => p.task_type === taskType);

setSelectedTaskType(taskType);
generateInferenceEndpointId(config, setFieldValue);

// transform the schema
const newTaskTypeSchema = Object.keys(newTaskType?.configuration ?? {}).map((k) => ({
Expand Down Expand Up @@ -150,19 +150,23 @@ const InferenceAPIConnectorFields: React.FunctionComponent<ActionConnectorFields
taskTypeConfig: configDefaults,
},
});
generateInferenceEndpointId(
{ ...config, taskType, taskTypeConfig: configDefaults },
setFieldValue
);
},
[config, http, setFieldValue, updateFieldValues]
[config, providers, setFieldValue, updateFieldValues]
);

const onProviderChange = useCallback(
async (provider?: string) => {
(provider?: string) => {
const newProvider = providers?.find((p) => p.provider === provider);

// Update task types list available for the selected provider
const providerTaskTypes = newProvider?.taskTypes ?? [];
const providerTaskTypes = (newProvider?.task_types ?? []).map((t) => t.task_type);
setTaskTypeOptions(getTaskTypeOptions(providerTaskTypes));
if (providerTaskTypes.length > 0) {
await onTaskTypeOptionsSelect(providerTaskTypes[0], provider);
onTaskTypeOptionsSelect(providerTaskTypes[0], provider);
}

// Update connector providerSchema
Expand Down Expand Up @@ -203,9 +207,8 @@ const InferenceAPIConnectorFields: React.FunctionComponent<ActionConnectorFields
);

useEffect(() => {
const getTaskTypeSchema = async () => {
const currentTaskTypes = await getTaskTypes(http, config?.provider ?? '');
const newTaskType = currentTaskTypes?.find((p) => p.task_type === config?.taskType);
const getTaskTypeSchema = (taskTypes: InferenceTaskType[]) => {
const newTaskType = taskTypes.find((p) => p.task_type === config?.taskType);

// transform the schema
const newTaskTypeSchema = Object.keys(newTaskType?.configuration ?? {}).map((k) => ({
Expand All @@ -228,7 +231,7 @@ const InferenceAPIConnectorFields: React.FunctionComponent<ActionConnectorFields

setProviderSchema(newProviderSchema);

getTaskTypeSchema();
getTaskTypeSchema(newProvider?.task_types ?? []);
}
}, [config?.provider, config?.taskType, http, isEdit, providers]);

Expand Down Expand Up @@ -309,6 +312,22 @@ const InferenceAPIConnectorFields: React.FunctionComponent<ActionConnectorFields
setFieldValue('config.provider', '');
}, [onProviderChange, setFieldValue]);

const providerIcon = useMemo(
() =>
Object.keys(SERVICE_PROVIDERS).includes(config?.provider)
? SERVICE_PROVIDERS[config?.provider as ServiceProviderKeys].icon
: undefined,
[config?.provider]
);

const providerName = useMemo(
() =>
Object.keys(SERVICE_PROVIDERS).includes(config?.provider)
? SERVICE_PROVIDERS[config?.provider as ServiceProviderKeys].name
: config?.provider,
[config?.provider]
);

const providerSuperSelect = useCallback(
(isInvalid: boolean) => (
<EuiFormControlLayout
Expand All @@ -317,21 +336,15 @@ const InferenceAPIConnectorFields: React.FunctionComponent<ActionConnectorFields
isDisabled={isEdit || readOnly}
isInvalid={isInvalid}
fullWidth
icon={
!config?.provider
? { type: 'sparkles', side: 'left' }
: SERVICE_PROVIDERS[config?.provider as ServiceProviderKeys].icon
}
icon={!config?.provider ? { type: 'sparkles', side: 'left' } : providerIcon}
>
<EuiFieldText
onClick={handleProviderPopover}
data-test-subj="provider-select"
isInvalid={isInvalid}
disabled={isEdit || readOnly}
onKeyDown={handleProviderKeyboardOpen}
value={
config?.provider ? SERVICE_PROVIDERS[config?.provider as ServiceProviderKeys].name : ''
}
value={config?.provider ? providerName : ''}
fullWidth
placeholder={i18n.SELECT_PROVIDER}
icon={{ type: 'arrowDown', side: 'right' }}
Expand All @@ -345,8 +358,10 @@ const InferenceAPIConnectorFields: React.FunctionComponent<ActionConnectorFields
readOnly,
onClearProvider,
config?.provider,
providerIcon,
handleProviderPopover,
handleProviderKeyboardOpen,
providerName,
isProviderPopoverOpen,
]
);
Expand Down

This file was deleted.

Loading