[ML] Add probability values in decision path visualization for classification data frame analytics #80229

Merged on Nov 4, 2020 (33 commits)
Changes from 14 commits
Commits
4ff0a24
[ML] Update baseline calc to use trained_models info
qn895 Oct 12, 2020
b615eac
[ML] Update baseline calc for classification
qn895 Oct 12, 2020
d0d807b
[ML] Update baseline calc for multi-class
qn895 Oct 12, 2020
ae105a5
[ML] Safeguard for hypothetically when for some reasons there's only …
qn895 Oct 12, 2020
218ddad
[ML] Remove now unused analyticsFeatureImportanceProvider
qn895 Oct 12, 2020
dfba16d
[ML] Update proper type for InferenceQueryParams.include
qn895 Oct 12, 2020
8121fe6
[ML] Fix results inconsistent for multi class due to different types
qn895 Oct 13, 2020
01a1961
[ML] Add unit test
qn895 Oct 13, 2020
9c6f79a
[ML] Add unit test
qn895 Oct 13, 2020
1f63405
Merge remote-tracking branch 'upstream/master' into ml-new-baseline-path
qn895 Oct 13, 2020
59b43e8
[ML] Add unit test
qn895 Oct 13, 2020
8cc0fcc
[ML] Change to using formatSingleValue instead
qn895 Oct 13, 2020
5452e04
Merge remote-tracking branch 'upstream/master' into ml-new-baseline-path
qn895 Oct 14, 2020
e338028
[ML] Fix missing baseline
qn895 Oct 14, 2020
0d43239
Merge remote-tracking branch 'upstream/master' into ml-new-baseline-path
qn895 Oct 15, 2020
3ca8efc
[ML] Remove duplicate formatSingleValue
qn895 Oct 15, 2020
55ab819
Merge remote-tracking branch 'upstream/master' into ml-new-baseline-path
qn895 Oct 18, 2020
10eae2d
[ML] Add type FeatureImportClassName
qn895 Oct 18, 2020
e5f33b3
[ML] Remove !
qn895 Oct 19, 2020
7526127
[ML] Rename functions to start with process for clarity
qn895 Oct 19, 2020
b123e61
[ML] Add extra other row to binary classification if num features > n…
qn895 Oct 22, 2020
fd4672d
Merge remote-tracking branch 'upstream/master' into ml-new-baseline-path
qn895 Oct 22, 2020
ed9376e
Merge branch 'master' into ml-new-baseline-path
kibanamachine Oct 27, 2020
ed538ff
[ML] Fix broken feature importance fields
qn895 Oct 28, 2020
5459fbf
Merge remote-tracking branch 'upstream/master' into ml-new-baseline-path
qn895 Oct 29, 2020
56764d4
Merge upstream/master into origin/ml-new-baseline-path
qn895 Nov 2, 2020
191991e
[ML] Adjust for multiclass
qn895 Nov 2, 2020
631fa87
[ML] Fix typo in import type
qn895 Nov 2, 2020
ecf321e
[ML] Rename FeatureImportanceClassName
qn895 Nov 2, 2020
19ff877
[ML] Fix FeatureImportanceClassName
qn895 Nov 2, 2020
845339e
[ML] Fix fi broken if result is an array with only one element
qn895 Nov 2, 2020
71e8549
Merge upstream/master into ml-new-baseline-path
qn895 Nov 3, 2020
84e8ed1
[ML] Remove analyticsFeatureImportanceProvider
qn895 Nov 3, 2020
30 changes: 29 additions & 1 deletion x-pack/plugins/ml/common/types/feature_importance.ts
@@ -5,7 +5,7 @@
  */

 export interface ClassFeatureImportance {
-  class_name: string | boolean;
+  class_name: string | number | boolean;
darnautov marked this conversation as resolved.
   importance: number;
 }
 export interface FeatureImportance {
@@ -49,6 +49,22 @@ export type TotalFeatureImportance =
   | ClassificationTotalFeatureImportance
   | RegressionTotalFeatureImportance;

+export interface FeatureImportanceClassBaseline {
+  class_name: string | number | boolean;
+  baseline: number;
+}
+export interface ClassificationFeatureImportanceBaseline {
+  classes: FeatureImportanceClassBaseline[];
+}
+
+export interface RegressionFeatureImportanceBaseline {
+  baseline: number;
+}
+
+export type FeatureImportanceBaseline =
+  | ClassificationFeatureImportanceBaseline
+  | RegressionFeatureImportanceBaseline;
+
 export function isClassificationTotalFeatureImportance(
   summary: ClassificationTotalFeatureImportance | RegressionTotalFeatureImportance
 ): summary is ClassificationTotalFeatureImportance {
@@ -60,3 +76,15 @@ export function isRegressionTotalFeatureImportance(
 ): summary is RegressionTotalFeatureImportance {
   return (summary as RegressionTotalFeatureImportance).importance !== undefined;
 }
+
+export function isClassificationFeatureImportanceBaseline(
+  baselineData: ClassificationFeatureImportanceBaseline | RegressionFeatureImportanceBaseline
+): baselineData is ClassificationFeatureImportanceBaseline {
+  return (baselineData as ClassificationFeatureImportanceBaseline).classes !== undefined;
darnautov marked this conversation as resolved.
+}
+
+export function isRegressionFeatureImportanceBaseline(
darnautov marked this conversation as resolved.
+  baselineData: ClassificationFeatureImportanceBaseline | RegressionFeatureImportanceBaseline
+): baselineData is RegressionFeatureImportanceBaseline {
+  return (baselineData as RegressionFeatureImportanceBaseline).baseline !== undefined;
+}
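As a rough illustration of how the two guards added in feature_importance.ts narrow the FeatureImportanceBaseline union, the sketch below re-declares the types from the diff locally so it can run on its own. The helper describeBaseline is hypothetical and is not part of this PR.

```typescript
// Local copies of the types added in feature_importance.ts (taken from the diff).
interface FeatureImportanceClassBaseline {
  class_name: string | number | boolean;
  baseline: number;
}
interface ClassificationFeatureImportanceBaseline {
  classes: FeatureImportanceClassBaseline[];
}
interface RegressionFeatureImportanceBaseline {
  baseline: number;
}
type FeatureImportanceBaseline =
  | ClassificationFeatureImportanceBaseline
  | RegressionFeatureImportanceBaseline;

function isClassificationFeatureImportanceBaseline(
  baselineData: FeatureImportanceBaseline
): baselineData is ClassificationFeatureImportanceBaseline {
  return (baselineData as ClassificationFeatureImportanceBaseline).classes !== undefined;
}

function isRegressionFeatureImportanceBaseline(
  baselineData: FeatureImportanceBaseline
): baselineData is RegressionFeatureImportanceBaseline {
  return (baselineData as RegressionFeatureImportanceBaseline).baseline !== undefined;
}

// Hypothetical helper (not in the PR): renders a baseline as text.
// Inside each branch the compiler knows which member of the union it has.
function describeBaseline(baseline: FeatureImportanceBaseline): string {
  if (isClassificationFeatureImportanceBaseline(baseline)) {
    // One baseline per class; class_name may be a string, number or boolean.
    return baseline.classes.map((c) => `${String(c.class_name)}=${c.baseline}`).join(', ');
  }
  if (isRegressionFeatureImportanceBaseline(baseline)) {
    // A single baseline value for the whole model.
    return `baseline=${baseline.baseline}`;
  }
  return 'unknown';
}
```

Calling describeBaseline({ baseline: 0.5 }) takes the regression branch, while an object with a classes array takes the classification branch.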
3 changes: 2 additions & 1 deletion x-pack/plugins/ml/common/types/trained_models.ts
@@ -5,7 +5,7 @@
  */

 import { DataFrameAnalyticsConfig } from './data_frame_analytics';
-import { TotalFeatureImportance } from './feature_importance';
+import { FeatureImportanceBaseline, TotalFeatureImportance } from './feature_importance';

 export interface IngestStats {
   count: number;
@@ -56,6 +56,7 @@ export interface TrainedModelConfigResponse {
         analytics_config: DataFrameAnalyticsConfig;
         input: any;
         total_feature_importance?: TotalFeatureImportance[];
+        feature_importance_baseline?: FeatureImportanceBaseline;
       }
     | Record<string, any>;
   model_id: string;
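A minimal sketch of reading the new optional feature_importance_baseline field from a trained model's metadata might look like the following. TrainedModelConfigSketch is a trimmed-down stand-in for TrainedModelConfigResponse (only the fields used here), and getBaseline is a hypothetical helper. Neither name comes from the PR.

```typescript
interface RegressionFeatureImportanceBaseline {
  baseline: number;
}
interface ClassificationFeatureImportanceBaseline {
  classes: Array<{ class_name: string | number | boolean; baseline: number }>;
}
type FeatureImportanceBaseline =
  | ClassificationFeatureImportanceBaseline
  | RegressionFeatureImportanceBaseline;

// Assumption: a reduced version of TrainedModelConfigResponse for illustration only.
interface TrainedModelConfigSketch {
  model_id: string;
  metadata?:
    | { feature_importance_baseline?: FeatureImportanceBaseline }
    | Record<string, any>;
}

// Hypothetical helper: returns the baseline if the model metadata carries one.
// Both metadata and the field itself are optional, so the result may be undefined.
function getBaseline(model: TrainedModelConfigSketch): FeatureImportanceBaseline | undefined {
  return model.metadata?.feature_importance_baseline;
}
```

Because the field is optional, callers still have to handle the undefined case before passing the value on to the decision path components.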
Expand Down
Original file line number Diff line number Diff line change
@@ -30,7 +30,7 @@ import { ANALYSIS_CONFIG_TYPE, INDEX_STATUS } from '../../data_frame_analytics/c
 import { euiDataGridStyle, euiDataGridToolbarSettings } from './common';
 import { UseIndexDataReturnType } from './types';
 import { DecisionPathPopover } from './feature_importance/decision_path_popover';
-import { TopClasses } from '../../../../common/types/feature_importance';
+import { FeatureImportanceBaseline, TopClasses } from '../../../../common/types/feature_importance';
 import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics';
 import { DataFrameAnalysisConfigType } from '../../../../common/types/data_frame_analytics';

@@ -45,7 +45,7 @@ export const DataGridTitle: FC<{ title: string }> = ({ title }) => (
 );

 interface PropsWithoutHeader extends UseIndexDataReturnType {
-  baseline?: number;
+  baseline?: FeatureImportanceBaseline;
   analysisType?: DataFrameAnalysisConfigType | 'unknown';
   resultsField?: string;
   dataTestSubj: string;
Expand Down
Original file line number Diff line number Diff line change
@@ -25,6 +25,11 @@ import React, { useCallback, useMemo } from 'react';
 import { i18n } from '@kbn/i18n';
 import euiVars from '@elastic/eui/dist/eui_theme_light.json';
 import { DecisionPathPlotData } from './use_classification_path_data';
+import {
+  FeatureImportanceBaseline,
+  isRegressionFeatureImportanceBaseline,
+} from '../../../../../common/types/feature_importance';
+import { formatSingleValue } from '../../../formatters/format_value';

 const { euiColorFullShade, euiColorMediumShade } = euiVars;
 const axisColor = euiColorMediumShade;
@@ -71,15 +76,13 @@ const theme: PartialTheme = {
 interface DecisionPathChartProps {
   decisionPathData: DecisionPathPlotData;
   predictionFieldName?: string;
-  baseline?: number;
+  baseline?: FeatureImportanceBaseline;
   minDomain: number | undefined;
   maxDomain: number | undefined;
-  showValues?: boolean;
 }

 const DECISION_PATH_MARGIN = 125;
 const DECISION_PATH_ROW_HEIGHT = 10;
-const NUM_PRECISION = 3;
 const AnnotationBaselineMarker = <EuiIcon type="dot" size="m" />;

 export const DecisionPathChart = ({
@@ -88,38 +91,37 @@ export const DecisionPathChart = ({
   minDomain,
   maxDomain,
   baseline,
-  showValues,
 }: DecisionPathChartProps) => {
   // adjust the height so it's compact for items with more features
-  const baselineData: LineAnnotationDatum[] = useMemo(
-    () => [
-      {
-        dataValue: baseline,
-        header: baseline ? baseline.toPrecision(NUM_PRECISION) : '',
-        details: i18n.translate(
-          'xpack.ml.dataframe.analytics.explorationResults.decisionPathBaselineText',
-          {
-            defaultMessage:
-              'baseline (average of predictions for all data points in the training data set)',
-          }
-        ),
-      },
-    ],
+  const baselineData: LineAnnotationDatum[] | undefined = useMemo(
+    () =>
+      baseline && isRegressionFeatureImportanceBaseline(baseline)
Contributor commented: With Dima's suggestion making the type guard accept any, the baseline && part here might then no longer be necessary.
Member Author replied: Updated here 10eae2d
+        ? [
+            {
+              dataValue: baseline.baseline,
+              header: formatSingleValue(baseline.baseline, '').toString(),
+              details: i18n.translate(
+                'xpack.ml.dataframe.analytics.explorationResults.decisionPathBaselineText',
+                {
+                  defaultMessage:
+                    'baseline (average of predictions for all data points in the training data set)',
+                }
+              ),
+            },
+          ]
+        : undefined,
     [baseline]
   );
   // if regression, guarantee up to num_precision significant digits without having it in scientific notation
   // if classification, hide the numeric values since we only want to show the path
-  const tickFormatter = useCallback(
-    (d) => (showValues === false ? '' : Number(d.toPrecision(NUM_PRECISION)).toString()),
-    []
-  );
+  const tickFormatter = useCallback((d) => formatSingleValue(d, '').toString(), []);

   return (
     <Chart
       size={{ height: DECISION_PATH_MARGIN + decisionPathData.length * DECISION_PATH_ROW_HEIGHT }}
     >
       <Settings theme={theme} rotation={90} />
-      {baseline && (
+      {baselineData && (
         <LineAnnotation
           id="xpack.ml.dataframe.analytics.explorationResults.decisionPathBaseline"
           domainType={AnnotationDomainTypes.YDomain}
@@ -132,7 +134,6 @@ export const DecisionPathChart = ({
       <Axis
         id={'xpack.ml.dataframe.analytics.explorationResults.decisionPathXAxis'}
         tickFormat={tickFormatter}
-        ticks={showValues === false ? 0 : undefined}
         title={i18n.translate(
           'xpack.ml.dataframe.analytics.explorationResults.decisionPathXAxisTitle',
           {
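The baseline annotation logic in the chart can be sketched outside React as a plain function: a regression baseline yields one annotation datum, anything else yields undefined so no baseline line is drawn. In this sketch AnnotationDatumSketch only loosely mirrors the objects built in the diff, buildBaselineAnnotation is a hypothetical name, and formatValueSketch is an assumed stand-in for Kibana's formatSingleValue, whose behaviour is not shown in this diff. It reuses the three-significant-digit formatting from the removed tick formatter.

```typescript
interface RegressionFeatureImportanceBaseline {
  baseline: number;
}
interface ClassificationFeatureImportanceBaseline {
  classes: Array<{ class_name: string | number | boolean; baseline: number }>;
}
type FeatureImportanceBaseline =
  | ClassificationFeatureImportanceBaseline
  | RegressionFeatureImportanceBaseline;

// Assumption: simplified shape of the annotation datum used by the chart.
interface AnnotationDatumSketch {
  dataValue: number;
  header: string;
  details: string;
}

// Assumption: stand-in for formatSingleValue. Keeps up to 3 significant digits
// and avoids scientific notation for typical values, like the removed formatter.
function formatValueSketch(value: number): string {
  return Number(value.toPrecision(3)).toString();
}

// Hypothetical helper: only a regression baseline produces an annotation.
function buildBaselineAnnotation(
  baseline?: FeatureImportanceBaseline
): AnnotationDatumSketch[] | undefined {
  if (baseline === undefined || !('baseline' in baseline)) {
    return undefined;
  }
  return [
    {
      dataValue: baseline.baseline,
      header: formatValueSketch(baseline.baseline),
      details: 'baseline (average of predictions for all data points in the training data set)',
    },
  ];
}
```

Returning undefined rather than an empty array matches the diff's `{baselineData && (...)}` check, which skips rendering the LineAnnotation entirely for classification baselines.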
Expand Down
Original file line number Diff line number Diff line change
@@ -13,7 +13,11 @@ import {
   useDecisionPathData,
   getStringBasedClassName,
 } from './use_classification_path_data';
-import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
+import {
+  FeatureImportance,
+  FeatureImportanceBaseline,
+  TopClasses,
+} from '../../../../../common/types/feature_importance';
 import { DecisionPathChart } from './decision_path_chart';
 import { MissingDecisionPathCallout } from './missing_decision_path_callout';

@@ -22,18 +26,21 @@ interface ClassificationDecisionPathProps {
   predictionFieldName?: string;
   featureImportance: FeatureImportance[];
   topClasses: TopClasses;
+  baseline?: FeatureImportanceBaseline;
 }

 export const ClassificationDecisionPath: FC<ClassificationDecisionPathProps> = ({
   featureImportance,
   predictedValue,
   topClasses,
   predictionFieldName,
+  baseline,
 }) => {
   const [currentClass, setCurrentClass] = useState<string>(
     getStringBasedClassName(topClasses[0].class_name)
   );
   const { decisionPathData } = useDecisionPathData({
+    baseline,
     featureImportance,
     predictedValue: currentClass,
   });
@@ -99,7 +106,6 @@ export const ClassificationDecisionPath: FC<ClassificationDecisionPathProps> = (
         predictionFieldName={predictionFieldName}
         minDomain={domain.minDomain}
         maxDomain={domain.maxDomain}
-        showValues={false}
       />
     </>
   );
Expand Down
Original file line number Diff line number Diff line change
@@ -9,7 +9,13 @@ import { EuiLink, EuiTab, EuiTabs, EuiText } from '@elastic/eui';
 import { FormattedMessage } from '@kbn/i18n/react';
 import { RegressionDecisionPath } from './decision_path_regression';
 import { DecisionPathJSONViewer } from './decision_path_json_viewer';
-import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
+import {
+  FeatureImportance,
+  FeatureImportanceBaseline,
+  isClassificationFeatureImportanceBaseline,
+  isRegressionFeatureImportanceBaseline,
+  TopClasses,
+} from '../../../../../common/types/feature_importance';
 import { ANALYSIS_CONFIG_TYPE } from '../../../data_frame_analytics/common';
 import { ClassificationDecisionPath } from './decision_path_classification';
 import { useMlKibana } from '../../../contexts/kibana';
@@ -19,7 +25,7 @@ interface DecisionPathPopoverProps {
   featureImportance: FeatureImportance[];
   analysisType: DataFrameAnalysisConfigType;
   predictionFieldName?: string;
-  baseline?: number;
+  baseline?: FeatureImportanceBaseline;
   predictedValue?: number | string | undefined;
   topClasses?: TopClasses;
 }
@@ -109,22 +115,27 @@ export const DecisionPathPopover: FC<DecisionPathPopoverProps> = ({
               }}
             />
           </EuiText>
-          {analysisType === ANALYSIS_CONFIG_TYPE.CLASSIFICATION && (
-            <ClassificationDecisionPath
-              featureImportance={featureImportance}
-              topClasses={topClasses as TopClasses}
-              predictedValue={predictedValue as string}
-              predictionFieldName={predictionFieldName}
-            />
-          )}
-          {analysisType === ANALYSIS_CONFIG_TYPE.REGRESSION && (
-            <RegressionDecisionPath
-              featureImportance={featureImportance}
-              baseline={baseline}
-              predictedValue={predictedValue as number}
-              predictionFieldName={predictionFieldName}
-            />
-          )}
+          {analysisType === ANALYSIS_CONFIG_TYPE.CLASSIFICATION &&
+            baseline &&
darnautov marked this conversation as resolved.
+            isClassificationFeatureImportanceBaseline(baseline) && (
+              <ClassificationDecisionPath
+                featureImportance={featureImportance}
+                topClasses={topClasses as TopClasses}
+                predictedValue={predictedValue as string}
Contributor commented: Instead of typecasting you should probably convert to string
Member Author replied: Updated here 10eae2d
+                predictionFieldName={predictionFieldName}
+                baseline={baseline}
+              />
+            )}
+          {analysisType === ANALYSIS_CONFIG_TYPE.REGRESSION &&
+            baseline &&
darnautov marked this conversation as resolved.
+            isRegressionFeatureImportanceBaseline(baseline) && (
+              <RegressionDecisionPath
+                featureImportance={featureImportance}
+                baseline={baseline}
+                predictedValue={predictedValue as number}
Contributor commented: Instead of typecasting you should probably convert to number
Member Author replied: Updated here 10eae2d
+                predictionFieldName={predictionFieldName}
+              />
+            )}
         </>
       )}
       {selectedTabId === DECISION_PATH_TABS.JSON && (
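The reviewer's suggestion to convert rather than typecast could look roughly like the sketch below. A cast such as `predictedValue as string` only changes the static type, so a numeric prediction would still be a number at runtime, whereas an explicit conversion changes the value itself. The helper names toPredictedClassName and toPredictedNumber are hypothetical and do not appear in the PR. How the follow-up commit actually addressed the comment is not shown in this diff.

```typescript
// Hypothetical helper: converts a predicted value to the string form expected
// by the classification decision path. undefined is passed through unchanged.
function toPredictedClassName(
  value: number | string | boolean | undefined
): string | undefined {
  return value === undefined ? undefined : String(value);
}

// Hypothetical helper: converts a predicted value to a number for the
// regression decision path. Values that cannot be parsed become undefined.
function toPredictedNumber(value: number | string | undefined): number | undefined {
  if (value === undefined) {
    return undefined;
  }
  const parsed = typeof value === 'number' ? value : Number(value);
  return isNaN(parsed) ? undefined : parsed;
}
```

With helpers like these, the popover would pass `toPredictedClassName(predictedValue)` or `toPredictedNumber(predictedValue)` instead of the `as string` and `as number` casts.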
Expand Down
Original file line number Diff line number Diff line change
@@ -8,14 +8,18 @@ import React, { FC, useMemo } from 'react';
 import { EuiCallOut } from '@elastic/eui';
 import { FormattedMessage } from '@kbn/i18n/react';
 import d3 from 'd3';
-import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
+import {
+  FeatureImportance,
+  FeatureImportanceBaseline,
+  TopClasses,
+} from '../../../../../common/types/feature_importance';
 import { useDecisionPathData, isDecisionPathData } from './use_classification_path_data';
 import { DecisionPathChart } from './decision_path_chart';
 import { MissingDecisionPathCallout } from './missing_decision_path_callout';

 interface RegressionDecisionPathProps {
   predictionFieldName?: string;
-  baseline?: number;
+  baseline?: FeatureImportanceBaseline;
   predictedValue?: number | undefined;
   featureImportance: FeatureImportance[];
   topClasses?: TopClasses;