Skip to content

Commit

Permalink
Get rid of maxClassesCardinality internal parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek committed Dec 20, 2019
1 parent c8b0259 commit 44ea9a8
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
Expand Down Expand Up @@ -78,26 +77,18 @@ public static Accuracy fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000;
private static final int MAX_CLASSES_CARDINALITY = 1000;

private final int maxClassesCardinality;
private final MulticlassConfusionMatrix matrix;
private final SetOnce<String> actualField = new SetOnce<>();
private final SetOnce<Double> overallAccuracy = new SetOnce<>();
private final SetOnce<Result> result = new SetOnce<>();

public Accuracy() {
this((Integer) null);
}

// Visible for testing
public Accuracy(@Nullable Integer maxClassesCardinality) {
this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY;
this.matrix = new MulticlassConfusionMatrix(this.maxClassesCardinality, NAME.getPreferredName() + "_");
this.matrix = new MulticlassConfusionMatrix(MAX_CLASSES_CARDINALITY, NAME.getPreferredName() + "_");
}

public Accuracy(StreamInput in) throws IOException {
this.maxClassesCardinality = DEFAULT_MAX_CLASSES_CARDINALITY;
this.matrix = new MulticlassConfusionMatrix(in);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ public static MulticlassConfusionMatrix fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
private static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
private static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
private static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
private static final String OTHER_BUCKET_KEY = "_other_";
private static final String DEFAULT_AGG_NAME_PREFIX = "";
private static final int DEFAULT_SIZE = 10;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> a
.size(MAX_CLASSES_CARDINALITY)),
List.of());
}
if (result == null) { // This is step 2
if (result.get() == null) { // This is step 2
KeyedFilter[] keyedFiltersPredicted =
topActualClassNames.get().stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,26 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;

import java.io.IOException;
import java.util.List;

import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;

public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
Expand Down Expand Up @@ -40,6 +53,80 @@ public static Accuracy createRandom() {
return new Accuracy();
}

public void testProcess() {
Aggregations aggs = new Aggregations(List.of(
mockTerms(
"accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
List.of(
mockTermsBucket("dog", new Aggregations(List.of())),
mockTermsBucket("cat", new Aggregations(List.of()))),
100L),
mockFilters(
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
List.of(
mockFiltersBucket(
"dog",
30,
new Aggregations(List.of(mockFilters(
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
mockFiltersBucket(
"cat",
70,
new Aggregations(List.of(mockFilters(
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1000L),
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));

Accuracy accuracy = new Accuracy();
accuracy.process(aggs);

assertThat(accuracy.aggs("act", "pred"), isTuple(empty(), empty()));

Result result = accuracy.getResult().get();
assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
assertThat(
result.getClasses(),
equalTo(
List.of(
new PerClassResult("dog", 0.5),
new PerClassResult("cat", 0.5))));
assertThat(result.getOverallAccuracy(), equalTo(0.5));
}

public void testProcess_GivenCardinalityTooHigh() {
Aggregations aggs = new Aggregations(List.of(
mockTerms(
"accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
List.of(
mockTermsBucket("dog", new Aggregations(List.of())),
mockTermsBucket("cat", new Aggregations(List.of()))),
100L),
mockFilters(
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
List.of(
mockFiltersBucket(
"dog",
30,
new Aggregations(List.of(mockFilters(
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
mockFiltersBucket(
"cat",
70,
new Aggregations(List.of(mockFilters(
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1001L),
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));

Accuracy accuracy = new Accuracy();
accuracy.aggs("foo", "bar");
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs));
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
}

public void testComputePerClassAccuracy() {
assertThat(
Accuracy.computePerClassAccuracy(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result;

import java.io.IOException;
import java.util.List;
Expand Down Expand Up @@ -85,34 +86,34 @@ public void testAggs() {
public void testEvaluate() {
Aggregations aggs = new Aggregations(List.of(
mockTerms(
"multiclass_confusion_matrix_step_1_by_actual_class",
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
List.of(
mockTermsBucket("dog", new Aggregations(List.of())),
mockTermsBucket("cat", new Aggregations(List.of()))),
0L),
mockFilters(
"multiclass_confusion_matrix_step_2_by_actual_class",
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
List.of(
mockFiltersBucket(
"dog",
30,
new Aggregations(List.of(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class",
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
mockFiltersBucket(
"cat",
70,
new Aggregations(List.of(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class",
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L)));
mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 2L)));

MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
confusionMatrix.process(aggs);

assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
MulticlassConfusionMatrix.Result result = confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
Result result = confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
assertThat(
result.getConfusionMatrix(),
equalTo(
Expand All @@ -125,34 +126,34 @@ public void testEvaluate() {
public void testEvaluate_OtherClassesCountGreaterThanZero() {
Aggregations aggs = new Aggregations(List.of(
mockTerms(
"multiclass_confusion_matrix_step_1_by_actual_class",
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
List.of(
mockTermsBucket("dog", new Aggregations(List.of())),
mockTermsBucket("cat", new Aggregations(List.of()))),
100L),
mockFilters(
"multiclass_confusion_matrix_step_2_by_actual_class",
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
List.of(
mockFiltersBucket(
"dog",
30,
new Aggregations(List.of(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class",
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
mockFiltersBucket(
"cat",
85,
new Aggregations(List.of(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class",
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))),
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L)));
mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 5L)));

MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
confusionMatrix.process(aggs);

assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
MulticlassConfusionMatrix.Result result = confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
Result result = confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
assertThat(
result.getConfusionMatrix(),
equalTo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,6 @@ public void testEvaluate_Accuracy_BooleanField() {
assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75));
}

public void testEvaluate_Accuracy_CardinalityTooHigh() {
ElasticsearchStatusException e =
expectThrows(
ElasticsearchStatusException.class,
() -> evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Accuracy(4)))));
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
}

public void testEvaluate_Precision() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
Expand Down

0 comments on commit 44ea9a8

Please sign in to comment.