From d9835f7fb4443a2baff05fc895dbd9fbc7a91353 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Thu, 21 Nov 2019 11:22:16 -0500 Subject: [PATCH] [ML] Fix r_squared eval when variance is 0 (#49439) (#49445) --- .../dataframe/evaluation/regression/RSquared.java | 6 +++++- .../evaluation/regression/RSquaredTests.java | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java index 9307d5ae0ae46..408f8ff0a6900 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java @@ -81,7 +81,11 @@ public void process(Aggregations aggs) { NumericMetricsAggregation.SingleValue residualSumOfSquares = aggs.get(SS_RES); ExtendedStats extendedStats = aggs.get(ExtendedStatsAggregationBuilder.NAME + "_actual"); // extendedStats.getVariance() is the statistical sumOfSquares divided by count - result = residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ? + final boolean validResult = residualSumOfSquares == null + || extendedStats == null + || extendedStats.getCount() == 0 + || extendedStats.getVariance() == 0; + result = validResult ? new Result(0.0) : new Result(1 - (residualSumOfSquares.value() / (extendedStats.getVariance() * extendedStats.getCount()))); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquaredTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquaredTests.java index 4913d232f74cc..8c637c9cf179a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquaredTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquaredTests.java @@ -74,6 +74,21 @@ public void testEvaluateWithZeroCount() { assertThat(result, equalTo(new RSquared.Result(0.0))); } + public void testEvaluateWithSingleCountZeroVariance() { + Aggregations aggs = new Aggregations(Arrays.asList( + createSingleMetricAgg("residual_sum_of_squares", 1), + createExtendedStatsAgg("extended_stats_actual", 0.0, 1), + createExtendedStatsAgg("some_other_extended_stats",99.1, 10_000), + createSingleMetricAgg("some_other_single_metric_agg", 0.2377) + )); + + RSquared rSquared = new RSquared(); + rSquared.process(aggs); + + EvaluationMetricResult result = rSquared.getResult().get(); + assertThat(result, equalTo(new RSquared.Result(0.0))); + } + public void testEvaluate_GivenMissingAggs() { Aggregations aggs = new Aggregations(Collections.singletonList( createSingleMetricAgg("some_other_single_metric_agg", 0.2377)