Skip to content

Commit

Permalink
[ML] Minor refactoring
Browse files Browse the repository at this point in the history
- useMemo for decision path data
- use dest.results_field instead of hardcoded .ml
- predictionfieldName uses default name if not available
- use d3.extent instead of custom max min func
- return eui callout if decision path data for some reason is not available
- new isDecisionPathData()
- tickFormatter doesn't need template literal
  • Loading branch information
qn895 committed Aug 24, 2020
1 parent 172a58e commit ef815d1
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 84 deletions.
33 changes: 33 additions & 0 deletions x-pack/plugins/ml/common/util/analytics_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,23 @@ export const isClassificationAnalysis = (arg: any): arg is ClassificationAnalysi
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.CLASSIFICATION;
};

export const getDependentVar = (
analysis: AnalysisConfig
):
| RegressionAnalysis['regression']['dependent_variable']
| ClassificationAnalysis['classification']['dependent_variable'] => {
let depVar = '';

if (isRegressionAnalysis(analysis)) {
depVar = analysis.regression.dependent_variable;
}

if (isClassificationAnalysis(analysis)) {
depVar = analysis.classification.dependent_variable;
}
return depVar;
};

export const getPredictionFieldName = (
analysis: AnalysisConfig
):
Expand All @@ -44,3 +61,19 @@ export const getPredictionFieldName = (
}
return predictionFieldName;
};

export const getDefaultPredictionFieldName = (analysis: AnalysisConfig) => {
return `${getDependentVar(analysis)}_prediction`;
};
export const getPredictedFieldName = (
resultsField: string,
analysis: AnalysisConfig,
forSort?: boolean
) => {
// default is 'ml'
const predictionFieldName = getPredictionFieldName(analysis);
const predictedField = `${resultsField}.${
predictionFieldName ? predictionFieldName : getDefaultPredictionFieldName(analysis)
}`;
return predictedField;
};
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ export const DataGridTitle: FC<{ title: string }> = ({ title }) => (
interface PropsWithoutHeader extends UseIndexDataReturnType {
baseline?: number;
analysisType?: ANALYSIS_CONFIG_TYPE;
resultsField?: string;
dataTestSubj: string;
toastNotifications: CoreSetup['notifications']['toasts'];
}
Expand Down Expand Up @@ -85,6 +86,7 @@ export const DataGrid: FC<Props> = memo(
toggleChartVisibility,
visibleColumns,
predictionFieldName,
resultsField,
analysisType,
} = props;
// TODO Fix row hovering + bar highlighting
Expand All @@ -103,16 +105,18 @@ export const DataGrid: FC<Props> = memo(
const rowIndex = children?.props?.visibleRowIndex;
const row = data[rowIndex];
if (!row) return <div />;
const parsedFIArray = row.ml.feature_importance;
// if resultsField for some reason is not available then use ml
const mlResultsField = resultsField ?? 'ml';

This comment has been minimized.

Copy link
@walterra

walterra Sep 7, 2020

Contributor

We have DEFAULT_RESULTS_FIELD for this in plugins/ml/public/application/data_frame_analytics/common/constants.ts

const parsedFIArray = row[mlResultsField].feature_importance;
let predictedValue: string | number | undefined;
let topClasses: TopClasses = [];
if (
predictionFieldName !== undefined &&
row &&
row.ml[predictionFieldName] !== undefined
row[mlResultsField][predictionFieldName] !== undefined
) {
predictedValue = row.ml[predictionFieldName];
topClasses = row.ml.top_classes;
predictedValue = row[mlResultsField][predictionFieldName];
topClasses = row[mlResultsField].top_classes;
}

return (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
* you may not use this file except in compliance with the Elastic License.
*/

// adjust the height so it's compact for items with more features
import {
AnnotationDomainTypes,
Axis,
Expand Down Expand Up @@ -64,6 +63,7 @@ export const DecisionPathChart = ({
maxDomain,
baseline,
}: DecisionPathChartProps) => {
// adjust the height so it's compact for items with more features
const heightMultiplier = Array.isArray(decisionPathData) && decisionPathData.length > 4 ? 30 : 75;
const baselineData: LineAnnotationDatum[] = useMemo(
() => [
Expand All @@ -80,7 +80,7 @@ export const DecisionPathChart = ({
],
[baseline]
);
const tickFormatter = useCallback((d) => `${Number(d).toPrecision(3)}`, []);
const tickFormatter = useCallback((d) => Number(d).toPrecision(3), []);

return (
<Chart size={{ height: decisionPathData.length * heightMultiplier }}>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
import React, { FC, useMemo, useState } from 'react';
import { i18n } from '@kbn/i18n';
import { EuiHealth, EuiSpacer, EuiSuperSelect, EuiTitle } from '@elastic/eui';
import { findMaxMin, useDecisionPathData } from './use_classification_path_data';
import d3 from 'd3';
import { isDecisionPathData, useDecisionPathData } from './use_classification_path_data';
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
import { DecisionPathChart } from './decision_path_chart';
import { MissingDecisionPathCallout } from './missing_decision_path_callout';

interface ClassificationDecisionPathProps {
predictedValue: string | undefined;
predictionFieldName?: string;
Expand Down Expand Up @@ -49,20 +52,16 @@ export const ClassificationDecisionPath: FC<ClassificationDecisionPathProps> = (
let maxDomain;
let minDomain;
// if decisionPathData has calculated cumulative path
if (
Array.isArray(decisionPathData) &&
decisionPathData.length > 0 &&
decisionPathData[0].length === 3
) {
const { max, min } = findMaxMin(decisionPathData, (d: [string, number, number]) => d[2]);
if (decisionPathData && isDecisionPathData(decisionPathData)) {
const [min, max] = d3.extent(decisionPathData, (d: [string, number, number]) => d[2]);
const buffer = Math.abs(max - min) * 0.1;
maxDomain = max + buffer;
minDomain = min - buffer;
}
return { maxDomain, minDomain };
}, [decisionPathData]);

if (!decisionPathData) return <div />;
if (!decisionPathData) return <MissingDecisionPathCallout />;

return (
<>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import React, { FC, useMemo } from 'react';
import { EuiCallOut } from '@elastic/eui';
import { FormattedMessage } from '@kbn/i18n/react';
import d3 from 'd3';
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
import { findMaxMin, useDecisionPathData } from './use_classification_path_data';
import { useDecisionPathData, isDecisionPathData } from './use_classification_path_data';
import { DecisionPathChart } from './decision_path_chart';
import { MissingDecisionPathCallout } from './missing_decision_path_callout';

interface RegressionDecisionPathProps {
predictionFieldName?: string;
Expand All @@ -34,12 +36,8 @@ export const RegressionDecisionPath: FC<RegressionDecisionPathProps> = ({
let maxDomain;
let minDomain;
// if decisionPathData has calculated cumulative path
if (
Array.isArray(decisionPathData) &&
decisionPathData.length > 0 &&
decisionPathData[0].length === 3
) {
const { max, min } = findMaxMin(decisionPathData, (d: [string, number, number]) => d[2]);
if (decisionPathData && isDecisionPathData(decisionPathData)) {
const [min, max] = d3.extent(decisionPathData, (d: [string, number, number]) => d[2]);
maxDomain = max;
minDomain = min;
const buffer = Math.abs(maxDomain - minDomain) * 0.1;
Expand All @@ -51,7 +49,7 @@ export const RegressionDecisionPath: FC<RegressionDecisionPathProps> = ({
return { maxDomain, minDomain };
}, [decisionPathData, baseline]);

if (!decisionPathData) return <div />;
if (!decisionPathData) return <MissingDecisionPathCallout />;

return (
<>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

import React from 'react';
import { EuiCallOut } from '@elastic/eui';
import { FormattedMessage } from '@kbn/i18n/react';

export const MissingDecisionPathCallout = () => {
return (
<EuiCallOut color={'warning'}>
<FormattedMessage
id="xpack.ml.dataframe.analytics.explorationResults.regressionDecisionPathDataMissingCallout"
defaultMessage="No decision path data available."
/>
</EuiCallOut>
);
};
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
* you may not use this file except in compliance with the Elastic License.
*/

import { useEffect, useState } from 'react';
import { useMemo } from 'react';
import { i18n } from '@kbn/i18n';
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
import { ExtendedFeatureImportance } from './decision_path_popover';
Expand All @@ -27,15 +27,20 @@ interface RegressionDecisionPathProps {
const FEATURE_NAME = 'feature_name';
const FEATURE_IMPORTANCE = 'importance';

export const isDecisionPathData = (decisionPathData: any): boolean => {
return (
Array.isArray(decisionPathData) &&
decisionPathData.length > 0 &&
decisionPathData[0].length === 3
);
};
export const useDecisionPathData = ({
baseline,
featureImportance,
predictedValue,
}: UseDecisionPathDataParams): { decisionPathData: DecisionPathPlotData | undefined } => {
const [decisionPathData, setDecisionPlotData] = useState<DecisionPathPlotData | undefined>();

useEffect(() => {
const result = baseline
const decisionPathData = useMemo(() => {
return baseline
? buildRegressionDecisionPathData({
baseline,
featureImportance,
Expand All @@ -45,8 +50,6 @@ export const useDecisionPathData = ({
featureImportance,
currentClass: predictedValue as string | undefined,
});

setDecisionPlotData(result);
}, [baseline, featureImportance, predictedValue]);

return { decisionPathData };
Expand Down Expand Up @@ -153,17 +156,3 @@ export const buildClassificationDecisionPathData = ({

return buildDecisionPathData(filteredFeatureImportance);
};

export const findMaxMin = (
data: DecisionPathPlotData,
getter: Function
): { max: number; min: number } => {
let min = Infinity;
let max = -Infinity;
data.forEach((d) => {
const value = getter(d);
if (value > max) max = value;
if (value < min) min = value;
});
return { max, min };
};
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ export interface UseIndexDataReturnType
| 'visibleColumns'
| 'baseline'
| 'predictionFieldName'
| 'resultsField'
> {
renderCellValue: RenderCellValue;
}
Expand Down Expand Up @@ -109,4 +110,5 @@ export interface UseDataGridReturnType {
visibleColumns: ColumnId[];
baseline?: number;
predictionFieldName?: string;
resultsField?: string;
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import {
isRegressionAnalysis,
isClassificationAnalysis,
getPredictionFieldName,
getDependentVar,
getPredictedFieldName,
} from '../../../../common/util/analytics_utils';
export type IndexPattern = string;

Expand Down Expand Up @@ -155,23 +157,6 @@ export const getAnalysisType = (analysis: AnalysisConfig): string => {
return 'unknown';
};

export const getDependentVar = (
analysis: AnalysisConfig
):
| RegressionAnalysis['regression']['dependent_variable']
| ClassificationAnalysis['classification']['dependent_variable'] => {
let depVar = '';

if (isRegressionAnalysis(analysis)) {
depVar = analysis.regression.dependent_variable;
}

if (isClassificationAnalysis(analysis)) {
depVar = analysis.classification.dependent_variable;
}
return depVar;
};

export const getTrainingPercent = (
analysis: AnalysisConfig
):
Expand Down Expand Up @@ -219,20 +204,6 @@ export const getNumTopFeatureImportanceValues = (
return numTopFeatureImportanceValues;
};

export const getPredictedFieldName = (
resultsField: string,
analysis: AnalysisConfig,
forSort?: boolean
) => {
// default is 'ml'
const predictionFieldName = getPredictionFieldName(analysis);
const defaultPredictionField = `${getDependentVar(analysis)}_prediction`;
const predictedField = `${resultsField}.${
predictionFieldName ? predictionFieldName : defaultPredictionField
}`;
return predictedField;
};

export const isResultsSearchBoolQuery = (arg: any): arg is ResultsSearchBoolQuery => {
if (arg === undefined) return false;
const keys = Object.keys(arg);
Expand Down Expand Up @@ -580,4 +551,6 @@ export {
isClassificationAnalysis,
getPredictionFieldName,
ANALYSIS_CONFIG_TYPE,
getDependentVar,
getPredictedFieldName,
};
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,16 @@
* you may not use this file except in compliance with the Elastic License.
*/

import { getNumTopClasses, getNumTopFeatureImportanceValues } from './analytics';
import { Field } from '../../../../common/types/fields';
import {
getNumTopClasses,
getNumTopFeatureImportanceValues,
getPredictedFieldName,
getDependentVar,
getPredictionFieldName,
isClassificationAnalysis,
isOutlierAnalysis,
isRegressionAnalysis,
} from './analytics';
import { Field } from '../../../../common/types/fields';
} from '../../../../common/util/analytics_utils';
import { ES_FIELD_TYPES, KBN_FIELD_TYPES } from '../../../../../../../src/plugins/data/public';
import { newJobCapsService } from '../../services/new_job_capabilities_service';

Expand Down
Loading

0 comments on commit ef815d1

Please sign in to comment.