Skip to content

Commit

Permalink
[ML] DF Analytics: create classification jobs results view (elastic#5…
Browse files Browse the repository at this point in the history
…2584)

* wip: create classification results page + table and evaluate panel

* enable view link for classification jobs

* wip: fetch classification eval data

* wip: display confusion matrix in datagrid

* evaluate panel: add heatmap for cells and doc count

* Update use of loadEvalData in expanded row component

* Add metric type for evaluate endpoint and fix localization error

* handle no incorrect prediction classes case for confusion matrix. remove unused translation

* setCellProps needs to be called from a lifecycle method - wrap in useEffect

* TypeScript improvements

* fix datagrid column resize affecting results table

* allow custom prediction field for classification jobs

* ensure values are rounded correctly and add tooltip

* temp workaroun for datagrid width issues
  • Loading branch information
alvarezmelissa87 authored Dec 12, 2019
1 parent 79a8528 commit 0cd5bb0
Show file tree
Hide file tree
Showing 25 changed files with 1,514 additions and 116 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
@import 'pages/analytics_exploration/components/exploration/index';
@import 'pages/analytics_exploration/components/regression_exploration/index';
@import 'pages/analytics_exploration/components/classification_exploration/index';
@import 'pages/analytics_management/components/analytics_list/index';
@import 'pages/analytics_management/components/create_analytics_form/index';
@import 'pages/analytics_management/components/create_analytics_flyout/index';
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ interface ClassificationAnalysis {
dependent_variable: string;
training_percent?: number;
num_top_classes?: string;
prediction_field_name?: string;
};
}

Expand Down Expand Up @@ -74,13 +75,33 @@ export interface RegressionEvaluateResponse {
};
}

export interface PredictedClass {
predicted_class: string;
count: number;
}

export interface ConfusionMatrix {
actual_class: string;
actual_class_doc_count: number;
predicted_classes: PredictedClass[];
other_predicted_class_doc_count: number;
}

export interface ClassificationEvaluateResponse {
classification: {
multiclass_confusion_matrix: {
confusion_matrix: ConfusionMatrix[];
};
};
}

interface GenericAnalysis {
[key: string]: Record<string, any>;
}

interface LoadEvaluateResult {
success: boolean;
eval: RegressionEvaluateResponse | null;
eval: RegressionEvaluateResponse | ClassificationEvaluateResponse | null;
error: string | null;
}

Expand Down Expand Up @@ -109,6 +130,7 @@ export const getAnalysisType = (analysis: AnalysisConfig) => {

export const getDependentVar = (analysis: AnalysisConfig) => {
let depVar = '';

if (isRegressionAnalysis(analysis)) {
depVar = analysis.regression.dependent_variable;
}
Expand All @@ -124,17 +146,26 @@ export const getPredictionFieldName = (analysis: AnalysisConfig) => {
let predictionFieldName;
if (isRegressionAnalysis(analysis) && analysis.regression.prediction_field_name !== undefined) {
predictionFieldName = analysis.regression.prediction_field_name;
} else if (
isClassificationAnalysis(analysis) &&
analysis.classification.prediction_field_name !== undefined
) {
predictionFieldName = analysis.classification.prediction_field_name;
}
return predictionFieldName;
};

export const getPredictedFieldName = (resultsField: string, analysis: AnalysisConfig) => {
export const getPredictedFieldName = (
resultsField: string,
analysis: AnalysisConfig,
forSort?: boolean
) => {
// default is 'ml'
const predictionFieldName = getPredictionFieldName(analysis);
const defaultPredictionField = `${getDependentVar(analysis)}_prediction`;
const predictedField = `${resultsField}.${
predictionFieldName ? predictionFieldName : defaultPredictionField
}`;
}${isClassificationAnalysis(analysis) && !forSort ? '.keyword' : ''}`;
return predictedField;
};

Expand All @@ -153,13 +184,32 @@ export const isClassificationAnalysis = (arg: any): arg is ClassificationAnalysi
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.CLASSIFICATION;
};

export const isRegressionResultsSearchBoolQuery = (
arg: any
): arg is RegressionResultsSearchBoolQuery => {
export const isResultsSearchBoolQuery = (arg: any): arg is ResultsSearchBoolQuery => {
const keys = Object.keys(arg);
return keys.length === 1 && keys[0] === 'bool';
};

export const isRegressionEvaluateResponse = (arg: any): arg is RegressionEvaluateResponse => {
const keys = Object.keys(arg);
return (
keys.length === 1 &&
keys[0] === ANALYSIS_CONFIG_TYPE.REGRESSION &&
arg?.regression?.mean_squared_error !== undefined &&
arg?.regression?.r_squared !== undefined
);
};

export const isClassificationEvaluateResponse = (
arg: any
): arg is ClassificationEvaluateResponse => {
const keys = Object.keys(arg);
return (
keys.length === 1 &&
keys[0] === ANALYSIS_CONFIG_TYPE.CLASSIFICATION &&
arg?.classification?.multiclass_confusion_matrix !== undefined
);
};

export interface DataFrameAnalyticsConfig {
id: DataFrameAnalyticsId;
// Description attribute is not supported yet
Expand Down Expand Up @@ -254,17 +304,14 @@ export function getValuesFromResponse(response: RegressionEvaluateResponse) {

return { meanSquaredError, rSquared };
}
interface RegressionResultsSearchBoolQuery {
interface ResultsSearchBoolQuery {
bool: Dictionary<any>;
}
interface RegressionResultsSearchTermQuery {
interface ResultsSearchTermQuery {
term: Dictionary<any>;
}

export type RegressionResultsSearchQuery =
| RegressionResultsSearchBoolQuery
| RegressionResultsSearchTermQuery
| SavedSearchQuery;
export type ResultsSearchQuery = ResultsSearchBoolQuery | ResultsSearchTermQuery | SavedSearchQuery;

export function getEvalQueryBody({
resultsField,
Expand All @@ -274,23 +321,44 @@ export function getEvalQueryBody({
}: {
resultsField: string;
isTraining: boolean;
searchQuery?: RegressionResultsSearchQuery;
searchQuery?: ResultsSearchQuery;
ignoreDefaultQuery?: boolean;
}) {
let query: RegressionResultsSearchQuery = {
let query: ResultsSearchQuery = {
term: { [`${resultsField}.is_training`]: { value: isTraining } },
};

if (searchQuery !== undefined && ignoreDefaultQuery === true) {
query = searchQuery;
} else if (searchQuery !== undefined && isRegressionResultsSearchBoolQuery(searchQuery)) {
} else if (searchQuery !== undefined && isResultsSearchBoolQuery(searchQuery)) {
const searchQueryClone = cloneDeep(searchQuery);
searchQueryClone.bool.must.push(query);
query = searchQueryClone;
}
return query;
}

interface EvaluateMetrics {
classification: {
multiclass_confusion_matrix: object;
};
regression: {
r_squared: object;
mean_squared_error: object;
};
}

interface LoadEvalDataConfig {
isTraining: boolean;
index: string;
dependentVariable: string;
resultsField: string;
predictionFieldName?: string;
searchQuery?: ResultsSearchQuery;
ignoreDefaultQuery?: boolean;
jobType: ANALYSIS_CONFIG_TYPE;
}

export const loadEvalData = async ({
isTraining,
index,
Expand All @@ -299,34 +367,38 @@ export const loadEvalData = async ({
predictionFieldName,
searchQuery,
ignoreDefaultQuery,
}: {
isTraining: boolean;
index: string;
dependentVariable: string;
resultsField: string;
predictionFieldName?: string;
searchQuery?: RegressionResultsSearchQuery;
ignoreDefaultQuery?: boolean;
}) => {
jobType,
}: LoadEvalDataConfig) => {
const results: LoadEvaluateResult = { success: false, eval: null, error: null };
const defaultPredictionField = `${dependentVariable}_prediction`;
const predictedField = `${resultsField}.${
let predictedField = `${resultsField}.${
predictionFieldName ? predictionFieldName : defaultPredictionField
}`;

if (jobType === ANALYSIS_CONFIG_TYPE.CLASSIFICATION) {
predictedField = `${predictedField}.keyword`;
}

const query = getEvalQueryBody({ resultsField, isTraining, searchQuery, ignoreDefaultQuery });

const metrics: EvaluateMetrics = {
classification: {
multiclass_confusion_matrix: {},
},
regression: {
r_squared: {},
mean_squared_error: {},
},
};

const config = {
index,
query,
evaluation: {
regression: {
[jobType]: {
actual_field: dependentVariable,
predicted_field: predictedField,
metrics: {
r_squared: {},
mean_squared_error: {},
},
metrics: metrics[jobType as keyof EvaluateMetrics],
},
},
};
Expand All @@ -341,3 +413,57 @@ export const loadEvalData = async ({
return results;
}
};

interface TrackTotalHitsSearchResponse {
hits: {
total: {
value: number;
relation: string;
};
hits: any[];
};
}

interface LoadDocsCountConfig {
ignoreDefaultQuery?: boolean;
isTraining: boolean;
searchQuery: SavedSearchQuery;
resultsField: string;
destIndex: string;
}

interface LoadDocsCountResponse {
docsCount: number | null;
success: boolean;
}

export const loadDocsCount = async ({
ignoreDefaultQuery = true,
isTraining,
searchQuery,
resultsField,
destIndex,
}: LoadDocsCountConfig): Promise<LoadDocsCountResponse> => {
const query = getEvalQueryBody({ resultsField, isTraining, ignoreDefaultQuery, searchQuery });

try {
const body: SearchQuery = {
track_total_hits: true,
query,
};

const resp: TrackTotalHitsSearchResponse = await ml.esSearch({
index: destIndex,
size: 0,
body,
});

const docsCount = resp.hits.total && resp.hits.total.value;
return { docsCount, success: docsCount !== undefined };
} catch (e) {
return {
docsCount: null,
success: false,
};
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ export const sortRegressionResultsFields = (
) => {
const dependentVariable = getDependentVar(jobConfig.analysis);
const resultsField = jobConfig.dest.results_field;
const predictedField = getPredictedFieldName(resultsField, jobConfig.analysis);
const predictedField = getPredictedFieldName(resultsField, jobConfig.analysis, true);
if (a === `${resultsField}.is_training`) {
return -1;
}
Expand All @@ -96,6 +96,14 @@ export const sortRegressionResultsFields = (
if (b === dependentVariable) {
return 1;
}

if (a === `${resultsField}.prediction_probability`) {
return -1;
}
if (b === `${resultsField}.prediction_probability`) {
return 1;
}

return a.localeCompare(b);
};

Expand All @@ -107,7 +115,7 @@ export const sortRegressionResultsColumns = (
) => (a: string, b: string) => {
const dependentVariable = getDependentVar(jobConfig.analysis);
const resultsField = jobConfig.dest.results_field;
const predictedField = getPredictedFieldName(resultsField, jobConfig.analysis);
const predictedField = getPredictedFieldName(resultsField, jobConfig.analysis, true);

const typeofA = typeof obj[a];
const typeofB = typeof obj[b];
Expand Down Expand Up @@ -136,6 +144,14 @@ export const sortRegressionResultsColumns = (
return 1;
}

if (a === `${resultsField}.prediction_probability`) {
return -1;
}

if (b === `${resultsField}.prediction_probability`) {
return 1;
}

if (typeofA !== 'string' && typeofB === 'string') {
return 1;
}
Expand Down Expand Up @@ -184,6 +200,43 @@ export function getFlattenedFields(obj: EsDocSource, resultsField: string): EsFi
return flatDocFields.filter(f => f !== ML__ID_COPY);
}

export const getDefaultClassificationFields = (
docs: EsDoc[],
jobConfig: DataFrameAnalyticsConfig
): EsFieldName[] => {
if (docs.length === 0) {
return [];
}
const resultsField = jobConfig.dest.results_field;
const newDocFields = getFlattenedFields(docs[0]._source, resultsField);
return newDocFields
.filter(k => {
if (k === `${resultsField}.is_training`) {
return true;
}
// predicted value of dependent variable
if (k === getPredictedFieldName(resultsField, jobConfig.analysis, true)) {
return true;
}
// actual value of dependent variable
if (k === getDependentVar(jobConfig.analysis)) {
return true;
}

if (k === `${resultsField}.prediction_probability`) {
return true;
}

if (k.split('.')[0] === resultsField) {
return false;
}

return docs.some(row => row._source[k] !== null);
})
.sort((a, b) => sortRegressionResultsFields(a, b, jobConfig))
.slice(0, DEFAULT_REGRESSION_COLUMNS);
};

export const getDefaultRegressionFields = (
docs: EsDoc[],
jobConfig: DataFrameAnalyticsConfig
Expand Down
Loading

0 comments on commit 0cd5bb0

Please sign in to comment.