Skip to content

Commit

Permalink
[ML] Fix DFA feature importance popover empty (elastic#91061)
Browse files Browse the repository at this point in the history
  • Loading branch information
qn895 authored and kibanamachine committed Feb 16, 2021
1 parent eeaaeb7 commit 857d478
Show file tree
Hide file tree
Showing 12 changed files with 56 additions and 71 deletions.
86 changes: 33 additions & 53 deletions x-pack/plugins/ml/public/application/components/data_grid/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ import {

import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics';
import { extractErrorMessage } from '../../../../common/util/errors';
import { FeatureImportance, TopClasses } from '../../../../common/types/feature_importance';
import {
FeatureImportance,
FeatureImportanceClassName,
TopClasses,
} from '../../../../common/types/feature_importance';

import {
BASIC_NUMERICAL_TYPES,
Expand Down Expand Up @@ -168,8 +172,9 @@ const getClassName = (className: string, isClassTypeBoolean: boolean) => {

return className;
};

/**
* Helper to transform feature importance flattened fields with arrays back to object structure
* Helper to transform feature importance fields with arrays back to primitive value
*
* @param row - EUI data grid data row
* @param mlResultsField - Data frame analytics results field
Expand All @@ -180,69 +185,44 @@ export const getFeatureImportance = (
mlResultsField: string,
isClassTypeBoolean = false
): FeatureImportance[] => {
const featureNames: string[] | undefined =
row[`${mlResultsField}.feature_importance.feature_name`];
const classNames: string[] | undefined =
row[`${mlResultsField}.feature_importance.classes.class_name`];
const classImportance: number[] | undefined =
row[`${mlResultsField}.feature_importance.classes.importance`];

if (featureNames === undefined) {
return [];
}

// return object structure for classification job
if (classNames !== undefined && classImportance !== undefined) {
const overallClassNames = classNames?.slice(0, classNames.length / featureNames.length);

return featureNames.map((fName, index) => {
const offset = overallClassNames.length * index;
const featureClassImportance = classImportance.slice(
offset,
offset + overallClassNames.length
);
return {
feature_name: fName,
classes: overallClassNames.map((fClassName, fIndex) => {
const featureImportance: Array<{
feature_name: string[];
classes?: Array<{ class_name: FeatureImportanceClassName[]; importance: number[] }>;
importance?: number | number[];
}> = row[`${mlResultsField}.feature_importance`];
if (featureImportance === undefined) return [];

return featureImportance.map((fi) => ({
feature_name: Array.isArray(fi.feature_name) ? fi.feature_name[0] : fi.feature_name,
classes: Array.isArray(fi.classes)
? fi.classes.map((c) => {
const processedClass = getProcessedFields(c);
return {
class_name: getClassName(fClassName, isClassTypeBoolean),
importance: featureClassImportance[fIndex],
importance: processedClass.importance,
class_name: getClassName(processedClass.class_name, isClassTypeBoolean),
};
}),
};
});
}

// return object structure for regression job
const importance: number[] = row[`${mlResultsField}.feature_importance.importance`];
return featureNames.map((fName, index) => ({
feature_name: fName,
importance: importance[index],
})
: fi.classes,
importance: Array.isArray(fi.importance) ? fi.importance[0] : fi.importance,
}));
};

/**
* Helper to transforms top classes flattened fields with arrays back to object structure
* Helper to transforms top classes fields with arrays back to original primitive value
*
* @param row - EUI data grid data row
* @param mlResultsField - Data frame analytics results field
* @returns nested object structure of feature importance values
*/
export const getTopClasses = (row: Record<string, any>, mlResultsField: string): TopClasses => {
const classNames: string[] | undefined = row[`${mlResultsField}.top_classes.class_name`];
const classProbabilities: number[] | undefined =
row[`${mlResultsField}.top_classes.class_probability`];
const classScores: number[] | undefined = row[`${mlResultsField}.top_classes.class_score`];

if (classNames === undefined || classProbabilities === undefined || classScores === undefined) {
return [];
}

return classNames.map((className, index) => ({
class_name: className,
class_probability: classProbabilities[index],
class_score: classScores[index],
}));
const topClasses: Array<{
class_name: FeatureImportanceClassName[];
class_probability: number[];
class_score: number[];
}> = row[`${mlResultsField}.top_classes`];

if (topClasses === undefined) return [];
return topClasses.map((tc) => getProcessedFields(tc)) as TopClasses;
};

export const useRenderCellValue = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import {
getTopClasses,
} from './common';
import { UseIndexDataReturnType } from './types';
import { DecisionPathPopover } from './feature_importance/decision_path_popover';
import { DecisionPathPopover } from '../../data_frame_analytics/pages/analytics_exploration/components/feature_importance/decision_path_popover';
import {
FeatureImportanceBaseline,
FeatureImportance,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,9 @@
.mlExpandableSection-contentPadding {
padding: $euiSizeS;
}

// Make sure the charts tooltip in popover
// have higher zIndex than Eui popover cells
[id^='echTooltipPortal'] {
z-index: $euiZLevel9 !important;
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ import { EuiIcon } from '@elastic/eui';
import React, { useCallback, useMemo } from 'react';
import { i18n } from '@kbn/i18n';
import euiVars from '@elastic/eui/dist/eui_theme_light.json';
import { DecisionPathPlotData } from './use_classification_path_data';
import { formatSingleValue } from '../../../formatters/format_value';
import type { DecisionPathPlotData } from './use_classification_path_data';
import { formatSingleValue } from '../../../../../formatters/format_value';
import {
FeatureImportanceBaseline,
isRegressionFeatureImportanceBaseline,
} from '../../../../../common/types/feature_importance';
} from '../../../../../../../common/types/feature_importance';
const { euiColorFullShade, euiColorMediumShade } = euiVars;
const axisColor = euiColorMediumShade;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ import {
useDecisionPathData,
getStringBasedClassName,
} from './use_classification_path_data';
import {
import type {
FeatureImportance,
FeatureImportanceBaseline,
TopClasses,
} from '../../../../../common/types/feature_importance';
} from '../../../../../../../common/types/feature_importance';
import { DecisionPathChart } from './decision_path_chart';
import { MissingDecisionPathCallout } from './missing_decision_path_callout';

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import React, { FC } from 'react';
import { EuiCodeBlock } from '@elastic/eui';
import { FeatureImportance } from '../../../../../common/types/feature_importance';
import type { FeatureImportance } from '../../../../../../../common/types/feature_importance';

interface DecisionPathJSONViewerProps {
featureImportance: FeatureImportance[];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ import {
isClassificationFeatureImportanceBaseline,
isRegressionFeatureImportanceBaseline,
TopClasses,
} from '../../../../../common/types/feature_importance';
import { ANALYSIS_CONFIG_TYPE } from '../../../data_frame_analytics/common';
} from '../../../../../../../common/types/feature_importance';
import { ANALYSIS_CONFIG_TYPE } from '../../../../common';
import { ClassificationDecisionPath } from './decision_path_classification';
import { useMlKibana } from '../../../contexts/kibana';
import { DataFrameAnalysisConfigType } from '../../../../../common/types/data_frame_analytics';
import { useMlKibana } from '../../../../../contexts/kibana';
import type { DataFrameAnalysisConfigType } from '../../../../../../../common/types/data_frame_analytics';
import { getStringBasedClassName } from './use_classification_path_data';

interface DecisionPathPopoverProps {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ import React, { FC, useMemo } from 'react';
import { EuiCallOut } from '@elastic/eui';
import { FormattedMessage } from '@kbn/i18n/react';
import d3 from 'd3';
import {
import type {
FeatureImportance,
FeatureImportanceBaseline,
TopClasses,
} from '../../../../../common/types/feature_importance';
} from '../../../../../../../common/types/feature_importance';
import { useDecisionPathData, isDecisionPathData } from './use_classification_path_data';
import { DecisionPathChart } from './decision_path_chart';
import { MissingDecisionPathCallout } from './missing_decision_path_callout';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
buildClassificationDecisionPathData,
buildRegressionDecisionPathData,
} from './use_classification_path_data';
import { FeatureImportance } from '../../../../../common/types/feature_importance';
import type { FeatureImportance } from '../../../../../../../common/types/feature_importance';

describe('buildClassificationDecisionPathData()', () => {
test('should return correct prediction probability for binary classification', () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ import {
isClassificationFeatureImportanceBaseline,
isRegressionFeatureImportanceBaseline,
TopClasses,
} from '../../../../../common/types/feature_importance';
import { ExtendedFeatureImportance } from './decision_path_popover';
} from '../../../../../../../common/types/feature_importance';
import type { ExtendedFeatureImportance } from './decision_path_popover';

export type DecisionPathPlotData = Array<[string, number, number]>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ export default function ({ getService }: FtrProviderContext) {
const esArchiver = getService('esArchiver');
const ml = getService('ml');

// Failing: See https://github.com/elastic/kibana/issues/90526
describe.skip('total feature importance panel and decision path popover', function () {
describe('total feature importance panel and decision path popover', function () {
const testDataList: Array<{
suiteTitle: string;
archive: string;
Expand Down

0 comments on commit 857d478

Please sign in to comment.