From 44ea9a845e029fb6dec15743c38ae11c98676484 Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Fri, 20 Dec 2019 10:15:14 +0100 Subject: [PATCH] Get rid of maxClassesCardinality internal parameter --- .../evaluation/classification/Accuracy.java | 13 +-- .../MulticlassConfusionMatrix.java | 8 +- .../evaluation/classification/Precision.java | 2 +- .../classification/AccuracyTests.java | 87 +++++++++++++++++++ .../MulticlassConfusionMatrixTests.java | 29 ++++--- .../ClassificationEvaluationIT.java | 9 -- 6 files changed, 109 insertions(+), 39 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java index 471714e4ede95..c6636329a65d9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java @@ -6,7 +6,6 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; import org.apache.lucene.util.SetOnce; -import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; @@ -78,26 +77,18 @@ public static Accuracy fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000; + private static final int MAX_CLASSES_CARDINALITY = 1000; - private final int maxClassesCardinality; private final MulticlassConfusionMatrix matrix; private final SetOnce actualField = new SetOnce<>(); private final SetOnce overallAccuracy = new SetOnce<>(); private final SetOnce result = new SetOnce<>(); public Accuracy() { - this((Integer) null); - } - - // Visible for testing - public Accuracy(@Nullable Integer maxClassesCardinality) { - this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY; - this.matrix = new MulticlassConfusionMatrix(this.maxClassesCardinality, NAME.getPreferredName() + "_"); + this.matrix = new MulticlassConfusionMatrix(MAX_CLASSES_CARDINALITY, NAME.getPreferredName() + "_"); } public Accuracy(StreamInput in) throws IOException { - this.maxClassesCardinality = DEFAULT_MAX_CLASSES_CARDINALITY; this.matrix = new MulticlassConfusionMatrix(in); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java index 3c4bf1f1cb5ca..e5a4de1605da0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -71,10 +71,10 @@ public static MulticlassConfusionMatrix fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - private static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class"; - private static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class"; - private static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class"; - private static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class"; + static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class"; + static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class"; + static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class"; + static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class"; private static final String OTHER_BUCKET_KEY = "_other_"; private static final String DEFAULT_AGG_NAME_PREFIX = ""; private static final int DEFAULT_SIZE = 10; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java index c3da03f080be3..87b45949b85ba 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java @@ -108,7 +108,7 @@ public final Tuple, List> a .size(MAX_CLASSES_CARDINALITY)), List.of()); } - if (result == null) { // This is step 2 + if (result.get() == null) { // This is step 2 KeyedFilter[] keyedFiltersPredicted = topActualClassNames.get().stream() .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java index f548fdbfd4c11..cac591a17d303 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java @@ -5,13 +5,26 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result; import java.io.IOException; import java.util.List; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; public class AccuracyTests extends AbstractSerializingTestCase { @@ -40,6 +53,80 @@ public static Accuracy createRandom() { return new Accuracy(); } + public void testProcess() { + Aggregations aggs = new Aggregations(List.of( + mockTerms( + "accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS, + List.of( + mockTermsBucket("dog", new Aggregations(List.of())), + mockTermsBucket("cat", new Aggregations(List.of()))), + 100L), + mockFilters( + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, + List.of( + mockFiltersBucket( + "dog", + 30, + new Aggregations(List.of(mockFilters( + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, + List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), + mockFiltersBucket( + "cat", + 70, + new Aggregations(List.of(mockFilters( + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, + List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), + mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1000L), + mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5))); + + Accuracy accuracy = new Accuracy(); + accuracy.process(aggs); + + assertThat(accuracy.aggs("act", "pred"), isTuple(empty(), empty())); + + Result result = accuracy.getResult().get(); + assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); + assertThat( + result.getClasses(), + equalTo( + List.of( + new PerClassResult("dog", 0.5), + new PerClassResult("cat", 0.5)))); + assertThat(result.getOverallAccuracy(), equalTo(0.5)); + } + + public void testProcess_GivenCardinalityTooHigh() { + Aggregations aggs = new Aggregations(List.of( + mockTerms( + "accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS, + List.of( + mockTermsBucket("dog", new Aggregations(List.of())), + mockTermsBucket("cat", new Aggregations(List.of()))), + 100L), + mockFilters( + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, + List.of( + mockFiltersBucket( + "dog", + 30, + new Aggregations(List.of(mockFilters( + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, + List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), + mockFiltersBucket( + "cat", + 70, + new Aggregations(List.of(mockFilters( + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, + List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), + mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1001L), + mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5))); + + Accuracy accuracy = new Accuracy(); + accuracy.aggs("foo", "bar"); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs)); + assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high")); + } + public void testComputePerClassAccuracy() { assertThat( Accuracy.computePerClassAccuracy( diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java index 6713974040c66..8c02a3c2c6fc3 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result; import java.io.IOException; import java.util.List; @@ -85,34 +86,34 @@ public void testAggs() { public void testEvaluate() { Aggregations aggs = new Aggregations(List.of( mockTerms( - "multiclass_confusion_matrix_step_1_by_actual_class", + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS, List.of( mockTermsBucket("dog", new Aggregations(List.of())), mockTermsBucket("cat", new Aggregations(List.of()))), 0L), mockFilters( - "multiclass_confusion_matrix_step_2_by_actual_class", + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, List.of( mockFiltersBucket( "dog", 30, new Aggregations(List.of(mockFilters( - "multiclass_confusion_matrix_step_2_by_predicted_class", + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), mockFiltersBucket( "cat", 70, new Aggregations(List.of(mockFilters( - "multiclass_confusion_matrix_step_2_by_predicted_class", + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), - mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L))); + mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 2L))); MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null); confusionMatrix.process(aggs); assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty())); - MulticlassConfusionMatrix.Result result = confusionMatrix.getResult().get(); - assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); + Result result = confusionMatrix.getResult().get(); + assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); assertThat( result.getConfusionMatrix(), equalTo( @@ -125,34 +126,34 @@ public void testEvaluate() { public void testEvaluate_OtherClassesCountGreaterThanZero() { Aggregations aggs = new Aggregations(List.of( mockTerms( - "multiclass_confusion_matrix_step_1_by_actual_class", + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS, List.of( mockTermsBucket("dog", new Aggregations(List.of())), mockTermsBucket("cat", new Aggregations(List.of()))), 100L), mockFilters( - "multiclass_confusion_matrix_step_2_by_actual_class", + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, List.of( mockFiltersBucket( "dog", 30, new Aggregations(List.of(mockFilters( - "multiclass_confusion_matrix_step_2_by_predicted_class", + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), mockFiltersBucket( "cat", 85, new Aggregations(List.of(mockFilters( - "multiclass_confusion_matrix_step_2_by_predicted_class", + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))), - mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L))); + mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 5L))); MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null); confusionMatrix.process(aggs); assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty())); - MulticlassConfusionMatrix.Result result = confusionMatrix.getResult().get(); - assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); + Result result = confusionMatrix.getResult().get(); + assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); assertThat( result.getConfusionMatrix(), equalTo( diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index be1961c36d4e0..da5439f6298dc 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -142,15 +142,6 @@ public void testEvaluate_Accuracy_BooleanField() { assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75)); } - public void testEvaluate_Accuracy_CardinalityTooHigh() { - ElasticsearchStatusException e = - expectThrows( - ElasticsearchStatusException.class, - () -> evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Accuracy(4))))); - assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high")); - } - public void testEvaluate_Precision() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame(