From e7fadf677495590fca0c5ade01b55a0aa30e72c7 Mon Sep 17 00:00:00 2001 From: Georgi Krastev Date: Tue, 12 Dec 2023 11:37:06 +0200 Subject: [PATCH] Fix count(null) and count(distinct null) Use `logical_nulls` when the array data type is `Null`. --- .../physical-expr/src/aggregate/count.rs | 56 +++++++++++++------ .../src/aggregate/count_distinct.rs | 5 ++ .../sqllogictest/test_files/aggregate.slt | 17 +++++- 3 files changed, 60 insertions(+), 18 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 738ca4e915f7d..8e196e9693d57 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -121,14 +121,20 @@ impl GroupsAccumulator for CountGroupsAccumulator { // Add one to each group's counter for each non null, non // filtered value self.counts.resize(total_num_groups, 0); - accumulate_indices( - group_indices, - values.nulls(), // ignore values - opt_filter, - |group_index| { - self.counts[group_index] += 1; - }, - ); + let index_fn = |group_index| { + self.counts[group_index] += 1; + }; + + if values.data_type() == &DataType::Null { + accumulate_indices( + group_indices, + values.logical_nulls().as_ref(), + opt_filter, + index_fn, + ); + } else { + accumulate_indices(group_indices, values.nulls(), opt_filter, index_fn); + } Ok(()) } @@ -195,19 +201,35 @@ impl GroupsAccumulator for CountGroupsAccumulator { /// count null values for multiple columns /// for each row if one column value is null, then null_count + 1 fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize { + fn and_nulls(acc: BooleanBuffer, arr: &ArrayRef) -> BooleanBuffer { + if arr.data_type() == &DataType::Null { + if let Some(nulls) = arr.logical_nulls() { + return acc.bitand(nulls.inner()); + } + } else if let Some(nulls) = arr.nulls() { + return acc.bitand(nulls.inner()); + } + + acc + } + if values.len() > 1 { - let result_bool_buf: Option<BooleanBuffer> = values - .iter()
.map(|a| a.nulls()) - .fold(None, |acc, b| match (acc, b) { - (Some(acc), Some(b)) => Some(acc.bitand(b.inner())), - (Some(acc), None) => Some(acc), - (None, Some(b)) => Some(b.inner().clone()), - _ => None, + let result_bool_buf: Option<BooleanBuffer> = + values.iter().fold(None, |acc, arr| { + Some(if let Some(acc) = acc { + and_nulls(acc, arr) + } else { + arr.logical_nulls()?.into_inner() + }) }); result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits()) } else { - values[0].null_count() + let values = &values[0]; + if values.data_type() == &DataType::Null { + values.len() + } else { + values.null_count() + } } } diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index f5242d983d4cf..c2fd32a96c4fb 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -152,7 +152,12 @@ impl Accumulator for DistinctCountAccumulator { if values.is_empty() { return Ok(()); } + let arr = &values[0]; + if arr.data_type() == &DataType::Null { + return Ok(()); + } + (0..arr.len()).try_for_each(|index| { if !arr.is_null(index) { let scalar = ScalarValue::try_from_array(arr, index)?; diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index bcda3464f49b0..78575c9dffc51 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1492,6 +1492,12 @@ SELECT count(c1, c2) FROM test ---- 3 +# count_null +query III +SELECT count(null), count(null, null), count(distinct null) FROM test +---- +0 0 0 + # count_multi_expr_group_by query I SELECT count(c1, c2) FROM test group by c1 order by c1 @@ -1501,6 +1507,15 @@ SELECT count(c1, c2) FROM test group by c1 order by c1 2 0 +# count_null_group_by +query III +SELECT count(null), count(null, null), count(distinct null) FROM test group by c1 order by c1 +---- +0 0 0 +0 0 0 +0 0 0 
+0 0 0 + # aggreggte_with_alias query II select c1, sum(c2) as `Total Salary` from test group by c1 order by c1 @@ -3241,4 +3256,4 @@ select count(*) from (select count(*) from (select 1)); query I select count(*) from (select count(*) a, count(*) b from (select 1)); ---- -1 \ No newline at end of file +1