Skip to content

Commit

Permalink
[7.5] Change format of MulticlassConfusionMatrix result to be more se…
Browse files Browse the repository at this point in the history
…lf-explanatory (#48174) (#48295)
  • Loading branch information
przemekwitek authored Oct 21, 2019
1 parent f7e0ad7 commit ed224d9
Show file tree
Hide file tree
Showing 9 changed files with 713 additions and 276 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.List;
import java.util.Objects;
import java.util.TreeMap;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;

/**
Expand Down Expand Up @@ -97,52 +97,52 @@ public int hashCode() {
public static class Result implements EvaluationMetric.Result {

private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix");
private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_");
private static final ParseField OTHER_ACTUAL_CLASS_COUNT = new ParseField("other_actual_class_count");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>(
"multiclass_confusion_matrix_result", true, a -> new Result((Map<String, Map<String, Long>>) a[0], (long) a[1]));
"multiclass_confusion_matrix_result", true, a -> new Result((List<ActualClass>) a[0], (Long) a[1]));

static {
PARSER.declareObject(
constructorArg(),
(p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)),
CONFUSION_MATRIX);
PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT);
PARSER.declareObjectArray(optionalConstructorArg(), ActualClass.PARSER, CONFUSION_MATRIX);
PARSER.declareLong(optionalConstructorArg(), OTHER_ACTUAL_CLASS_COUNT);
}

public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

// Immutable
private final Map<String, Map<String, Long>> confusionMatrix;
private final long otherClassesCount;
private final List<ActualClass> confusionMatrix;
private final Long otherActualClassCount;

public Result(Map<String, Map<String, Long>> confusionMatrix, long otherClassesCount) {
this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix));
this.otherClassesCount = otherClassesCount;
public Result(@Nullable List<ActualClass> confusionMatrix, @Nullable Long otherActualClassCount) {
this.confusionMatrix = confusionMatrix != null ? Collections.unmodifiableList(Objects.requireNonNull(confusionMatrix)) : null;
this.otherActualClassCount = otherActualClassCount;
}

@Override
public String getMetricName() {
return NAME;
}

public Map<String, Map<String, Long>> getConfusionMatrix() {
public List<ActualClass> getConfusionMatrix() {
return confusionMatrix;
}

public long getOtherClassesCount() {
return otherClassesCount;
public Long getOtherActualClassCount() {
return otherActualClassCount;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount);
if (confusionMatrix != null) {
builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
}
if (otherActualClassCount != null) {
builder.field(OTHER_ACTUAL_CLASS_COUNT.getPreferredName(), otherActualClassCount);
}
builder.endObject();
return builder;
}
Expand All @@ -153,12 +153,140 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(this.confusionMatrix, that.confusionMatrix)
&& this.otherClassesCount == that.otherClassesCount;
&& Objects.equals(this.otherActualClassCount, that.otherActualClassCount);
}

@Override
public int hashCode() {
return Objects.hash(confusionMatrix, otherClassesCount);
return Objects.hash(confusionMatrix, otherActualClassCount);
}
}

public static class ActualClass implements ToXContentObject {

private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes");
private static final ParseField OTHER_PREDICTED_CLASS_DOC_COUNT = new ParseField("other_predicted_class_doc_count");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
new ConstructingObjectParser<>(
"multiclass_confusion_matrix_actual_class",
true,
a -> new ActualClass((String) a[0], (Long) a[1], (List<PredictedClass>) a[2], (Long) a[3]));

static {
PARSER.declareString(optionalConstructorArg(), ACTUAL_CLASS);
PARSER.declareLong(optionalConstructorArg(), ACTUAL_CLASS_DOC_COUNT);
PARSER.declareObjectArray(optionalConstructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES);
PARSER.declareLong(optionalConstructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT);
}

private final String actualClass;
private final Long actualClassDocCount;
private final List<PredictedClass> predictedClasses;
private final Long otherPredictedClassDocCount;

public ActualClass(@Nullable String actualClass,
@Nullable Long actualClassDocCount,
@Nullable List<PredictedClass> predictedClasses,
@Nullable Long otherPredictedClassDocCount) {
this.actualClass = actualClass;
this.actualClassDocCount = actualClassDocCount;
this.predictedClasses = predictedClasses != null ? Collections.unmodifiableList(predictedClasses) : null;
this.otherPredictedClassDocCount = otherPredictedClassDocCount;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (actualClass != null) {
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
}
if (actualClassDocCount != null) {
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
}
if (predictedClasses != null) {
builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses);
}
if (otherPredictedClassDocCount != null) {
builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount);
}
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ActualClass that = (ActualClass) o;
return Objects.equals(this.actualClass, that.actualClass)
&& Objects.equals(this.actualClassDocCount, that.actualClassDocCount)
&& Objects.equals(this.predictedClasses, that.predictedClasses)
&& Objects.equals(this.otherPredictedClassDocCount, that.otherPredictedClassDocCount);
}

@Override
public int hashCode() {
return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount);
}

@Override
public String toString() {
return Strings.toString(this);
}
}

public static class PredictedClass implements ToXContentObject {

private static final ParseField PREDICTED_CLASS = new ParseField("predicted_class");
private static final ParseField COUNT = new ParseField("count");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<PredictedClass, Void> PARSER =
new ConstructingObjectParser<>(
"multiclass_confusion_matrix_predicted_class", true, a -> new PredictedClass((String) a[0], (Long) a[1]));

static {
PARSER.declareString(optionalConstructorArg(), PREDICTED_CLASS);
PARSER.declareLong(optionalConstructorArg(), COUNT);
}

private final String predictedClass;
private final Long count;

public PredictedClass(@Nullable String predictedClass, @Nullable Long count) {
this.predictedClass = predictedClass;
this.count = count;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (predictedClass != null) {
builder.field(PREDICTED_CLASS.getPreferredName(), predictedClass);
}
if (count != null) {
builder.field(COUNT.getPreferredName(), count);
}
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PredictedClass that = (PredictedClass) o;
return Objects.equals(this.predictedClass, that.predictedClass)
&& Objects.equals(this.count, that.count);
}

@Override
public int hashCode() {
return Objects.hash(predictedClass, count);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@
import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
Expand Down Expand Up @@ -1807,7 +1809,7 @@ public void testEvaluateDataFrame_Classification() throws IOException {
.add(docForClassification(indexName, "dog", "dog"))
.add(docForClassification(indexName, "dog", "dog"))
.add(docForClassification(indexName, "dog", "dog"))
.add(docForClassification(indexName, "horse", "cat"));
.add(docForClassification(indexName, "ant", "cat"));
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);

MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
Expand All @@ -1827,22 +1829,26 @@ public void testEvaluateDataFrame_Classification() throws IOException {
MulticlassConfusionMatrixMetric.Result mcmResult =
evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
expectedConfusionMatrix.put("cat", new HashMap<>());
expectedConfusionMatrix.get("cat").put("cat", 3L);
expectedConfusionMatrix.get("cat").put("dog", 1L);
expectedConfusionMatrix.get("cat").put("horse", 0L);
expectedConfusionMatrix.get("cat").put("_other_", 1L);
expectedConfusionMatrix.put("dog", new HashMap<>());
expectedConfusionMatrix.get("dog").put("cat", 1L);
expectedConfusionMatrix.get("dog").put("dog", 3L);
expectedConfusionMatrix.get("dog").put("horse", 0L);
expectedConfusionMatrix.put("horse", new HashMap<>());
expectedConfusionMatrix.get("horse").put("cat", 1L);
expectedConfusionMatrix.get("horse").put("dog", 0L);
expectedConfusionMatrix.get("horse").put("horse", 0L);
assertThat(mcmResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(mcmResult.getOtherClassesCount(), equalTo(0L));
assertThat(
mcmResult.getConfusionMatrix(),
equalTo(
Arrays.asList(
new ActualClass(
"ant",
1L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)),
0L),
new ActualClass(
"cat",
5L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)),
1L),
new ActualClass(
"dog",
4L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
0L))));
assertThat(mcmResult.getOtherActualClassCount(), equalTo(0L));
}
{ // Explicit size provided for MulticlassConfusionMatrixMetric metric
EvaluateDataFrameRequest evaluateDataFrameRequest =
Expand All @@ -1859,16 +1865,14 @@ public void testEvaluateDataFrame_Classification() throws IOException {
MulticlassConfusionMatrixMetric.Result mcmResult =
evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
expectedConfusionMatrix.put("cat", new HashMap<>());
expectedConfusionMatrix.get("cat").put("cat", 3L);
expectedConfusionMatrix.get("cat").put("dog", 1L);
expectedConfusionMatrix.get("cat").put("_other_", 1L);
expectedConfusionMatrix.put("dog", new HashMap<>());
expectedConfusionMatrix.get("dog").put("cat", 1L);
expectedConfusionMatrix.get("dog").put("dog", 3L);
assertThat(mcmResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(mcmResult.getOtherClassesCount(), equalTo(1L));
assertThat(
mcmResult.getConfusionMatrix(),
equalTo(
Arrays.asList(
new ActualClass("cat", 5L, Arrays.asList(new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1L),
new ActualClass("dog", 4L, Arrays.asList(new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0L)
)));
assertThat(mcmResult.getOtherActualClassCount(), equalTo(1L));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
Expand Down Expand Up @@ -3355,32 +3357,30 @@ public void testEvaluateDataFrame_Classification() throws Exception {
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1>

Map<String, Map<String, Long>> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2>
long otherClassesCount = multiclassConfusionMatrix.getOtherClassesCount(); // <3>
List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2>
long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <3>
// end::evaluate-data-frame-results-classification

assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
assertThat(
confusionMatrix,
equalTo(
new HashMap<String, Map<String, Long>>() {{
put("cat", new HashMap<String, Long>() {{
put("cat", 3L);
put("dog", 1L);
put("ant", 0L);
put("_other_", 1L);
}});
put("dog", new HashMap<String, Long>() {{
put("cat", 1L);
put("dog", 3L);
put("ant", 0L);
}});
put("ant", new HashMap<String, Long>() {{
put("cat", 1L);
put("dog", 0L);
put("ant", 0L);
}});
}}));
Arrays.asList(
new ActualClass(
"ant",
1L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)),
0L),
new ActualClass(
"cat",
5L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)),
1L),
new ActualClass(
"dog",
4L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
0L))));
assertThat(otherClassesCount, equalTo(0L));
}
}
Expand Down
Loading

0 comments on commit ed224d9

Please sign in to comment.