From a853e829b29bce3dabc25e4673b5d9f477bb348a Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Thu, 26 Mar 2020 07:50:11 +0100 Subject: [PATCH] Do not fail Evaluate API when the actual and predicted fields' types differ --- .../MulticlassConfusionMatrix.java | 4 +- .../evaluation/classification/Precision.java | 2 +- .../ClassificationEvaluationIT.java | 454 ++++++++++-------- 3 files changed, 270 insertions(+), 190 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java index f041e66698aa6..eb80d6afdfbb4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -142,7 +142,7 @@ public final Tuple, List> a if (result.get() == null) { // These are steps 2, 3, 4 etc. KeyedFilter[] keyedFiltersPredicted = topActualClassNames.get().stream() - .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) + .map(className -> new KeyedFilter(className, QueryBuilders.matchQuery(predictedField, className).lenient(true))) .toArray(KeyedFilter[]::new); // Knowing exactly how many buckets does each aggregation use, we can choose the size of the batch so that // too_many_buckets_exception exception is not thrown. @@ -153,7 +153,7 @@ public final Tuple, List> a topActualClassNames.get().stream() .skip(actualClasses.size()) .limit(actualClassesPerBatch) - .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className))) + .map(className -> new KeyedFilter(className, QueryBuilders.matchQuery(actualField, className).lenient(true))) .toArray(KeyedFilter[]::new); if (keyedFiltersActual.length > 0) { return Tuple.tuple( diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java index c2fe3e4069fcc..6343f46e2874a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java @@ -107,7 +107,7 @@ public final Tuple, List> a if (result.get() == null) { // This is step 2 KeyedFilter[] keyedFiltersPredicted = topActualClassNames.get().stream() - .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) + .map(className -> new KeyedFilter(className, QueryBuilders.matchQuery(predictedField, className).lenient(true))) .toArray(KeyedFilter[]::new); Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField); return Tuple.tuple( diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 773f91b5027d2..aa6283c9f131c 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -28,23 +28,25 @@ import static java.util.stream.Collectors.toList; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notANumber; public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase { private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index"; private static final String ANIMAL_NAME_KEYWORD_FIELD = "animal_name_keyword"; - private static final String ANIMAL_NAME_PREDICTION_FIELD = "animal_name_prediction"; + private static final String ANIMAL_NAME_PREDICTION_KEYWORD_FIELD = "animal_name_keyword_prediction"; private static final String NO_LEGS_KEYWORD_FIELD = "no_legs_keyword"; private static final String NO_LEGS_INTEGER_FIELD = "no_legs_integer"; - private static final String NO_LEGS_PREDICTION_FIELD = "no_legs_prediction"; + private static final String NO_LEGS_PREDICTION_INTEGER_FIELD = "no_legs_integer_prediction"; private static final String IS_PREDATOR_KEYWORD_FIELD = "predator_keyword"; private static final String IS_PREDATOR_BOOLEAN_FIELD = "predator_boolean"; - private static final String IS_PREDATOR_PREDICTION_FIELD = "predator_prediction"; + private static final String IS_PREDATOR_PREDICTION_BOOLEAN_FIELD = "predator_boolean_prediction"; @Before public void setup() { @@ -63,7 +65,8 @@ public void cleanup() { public void testEvaluate_DefaultMetrics() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null)); + evaluateDataFrame( + ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, null)); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat( @@ -77,7 +80,7 @@ public void testEvaluate_AllMetrics() { ANIMALS_DATA_INDEX, new Classification( ANIMAL_NAME_KEYWORD_FIELD, - ANIMAL_NAME_PREDICTION_FIELD, + ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, List.of(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); @@ -90,161 +93,212 @@ public void testEvaluate_AllMetrics() { Recall.NAME.getPreferredName())); } - public void testEvaluate_Accuracy_KeywordField() { + private Accuracy.Result evaluateAccuracy(String actualField, String predictedField) { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Accuracy()))); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, List.of(new Accuracy()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); - assertThat( - accuracyResult.getClasses(), - equalTo( - List.of( - new Accuracy.PerClassResult("ant", 47.0 / 75), - new Accuracy.PerClassResult("cat", 47.0 / 75), - new Accuracy.PerClassResult("dog", 47.0 / 75), - new Accuracy.PerClassResult("fox", 47.0 / 75), - new Accuracy.PerClassResult("mouse", 47.0 / 75)))); - assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75)); + return accuracyResult; } - private void evaluateAccuracy_IntegerField(String actualField) { - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_PREDICTION_FIELD, List.of(new Accuracy()))); + public void testEvaluate_Accuracy_KeywordField() { + List expectedPerClassResults = + List.of( + new Accuracy.PerClassResult("ant", 47.0 / 75), + new Accuracy.PerClassResult("cat", 47.0 / 75), + new Accuracy.PerClassResult("dog", 47.0 / 75), + new Accuracy.PerClassResult("fox", 47.0 / 75), + new Accuracy.PerClassResult("mouse", 47.0 / 75)); + double expectedOverallAccuracy = 5.0 / 75; - assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + Accuracy.Result accuracyResult = evaluateAccuracy(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); - Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); - assertThat( - accuracyResult.getClasses(), - equalTo( - List.of( - new Accuracy.PerClassResult("1", 57.0 / 75), - new Accuracy.PerClassResult("2", 54.0 / 75), - new Accuracy.PerClassResult("3", 51.0 / 75), - new Accuracy.PerClassResult("4", 48.0 / 75), - new Accuracy.PerClassResult("5", 45.0 / 75)))); - assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75)); + // Actual and predicted fields are swapped but this does not alter the result (accuracy is symmetric) + accuracyResult = evaluateAccuracy(ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, ANIMAL_NAME_KEYWORD_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); } public void testEvaluate_Accuracy_IntegerField() { - evaluateAccuracy_IntegerField(NO_LEGS_INTEGER_FIELD); - } - - public void testEvaluate_Accuracy_IntegerField_MappingTypeMismatch() { - evaluateAccuracy_IntegerField(NO_LEGS_KEYWORD_FIELD); - } - - private void evaluateAccuracy_BooleanField(String actualField) { - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, List.of(new Accuracy()))); - - assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); - - Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); - assertThat( - accuracyResult.getClasses(), - equalTo( - List.of( - new Accuracy.PerClassResult("false", 18.0 / 30), - new Accuracy.PerClassResult("true", 27.0 / 45)))); - assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75)); + List expectedPerClassResults = + List.of( + new Accuracy.PerClassResult("1", 57.0 / 75), + new Accuracy.PerClassResult("2", 54.0 / 75), + new Accuracy.PerClassResult("3", 51.0 / 75), + new Accuracy.PerClassResult("4", 48.0 / 75), + new Accuracy.PerClassResult("5", 45.0 / 75)); + double expectedOverallAccuracy = 15.0 / 75; + + Accuracy.Result accuracyResult = evaluateAccuracy(NO_LEGS_INTEGER_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + + // Actual and predicted fields are swapped but this does not alter the result (accuracy is symmetric) + accuracyResult = evaluateAccuracy(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_INTEGER_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + + // Actual and predicted fields are of different types but the values are matched correctly + accuracyResult = evaluateAccuracy(NO_LEGS_KEYWORD_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + + // Actual and predicted fields are of different types but the values are matched correctly + accuracyResult = evaluateAccuracy(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_KEYWORD_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); } public void testEvaluate_Accuracy_BooleanField() { - evaluateAccuracy_BooleanField(IS_PREDATOR_BOOLEAN_FIELD); - } + List expectedPerClassResults = + List.of( + new Accuracy.PerClassResult("false", 18.0 / 30), + new Accuracy.PerClassResult("true", 27.0 / 45)); + double expectedOverallAccuracy = 45.0 / 75; + + Accuracy.Result accuracyResult = evaluateAccuracy(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + + // Actual and predicted fields are swapped but this does not alter the result (accuracy is symmetric) + accuracyResult = evaluateAccuracy(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_BOOLEAN_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + + // Actual and predicted fields are of different types but the values are matched correctly + accuracyResult = evaluateAccuracy(IS_PREDATOR_KEYWORD_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + + // Actual and predicted fields are of different types but the values are matched correctly + accuracyResult = evaluateAccuracy(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_KEYWORD_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + } + + public void testEvaluate_Accuracy_FieldTypeMismatch() { + { + // When actual and predicted fields have different types, the sets of classes are disjoint + List expectedPerClassResults = + List.of( + new Accuracy.PerClassResult("1", 0.8), + new Accuracy.PerClassResult("2", 0.8), + new Accuracy.PerClassResult("3", 0.8), + new Accuracy.PerClassResult("4", 0.8), + new Accuracy.PerClassResult("5", 0.8)); + double expectedOverallAccuracy = 0.0; + + Accuracy.Result accuracyResult = evaluateAccuracy(NO_LEGS_INTEGER_FIELD, IS_PREDATOR_BOOLEAN_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + } + { + // When actual and predicted fields have different types, the sets of classes are disjoint + List expectedPerClassResults = + List.of( + new Accuracy.PerClassResult("false", 0.6), + new Accuracy.PerClassResult("true", 0.4)); + double expectedOverallAccuracy = 0.0; - public void testEvaluate_Accuracy_BooleanField_MappingTypeMismatch() { - evaluateAccuracy_BooleanField(IS_PREDATOR_KEYWORD_FIELD); + Accuracy.Result accuracyResult = evaluateAccuracy(IS_PREDATOR_BOOLEAN_FIELD, NO_LEGS_INTEGER_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + } } - public void testEvaluate_Precision_KeywordField() { + private Precision.Result evaluatePrecision(String actualField, String predictedField) { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Precision()))); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, List.of(new Precision()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName())); - assertThat( - precisionResult.getClasses(), - equalTo( - List.of( - new Precision.PerClassResult("ant", 1.0 / 15), - new Precision.PerClassResult("cat", 1.0 / 15), - new Precision.PerClassResult("dog", 1.0 / 15), - new Precision.PerClassResult("fox", 1.0 / 15), - new Precision.PerClassResult("mouse", 1.0 / 15)))); - assertThat(precisionResult.getAvgPrecision(), equalTo(5.0 / 75)); + return precisionResult; } - private void evaluatePrecision_IntegerField(String actualField) { - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_PREDICTION_FIELD, List.of(new Precision()))); + public void testEvaluate_Precision_KeywordField() { + List expectedPerClassResults = + List.of( + new Precision.PerClassResult("ant", 1.0 / 15), + new Precision.PerClassResult("cat", 1.0 / 15), + new Precision.PerClassResult("dog", 1.0 / 15), + new Precision.PerClassResult("fox", 1.0 / 15), + new Precision.PerClassResult("mouse", 1.0 / 15)); + double expectedAvgPrecision = 5.0 / 75; - assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + Precision.Result precisionResult = evaluatePrecision(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD); + assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision)); - Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName())); - assertThat( - precisionResult.getClasses(), - equalTo( - List.of( - new Precision.PerClassResult("1", 0.2), - new Precision.PerClassResult("2", 0.2), - new Precision.PerClassResult("3", 0.2), - new Precision.PerClassResult("4", 0.2), - new Precision.PerClassResult("5", 0.2)))); - assertThat(precisionResult.getAvgPrecision(), equalTo(0.2)); + evaluatePrecision(ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, ANIMAL_NAME_KEYWORD_FIELD); } public void testEvaluate_Precision_IntegerField() { - evaluatePrecision_IntegerField(NO_LEGS_INTEGER_FIELD); - } + List expectedPerClassResults = + List.of( + new Precision.PerClassResult("1", 0.2), + new Precision.PerClassResult("2", 0.2), + new Precision.PerClassResult("3", 0.2), + new Precision.PerClassResult("4", 0.2), + new Precision.PerClassResult("5", 0.2)); + double expectedAvgPrecision = 0.2; - public void testEvaluate_Precision_IntegerField_MappingTypeMismatch() { - evaluatePrecision_IntegerField(NO_LEGS_KEYWORD_FIELD); - } + Precision.Result precisionResult = evaluatePrecision(NO_LEGS_INTEGER_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); + assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision)); - private void evaluatePrecision_BooleanField(String actualField) { - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, List.of(new Precision()))); + // Actual and predicted fields are of different types but the values are matched correctly + precisionResult = evaluatePrecision(NO_LEGS_KEYWORD_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); + assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision)); - assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + evaluatePrecision(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_INTEGER_FIELD); - Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName())); - assertThat( - precisionResult.getClasses(), - equalTo( - List.of( - new Precision.PerClassResult("false", 0.5), - new Precision.PerClassResult("true", 9.0 / 13)))); - assertThat(precisionResult.getAvgPrecision(), equalTo(31.0 / 52)); + evaluatePrecision(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_KEYWORD_FIELD); } public void testEvaluate_Precision_BooleanField() { - evaluatePrecision_BooleanField(IS_PREDATOR_BOOLEAN_FIELD); + List expectedPerClassResults = + List.of( + new Precision.PerClassResult("false", 0.5), + new Precision.PerClassResult("true", 9.0 / 13)); + double expectedAvgPrecision = 31.0 / 52; + + Precision.Result precisionResult = evaluatePrecision(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); + assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision)); + + // Actual and predicted fields are of different types but the values are matched correctly + precisionResult = evaluatePrecision(IS_PREDATOR_KEYWORD_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); + assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision)); + + evaluatePrecision(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_BOOLEAN_FIELD); + + evaluatePrecision(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_KEYWORD_FIELD); } - public void testEvaluate_Precision_BooleanField_MappingTypeMismatch() { - evaluatePrecision_BooleanField(IS_PREDATOR_KEYWORD_FIELD); + public void testEvaluate_Precision_FieldTypeMismatch() { + { + Precision.Result precisionResult = evaluatePrecision(NO_LEGS_INTEGER_FIELD, IS_PREDATOR_BOOLEAN_FIELD); + // When actual and predicted fields have different types, the sets of classes are disjoint, hence empty results here + assertThat(precisionResult.getClasses(), empty()); + assertThat(precisionResult.getAvgPrecision(), is(notANumber())); + } + { + Precision.Result precisionResult = evaluatePrecision(IS_PREDATOR_BOOLEAN_FIELD, NO_LEGS_INTEGER_FIELD); + // When actual and predicted fields have different types, the sets of classes are disjoint, hence empty results here + assertThat(precisionResult.getClasses(), empty()); + assertThat(precisionResult.getAvgPrecision(), is(notANumber())); + } } public void testEvaluate_Precision_CardinalityTooHigh() { @@ -254,87 +308,112 @@ public void testEvaluate_Precision_CardinalityTooHigh() { ElasticsearchStatusException.class, () -> evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Precision())))); + new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, List.of(new Precision())))); assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high")); } - public void testEvaluate_Recall_KeywordField() { + private Recall.Result evaluateRecall(String actualField, String predictedField) { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Recall()))); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, List.of(new Recall()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName())); - assertThat( - recallResult.getClasses(), - equalTo( - List.of( - new Recall.PerClassResult("ant", 1.0 / 15), - new Recall.PerClassResult("cat", 1.0 / 15), - new Recall.PerClassResult("dog", 1.0 / 15), - new Recall.PerClassResult("fox", 1.0 / 15), - new Recall.PerClassResult("mouse", 1.0 / 15)))); - assertThat(recallResult.getAvgRecall(), equalTo(5.0 / 75)); + return recallResult; } - private void evaluateRecall_IntegerField(String actualField) { - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_INTEGER_FIELD, List.of(new Recall()))); + public void testEvaluate_Recall_KeywordField() { + List expectedPerClassResults = + List.of( + new Recall.PerClassResult("ant", 1.0 / 15), + new Recall.PerClassResult("cat", 1.0 / 15), + new Recall.PerClassResult("dog", 1.0 / 15), + new Recall.PerClassResult("fox", 1.0 / 15), + new Recall.PerClassResult("mouse", 1.0 / 15)); + double expectedAvgRecall = 5.0 / 75; - assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + Recall.Result recallResult = evaluateRecall(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); - Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName())); - assertThat( - recallResult.getClasses(), - equalTo( - List.of( - new Recall.PerClassResult("1", 1.0), - new Recall.PerClassResult("2", 1.0), - new Recall.PerClassResult("3", 1.0), - new Recall.PerClassResult("4", 1.0), - new Recall.PerClassResult("5", 1.0)))); - assertThat(recallResult.getAvgRecall(), equalTo(1.0)); + evaluateRecall(ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, ANIMAL_NAME_KEYWORD_FIELD); } public void testEvaluate_Recall_IntegerField() { - evaluateRecall_IntegerField(NO_LEGS_INTEGER_FIELD); - } + List expectedPerClassResults = + List.of( + new Recall.PerClassResult("1", 1.0 / 15), + new Recall.PerClassResult("2", 2.0 / 15), + new Recall.PerClassResult("3", 3.0 / 15), + new Recall.PerClassResult("4", 4.0 / 15), + new Recall.PerClassResult("5", 5.0 / 15)); + double expectedAvgRecall = 3.0 / 15; - public void testEvaluate_Recall_IntegerField_MappingTypeMismatch() { - evaluateRecall_IntegerField(NO_LEGS_KEYWORD_FIELD); - } + Recall.Result recallResult = evaluateRecall(NO_LEGS_INTEGER_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); - private void evaluateRecall_BooleanField(String actualField) { - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, List.of(new Recall()))); + // Actual and predicted fields are of different types but the values are matched correctly + recallResult = evaluateRecall(NO_LEGS_KEYWORD_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); - assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + evaluateRecall(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_INTEGER_FIELD); - Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName())); - assertThat( - recallResult.getClasses(), - equalTo( - List.of( - new Recall.PerClassResult("true", 0.6), - new Recall.PerClassResult("false", 0.6)))); - assertThat(recallResult.getAvgRecall(), equalTo(0.6)); + evaluateRecall(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_KEYWORD_FIELD); } public void testEvaluate_Recall_BooleanField() { - evaluateRecall_BooleanField(IS_PREDATOR_BOOLEAN_FIELD); + List expectedPerClassResults = + List.of( + new Recall.PerClassResult("true", 0.6), + new Recall.PerClassResult("false", 0.6)); + double expectedAvgRecall = 0.6; + + Recall.Result recallResult = evaluateRecall(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); + + // Actual and predicted fields are of different types but the values are matched correctly + recallResult = evaluateRecall(IS_PREDATOR_KEYWORD_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); + + evaluateRecall(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_BOOLEAN_FIELD); + + evaluateRecall(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_KEYWORD_FIELD); } - public void testEvaluate_Recall_BooleanField_MappingTypeMismatch() { - evaluateRecall_BooleanField(IS_PREDATOR_KEYWORD_FIELD); + public void testEvaluate_Recall_FieldTypeMismatch() { + { + // When actual and predicted fields have different types, the sets of classes are disjoint, hence 0.0 results here + List expectedPerClassResults = + List.of( + new Recall.PerClassResult("1", 0.0), + new Recall.PerClassResult("2", 0.0), + new Recall.PerClassResult("3", 0.0), + new Recall.PerClassResult("4", 0.0), + new Recall.PerClassResult("5", 0.0)); + double expectedAvgRecall = 0.0; + + Recall.Result recallResult = evaluateRecall(NO_LEGS_INTEGER_FIELD, IS_PREDATOR_BOOLEAN_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); + } + { + // When actual and predicted fields have different types, the sets of classes are disjoint, hence 0.0 results here + List expectedPerClassResults = + List.of( + new Recall.PerClassResult("true", 0.0), + new Recall.PerClassResult("false", 0.0)); + double expectedAvgRecall = 0.0; + + Recall.Result recallResult = evaluateRecall(IS_PREDATOR_BOOLEAN_FIELD, NO_LEGS_INTEGER_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); + } } public void testEvaluate_Recall_CardinalityTooHigh() { @@ -344,15 +423,16 @@ public void testEvaluate_Recall_CardinalityTooHigh() { ElasticsearchStatusException.class, () -> evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Recall())))); + new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, List.of(new Recall())))); assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high")); } - private void evaluateWithMulticlassConfusionMatrix() { + private void evaluateMulticlassConfusionMatrix() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new MulticlassConfusionMatrix()))); + new Classification( + ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, List.of(new MulticlassConfusionMatrix()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -412,16 +492,16 @@ private void evaluateWithMulticlassConfusionMatrix() { } public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() { - evaluateWithMulticlassConfusionMatrix(); + evaluateMulticlassConfusionMatrix(); client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 20)).get(); - evaluateWithMulticlassConfusionMatrix(); + evaluateMulticlassConfusionMatrix(); client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 7)).get(); - evaluateWithMulticlassConfusionMatrix(); + evaluateMulticlassConfusionMatrix(); client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 6)).get(); - ElasticsearchException e = expectThrows(ElasticsearchException.class, this::evaluateWithMulticlassConfusionMatrix); + ElasticsearchException e = expectThrows(ElasticsearchException.class, this::evaluateMulticlassConfusionMatrix); assertThat(e.getCause(), is(instanceOf(TooManyBucketsException.class))); TooManyBucketsException tmbe = (TooManyBucketsException) e.getCause(); @@ -433,7 +513,7 @@ public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() { evaluateDataFrame( ANIMALS_DATA_INDEX, new Classification( - ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new MulticlassConfusionMatrix(3, null)))); + ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, List.of(new MulticlassConfusionMatrix(3, null)))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -471,13 +551,13 @@ private static void createAnimalsIndex(String indexName) { client().admin().indices().prepareCreate(indexName) .setMapping( ANIMAL_NAME_KEYWORD_FIELD, "type=keyword", - ANIMAL_NAME_PREDICTION_FIELD, "type=keyword", + ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, "type=keyword", NO_LEGS_KEYWORD_FIELD, "type=keyword", NO_LEGS_INTEGER_FIELD, "type=integer", - NO_LEGS_PREDICTION_FIELD, "type=integer", + NO_LEGS_PREDICTION_INTEGER_FIELD, "type=integer", IS_PREDATOR_KEYWORD_FIELD, "type=keyword", IS_PREDATOR_BOOLEAN_FIELD, "type=boolean", - IS_PREDATOR_PREDICTION_FIELD, "type=boolean") + IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, "type=boolean") .get(); } @@ -492,13 +572,13 @@ private static void indexAnimalsData(String indexName) { new IndexRequest(indexName) .source( ANIMAL_NAME_KEYWORD_FIELD, animalNames.get(i), - ANIMAL_NAME_PREDICTION_FIELD, animalNames.get((i + j) % animalNames.size()), + ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, animalNames.get((i + j) % animalNames.size()), NO_LEGS_KEYWORD_FIELD, String.valueOf(i + 1), NO_LEGS_INTEGER_FIELD, i + 1, - NO_LEGS_PREDICTION_FIELD, j + 1, + NO_LEGS_PREDICTION_INTEGER_FIELD, j + 1, IS_PREDATOR_KEYWORD_FIELD, String.valueOf(i % 2 == 0), IS_PREDATOR_BOOLEAN_FIELD, i % 2 == 0, - IS_PREDATOR_PREDICTION_FIELD, (i + j) % 2 == 0)); + IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, (i + j) % 2 == 0)); } } } @@ -514,7 +594,7 @@ private static void indexDistinctAnimals(String indexName, int distinctAnimalCou for (int i = 0; i < distinctAnimalCount; i++) { bulkRequestBuilder.add( new IndexRequest(indexName) - .source(ANIMAL_NAME_KEYWORD_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_FIELD, randomAlphaOfLength(5))); + .source(ANIMAL_NAME_KEYWORD_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, randomAlphaOfLength(5))); } BulkResponse bulkResponse = bulkRequestBuilder.get(); if (bulkResponse.hasFailures()) {