From 26fad793db31bd4259f084c02bc9f6169fc95a3b Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 9 Jul 2019 15:39:46 -0500 Subject: [PATCH 1/3] [ML] Adds support for regression.mean_squared_error to eval API --- .../MlEvaluationNamedXContentProvider.java | 6 + .../regression/MeanSquaredError.java | 130 +++++++++++++ .../evaluation/regression/Regression.java | 123 +++++++++++++ .../client/MachineLearningIT.java | 51 ++++++ .../client/RestHighLevelClientTests.java | 18 +- .../ml/EvaluateDataFrameResponseTests.java | 4 + .../MeanSquaredErrorResultsTests.java | 53 ++++++ .../regression/MeanSquaredErrorTests.java | 49 +++++ .../regression/RegressionTests.java | 55 ++++++ .../MlEvaluationNamedXContentProvider.java | 14 ++ .../regression/MeanSquaredError.java | 142 +++++++++++++++ .../evaluation/regression/Regression.java | 171 ++++++++++++++++++ .../regression/RegressionMetric.java | 37 ++++ .../regression/MeanSquaredErrorTests.java | 82 +++++++++ .../regression/RegressionTests.java | 61 +++++++ .../test/ml/evaluate_data_frame.yml | 51 +++++- 16 files changed, 1039 insertions(+), 8 deletions(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredError.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorResultsTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index 764ff41de86e0..70c3621d1c090 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -18,6 +18,8 @@ */ package org.elasticsearch.client.ml.dataframe.evaluation; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredError; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.NamedXContentRegistry; @@ -38,12 +40,14 @@ public List getNamedXContentParsers() { // Evaluations new NamedXContentRegistry.Entry( Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClassification::fromXContent), + new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Regression.NAME), Regression::fromXContent), // Evaluation metrics new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent), new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric::fromXContent), new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent), + new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(MeanSquaredError.NAME), MeanSquaredError::fromXContent), // Evaluation metrics results new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.Result::fromXContent), @@ -51,6 +55,8 @@ EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.R EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.Result.class, new ParseField(MeanSquaredError.NAME), MeanSquaredError.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent)); } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredError.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredError.java new file mode 100644 index 0000000000000..b526291385e7b --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredError.java @@ -0,0 +1,130 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.regression; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +/** + * Calculates the mean squared error between two known numerical fields. + * + * equation: mse = 1/n * Σ(y - y´)^2 + */ +public class MeanSquaredError implements EvaluationMetric { + + public static final String NAME = "mean_squared_error"; + + private static final ObjectParser PARSER = + new ObjectParser<>("mean_squared_error", true, MeanSquaredError::new); + + public static MeanSquaredError fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public MeanSquaredError() { + + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return true; + } + + @Override + public int hashCode() { + // create static hash code from name as there are currently no unique fields per class instance + return Objects.hashCode(NAME); + } + + @Override + public String getName() { + return NAME; + } + + public static class Result implements EvaluationMetric.Result { + + public static final ParseField ERROR = new ParseField("error"); + private final double error; + + public static Result fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("mean_squared_error_result", true, args -> new Result((double) args[0])); + + static { + PARSER.declareDouble(constructorArg(), ERROR); + } + + public Result(double error) { + this.error = error; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.field(ERROR.getPreferredName(), error); + builder.endObject(); + return builder; + } + + public double getError() { + return error; + } + + @Override + public String getMetricName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(that.error, this.error); + } + + @Override + public int hashCode() { + return Objects.hash(error); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java new file mode 100644 index 0000000000000..385bfd8bcc25a --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java @@ -0,0 +1,123 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.regression; + +import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +/** + * Evaluation of regression results. + */ +public class Regression implements Evaluation { + + public static final String NAME = "regression"; + + private static final ParseField ACTUAL_FIELD = new ParseField("actual_field"); + private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field"); + private static final ParseField METRICS = new ParseField("metrics"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, a -> new Regression((String) a[0], (String) a[1], (List) a[2])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD); + PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(), + (p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS); + } + + public static Regression fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** + * The field containing the actual value + * The value of this field is assumed to be numeric + */ + private final String actualField; + + /** + * The field containing the predicted value + * The value of this field is assumed to be numeric + */ + private final String predictedField; + + /** + * The list of metrics to calculate + */ + private final List metrics; + + public Regression(String actualField, String predictedField) { + this(actualField, predictedField, Collections.singletonList(new MeanSquaredError())); + } + + Regression(String actualField, String predictedField, @Nullable List metrics) { + this.actualField = actualField; + this.predictedField = predictedField; + this.metrics = metrics; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.field(ACTUAL_FIELD.getPreferredName(), actualField); + builder.field(PREDICTED_FIELD.getPreferredName(), predictedField); + + builder.startObject(METRICS.getPreferredName()); + for (EvaluationMetric metric : metrics) { + builder.field(metric.getName(), metric); + } + builder.endObject(); + + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Regression that = (Regression) o; + return Objects.equals(that.actualField, this.actualField) + && Objects.equals(that.predictedField, this.predictedField) + && Objects.equals(that.metrics, this.metrics); + } + + @Override + public int hashCode() { + return Objects.hash(actualField, predictedField, metrics); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index b542db9c4b0bf..3b8d65a6fd0ee 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -123,6 +123,8 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; import org.elasticsearch.client.ml.dataframe.OutlierDetection; import org.elasticsearch.client.ml.dataframe.QueryConfig; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredError; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; @@ -1578,6 +1580,33 @@ public void testEvaluateDataFrame() throws IOException { assertThat(curvePointAtThreshold1.getTruePositiveRate(), equalTo(0.0)); assertThat(curvePointAtThreshold1.getFalsePositiveRate(), equalTo(0.0)); assertThat(curvePointAtThreshold1.getThreshold(), equalTo(1.0)); + + String regressionIndex = "evaluate-regression-test-index"; + createIndex(regressionIndex, mappingForRegression()); + BulkRequest regressionBulk = new BulkRequest() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .add(docForRegression(regressionIndex, 0.3, 0.1)) // #0 + .add(docForRegression(regressionIndex, 0.3, 0.2)) // #1 + .add(docForRegression(regressionIndex, 0.3, 0.3)) // #2 + .add(docForRegression(regressionIndex, 0.3, 0.4)) // #3 + .add(docForRegression(regressionIndex, 0.3, 0.7)) // #4 + .add(docForRegression(regressionIndex, 0.5, 0.2)) // #5 + .add(docForRegression(regressionIndex, 0.5, 0.3)) // #6 + .add(docForRegression(regressionIndex, 0.5, 0.4)) // #7 + .add(docForRegression(regressionIndex, 0.5, 0.8)) // #8 + .add(docForRegression(regressionIndex, 0.5, 0.9)); // #9 + highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT); + + evaluateDataFrameRequest = new EvaluateDataFrameRequest(regressionIndex, new Regression(actualRegression, probabilityRegression)); + + evaluateDataFrameResponse = + execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME)); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); + + MeanSquaredError.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredError.NAME); + assertThat(mseResult.getMetricName(), equalTo(MeanSquaredError.NAME)); + assertThat(mseResult.getError(), closeTo(0.061000000, 1e-9)); } private static XContentBuilder defaultMappingForTest() throws IOException { @@ -1615,6 +1644,28 @@ private static IndexRequest docForClassification(String indexName, boolean isTru .source(XContentType.JSON, actualField, Boolean.toString(isTrue), probabilityField, p); } + private static final String actualRegression = "regression_actual"; + private static final String probabilityRegression = "regression_prob"; + + private static XContentBuilder mappingForRegression() throws IOException { + return XContentFactory.jsonBuilder().startObject() + .startObject("properties") + .startObject(actualRegression) + .field("type", "double") + .endObject() + .startObject(probabilityRegression) + .field("type", "double") + .endObject() + .endObject() + .endObject(); + } + + private static IndexRequest docForRegression(String indexName, double act, double p) { + return new IndexRequest() + .index(indexName) + .source(XContentType.JSON, actualRegression, act, probabilityRegression, p); + } + private void createIndex(String indexName, XContentBuilder mapping) throws IOException { highLevelClient().indices().create(new CreateIndexRequest(indexName).mapping(mapping), RequestOptions.DEFAULT); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index ae1cd5eb45edf..08a9a6c47decf 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -60,6 +60,8 @@ import org.elasticsearch.client.indexlifecycle.UnfollowAction; import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis; import org.elasticsearch.client.ml.dataframe.OutlierDetection; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredError; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; @@ -674,7 +676,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(31, namedXContents.size()); + assertEquals(34, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -712,12 +714,14 @@ public void testProvidedNamedXContents() { assertTrue(names.contains(OutlierDetection.NAME.getPreferredName())); assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class)); assertTrue(names.contains(TimeSyncConfig.NAME)); - assertEquals(Integer.valueOf(1), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); - assertThat(names, hasItems(BinarySoftClassification.NAME)); - assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); - assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME)); - assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); - assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME)); + assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); + assertThat(names, hasItems(BinarySoftClassification.NAME, Regression.NAME)); + assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); + assertThat(names, + hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, MeanSquaredError.NAME)); + assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); + assertThat(names, + hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, MeanSquaredError.NAME)); } public void testApiNamingConventions() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java index b41d113686ccf..dbe93bcfd4eac 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorResultsTests; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; @@ -45,6 +46,9 @@ public static EvaluateDataFrameResponse randomResponse() { if (randomBoolean()) { metrics.add(ConfusionMatrixMetricResultTests.randomResult()); } + if (randomBoolean()) { + metrics.add(MeanSquaredErrorResultsTests.randomResult()); + } return new EvaluateDataFrameResponse(randomAlphaOfLength(5), metrics); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorResultsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorResultsTests.java new file mode 100644 index 0000000000000..c30a95406201e --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorResultsTests.java @@ -0,0 +1,53 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.regression; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class MeanSquaredErrorResultsTests extends AbstractXContentTestCase { + + public static MeanSquaredError.Result randomResult() { + return new MeanSquaredError.Result(randomDouble()); + } + + @Override + protected MeanSquaredError.Result createTestInstance() { + return randomResult(); + } + + @Override + protected MeanSquaredError.Result doParseInstance(XContentParser parser) throws IOException { + return MeanSquaredError.Result.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java new file mode 100644 index 0000000000000..f0647f4c37a55 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java @@ -0,0 +1,49 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.regression; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class MeanSquaredErrorTests extends AbstractXContentTestCase { + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + @Override + protected MeanSquaredError createTestInstance() { + return new MeanSquaredError(); + } + + @Override + protected MeanSquaredError doParseInstance(XContentParser parser) throws IOException { + return MeanSquaredError.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java new file mode 100644 index 0000000000000..e41070e29e350 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java @@ -0,0 +1,55 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.regression; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.Collections; + +public class RegressionTests extends AbstractXContentTestCase { + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + @Override + protected Regression createTestInstance() { + Regression regression = randomBoolean() ? + new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10)) : + new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), Collections.singletonList(new MeanSquaredError())); + System.out.println(Strings.toString(regression, true, true)); + return regression; + } + + @Override + protected Regression doParseInstance(XContentParser parser) throws IOException { + return Regression.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index f4a6dba88e3b1..f713aa0033d00 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -8,6 +8,9 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.plugins.spi.NamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix; @@ -28,6 +31,7 @@ public List getNamedXContentParsers() { // Evaluations namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME, BinarySoftClassification::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent)); // Soft classification metrics namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME, AucRoc::fromXContent)); @@ -36,6 +40,9 @@ public List getNamedXContentParsers() { namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME, ConfusionMatrix::fromXContent)); + // Regression metrics + namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent)); + return namedXContent; } @@ -45,6 +52,7 @@ public List getNamedWriteables() { // Evaluations namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), BinarySoftClassification::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Regression.NAME.getPreferredName(), Regression::new)); // Evaluation Metrics namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), @@ -55,6 +63,9 @@ public List getNamedWriteables() { Recall::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(), ConfusionMatrix::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class, + MeanSquaredError.NAME.getPreferredName(), + MeanSquaredError::new)); // Evaluation Metrics Results namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(), @@ -63,6 +74,9 @@ public List getNamedWriteables() { ScoreByThresholdResult::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(), ConfusionMatrix.Result::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, + MeanSquaredError.NAME.getPreferredName(), + MeanSquaredError.Result::new)); return namedWriteables; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java new file mode 100644 index 0000000000000..f81740157bc64 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java @@ -0,0 +1,142 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.script.Script; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.io.IOException; +import java.text.MessageFormat; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Objects; + +/** + * Calculates the mean squared error between two known numerical fields. + * + * equation: mse = 1/n * Σ(y - y´)^2 + */ +public class MeanSquaredError implements RegressionMetric { + + public static final ParseField NAME = new ParseField("mean_squared_error"); + + private static final String PAINLESS_TEMPLATE = "def diff = doc[''{0}''].value - doc[''{1}''].value;return diff * diff;"; + private static final String AGG_NAME = "regression_" + NAME.getPreferredName(); + + private static String buildScript(Object...args) { + return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args); + } + + private static final ObjectParser PARSER = + new ObjectParser<>("mean_squared_error", true, MeanSquaredError::new); + + public static MeanSquaredError fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public MeanSquaredError(StreamInput in) { + + } + + public MeanSquaredError() { + + } + + @Override + public String getMetricName() { + return NAME.getPreferredName(); + } + + @Override + public List aggs(String actualField, String predictedField) { + return Collections.singletonList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField)))); + } + + @Override + public EvaluationMetricResult evaluate(Aggregations aggs) { + NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME); + return value == null ? null : new Result(value.value()); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return true; + } + + @Override + public int hashCode() { + // create static hash code from name as there are currently no unique fields per class instance + return Objects.hashCode(NAME.getPreferredName()); + } + + + public static class Result implements EvaluationMetricResult { + + private final double error; + + public Result(double error) { + this.error = error; + } + + public Result(StreamInput in) throws IOException { + this.error = in.readDouble(); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(error); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("error", error); + builder.endObject(); + return builder; + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java new file mode 100644 index 0000000000000..455f44ae3c168 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java @@ -0,0 +1,171 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +/** + * Evaluation of regression results. + */ +public class Regression implements Evaluation { + + public static final ParseField NAME = new ParseField("regression"); + + private static final ParseField ACTUAL_FIELD = new ParseField("actual_field"); + private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field"); + private static final ParseField METRICS = new ParseField("metrics"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME.getPreferredName(), a -> new Regression((String) a[0], (String) a[1], (List) a[2])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD); + PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(), + (p, c, n) -> p.namedObject(RegressionMetric.class, n, c), METRICS); + } + + public static Regression fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** + * The field containing the actual value + * The value of this field is assumed to be numeric + */ + private final String actualField; + + /** + * The field containing the predicted value + * The value of this field is assumed to be numeric + */ + private final String predictedField; + + /** + * The list of metrics to calculate + */ + private final List metrics; + + public Regression(String actualField, String predictedField, @Nullable List metrics) { + this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD); + this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD); + this.metrics = initMetrics(metrics); + } + + public Regression(StreamInput in) throws IOException { + this.actualField = in.readString(); + this.predictedField = in.readString(); + this.metrics = in.readNamedWriteableList(RegressionMetric.class); + } + + private static List initMetrics(@Nullable List parsedMetrics) { + List metrics = parsedMetrics == null ? defaultMetrics() : parsedMetrics; + if (metrics.isEmpty()) { + throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName()); + } + Collections.sort(metrics, Comparator.comparing(RegressionMetric::getMetricName)); + return metrics; + } + + private static List defaultMetrics() { + List defaultMetrics = new ArrayList<>(1); + defaultMetrics.add(new MeanSquaredError()); + return defaultMetrics; + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public SearchSourceBuilder buildSearch() { + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery() + .filter(QueryBuilders.existsQuery(actualField)) + .filter(QueryBuilders.existsQuery(predictedField)); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery); + for (RegressionMetric metric : metrics) { + List aggs = metric.aggs(actualField, predictedField); + aggs.forEach(searchSourceBuilder::aggregation); + } + return searchSourceBuilder; + } + + @Override + public void evaluate(SearchResponse searchResponse, ActionListener> listener) { + List results = new ArrayList<>(metrics.size()); + for (RegressionMetric metric : metrics) { + results.add(metric.evaluate(searchResponse.getAggregations())); + } + listener.onResponse(results); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(actualField); + out.writeString(predictedField); + out.writeNamedWriteableList(metrics); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ACTUAL_FIELD.getPreferredName(), actualField); + builder.field(PREDICTED_FIELD.getPreferredName(), predictedField); + + builder.startObject(METRICS.getPreferredName()); + for (RegressionMetric metric : metrics) { + builder.field(metric.getWriteableName(), metric); + } + builder.endObject(); + + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Regression that = (Regression) o; + return Objects.equals(that.actualField, this.actualField) + && Objects.equals(that.predictedField, this.predictedField) + && Objects.equals(that.metrics, this.metrics); + } + + @Override + public int hashCode() { + return Objects.hash(actualField, predictedField, metrics); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java new file mode 100644 index 0000000000000..1da48e2f305e6 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.util.List; + +public interface RegressionMetric extends ToXContentObject, NamedWriteable { + + /** + * Returns the name of the metric (which may differ to the writeable name) + */ + String getMetricName(); + + /** + * Builds the aggregation that collect required data to compute the metric + * @param actualField the field that stores the actual value + * @param predictedField the field that stores the predicted value + * @return the aggregations required to compute the metric + */ + List aggs(String actualField, String predictedField); + + /** + * Calculates the metric result + * @param aggs the aggregations + * @return the metric result + */ + EvaluationMetricResult evaluate(Aggregations aggs); +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java new file mode 100644 index 0000000000000..8c8832b8569bc --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric; + +import java.io.IOException; +import java.util.Arrays; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class MeanSquaredErrorTests extends AbstractSerializingTestCase { + + @Override + protected MeanSquaredError doParseInstance(XContentParser parser) throws IOException { + return MeanSquaredError.fromXContent(parser); + } + + @Override + protected MeanSquaredError createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return MeanSquaredError::new; + } + + public static MeanSquaredError createRandom() { + return new MeanSquaredError(); + } + + public void testEvaluate() { + SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class); + when(classInfo.getName()).thenReturn("foo"); + + Aggregations aggs = new Aggregations(Arrays.asList( + createSingleMetricAgg("regression_mean_squared_error", 0.8123), + createSingleMetricAgg("some_other_single_metric_agg", 0.2377) + )); + + MeanSquaredError precision = new MeanSquaredError(); + EvaluationMetricResult result = precision.evaluate(aggs); + + String expected = "{\"error\":0.8123}"; + assertThat(Strings.toString(result), equalTo(expected)); + } + + public void testEvaluate_GivenMissingAggs() { + SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class); + when(classInfo.getName()).thenReturn("foo"); + + Aggregations aggs = new Aggregations(Arrays.asList( + createSingleMetricAgg("some_other_single_metric_agg", 0.2377) + )); + + MeanSquaredError precision = new MeanSquaredError(); + EvaluationMetricResult result = precision.evaluate(aggs); + assertThat(result, is(nullValue())); + } + + private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) { + NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class); + when(agg.getName()).thenReturn(name); + when(agg.value()).thenReturn(value); + return agg; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java new file mode 100644 index 0000000000000..009fc13e829f0 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public class RegressionTests extends AbstractSerializingTestCase { + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables()); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + public static Regression createRandom() { + List metrics = new ArrayList<>(); + metrics.add(MeanSquaredErrorTests.createRandom()); + return new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), randomBoolean() ? null : metrics); + } + + @Override + protected Regression doParseInstance(XContentParser parser) throws IOException { + return Regression.fromXContent(parser); + } + + @Override + protected Regression createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return Regression::new; + } + + public void testConstructor_GivenEmptyMetrics() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Regression("foo", "bar", Collections.emptyList())); + assertThat(e.getMessage(), equalTo("[regression] must have one or more metrics")); + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index ef844d61f1626..4c39d1b8bbdfe 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -8,6 +8,8 @@ setup: "is_outlier": false, "is_outlier_int": 0, "outlier_score": 0.0, + "regression_field_act": 10.9, + "regression_field_pred": 10.9, "all_true_field": true, "all_false_field": false } @@ -20,6 +22,8 @@ setup: "is_outlier": false, "is_outlier_int": 0, "outlier_score": 0.2, + "regression_field_act": 12.0, + "regression_field_pred": 9.9, "all_true_field": true, "all_false_field": false } @@ -32,6 +36,8 @@ setup: "is_outlier": false, "is_outlier_int": 0, "outlier_score": 0.3, + "regression_field_act": 20.9, + "regression_field_pred": 5.9, "all_true_field": true, "all_false_field": false } @@ -44,6 +50,8 @@ setup: "is_outlier": true, "is_outlier_int": 1, "outlier_score": 0.3, + "regression_field_act": 11.9, + "regression_field_pred": 11.9, "all_true_field": true, "all_false_field": false } @@ -56,6 +64,8 @@ setup: "is_outlier": true, "is_outlier_int": 1, "outlier_score": 0.4, + "regression_field_act": 42.9, + "regression_field_pred": 42.9, "all_true_field": true, "all_false_field": false } @@ -68,6 +78,8 @@ setup: "is_outlier": true, "is_outlier_int": 1, "outlier_score": 0.5, + "regression_field_act": 0.42, + "regression_field_pred": 0.42, "all_true_field": true, "all_false_field": false } @@ -80,6 +92,8 @@ setup: "is_outlier": true, "is_outlier_int": 1, "outlier_score": 0.9, + "regression_field_act": 1.1235813, + "regression_field_pred": 1.12358, "all_true_field": true, "all_false_field": false } @@ -92,6 +106,8 @@ setup: "is_outlier": true, "is_outlier_int": 1, "outlier_score": 0.95, + "regression_field_act": -5.20, + "regression_field_pred": -5.1, "all_true_field": true, "all_false_field": false } @@ -356,7 +372,7 @@ setup: } --- -"Test binary_soft_classification given evaluation with emtpy metrics": +"Test binary_soft_classification given evaluation with empty metrics": - do: catch: /\[binary_soft_classification\] must have one or more metrics/ ml.evaluate_data_frame: @@ -518,3 +534,36 @@ setup: } } } +--- +"Test regression given empty metrics": + - do: + catch: /\[regression\] must have one or more metrics/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "regression": { + "actual_field": "regression_field_act", + "predicted_field": "regression_field_pred", + "metrics": { } + } + } + } +--- +"Test regression mean_squared_error": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "regression": { + "actual_field": "regression_field_act", + "predicted_field": "regression_field_pred", + "metrics": { "mean_squared_error": {} } + } + } + } + + - match: { regression.mean_squared_error.error: 28.67749840974834 } From e8e11f62a83a47f196b375f57258b3d1e5aaf676 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 10 Jul 2019 08:42:15 -0500 Subject: [PATCH 2/3] addressing PR comments --- .../MlEvaluationNamedXContentProvider.java | 7 +++--- ...Error.java => MeanSquaredErrorMetric.java} | 11 +++++----- .../evaluation/regression/Regression.java | 22 ++++++++++++------- .../client/MachineLearningIT.java | 6 ++--- .../client/RestHighLevelClientTests.java | 6 ++--- .../ml/EvaluateDataFrameResponseTests.java | 4 ++-- ...=> MeanSquaredErrorMetricResultTests.java} | 12 +++++----- ....java => MeanSquaredErrorMetricTests.java} | 10 ++++----- .../regression/RegressionTests.java | 16 +++++++++----- .../regression/MeanSquaredError.java | 5 ++--- .../regression/MeanSquaredErrorTests.java | 18 +++++---------- .../regression/RegressionTests.java | 4 +--- .../test/ml/evaluate_data_frame.yml | 16 ++++++++++++++ 13 files changed, 77 insertions(+), 60 deletions(-) rename client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/{MeanSquaredError.java => MeanSquaredErrorMetric.java} (92%) rename client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/{MeanSquaredErrorResultsTests.java => MeanSquaredErrorMetricResultTests.java} (75%) rename client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/{MeanSquaredErrorTests.java => MeanSquaredErrorMetricTests.java} (80%) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index 70c3621d1c090..b6f07fd49492e 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -18,7 +18,7 @@ */ package org.elasticsearch.client.ml.dataframe.evaluation; -import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredError; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.common.ParseField; @@ -47,7 +47,8 @@ Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClass new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent), - new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(MeanSquaredError.NAME), MeanSquaredError::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent), // Evaluation metrics results new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.Result::fromXContent), @@ -56,7 +57,7 @@ EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMe new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent), new NamedXContentRegistry.Entry( - EvaluationMetric.Result.class, new ParseField(MeanSquaredError.NAME), MeanSquaredError.Result::fromXContent), + EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent)); } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredError.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetric.java similarity index 92% rename from client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredError.java rename to client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetric.java index b526291385e7b..5b961dacbcc52 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredError.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetric.java @@ -20,7 +20,6 @@ import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.common.ParseField; -import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContent; @@ -37,18 +36,18 @@ * * equation: mse = 1/n * Σ(y - y´)^2 */ -public class MeanSquaredError implements EvaluationMetric { +public class MeanSquaredErrorMetric implements EvaluationMetric { public static final String NAME = "mean_squared_error"; - private static final ObjectParser PARSER = - new ObjectParser<>("mean_squared_error", true, MeanSquaredError::new); + private static final ObjectParser PARSER = + new ObjectParser<>("mean_squared_error", true, MeanSquaredErrorMetric::new); - public static MeanSquaredError fromXContent(XContentParser parser) { + public static MeanSquaredErrorMetric fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - public MeanSquaredError() { + public MeanSquaredErrorMetric() { } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java index 385bfd8bcc25a..13b14f6e0b017 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java @@ -28,7 +28,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; -import java.util.Collections; +import java.util.Arrays; import java.util.List; import java.util.Objects; @@ -45,7 +45,7 @@ public class Regression implements Evaluation { @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - NAME, a -> new Regression((String) a[0], (String) a[1], (List) a[2])); + NAME, true, a -> new Regression((String) a[0], (String) a[1], (List) a[2])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD); @@ -76,10 +76,14 @@ public static Regression fromXContent(XContentParser parser) { private final List metrics; public Regression(String actualField, String predictedField) { - this(actualField, predictedField, Collections.singletonList(new MeanSquaredError())); + this(actualField, predictedField, (List)null); } - Regression(String actualField, String predictedField, @Nullable List metrics) { + public Regression(String actualField, String predictedField, EvaluationMetric... metrics) { + this(actualField, predictedField, Arrays.asList(metrics)); + } + + public Regression(String actualField, String predictedField, @Nullable List metrics) { this.actualField = actualField; this.predictedField = predictedField; this.metrics = metrics; @@ -96,11 +100,13 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par builder.field(ACTUAL_FIELD.getPreferredName(), actualField); builder.field(PREDICTED_FIELD.getPreferredName(), predictedField); - builder.startObject(METRICS.getPreferredName()); - for (EvaluationMetric metric : metrics) { - builder.field(metric.getName(), metric); + if (metrics != null) { + builder.startObject(METRICS.getPreferredName()); + for (EvaluationMetric metric : metrics) { + builder.field(metric.getName(), metric); + } + builder.endObject(); } - builder.endObject(); builder.endObject(); return builder; diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 3b8d65a6fd0ee..d99d9ecd29d90 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -123,7 +123,7 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; import org.elasticsearch.client.ml.dataframe.OutlierDetection; import org.elasticsearch.client.ml.dataframe.QueryConfig; -import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredError; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; @@ -1604,8 +1604,8 @@ public void testEvaluateDataFrame() throws IOException { assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME)); assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); - MeanSquaredError.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredError.NAME); - assertThat(mseResult.getMetricName(), equalTo(MeanSquaredError.NAME)); + MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME); + assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME)); assertThat(mseResult.getError(), closeTo(0.061000000, 1e-9)); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 08a9a6c47decf..77dc9ee53fd69 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -60,7 +60,7 @@ import org.elasticsearch.client.indexlifecycle.UnfollowAction; import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis; import org.elasticsearch.client.ml.dataframe.OutlierDetection; -import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredError; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; @@ -718,10 +718,10 @@ public void testProvidedNamedXContents() { assertThat(names, hasItems(BinarySoftClassification.NAME, Regression.NAME)); assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); assertThat(names, - hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, MeanSquaredError.NAME)); + hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME)); assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); assertThat(names, - hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, MeanSquaredError.NAME)); + hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME)); } public void testApiNamingConventions() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java index dbe93bcfd4eac..70740a3268f10 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java @@ -20,7 +20,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; -import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorResultsTests; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetricResultTests; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; @@ -47,7 +47,7 @@ public static EvaluateDataFrameResponse randomResponse() { metrics.add(ConfusionMatrixMetricResultTests.randomResult()); } if (randomBoolean()) { - metrics.add(MeanSquaredErrorResultsTests.randomResult()); + metrics.add(MeanSquaredErrorMetricResultTests.randomResult()); } return new EvaluateDataFrameResponse(randomAlphaOfLength(5), metrics); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorResultsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetricResultTests.java similarity index 75% rename from client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorResultsTests.java rename to client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetricResultTests.java index c30a95406201e..290938ba37048 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorResultsTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetricResultTests.java @@ -25,20 +25,20 @@ import java.io.IOException; -public class MeanSquaredErrorResultsTests extends AbstractXContentTestCase { +public class MeanSquaredErrorMetricResultTests extends AbstractXContentTestCase { - public static MeanSquaredError.Result randomResult() { - return new MeanSquaredError.Result(randomDouble()); + public static MeanSquaredErrorMetric.Result randomResult() { + return new MeanSquaredErrorMetric.Result(randomDouble()); } @Override - protected MeanSquaredError.Result createTestInstance() { + protected MeanSquaredErrorMetric.Result createTestInstance() { return randomResult(); } @Override - protected MeanSquaredError.Result doParseInstance(XContentParser parser) throws IOException { - return MeanSquaredError.Result.fromXContent(parser); + protected MeanSquaredErrorMetric.Result doParseInstance(XContentParser parser) throws IOException { + return MeanSquaredErrorMetric.Result.fromXContent(parser); } @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetricTests.java similarity index 80% rename from client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java rename to client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetricTests.java index f0647f4c37a55..9027462b21e75 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetricTests.java @@ -25,7 +25,7 @@ import java.io.IOException; -public class MeanSquaredErrorTests extends AbstractXContentTestCase { +public class MeanSquaredErrorMetricTests extends AbstractXContentTestCase { @Override protected NamedXContentRegistry xContentRegistry() { @@ -33,13 +33,13 @@ protected NamedXContentRegistry xContentRegistry() { } @Override - protected MeanSquaredError createTestInstance() { - return new MeanSquaredError(); + protected MeanSquaredErrorMetric createTestInstance() { + return new MeanSquaredErrorMetric(); } @Override - protected MeanSquaredError doParseInstance(XContentParser parser) throws IOException { - return MeanSquaredError.fromXContent(parser); + protected MeanSquaredErrorMetric doParseInstance(XContentParser parser) throws IOException { + return MeanSquaredErrorMetric.fromXContent(parser); } @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java index e41070e29e350..f5b3db9cec87c 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java @@ -19,13 +19,13 @@ package org.elasticsearch.client.ml.dataframe.evaluation.regression; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; import java.io.IOException; import java.util.Collections; +import java.util.function.Predicate; public class RegressionTests extends AbstractXContentTestCase { @@ -36,11 +36,9 @@ protected NamedXContentRegistry xContentRegistry() { @Override protected Regression createTestInstance() { - Regression regression = randomBoolean() ? + return randomBoolean() ? new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10)) : - new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), Collections.singletonList(new MeanSquaredError())); - System.out.println(Strings.toString(regression, true, true)); - return regression; + new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), Collections.singletonList(new MeanSquaredErrorMetric())); } @Override @@ -50,6 +48,12 @@ protected Regression doParseInstance(XContentParser parser) throws IOException { @Override protected boolean supportsUnknownFields() { - return false; + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // allow unknown fields in the root of the object only + return field -> !field.isEmpty(); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java index f81740157bc64..8dd922b6ac26e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java @@ -6,7 +6,6 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression; import org.elasticsearch.common.ParseField; -import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ObjectParser; @@ -103,9 +102,9 @@ public int hashCode() { return Objects.hashCode(NAME.getPreferredName()); } - public static class Result implements EvaluationMetricResult { + private static final String ERROR = "error"; private final double error; public Result(double error) { @@ -134,7 +133,7 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field("error", error); + builder.field(ERROR, error); builder.endObject(); return builder; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java index 8c8832b8569bc..4351351474761 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java @@ -12,10 +12,10 @@ import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric; import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; @@ -45,31 +45,25 @@ public static MeanSquaredError createRandom() { } public void testEvaluate() { - SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class); - when(classInfo.getName()).thenReturn("foo"); - Aggregations aggs = new Aggregations(Arrays.asList( createSingleMetricAgg("regression_mean_squared_error", 0.8123), createSingleMetricAgg("some_other_single_metric_agg", 0.2377) )); - MeanSquaredError precision = new MeanSquaredError(); - EvaluationMetricResult result = precision.evaluate(aggs); + MeanSquaredError mse = new MeanSquaredError(); + EvaluationMetricResult result = mse.evaluate(aggs); String expected = "{\"error\":0.8123}"; assertThat(Strings.toString(result), equalTo(expected)); } public void testEvaluate_GivenMissingAggs() { - SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class); - when(classInfo.getName()).thenReturn("foo"); - - Aggregations aggs = new Aggregations(Arrays.asList( + Aggregations aggs = new Aggregations(Collections.singletonList( createSingleMetricAgg("some_other_single_metric_agg", 0.2377) )); - MeanSquaredError precision = new MeanSquaredError(); - EvaluationMetricResult result = precision.evaluate(aggs); + MeanSquaredError mse = new MeanSquaredError(); + EvaluationMetricResult result = mse.evaluate(aggs); assertThat(result, is(nullValue())); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java index 009fc13e829f0..33ce6e56ff5c0 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import java.io.IOException; -import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -33,8 +32,7 @@ protected NamedXContentRegistry xContentRegistry() { } public static Regression createRandom() { - List metrics = new ArrayList<>(); - metrics.add(MeanSquaredErrorTests.createRandom()); + List metrics = Collections.singletonList(MeanSquaredErrorTests.createRandom()); return new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), randomBoolean() ? null : metrics); } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 4c39d1b8bbdfe..a11d80b2ab197 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -567,3 +567,19 @@ setup: } - match: { regression.mean_squared_error.error: 28.67749840974834 } +--- +"Test regression with null metrics": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "regression": { + "actual_field": "regression_field_act", + "predicted_field": "regression_field_pred" + } + } + } + + - match: { regression.mean_squared_error.error: 28.67749840974834 } From 4c1de835124f0e24484247e3ca364c9889e50263 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 10 Jul 2019 10:08:52 -0500 Subject: [PATCH 3/3] fixing tests --- .../ml/qa/ml-with-security/build.gradle | 7 ++++--- .../test/ml/evaluate_data_frame.yml | 20 +++++++++---------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 6077b8ab099f6..6aedffa4a4931 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -75,9 +75,9 @@ integTest.runner { 'ml/evaluate_data_frame/Test given missing index', 'ml/evaluate_data_frame/Test given index does not exist', 'ml/evaluate_data_frame/Test given missing evaluation', - 'ml/evaluate_data_frame/Test binary_soft_classifition auc_roc given actual_field is always true', - 'ml/evaluate_data_frame/Test binary_soft_classifition auc_roc given actual_field is always false', - 'ml/evaluate_data_frame/Test binary_soft_classification given evaluation with emtpy metrics', + 'ml/evaluate_data_frame/Test binary_soft_classification auc_roc given actual_field is always true', + 'ml/evaluate_data_frame/Test binary_soft_classification auc_roc given actual_field is always false', + 'ml/evaluate_data_frame/Test binary_soft_classification given evaluation with empty metrics', 'ml/evaluate_data_frame/Test binary_soft_classification given missing actual_field', 'ml/evaluate_data_frame/Test binary_soft_classification given missing predicted_probability_field', 'ml/evaluate_data_frame/Test binary_soft_classification given precision with threshold less than zero', @@ -86,6 +86,7 @@ integTest.runner { 'ml/evaluate_data_frame/Test binary_soft_classification given precision with empty thresholds', 'ml/evaluate_data_frame/Test binary_soft_classification given recall with empty thresholds', 'ml/evaluate_data_frame/Test binary_soft_classification given confusion_matrix with empty thresholds', + 'ml/evaluate_data_frame/Test regression given evaluation with empty metrics', 'ml/delete_job_force/Test cannot force delete a non-existent job', 'ml/delete_model_snapshot/Test delete snapshot missing snapshotId', 'ml/delete_model_snapshot/Test delete snapshot missing job_id', diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index a11d80b2ab197..d0ed46b0f0404 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -125,7 +125,7 @@ setup: indices.refresh: {} --- -"Test binary_soft_classifition auc_roc": +"Test binary_soft_classification auc_roc": - do: ml.evaluate_data_frame: body: > @@ -145,7 +145,7 @@ setup: - is_false: binary_soft_classification.auc_roc.curve --- -"Test binary_soft_classifition auc_roc given actual_field is int": +"Test binary_soft_classification auc_roc given actual_field is int": - do: ml.evaluate_data_frame: body: > @@ -165,7 +165,7 @@ setup: - is_false: binary_soft_classification.auc_roc.curve --- -"Test binary_soft_classifition auc_roc include curve": +"Test binary_soft_classification auc_roc include curve": - do: ml.evaluate_data_frame: body: > @@ -185,7 +185,7 @@ setup: - is_true: binary_soft_classification.auc_roc.curve --- -"Test binary_soft_classifition auc_roc given actual_field is always true": +"Test binary_soft_classification auc_roc given actual_field is always true": - do: catch: /\[auc_roc\] requires at least one actual_field to have a different value than \[true\]/ ml.evaluate_data_frame: @@ -204,7 +204,7 @@ setup: } --- -"Test binary_soft_classifition auc_roc given actual_field is always false": +"Test binary_soft_classification auc_roc given actual_field is always false": - do: catch: /\[auc_roc\] requires at least one actual_field to have the value \[true\]/ ml.evaluate_data_frame: @@ -223,7 +223,7 @@ setup: } --- -"Test binary_soft_classifition precision": +"Test binary_soft_classification precision": - do: ml.evaluate_data_frame: body: > @@ -246,7 +246,7 @@ setup: '0.5': 1.0 --- -"Test binary_soft_classifition recall": +"Test binary_soft_classification recall": - do: ml.evaluate_data_frame: body: > @@ -270,7 +270,7 @@ setup: '0.5': 0.6 --- -"Test binary_soft_classifition confusion_matrix": +"Test binary_soft_classification confusion_matrix": - do: ml.evaluate_data_frame: body: > @@ -306,7 +306,7 @@ setup: fn: 2 --- -"Test binary_soft_classifition default metrics": +"Test binary_soft_classification default metrics": - do: ml.evaluate_data_frame: body: > @@ -535,7 +535,7 @@ setup: } } --- -"Test regression given empty metrics": +"Test regression given evaluation with empty metrics": - do: catch: /\[regression\] must have one or more metrics/ ml.evaluate_data_frame: