Skip to content

Commit

Permalink
[ML] Data Frame Analytics: Fix feature importance (#61761) (#62541)
Browse files Browse the repository at this point in the history
- Fixes missing num_top_feature_importance_values parameter for analytics job configurations
- Fixes analytics create form to consider feature importance
- Fixes missing feature importance fields from results pages
  • Loading branch information
walterra authored Apr 4, 2020
1 parent e979b14 commit 7884336
Show file tree
Hide file tree
Showing 9 changed files with 246 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ interface OutlierAnalysis {
/**
 * Settings object for a regression analysis. Its properties are read through
 * `analysis.regression.*` by the getter helpers in this file.
 */
interface Regression {
// Id of the dependent variable field; it is also used to build the default
// `<dependent_variable>_prediction` results field name.
dependent_variable: string;
// Optional. `TRAINING_PERCENT_MIN` / `TRAINING_PERCENT_MAX` are exported from this
// file — presumably the valid range for this value; confirm against the form validation.
training_percent?: number;
// Optional. Maximum number of feature importance values per document to return.
num_top_feature_importance_values?: number;
// Optional. When left undefined, a default name is applied when the config is created.
prediction_field_name?: string;
}
export interface RegressionAnalysis {
Expand All @@ -44,6 +45,7 @@ interface Classification {
dependent_variable: string;
training_percent?: number;
num_top_classes?: string;
num_top_feature_importance_values?: number;
prediction_field_name?: string;
}
export interface ClassificationAnalysis {
Expand All @@ -65,6 +67,8 @@ export const SEARCH_SIZE = 1000;
// Lower and upper bound for the `training_percent` setting
// (presumably inclusive — confirm against the create form's slider).
export const TRAINING_PERCENT_MIN = 1;
export const TRAINING_PERCENT_MAX = 100;

// Lower bound for `num_top_feature_importance_values`; used as the `min` of the
// feature importance number input in the create analytics form.
export const NUM_TOP_FEATURE_IMPORTANCE_VALUES_MIN = 0;

// Default search query: an Elasticsearch `match_all` clause with no options.
export const defaultSearchQuery = {
match_all: {},
};
Expand Down Expand Up @@ -152,7 +156,7 @@ type AnalysisConfig =
| ClassificationAnalysis
| GenericAnalysis;

export const getAnalysisType = (analysis: AnalysisConfig) => {
export const getAnalysisType = (analysis: AnalysisConfig): string => {
const keys = Object.keys(analysis);

if (keys.length === 1) {
Expand All @@ -162,7 +166,11 @@ export const getAnalysisType = (analysis: AnalysisConfig) => {
return 'unknown';
};

export const getDependentVar = (analysis: AnalysisConfig) => {
export const getDependentVar = (
analysis: AnalysisConfig
):
| RegressionAnalysis['regression']['dependent_variable']
| ClassificationAnalysis['classification']['dependent_variable'] => {
let depVar = '';

if (isRegressionAnalysis(analysis)) {
Expand All @@ -175,7 +183,11 @@ export const getDependentVar = (analysis: AnalysisConfig) => {
return depVar;
};

export const getTrainingPercent = (analysis: AnalysisConfig) => {
export const getTrainingPercent = (
analysis: AnalysisConfig
):
| RegressionAnalysis['regression']['training_percent']
| ClassificationAnalysis['classification']['training_percent'] => {
let trainingPercent;

if (isRegressionAnalysis(analysis)) {
Expand All @@ -188,7 +200,11 @@ export const getTrainingPercent = (analysis: AnalysisConfig) => {
return trainingPercent;
};

export const getPredictionFieldName = (analysis: AnalysisConfig) => {
export const getPredictionFieldName = (
analysis: AnalysisConfig
):
| RegressionAnalysis['regression']['prediction_field_name']
| ClassificationAnalysis['classification']['prediction_field_name'] => {
// If undefined will be defaulted to dependent_variable when config is created
let predictionFieldName;
if (isRegressionAnalysis(analysis) && analysis.regression.prediction_field_name !== undefined) {
Expand All @@ -202,6 +218,26 @@ export const getPredictionFieldName = (analysis: AnalysisConfig) => {
return predictionFieldName;
};

/**
 * Reads the `num_top_feature_importance_values` setting from an analysis config.
 *
 * Only regression and classification configs are inspected. For any other
 * analysis type, or when the setting is not present, `undefined` is returned.
 *
 * @param analysis - The `analysis` section of a data frame analytics job config.
 * @returns The configured value, or `undefined` if none is set.
 */
export const getNumTopFeatureImportanceValues = (
  analysis: AnalysisConfig
):
  | RegressionAnalysis['regression']['num_top_feature_importance_values']
  | ClassificationAnalysis['classification']['num_top_feature_importance_values'] => {
  // Regression is checked first; a regression config without the setting
  // still falls through to the classification check below.
  if (isRegressionAnalysis(analysis)) {
    const regressionValue = analysis.regression.num_top_feature_importance_values;
    if (regressionValue !== undefined) {
      return regressionValue;
    }
  }

  if (isClassificationAnalysis(analysis)) {
    const classificationValue = analysis.classification.num_top_feature_importance_values;
    if (classificationValue !== undefined) {
      return classificationValue;
    }
  }

  return undefined;
};

export const getPredictedFieldName = (
resultsField: string,
analysis: AnalysisConfig,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
import { getNestedProperty } from '../../util/object_utils';
import {
DataFrameAnalyticsConfig,
getNumTopFeatureImportanceValues,
getPredictedFieldName,
getDependentVar,
getPredictionFieldName,
} from './analytics';
import { Field } from '../../../../common/types/fields';
import { ES_FIELD_TYPES } from '../../../../../../../src/plugins/data/public';
import { ES_FIELD_TYPES, KBN_FIELD_TYPES } from '../../../../../../../src/plugins/data/public';
import { newJobCapsService } from '../../services/new_job_capabilities_service';

export type EsId = string;
Expand Down Expand Up @@ -254,14 +255,28 @@ export const getDefaultFieldsFromJobCaps = (
const dependentVariable = getDependentVar(jobConfig.analysis);
const type = newJobCapsService.getFieldById(dependentVariable)?.type;
const predictionFieldName = getPredictionFieldName(jobConfig.analysis);
const numTopFeatureImportanceValues = getNumTopFeatureImportanceValues(jobConfig.analysis);
// default is 'ml'
const resultsField = jobConfig.dest.results_field;

const defaultPredictionField = `${dependentVariable}_prediction`;
const predictedField = `${resultsField}.${
predictionFieldName ? predictionFieldName : defaultPredictionField
}`;
// Only need to add these first two fields if we didn't use dest index pattern to get the fields

const featureImportanceFields = [];

if ((numTopFeatureImportanceValues ?? 0) > 0) {
featureImportanceFields.push(
...fields.map(d => ({
id: `${resultsField}.feature_importance.${d.id}`,
name: `${resultsField}.feature_importance.${d.name}`,
type: KBN_FIELD_TYPES.NUMBER,
}))
);
}

// Only need to add these fields if we didn't use dest index pattern to get the fields
const allFields: any =
needsDestIndexFields === true
? [
Expand All @@ -271,16 +286,20 @@ export const getDefaultFieldsFromJobCaps = (
type: ES_FIELD_TYPES.BOOLEAN,
},
{ id: predictedField, name: predictedField, type },
...featureImportanceFields,
]
: [];

allFields.push(...fields);
// @ts-ignore
allFields.sort(({ name: a }, { name: b }) => sortRegressionResultsFields(a, b, jobConfig));

let selectedFields = allFields
.slice(0, DEFAULT_REGRESSION_COLUMNS * 2)
.filter((field: any) => field.name === predictedField || !field.name.includes('.keyword'));
allFields.sort(({ name: a }: { name: string }, { name: b }: { name: string }) =>
sortRegressionResultsFields(a, b, jobConfig)
);

let selectedFields = allFields.filter(
(field: any) =>
field.name === predictedField ||
(!field.name.includes('.keyword') && !field.name.includes('.feature_importance.'))
);

if (selectedFields.length > DEFAULT_REGRESSION_COLUMNS) {
selectedFields = selectedFields.slice(0, DEFAULT_REGRESSION_COLUMNS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ describe('Analytics job clone action', () => {
classification: {
dependent_variable: 'y',
num_top_classes: 2,
num_top_feature_importance_values: 4,
prediction_field_name: 'y_prediction',
training_percent: 2,
randomize_seed: 6233212276062807000,
Expand Down Expand Up @@ -90,6 +91,7 @@ describe('Analytics job clone action', () => {
prediction_field_name: 'stab_prediction',
training_percent: 20,
randomize_seed: -2228827740028660200,
num_top_feature_importance_values: 4,
},
},
analyzed_fields: {
Expand Down Expand Up @@ -120,6 +122,7 @@ describe('Analytics job clone action', () => {
classification: {
dependent_variable: 'y',
num_top_classes: 2,
num_top_feature_importance_values: 4,
prediction_field_name: 'y_prediction',
training_percent: 2,
randomize_seed: 6233212276062807000,
Expand Down Expand Up @@ -188,6 +191,7 @@ describe('Analytics job clone action', () => {
prediction_field_name: 'stab_prediction',
training_percent: 20,
randomize_seed: -2228827740028660200,
num_top_feature_importance_values: 4,
},
},
analyzed_fields: {
Expand Down Expand Up @@ -218,6 +222,7 @@ describe('Analytics job clone action', () => {
dependent_variable: 'y',
training_percent: 71,
max_trees: 1500,
num_top_feature_importance_values: 4,
},
},
model_memory_limit: '400mb',
Expand All @@ -243,6 +248,7 @@ describe('Analytics job clone action', () => {
dependent_variable: 'y',
training_percent: 71,
maximum_number_trees: 1500,
num_top_feature_importance_values: 4,
},
},
model_memory_limit: '400mb',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ import { i18n } from '@kbn/i18n';
import { DeepReadonly } from '../../../../../../../common/types/common';
import { DataFrameAnalyticsConfig, isOutlierAnalysis } from '../../../../common';
import { isClassificationAnalysis, isRegressionAnalysis } from '../../../../common/analytics';
import { CreateAnalyticsFormProps } from '../../hooks/use_create_analytics_form';
import {
CreateAnalyticsFormProps,
DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES,
} from '../../hooks/use_create_analytics_form';
import { State } from '../../hooks/use_create_analytics_form/state';
import { DataFrameAnalyticsListRow } from './common';

Expand Down Expand Up @@ -97,6 +100,8 @@ const getAnalyticsJobMeta = (config: CloneDataFrameAnalyticsConfig): AnalyticsJo
},
num_top_feature_importance_values: {
optional: true,
defaultValue: DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES,
formKey: 'numTopFeatureImportanceValues',
},
class_assignment_objective: {
optional: true,
Expand Down Expand Up @@ -164,6 +169,8 @@ const getAnalyticsJobMeta = (config: CloneDataFrameAnalyticsConfig): AnalyticsJo
},
num_top_feature_importance_values: {
optional: true,
defaultValue: DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES,
formKey: 'numTopFeatureImportanceValues',
},
randomize_seed: {
optional: true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
EuiComboBox,
EuiComboBoxOptionOption,
EuiForm,
EuiFieldNumber,
EuiFieldText,
EuiFormRow,
EuiLink,
Expand Down Expand Up @@ -41,6 +42,7 @@ import {
ANALYSIS_CONFIG_TYPE,
DfAnalyticsExplainResponse,
FieldSelectionItem,
NUM_TOP_FEATURE_IMPORTANCE_VALUES_MIN,
TRAINING_PERCENT_MIN,
TRAINING_PERCENT_MAX,
} from '../../../../common/analytics';
Expand Down Expand Up @@ -83,6 +85,8 @@ export const CreateAnalyticsForm: FC<CreateAnalyticsFormProps> = ({ actions, sta
maxDistinctValuesError,
modelMemoryLimit,
modelMemoryLimitValidationResult,
numTopFeatureImportanceValues,
numTopFeatureImportanceValuesValid,
previousJobType,
previousSourceIndex,
sourceIndex,
Expand Down Expand Up @@ -645,6 +649,54 @@ export const CreateAnalyticsForm: FC<CreateAnalyticsFormProps> = ({ actions, sta
data-test-subj="mlAnalyticsCreateJobFlyoutTrainingPercentSlider"
/>
</EuiFormRow>
{/* num_top_feature_importance_values */}
<EuiFormRow
label={i18n.translate(
'xpack.ml.dataframe.analytics.create.numTopFeatureImportanceValuesLabel',
{
defaultMessage: 'Feature importance values',
}
)}
helpText={i18n.translate(
'xpack.ml.dataframe.analytics.create.numTopFeatureImportanceValuesHelpText',
{
defaultMessage:
'Specify the maximum number of feature importance values per document to return.',
}
)}
isInvalid={numTopFeatureImportanceValuesValid === false}
error={[
...(numTopFeatureImportanceValuesValid === false
? [
<Fragment>
{i18n.translate(
'xpack.ml.dataframe.analytics.create.numTopFeatureImportanceValuesErrorText',
{
defaultMessage:
'Invalid maximum number of feature importance values.',
}
)}
</Fragment>,
]
: []),
]}
>
<EuiFieldNumber
aria-label={i18n.translate(
'xpack.ml.dataframe.analytics.create.numTopFeatureImportanceValuesInputAriaLabel',
{
defaultMessage: 'Maximum number of feature importance values per document.',
}
)}
data-test-subj="mlAnalyticsCreateJobFlyoutnumTopFeatureImportanceValuesInput"
disabled={false}
isInvalid={numTopFeatureImportanceValuesValid === false}
min={NUM_TOP_FEATURE_IMPORTANCE_VALUES_MIN}
onChange={e => setFormState({ numTopFeatureImportanceValues: +e.target.value })}
step={1}
value={numTopFeatureImportanceValues}
/>
</EuiFormRow>
</Fragment>
)}
<EuiFormRow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
* you may not use this file except in compliance with the Elastic License.
*/

export { DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES } from './state';
export { useCreateAnalyticsForm, CreateAnalyticsFormProps } from './use_create_analytics_form';
Loading

0 comments on commit 7884336

Please sign in to comment.