diff --git a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.test.ts b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.test.ts
index 66ea0c2d3237c..a99a7a4bfc80b 100644
--- a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.test.ts
+++ b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.test.ts
@@ -6,6 +6,7 @@
  */
 
 import {
+  IngestRemoveProcessor,
   IngestSetProcessor,
   MlTrainedModelConfig,
   MlTrainedModelStats,
@@ -19,11 +20,12 @@ import {
   BUILT_IN_MODEL_TAG as LOCAL_BUILT_IN_MODEL_TAG,
   generateMlInferencePipelineBody,
   getMlModelTypesForModelConfig,
+  getRemoveProcessorForInferenceType,
   getSetProcessorForInferenceType,
-  SUPPORTED_PYTORCH_TASKS as LOCAL_SUPPORTED_PYTORCH_TASKS,
   parseMlInferenceParametersFromPipeline,
   parseModelStateFromStats,
   parseModelStateReasonFromStats,
+  SUPPORTED_PYTORCH_TASKS as LOCAL_SUPPORTED_PYTORCH_TASKS,
 } from '.';
 
 const mockModel: MlTrainedModelConfig = {
@@ -69,6 +71,38 @@ describe('getMlModelTypesForModelConfig lib function', () => {
   });
 });
 
+describe('getRemoveProcessorForInferenceType lib function', () => {
+  const destinationField = 'dest';
+
+  it('should return expected value for TEXT_CLASSIFICATION', () => {
+    const inferenceType = SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION;
+
+    const expected: IngestRemoveProcessor = {
+      field: destinationField,
+      ignore_missing: true,
+    };
+
+    expect(getRemoveProcessorForInferenceType(destinationField, inferenceType)).toEqual(expected);
+  });
+
+  it('should return expected value for TEXT_EMBEDDING', () => {
+    const inferenceType = SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING;
+
+    const expected: IngestRemoveProcessor = {
+      field: destinationField,
+      ignore_missing: true,
+    };
+
+    expect(getRemoveProcessorForInferenceType(destinationField, inferenceType)).toEqual(expected);
+  });
+
+  it('should return undefined for unknown inferenceType', () => {
+    const inferenceType = 'wrongInferenceType';
+
+    expect(getRemoveProcessorForInferenceType(destinationField, inferenceType)).toBeUndefined();
+  });
+});
+
 describe('getSetProcessorForInferenceType lib function', () => {
   const destinationField = 'dest';
 
@@ -84,7 +118,7 @@ describe('getSetProcessorForInferenceType lib function', () => {
       description:
         "Copy the predicted_value to 'dest' if the prediction_probability is greater than 0.5",
       field: destinationField,
-      if: 'ctx.ml.inference.dest.prediction_probability > 0.5',
+      if: "ctx?.ml?.inference != null && ctx.ml.inference['dest'] != null && ctx.ml.inference['dest'].prediction_probability > 0.5",
       value: undefined,
     };
 
@@ -98,6 +132,7 @@ describe('getSetProcessorForInferenceType lib function', () => {
       copy_from: 'ml.inference.dest.predicted_value',
       description: "Copy the predicted_value to 'dest'",
       field: destinationField,
+      if: "ctx?.ml?.inference != null && ctx.ml.inference['dest'] != null",
       value: undefined,
     };
 
@@ -191,13 +226,19 @@ describe('generateMlInferencePipelineBody lib function', () => {
       expect.objectContaining({
         description: expect.any(String),
         processors: expect.arrayContaining([
+          expect.objectContaining({
+            remove: {
+              field: 'my-destination-field',
+              ignore_missing: true,
+            },
+          }),
           expect.objectContaining({
             set: {
               copy_from: 'ml.inference.my-destination-field.predicted_value',
               description:
                 "Copy the predicted_value to 'my-destination-field' if the prediction_probability is greater than 0.5",
               field: 'my-destination-field',
-              if: 'ctx.ml.inference.my-destination-field.prediction_probability > 0.5',
+              if: "ctx?.ml?.inference != null && ctx.ml.inference['my-destination-field'] != null && ctx.ml.inference['my-destination-field'].prediction_probability > 0.5",
             },
           }),
         ]),
diff --git a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts
index b7525734fd5a1..61669d36badad 100644
--- a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts
+++ b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts
@@ -7,6 +7,7 @@
 
 import {
   IngestPipeline,
+  IngestRemoveProcessor,
   IngestSetProcessor,
   MlTrainedModelConfig,
   MlTrainedModelStats,
@@ -58,6 +59,7 @@ export const generateMlInferencePipelineBody = ({
     model.input?.field_names?.length > 0 ? model.input.field_names[0] : 'MODEL_INPUT_FIELD';
 
   const inferenceType = Object.keys(model.inference_config)[0];
+  const remove = getRemoveProcessorForInferenceType(destinationField, inferenceType);
   const set = getSetProcessorForInferenceType(destinationField, inferenceType);
 
   return {
@@ -69,6 +71,7 @@ export const generateMlInferencePipelineBody = ({
           ignore_missing: true,
         },
       },
+      ...(remove ? [{ remove }] : []),
       {
         inference: {
           field_map: {
@@ -123,7 +126,7 @@ export const getSetProcessorForInferenceType = (
       copy_from: `${prefixedDestinationField}.predicted_value`,
       description: `Copy the predicted_value to '${destinationField}' if the prediction_probability is greater than 0.5`,
       field: destinationField,
-      if: `ctx.${prefixedDestinationField}.prediction_probability > 0.5`,
+      if: `ctx?.ml?.inference != null && ctx.ml.inference['${destinationField}'] != null && ctx.ml.inference['${destinationField}'].prediction_probability > 0.5`,
       value: undefined,
     };
   } else if (inferenceType === SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING) {
@@ -131,6 +134,7 @@ export const getSetProcessorForInferenceType = (
       copy_from: `${prefixedDestinationField}.predicted_value`,
       description: `Copy the predicted_value to '${destinationField}'`,
       field: destinationField,
+      if: `ctx?.ml?.inference != null && ctx.ml.inference['${destinationField}'] != null`,
       value: undefined,
     };
   }
@@ -138,6 +142,21 @@ export const getSetProcessorForInferenceType = (
   return set;
 };
 
+export const getRemoveProcessorForInferenceType = (
+  destinationField: string,
+  inferenceType: string
+): IngestRemoveProcessor | undefined => {
+  if (
+    inferenceType === SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION ||
+    inferenceType === SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING
+  ) {
+    return {
+      field: destinationField,
+      ignore_missing: true,
+    };
+  }
+};
+
 /**
  * Parses model types list from the given configuration of a trained machine learning model
  * @param trainedModel configuration for a trained machine learning model
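For context, a minimal sketch of how the two helpers compose in the generated pipeline. The destination field and inference type values below are made-up examples, not taken from this patch:

```ts
import {
  getRemoveProcessorForInferenceType,
  getSetProcessorForInferenceType,
} from '.';

// Hypothetical example values for illustration only.
const destinationField = 'my-destination-field';
const inferenceType = 'text_classification'; // SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION

// The remove processor clears any stale copy of the destination field before
// inference runs; ignore_missing keeps it from failing on documents that
// never had the field.
const remove = getRemoveProcessorForInferenceType(destinationField, inferenceType);
// => { field: 'my-destination-field', ignore_missing: true }

// The set processor copies the prediction out of ml.inference only when it
// exists. Bracket access (ctx.ml.inference['my-destination-field']) is used
// because hyphenated field names are not valid dot-path identifiers in Painless.
const set = getSetProcessorForInferenceType(destinationField, inferenceType);
// set?.if =>
//   "ctx?.ml?.inference != null && ctx.ml.inference['my-destination-field'] != null
//    && ctx.ml.inference['my-destination-field'].prediction_probability > 0.5"
```

In `generateMlInferencePipelineBody` the remove processor is spread in ahead of the inference processor, so any leftover destination field is gone before inference writes `ml.inference.*`, and the null-guarded set processor then copies the new prediction without throwing on documents where inference was skipped.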