Skip to content

Commit

Permalink
Add failure handling for set processors in ML inference pipelines (el…
Browse files Browse the repository at this point in the history
…astic#144654)

## Summary

Also, add a `remove` processor and `text_classification` and
`text_embedding` types.

### Checklist

Delete any items that are not applicable to this PR.

- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios

(cherry picked from commit 8e81a7d)
  • Loading branch information
brianmcgue committed Nov 16, 2022
1 parent a9f7ba6 commit f6f5233
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/

import {
IngestRemoveProcessor,
IngestSetProcessor,
MlTrainedModelConfig,
MlTrainedModelStats,
Expand All @@ -19,11 +20,12 @@ import {
BUILT_IN_MODEL_TAG as LOCAL_BUILT_IN_MODEL_TAG,
generateMlInferencePipelineBody,
getMlModelTypesForModelConfig,
getRemoveProcessorForInferenceType,
getSetProcessorForInferenceType,
SUPPORTED_PYTORCH_TASKS as LOCAL_SUPPORTED_PYTORCH_TASKS,
parseMlInferenceParametersFromPipeline,
parseModelStateFromStats,
parseModelStateReasonFromStats,
SUPPORTED_PYTORCH_TASKS as LOCAL_SUPPORTED_PYTORCH_TASKS,
} from '.';

const mockModel: MlTrainedModelConfig = {
Expand Down Expand Up @@ -69,6 +71,38 @@ describe('getMlModelTypesForModelConfig lib function', () => {
});
});

describe('getRemoveProcessorForInferenceType lib function', () => {
const destinationField = 'dest';

it('should return expected value for TEXT_CLASSIFICATION', () => {
const inferenceType = SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION;

const expected: IngestRemoveProcessor = {
field: destinationField,
ignore_missing: true,
};

expect(getRemoveProcessorForInferenceType(destinationField, inferenceType)).toEqual(expected);
});

it('should return expected value for TEXT_EMBEDDING', () => {
const inferenceType = SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING;

const expected: IngestRemoveProcessor = {
field: destinationField,
ignore_missing: true,
};

expect(getRemoveProcessorForInferenceType(destinationField, inferenceType)).toEqual(expected);
});

it('should return undefined for unknown inferenceType', () => {
const inferenceType = 'wrongInferenceType';

expect(getRemoveProcessorForInferenceType(destinationField, inferenceType)).toBeUndefined();
});
});

describe('getSetProcessorForInferenceType lib function', () => {
const destinationField = 'dest';

Expand All @@ -84,7 +118,7 @@ describe('getSetProcessorForInferenceType lib function', () => {
description:
"Copy the predicted_value to 'dest' if the prediction_probability is greater than 0.5",
field: destinationField,
if: 'ctx.ml.inference.dest.prediction_probability > 0.5',
if: "ctx?.ml?.inference != null && ctx.ml.inference['dest'] != null && ctx.ml.inference['dest'].prediction_probability > 0.5",
value: undefined,
};

Expand All @@ -98,6 +132,7 @@ describe('getSetProcessorForInferenceType lib function', () => {
copy_from: 'ml.inference.dest.predicted_value',
description: "Copy the predicted_value to 'dest'",
field: destinationField,
if: "ctx?.ml?.inference != null && ctx.ml.inference['dest'] != null",
value: undefined,
};

Expand Down Expand Up @@ -191,13 +226,19 @@ describe('generateMlInferencePipelineBody lib function', () => {
expect.objectContaining({
description: expect.any(String),
processors: expect.arrayContaining([
expect.objectContaining({
remove: {
field: 'my-destination-field',
ignore_missing: true,
},
}),
expect.objectContaining({
set: {
copy_from: 'ml.inference.my-destination-field.predicted_value',
description:
"Copy the predicted_value to 'my-destination-field' if the prediction_probability is greater than 0.5",
field: 'my-destination-field',
if: 'ctx.ml.inference.my-destination-field.prediction_probability > 0.5',
if: "ctx?.ml?.inference != null && ctx.ml.inference['my-destination-field'] != null && ctx.ml.inference['my-destination-field'].prediction_probability > 0.5",
},
}),
]),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import {
IngestPipeline,
IngestRemoveProcessor,
IngestSetProcessor,
MlTrainedModelConfig,
MlTrainedModelStats,
Expand Down Expand Up @@ -58,6 +59,7 @@ export const generateMlInferencePipelineBody = ({
model.input?.field_names?.length > 0 ? model.input.field_names[0] : 'MODEL_INPUT_FIELD';

const inferenceType = Object.keys(model.inference_config)[0];
const remove = getRemoveProcessorForInferenceType(destinationField, inferenceType);
const set = getSetProcessorForInferenceType(destinationField, inferenceType);

return {
Expand All @@ -69,6 +71,7 @@ export const generateMlInferencePipelineBody = ({
ignore_missing: true,
},
},
...(remove ? [{ remove }] : []),
{
inference: {
field_map: {
Expand Down Expand Up @@ -123,21 +126,37 @@ export const getSetProcessorForInferenceType = (
copy_from: `${prefixedDestinationField}.predicted_value`,
description: `Copy the predicted_value to '${destinationField}' if the prediction_probability is greater than 0.5`,
field: destinationField,
if: `ctx.${prefixedDestinationField}.prediction_probability > 0.5`,
if: `ctx?.ml?.inference != null && ctx.ml.inference['${destinationField}'] != null && ctx.ml.inference['${destinationField}'].prediction_probability > 0.5`,
value: undefined,
};
} else if (inferenceType === SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING) {
set = {
copy_from: `${prefixedDestinationField}.predicted_value`,
description: `Copy the predicted_value to '${destinationField}'`,
field: destinationField,
if: `ctx?.ml?.inference != null && ctx.ml.inference['${destinationField}'] != null`,
value: undefined,
};
}

return set;
};

export const getRemoveProcessorForInferenceType = (
destinationField: string,
inferenceType: string
): IngestRemoveProcessor | undefined => {
if (
inferenceType === SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION ||
inferenceType === SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING
) {
return {
field: destinationField,
ignore_missing: true,
};
}
};

/**
* Parses model types list from the given configuration of a trained machine learning model
* @param trainedModel configuration for a trained machine learning model
Expand Down

0 comments on commit f6f5233

Please sign in to comment.