From ba79462432596aa52c95832a4b99d2740e1174ab Mon Sep 17 00:00:00 2001 From: Michael-J-Ward Date: Tue, 30 Jul 2024 14:05:43 -0500 Subject: [PATCH] fix: regr_count now returns Uint64 Fixes https://github.com/apache/datafusion/issues/11726 --- datafusion/functions-aggregate/src/regr.rs | 8 +++-- .../sqllogictest/test_files/aggregate.slt | 32 +++++++++---------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/datafusion/functions-aggregate/src/regr.rs b/datafusion/functions-aggregate/src/regr.rs index 8d04ae87157d..aad110a13e13 100644 --- a/datafusion/functions-aggregate/src/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -153,7 +153,11 @@ impl AggregateUDFImpl for Regr { return plan_err!("Covariance requires numeric input types"); } - Ok(DataType::Float64) + if matches!(self.regr_type, RegrType::Count) { + Ok(DataType::UInt64) + } else { + Ok(DataType::Float64) + } } fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { @@ -480,7 +484,7 @@ impl Accumulator for RegrAccumulator { let nullif_cond = self.count <= 1 || var_pop_x == 0.0; nullif_or_stat(nullif_cond, self.mean_y - slope * self.mean_x) } - RegrType::Count => Ok(ScalarValue::Float64(Some(self.count as f64))), + RegrType::Count => Ok(ScalarValue::UInt64(Some(self.count))), RegrType::R2 => { // Only 0/1 point or all x(or y) is the same let nullif_cond = self.count <= 1 || var_pop_x == 0.0 || var_pop_y == 0.0; diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index fa228d499d1f..5cc66bb493ac 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -4742,27 +4742,27 @@ select regr_sxy(NULL, 'bar'); # regr_*() NULL results -query RRRRRRRRR +query RRIRRRRRR select regr_slope(1,1), regr_intercept(1,1), regr_count(1,1), regr_r2(1,1), regr_avgx(1,1), regr_avgy(1,1), regr_sxx(1,1), regr_syy(1,1), regr_sxy(1,1); ---- NULL NULL 1 NULL 1 1 0 0 0 -query RRRRRRRRR +query RRIRRRRRR select regr_slope(1, NULL), regr_intercept(1, NULL), regr_count(1, NULL), regr_r2(1, NULL), regr_avgx(1, NULL), regr_avgy(1, NULL), regr_sxx(1, NULL), regr_syy(1, NULL), regr_sxy(1, NULL); ---- NULL NULL 0 NULL NULL NULL NULL NULL NULL -query RRRRRRRRR +query RRIRRRRRR select regr_slope(NULL, 1), regr_intercept(NULL, 1), regr_count(NULL, 1), regr_r2(NULL, 1), regr_avgx(NULL, 1), regr_avgy(NULL, 1), regr_sxx(NULL, 1), regr_syy(NULL, 1), regr_sxy(NULL, 1); ---- NULL NULL 0 NULL NULL NULL NULL NULL NULL -query RRRRRRRRR +query RRIRRRRRR select regr_slope(NULL, NULL), regr_intercept(NULL, NULL), regr_count(NULL, NULL), regr_r2(NULL, NULL), regr_avgx(NULL, NULL), regr_avgy(NULL, NULL), regr_sxx(NULL, NULL), regr_syy(NULL, NULL), regr_sxy(NULL, NULL); ---- NULL NULL 0 NULL NULL NULL NULL NULL NULL -query RRRRRRRRR +query RRIRRRRRR select regr_slope(column2, column1), regr_intercept(column2, column1), regr_count(column2, column1), regr_r2(column2, column1), regr_avgx(column2, column1), regr_avgy(column2, column1), regr_sxx(column2, column1), regr_syy(column2, column1), regr_sxy(column2, column1) from (values (1,2), (1,4), (1,6)); ---- NULL NULL 3 NULL 1 4 0 8 0 @@ -4770,7 +4770,7 @@ NULL NULL 3 NULL 1 4 0 8 0 # regr_*() basic tests -query RRRRRRRRR +query RRIRRRRRR select regr_slope(column2, column1), regr_intercept(column2, column1), @@ -4785,7 +4785,7 @@ from (values (1,2), (2,4), (3,6)); ---- 2 0 3 1 2 4 2 8 4 -query RRRRRRRRR +query RRIRRRRRR select regr_slope(c12, c11), regr_intercept(c12, c11), @@ -4803,7 +4803,7 @@ from aggregate_test_100; # regr_*() functions ignore NULLs -query RRRRRRRRR +query RRIRRRRRR select regr_slope(column2, column1), regr_intercept(column2, column1), @@ -4818,7 +4818,7 @@ from (values (1,NULL), (2,4), (3,6)); ---- 2 0 2 1 2.5 5 0.5 2 1 -query RRRRRRRRR +query RRIRRRRRR select regr_slope(column2, column1), regr_intercept(column2, column1), @@ -4833,7 +4833,7 @@ from (values (1,NULL), (NULL,4), (3,6)); ---- NULL NULL 1 NULL 3 6 0 0 0 -query RRRRRRRRR +query RRIRRRRRR select regr_slope(column2, column1), regr_intercept(column2, column1), @@ -4848,7 +4848,7 @@ from (values (1,NULL), (NULL,4), (NULL,NULL)); ---- NULL NULL 0 NULL NULL NULL NULL NULL NULL -query TRRRRRRRRR rowsort +query TRRIRRRRRR rowsort select column3, regr_slope(column2, column1), @@ -4873,7 +4873,7 @@ c NULL NULL 1 NULL 1 10 0 0 0 statement ok set datafusion.execution.batch_size = 1; -query RRRRRRRRR +query RRIRRRRRR select regr_slope(c12, c11), regr_intercept(c12, c11), @@ -4891,7 +4891,7 @@ from aggregate_test_100; statement ok set datafusion.execution.batch_size = 2; -query RRRRRRRRR +query RRIRRRRRR select regr_slope(c12, c11), regr_intercept(c12, c11), @@ -4909,7 +4909,7 @@ from aggregate_test_100; statement ok set datafusion.execution.batch_size = 3; -query RRRRRRRRR +query RRIRRRRRR select regr_slope(c12, c11), regr_intercept(c12, c11), @@ -4930,7 +4930,7 @@ set datafusion.execution.batch_size = 8192; # regr_*() testing retract_batch() from RegrAccumulator's internal implementation -query RRRRRRRRR +query RRIRRRRRR SELECT regr_slope(column2, column1) OVER w AS slope, regr_intercept(column2, column1) OVER w AS intercept, @@ -4951,7 +4951,7 @@ NULL NULL 1 NULL 1 2 0 0 0 4.5 -7 3 0.964285714286 4 11 2 42 9 3 0 3 1 5 15 2 18 6 -query RRRRRRRRR +query RRIRRRRRR SELECT regr_slope(column2, column1) OVER w AS slope, regr_intercept(column2, column1) OVER w AS intercept,