Skip to content

Commit

Permalink
[ML] Add probability values in decision path visualization for classi…
Browse files Browse the repository at this point in the history
…fication data frame analytics (#80229) (#82551)

Co-authored-by: Kibana Machine <[email protected]>

Co-authored-by: Kibana Machine <[email protected]>
  • Loading branch information
qn895 and kibanamachine authored Nov 4, 2020
1 parent 7bce2a1 commit e937618
Show file tree
Hide file tree
Showing 16 changed files with 652 additions and 221 deletions.
40 changes: 37 additions & 3 deletions x-pack/plugins/ml/common/types/feature_importance.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
* you may not use this file except in compliance with the Elastic License.
*/

export type FeatureImportanceClassName = string | number | boolean;

export interface ClassFeatureImportance {
class_name: string | boolean;
class_name: FeatureImportanceClassName;
importance: number;
}

Expand All @@ -18,15 +20,15 @@ export interface FeatureImportance {
}

export interface TopClass {
class_name: string;
class_name: FeatureImportanceClassName;
class_probability: number;
class_score: number;
}

export type TopClasses = TopClass[];

export interface ClassFeatureImportanceSummary {
class_name: string;
class_name: FeatureImportanceClassName;
importance: {
max: number;
min: number;
Expand All @@ -52,6 +54,22 @@ export type TotalFeatureImportance =
| ClassificationTotalFeatureImportance
| RegressionTotalFeatureImportance;

export interface FeatureImportanceClassBaseline {
class_name: FeatureImportanceClassName;
baseline: number;
}
export interface ClassificationFeatureImportanceBaseline {
classes: FeatureImportanceClassBaseline[];
}

export interface RegressionFeatureImportanceBaseline {
baseline: number;
}

export type FeatureImportanceBaseline =
| ClassificationFeatureImportanceBaseline
| RegressionFeatureImportanceBaseline;

export function isClassificationTotalFeatureImportance(
summary: ClassificationTotalFeatureImportance | RegressionTotalFeatureImportance
): summary is ClassificationTotalFeatureImportance {
Expand All @@ -63,3 +81,19 @@ export function isRegressionTotalFeatureImportance(
): summary is RegressionTotalFeatureImportance {
return (summary as RegressionTotalFeatureImportance).importance !== undefined;
}

export function isClassificationFeatureImportanceBaseline(
baselineData: any
): baselineData is ClassificationFeatureImportanceBaseline {
return (
typeof baselineData === 'object' &&
baselineData.hasOwnProperty('classes') &&
Array.isArray(baselineData.classes)
);
}

export function isRegressionFeatureImportanceBaseline(
baselineData: any
): baselineData is RegressionFeatureImportanceBaseline {
return typeof baselineData === 'object' && baselineData.hasOwnProperty('baseline');
}
3 changes: 2 additions & 1 deletion x-pack/plugins/ml/common/types/trained_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
*/

import { DataFrameAnalyticsConfig } from './data_frame_analytics';
import { TotalFeatureImportance } from './feature_importance';
import { FeatureImportanceBaseline, TotalFeatureImportance } from './feature_importance';

export interface IngestStats {
count: number;
Expand Down Expand Up @@ -56,6 +56,7 @@ export interface TrainedModelConfigResponse {
analytics_config: DataFrameAnalyticsConfig;
input: any;
total_feature_importance?: TotalFeatureImportance[];
feature_importance_baseline?: FeatureImportanceBaseline;
}
| Record<string, any>;
model_id: string;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ import {
} from './common';
import { UseIndexDataReturnType } from './types';
import { DecisionPathPopover } from './feature_importance/decision_path_popover';
import { FeatureImportance, TopClasses } from '../../../../common/types/feature_importance';
import {
FeatureImportanceBaseline,
FeatureImportance,
TopClasses,
} from '../../../../common/types/feature_importance';
import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics';
import { DataFrameAnalysisConfigType } from '../../../../common/types/data_frame_analytics';

Expand All @@ -50,7 +54,7 @@ export const DataGridTitle: FC<{ title: string }> = ({ title }) => (
);

interface PropsWithoutHeader extends UseIndexDataReturnType {
baseline?: number;
baseline?: FeatureImportanceBaseline;
analysisType?: DataFrameAnalysisConfigType | 'unknown';
resultsField?: string;
dataTestSubj: string;
Expand Down Expand Up @@ -124,6 +128,7 @@ export const DataGrid: FC<Props> = memo(
// if resultsField for some reason is not available then use ml
const mlResultsField = resultsField ?? DEFAULT_RESULTS_FIELD;
let predictedValue: string | number | undefined;
let predictedProbability: number | undefined;
let topClasses: TopClasses = [];
if (
predictionFieldName !== undefined &&
Expand All @@ -132,6 +137,7 @@ export const DataGrid: FC<Props> = memo(
) {
predictedValue = row[`${mlResultsField}.${predictionFieldName}`];
topClasses = getTopClasses(row, mlResultsField);
predictedProbability = row[`${mlResultsField}.prediction_probability`];
}

const isClassTypeBoolean = topClasses.reduce(
Expand All @@ -149,6 +155,7 @@ export const DataGrid: FC<Props> = memo(
<DecisionPathPopover
analysisType={analysisType}
predictedValue={predictedValue}
predictedProbability={predictedProbability}
baseline={baseline}
featureImportance={parsedFIArray}
topClasses={topClasses}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ import { i18n } from '@kbn/i18n';
import euiVars from '@elastic/eui/dist/eui_theme_light.json';
import { DecisionPathPlotData } from './use_classification_path_data';
import { formatSingleValue } from '../../../formatters/format_value';

import {
FeatureImportanceBaseline,
isRegressionFeatureImportanceBaseline,
} from '../../../../../common/types/feature_importance';
const { euiColorFullShade, euiColorMediumShade } = euiVars;
const axisColor = euiColorMediumShade;

Expand Down Expand Up @@ -72,10 +75,9 @@ const theme: PartialTheme = {
interface DecisionPathChartProps {
decisionPathData: DecisionPathPlotData;
predictionFieldName?: string;
baseline?: number;
baseline?: FeatureImportanceBaseline;
minDomain: number | undefined;
maxDomain: number | undefined;
showValues?: boolean;
}

const DECISION_PATH_MARGIN = 125;
Expand All @@ -88,38 +90,37 @@ export const DecisionPathChart = ({
minDomain,
maxDomain,
baseline,
showValues,
}: DecisionPathChartProps) => {
// adjust the height so it's compact for items with more features
const baselineData: LineAnnotationDatum[] = useMemo(
() => [
{
dataValue: baseline,
header: baseline ? formatSingleValue(baseline).toString() : '',
details: i18n.translate(
'xpack.ml.dataframe.analytics.explorationResults.decisionPathBaselineText',
{
defaultMessage:
'baseline (average of predictions for all data points in the training data set)',
}
),
},
],
const baselineData: LineAnnotationDatum[] | undefined = useMemo(
() =>
baseline && isRegressionFeatureImportanceBaseline(baseline)
? [
{
dataValue: baseline.baseline,
header: formatSingleValue(baseline.baseline, '').toString(),
details: i18n.translate(
'xpack.ml.dataframe.analytics.explorationResults.decisionPathBaselineText',
{
defaultMessage:
'baseline (average of predictions for all data points in the training data set)',
}
),
},
]
: undefined,
[baseline]
);
// if regression, guarantee up to num_precision significant digits without having it in scientific notation
// if classification, hide the numeric values since we only want to show the path
const tickFormatter = useCallback(
(d) => (showValues === false ? '' : formatSingleValue(d).toString()),
[]
);
const tickFormatter = useCallback((d) => formatSingleValue(d, '').toString(), []);

return (
<Chart
size={{ height: DECISION_PATH_MARGIN + decisionPathData.length * DECISION_PATH_ROW_HEIGHT }}
>
<Settings theme={theme} rotation={90} />
{baseline && (
{baselineData && (
<LineAnnotation
id="xpack.ml.dataframe.analytics.explorationResults.decisionPathBaseline"
domainType={AnnotationDomainTypes.YDomain}
Expand All @@ -132,7 +133,6 @@ export const DecisionPathChart = ({
<Axis
id={'xpack.ml.dataframe.analytics.explorationResults.decisionPathXAxis'}
tickFormat={tickFormatter}
ticks={showValues === false ? 0 : undefined}
title={i18n.translate(
'xpack.ml.dataframe.analytics.explorationResults.decisionPathXAxisTitle',
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,39 @@ import {
useDecisionPathData,
getStringBasedClassName,
} from './use_classification_path_data';
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
import {
FeatureImportance,
FeatureImportanceBaseline,
TopClasses,
} from '../../../../../common/types/feature_importance';
import { DecisionPathChart } from './decision_path_chart';
import { MissingDecisionPathCallout } from './missing_decision_path_callout';

interface ClassificationDecisionPathProps {
predictedValue: string | boolean;
predictedProbability: number | undefined;
predictionFieldName?: string;
featureImportance: FeatureImportance[];
topClasses: TopClasses;
baseline?: FeatureImportanceBaseline;
}

export const ClassificationDecisionPath: FC<ClassificationDecisionPathProps> = ({
featureImportance,
predictedValue,
topClasses,
predictionFieldName,
predictedProbability,
baseline,
}) => {
const [currentClass, setCurrentClass] = useState<string>(
getStringBasedClassName(topClasses[0].class_name)
);
const { decisionPathData } = useDecisionPathData({
baseline,
featureImportance,
predictedValue: currentClass,
predictedProbability,
});
const options = useMemo(() => {
const predictionValueStr = getStringBasedClassName(predictedValue);
Expand Down Expand Up @@ -99,7 +109,6 @@ export const ClassificationDecisionPath: FC<ClassificationDecisionPathProps> = (
predictionFieldName={predictionFieldName}
minDomain={domain.minDomain}
maxDomain={domain.maxDomain}
showValues={false}
/>
</>
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,26 @@ import { EuiLink, EuiTab, EuiTabs, EuiText } from '@elastic/eui';
import { FormattedMessage } from '@kbn/i18n/react';
import { RegressionDecisionPath } from './decision_path_regression';
import { DecisionPathJSONViewer } from './decision_path_json_viewer';
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
import {
FeatureImportance,
FeatureImportanceBaseline,
isClassificationFeatureImportanceBaseline,
isRegressionFeatureImportanceBaseline,
TopClasses,
} from '../../../../../common/types/feature_importance';
import { ANALYSIS_CONFIG_TYPE } from '../../../data_frame_analytics/common';
import { ClassificationDecisionPath } from './decision_path_classification';
import { useMlKibana } from '../../../contexts/kibana';
import { DataFrameAnalysisConfigType } from '../../../../../common/types/data_frame_analytics';
import { getStringBasedClassName } from './use_classification_path_data';

interface DecisionPathPopoverProps {
featureImportance: FeatureImportance[];
analysisType: DataFrameAnalysisConfigType;
predictionFieldName?: string;
baseline?: number;
baseline?: FeatureImportanceBaseline;
predictedValue?: number | string | undefined;
predictedProbability?: number; // for classification
topClasses?: TopClasses;
}

Expand All @@ -30,7 +38,7 @@ enum DECISION_PATH_TABS {
}

export interface ExtendedFeatureImportance extends FeatureImportance {
absImportance?: number;
absImportance: number;
}

export const DecisionPathPopover: FC<DecisionPathPopoverProps> = ({
Expand All @@ -40,6 +48,7 @@ export const DecisionPathPopover: FC<DecisionPathPopoverProps> = ({
topClasses,
analysisType,
predictionFieldName,
predictedProbability,
}) => {
const [selectedTabId, setSelectedTabId] = useState(DECISION_PATH_TABS.CHART);
const {
Expand Down Expand Up @@ -109,22 +118,29 @@ export const DecisionPathPopover: FC<DecisionPathPopoverProps> = ({
}}
/>
</EuiText>
{analysisType === ANALYSIS_CONFIG_TYPE.CLASSIFICATION && (
<ClassificationDecisionPath
featureImportance={featureImportance}
topClasses={topClasses as TopClasses}
predictedValue={predictedValue as string}
predictionFieldName={predictionFieldName}
/>
)}
{analysisType === ANALYSIS_CONFIG_TYPE.REGRESSION && (
<RegressionDecisionPath
featureImportance={featureImportance}
baseline={baseline}
predictedValue={predictedValue as number}
predictionFieldName={predictionFieldName}
/>
)}
{analysisType === ANALYSIS_CONFIG_TYPE.CLASSIFICATION &&
isClassificationFeatureImportanceBaseline(baseline) && (
<ClassificationDecisionPath
featureImportance={featureImportance}
topClasses={topClasses as TopClasses}
predictedValue={getStringBasedClassName(predictedValue)}
predictedProbability={predictedProbability}
predictionFieldName={predictionFieldName}
baseline={baseline}
/>
)}
{analysisType === ANALYSIS_CONFIG_TYPE.REGRESSION &&
isRegressionFeatureImportanceBaseline(baseline) &&
predictedValue !== undefined && (
<RegressionDecisionPath
featureImportance={featureImportance}
baseline={baseline}
predictedValue={
typeof predictedValue === 'string' ? parseFloat(predictedValue) : predictedValue
}
predictionFieldName={predictionFieldName}
/>
)}
</>
)}
{selectedTabId === DECISION_PATH_TABS.JSON && (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@ import React, { FC, useMemo } from 'react';
import { EuiCallOut } from '@elastic/eui';
import { FormattedMessage } from '@kbn/i18n/react';
import d3 from 'd3';
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
import {
FeatureImportance,
FeatureImportanceBaseline,
TopClasses,
} from '../../../../../common/types/feature_importance';
import { useDecisionPathData, isDecisionPathData } from './use_classification_path_data';
import { DecisionPathChart } from './decision_path_chart';
import { MissingDecisionPathCallout } from './missing_decision_path_callout';

interface RegressionDecisionPathProps {
predictionFieldName?: string;
baseline?: number;
baseline?: FeatureImportanceBaseline;
predictedValue?: number | undefined;
featureImportance: FeatureImportance[];
topClasses?: TopClasses;
Expand Down
Loading

0 comments on commit e937618

Please sign in to comment.