/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/count.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use ahash::RandomState; |
19 | | use datafusion_common::stats::Precision; |
20 | | use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; |
21 | | use datafusion_physical_expr::expressions; |
22 | | use std::collections::HashSet; |
23 | | use std::ops::BitAnd; |
24 | | use std::{fmt::Debug, sync::Arc}; |
25 | | |
26 | | use arrow::{ |
27 | | array::{ArrayRef, AsArray}, |
28 | | compute, |
29 | | datatypes::{ |
30 | | DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, |
31 | | Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, |
32 | | Time32MillisecondType, Time32SecondType, Time64MicrosecondType, |
33 | | Time64NanosecondType, TimeUnit, TimestampMicrosecondType, |
34 | | TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, |
35 | | UInt16Type, UInt32Type, UInt64Type, UInt8Type, |
36 | | }, |
37 | | }; |
38 | | |
39 | | use arrow::{ |
40 | | array::{Array, BooleanArray, Int64Array, PrimitiveArray}, |
41 | | buffer::BooleanBuffer, |
42 | | }; |
43 | | use datafusion_common::{ |
44 | | downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue, |
45 | | }; |
46 | | use datafusion_expr::function::StateFieldsArgs; |
47 | | use datafusion_expr::{ |
48 | | function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, |
49 | | EmitTo, GroupsAccumulator, Signature, Volatility, |
50 | | }; |
51 | | use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature}; |
52 | | use datafusion_functions_aggregate_common::aggregate::count_distinct::{ |
53 | | BytesDistinctCountAccumulator, FloatDistinctCountAccumulator, |
54 | | PrimitiveDistinctCountAccumulator, |
55 | | }; |
56 | | use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices; |
57 | | use datafusion_physical_expr_common::binary_map::OutputType; |
58 | | |
59 | | use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; |
60 | | make_udaf_expr_and_func!( |
61 | | Count, |
62 | | count, |
63 | | expr, |
64 | | "Count the number of non-null values in the column", |
65 | | count_udaf |
66 | | ); |
67 | | |
68 | 0 | pub fn count_distinct(expr: Expr) -> Expr { |
69 | 0 | Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( |
70 | 0 | count_udaf(), |
71 | 0 | vec![expr], |
72 | 0 | true, |
73 | 0 | None, |
74 | 0 | None, |
75 | 0 | None, |
76 | 0 | )) |
77 | 0 | } |
78 | | |
79 | | pub struct Count { |
80 | | signature: Signature, |
81 | | } |
82 | | |
83 | | impl Debug for Count { |
84 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
85 | 0 | f.debug_struct("Count") |
86 | 0 | .field("name", &self.name()) |
87 | 0 | .field("signature", &self.signature) |
88 | 0 | .finish() |
89 | 0 | } |
90 | | } |
91 | | |
92 | | impl Default for Count { |
93 | 1 | fn default() -> Self { |
94 | 1 | Self::new() |
95 | 1 | } |
96 | | } |
97 | | |
98 | | impl Count { |
99 | 1 | pub fn new() -> Self { |
100 | 1 | Self { |
101 | 1 | signature: Signature::one_of( |
102 | 1 | // TypeSignature::Any(0) is required to handle `Count()` with no args |
103 | 1 | vec![TypeSignature::VariadicAny, TypeSignature::Any(0)], |
104 | 1 | Volatility::Immutable, |
105 | 1 | ), |
106 | 1 | } |
107 | 1 | } |
108 | | } |
109 | | |
110 | | impl AggregateUDFImpl for Count { |
111 | 0 | fn as_any(&self) -> &dyn std::any::Any { |
112 | 0 | self |
113 | 0 | } |
114 | | |
115 | 11 | fn name(&self) -> &str { |
116 | 11 | "count" |
117 | 11 | } |
118 | | |
119 | 10 | fn signature(&self) -> &Signature { |
120 | 10 | &self.signature |
121 | 10 | } |
122 | | |
123 | 10 | fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { |
124 | 10 | Ok(DataType::Int64) |
125 | 10 | } |
126 | | |
127 | 10 | fn is_nullable(&self) -> bool { |
128 | 10 | false |
129 | 10 | } |
130 | | |
131 | 26 | fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { |
132 | 26 | if args.is_distinct { |
133 | 0 | Ok(vec![Field::new_list( |
134 | 0 | format_state_name(args.name, "count distinct"), |
135 | 0 | // See COMMENTS.md to understand why nullable is set to true |
136 | 0 | Field::new("item", args.input_types[0].clone(), true), |
137 | 0 | false, |
138 | 0 | )]) |
139 | | } else { |
140 | 26 | Ok(vec![Field::new( |
141 | 26 | format_state_name(args.name, "count"), |
142 | 26 | DataType::Int64, |
143 | 26 | false, |
144 | 26 | )]) |
145 | | } |
146 | 26 | } |
147 | | |
148 | 3 | fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
149 | 3 | if !acc_args.is_distinct { |
150 | 3 | return Ok(Box::new(CountAccumulator::new())); |
151 | 0 | } |
152 | 0 |
|
153 | 0 | if acc_args.exprs.len() > 1 { |
154 | 0 | return not_impl_err!("COUNT DISTINCT with multiple arguments"); |
155 | 0 | } |
156 | | |
157 | 0 | let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?; |
158 | 0 | Ok(match data_type { |
159 | | // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator |
160 | 0 | DataType::Int8 => Box::new( |
161 | 0 | PrimitiveDistinctCountAccumulator::<Int8Type>::new(data_type), |
162 | 0 | ), |
163 | 0 | DataType::Int16 => Box::new( |
164 | 0 | PrimitiveDistinctCountAccumulator::<Int16Type>::new(data_type), |
165 | 0 | ), |
166 | 0 | DataType::Int32 => Box::new( |
167 | 0 | PrimitiveDistinctCountAccumulator::<Int32Type>::new(data_type), |
168 | 0 | ), |
169 | 0 | DataType::Int64 => Box::new( |
170 | 0 | PrimitiveDistinctCountAccumulator::<Int64Type>::new(data_type), |
171 | 0 | ), |
172 | 0 | DataType::UInt8 => Box::new( |
173 | 0 | PrimitiveDistinctCountAccumulator::<UInt8Type>::new(data_type), |
174 | 0 | ), |
175 | 0 | DataType::UInt16 => Box::new( |
176 | 0 | PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type), |
177 | 0 | ), |
178 | 0 | DataType::UInt32 => Box::new( |
179 | 0 | PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type), |
180 | 0 | ), |
181 | 0 | DataType::UInt64 => Box::new( |
182 | 0 | PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type), |
183 | 0 | ), |
184 | 0 | DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< |
185 | 0 | Decimal128Type, |
186 | 0 | >::new(data_type)), |
187 | 0 | DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< |
188 | 0 | Decimal256Type, |
189 | 0 | >::new(data_type)), |
190 | | |
191 | 0 | DataType::Date32 => Box::new( |
192 | 0 | PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type), |
193 | 0 | ), |
194 | 0 | DataType::Date64 => Box::new( |
195 | 0 | PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type), |
196 | 0 | ), |
197 | 0 | DataType::Time32(TimeUnit::Millisecond) => Box::new( |
198 | 0 | PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new( |
199 | 0 | data_type, |
200 | 0 | ), |
201 | 0 | ), |
202 | 0 | DataType::Time32(TimeUnit::Second) => Box::new( |
203 | 0 | PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type), |
204 | 0 | ), |
205 | 0 | DataType::Time64(TimeUnit::Microsecond) => Box::new( |
206 | 0 | PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new( |
207 | 0 | data_type, |
208 | 0 | ), |
209 | 0 | ), |
210 | 0 | DataType::Time64(TimeUnit::Nanosecond) => Box::new( |
211 | 0 | PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type), |
212 | 0 | ), |
213 | 0 | DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new( |
214 | 0 | PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new( |
215 | 0 | data_type, |
216 | 0 | ), |
217 | 0 | ), |
218 | 0 | DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new( |
219 | 0 | PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new( |
220 | 0 | data_type, |
221 | 0 | ), |
222 | 0 | ), |
223 | 0 | DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new( |
224 | 0 | PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new( |
225 | 0 | data_type, |
226 | 0 | ), |
227 | 0 | ), |
228 | 0 | DataType::Timestamp(TimeUnit::Second, _) => Box::new( |
229 | 0 | PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type), |
230 | 0 | ), |
231 | | |
232 | | DataType::Float16 => { |
233 | 0 | Box::new(FloatDistinctCountAccumulator::<Float16Type>::new()) |
234 | | } |
235 | | DataType::Float32 => { |
236 | 0 | Box::new(FloatDistinctCountAccumulator::<Float32Type>::new()) |
237 | | } |
238 | | DataType::Float64 => { |
239 | 0 | Box::new(FloatDistinctCountAccumulator::<Float64Type>::new()) |
240 | | } |
241 | | |
242 | | DataType::Utf8 => { |
243 | 0 | Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8)) |
244 | | } |
245 | | DataType::Utf8View => { |
246 | 0 | Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View)) |
247 | | } |
248 | | DataType::LargeUtf8 => { |
249 | 0 | Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8)) |
250 | | } |
251 | 0 | DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new( |
252 | 0 | OutputType::Binary, |
253 | 0 | )), |
254 | 0 | DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new( |
255 | 0 | OutputType::BinaryView, |
256 | 0 | )), |
257 | 0 | DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new( |
258 | 0 | OutputType::Binary, |
259 | 0 | )), |
260 | | |
261 | | // Use the generic accumulator based on `ScalarValue` for all other types |
262 | 0 | _ => Box::new(DistinctCountAccumulator { |
263 | 0 | values: HashSet::default(), |
264 | 0 | state_data_type: data_type.clone(), |
265 | 0 | }), |
266 | | }) |
267 | 3 | } |
268 | | |
269 | 0 | fn aliases(&self) -> &[String] { |
270 | 0 | &[] |
271 | 0 | } |
272 | | |
273 | 15 | fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { |
274 | 15 | // groups accumulator only supports `COUNT(c1)`, not |
275 | 15 | // `COUNT(c1, c2)`, etc |
276 | 15 | if args.is_distinct { |
277 | 0 | return false; |
278 | 15 | } |
279 | 15 | args.exprs.len() == 1 |
280 | 15 | } |
281 | | |
282 | 15 | fn create_groups_accumulator( |
283 | 15 | &self, |
284 | 15 | _args: AccumulatorArgs, |
285 | 15 | ) -> Result<Box<dyn GroupsAccumulator>> { |
286 | 15 | // instantiate specialized accumulator |
287 | 15 | Ok(Box::new(CountGroupsAccumulator::new())) |
288 | 15 | } |
289 | | |
290 | 0 | fn reverse_expr(&self) -> ReversedUDAF { |
291 | 0 | ReversedUDAF::Identical |
292 | 0 | } |
293 | | |
294 | 0 | fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> { |
295 | 0 | Ok(ScalarValue::Int64(Some(0))) |
296 | 0 | } |
297 | | |
298 | 0 | fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> { |
299 | 0 | if statistics_args.is_distinct { |
300 | 0 | return None; |
301 | 0 | } |
302 | 0 | if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows { |
303 | 0 | if statistics_args.exprs.len() == 1 { |
304 | | // TODO optimize with exprs other than Column |
305 | 0 | if let Some(col_expr) = statistics_args.exprs[0] |
306 | 0 | .as_any() |
307 | 0 | .downcast_ref::<expressions::Column>() |
308 | | { |
309 | 0 | let current_val = &statistics_args.statistics.column_statistics |
310 | 0 | [col_expr.index()] |
311 | 0 | .null_count; |
312 | 0 | if let &Precision::Exact(val) = current_val { |
313 | 0 | return Some(ScalarValue::Int64(Some((num_rows - val) as i64))); |
314 | 0 | } |
315 | 0 | } else if let Some(lit_expr) = statistics_args.exprs[0] |
316 | 0 | .as_any() |
317 | 0 | .downcast_ref::<expressions::Literal>() |
318 | | { |
319 | 0 | if lit_expr.value() == &COUNT_STAR_EXPANSION { |
320 | 0 | return Some(ScalarValue::Int64(Some(num_rows as i64))); |
321 | 0 | } |
322 | 0 | } |
323 | 0 | } |
324 | 0 | } |
325 | 0 | None |
326 | 0 | } |
327 | | } |
328 | | |
329 | | #[derive(Debug)] |
330 | | struct CountAccumulator { |
331 | | count: i64, |
332 | | } |
333 | | |
334 | | impl CountAccumulator { |
335 | | /// new count accumulator |
336 | 3 | pub fn new() -> Self { |
337 | 3 | Self { count: 0 } |
338 | 3 | } |
339 | | } |
340 | | |
341 | | impl Accumulator for CountAccumulator { |
342 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
343 | 0 | Ok(vec![ScalarValue::Int64(Some(self.count))]) |
344 | 0 | } |
345 | | |
346 | 6 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
347 | 6 | let array = &values[0]; |
348 | 6 | self.count += (array.len() - null_count_for_multiple_cols(values)) as i64; |
349 | 6 | Ok(()) |
350 | 6 | } |
351 | | |
352 | 6 | fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
353 | 6 | let array = &values[0]; |
354 | 6 | self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64; |
355 | 6 | Ok(()) |
356 | 6 | } |
357 | | |
358 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
359 | 0 | let counts = downcast_value!(states[0], Int64Array); |
360 | 0 | let delta = &arrow::compute::sum(counts); |
361 | 0 | if let Some(d) = delta { |
362 | 0 | self.count += *d; |
363 | 0 | } |
364 | 0 | Ok(()) |
365 | 0 | } |
366 | | |
367 | 8 | fn evaluate(&mut self) -> Result<ScalarValue> { |
368 | 8 | Ok(ScalarValue::Int64(Some(self.count))) |
369 | 8 | } |
370 | | |
371 | 3 | fn supports_retract_batch(&self) -> bool { |
372 | 3 | true |
373 | 3 | } |
374 | | |
375 | 0 | fn size(&self) -> usize { |
376 | 0 | std::mem::size_of_val(self) |
377 | 0 | } |
378 | | } |
379 | | |
380 | | /// An accumulator to compute the counts of [`PrimitiveArray<T>`]. |
381 | | /// Stores values as native types, and does overflow checking |
382 | | /// |
383 | | /// Unlike most other accumulators, COUNT never produces NULLs. If no |
384 | | /// non-null values are seen in any group the output is 0. Thus, this |
385 | | /// accumulator has no additional null or seen filter tracking. |
386 | | #[derive(Debug)] |
387 | | struct CountGroupsAccumulator { |
388 | | /// Count per group. |
389 | | /// |
390 | | /// Note this is an i64 and not a u64 (or usize) because the |
391 | | /// output type of count is `DataType::Int64`. Thus by using `i64` |
392 | | /// for the counts, the output [`Int64Array`] can be created |
393 | | /// without copy. |
394 | | counts: Vec<i64>, |
395 | | } |
396 | | |
397 | | impl CountGroupsAccumulator { |
398 | 15 | pub fn new() -> Self { |
399 | 15 | Self { counts: vec![] } |
400 | 15 | } |
401 | | } |
402 | | |
403 | | impl GroupsAccumulator for CountGroupsAccumulator { |
404 | 63 | fn update_batch( |
405 | 63 | &mut self, |
406 | 63 | values: &[ArrayRef], |
407 | 63 | group_indices: &[usize], |
408 | 63 | opt_filter: Option<&BooleanArray>, |
409 | 63 | total_num_groups: usize, |
410 | 63 | ) -> Result<()> { |
411 | 63 | assert_eq!(values.len(), 1, "single argument to update_batch"0 ); |
412 | 63 | let values = &values[0]; |
413 | 63 | |
414 | 63 | // Add one to each group's counter for each non null, non |
415 | 63 | // filtered value |
416 | 63 | self.counts.resize(total_num_groups, 0); |
417 | 63 | accumulate_indices( |
418 | 63 | group_indices, |
419 | 63 | values.logical_nulls().as_ref(), |
420 | 63 | opt_filter, |
421 | 98.5k | |group_index| { |
422 | 98.5k | self.counts[group_index] += 1; |
423 | 98.5k | }, |
424 | 63 | ); |
425 | 63 | |
426 | 63 | Ok(()) |
427 | 63 | } |
428 | | |
429 | 8 | fn merge_batch( |
430 | 8 | &mut self, |
431 | 8 | values: &[ArrayRef], |
432 | 8 | group_indices: &[usize], |
433 | 8 | opt_filter: Option<&BooleanArray>, |
434 | 8 | total_num_groups: usize, |
435 | 8 | ) -> Result<()> { |
436 | 8 | assert_eq!(values.len(), 1, "one argument to merge_batch"0 ); |
437 | | // first batch is counts, second is partial sums |
438 | 8 | let partial_counts = values[0].as_primitive::<Int64Type>(); |
439 | 8 | |
440 | 8 | // intermediate counts are always created as non null |
441 | 8 | assert_eq!(partial_counts.null_count(), 0); |
442 | 8 | let partial_counts = partial_counts.values(); |
443 | 8 | |
444 | 8 | // Adds the counts with the partial counts |
445 | 8 | self.counts.resize(total_num_groups, 0); |
446 | 8 | match opt_filter { |
447 | 0 | Some(filter) => filter |
448 | 0 | .iter() |
449 | 0 | .zip(group_indices.iter()) |
450 | 0 | .zip(partial_counts.iter()) |
451 | 0 | .for_each(|((filter_value, &group_index), partial_count)| { |
452 | 0 | if let Some(true) = filter_value { |
453 | 0 | self.counts[group_index] += partial_count; |
454 | 0 | } |
455 | 0 | }), |
456 | 8 | None => group_indices.iter().zip(partial_counts.iter()).for_each( |
457 | 48 | |(&group_index, partial_count)| { |
458 | 48 | self.counts[group_index] += partial_count; |
459 | 48 | }, |
460 | 8 | ), |
461 | | } |
462 | | |
463 | 8 | Ok(()) |
464 | 8 | } |
465 | | |
466 | 4 | fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> { |
467 | 4 | let counts = emit_to.take_needed(&mut self.counts); |
468 | 4 | |
469 | 4 | // Count is always non null (null inputs just don't contribute to the overall values) |
470 | 4 | let nulls = None; |
471 | 4 | let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls); |
472 | 4 | |
473 | 4 | Ok(Arc::new(array)) |
474 | 4 | } |
475 | | |
476 | | // return arrays for counts |
477 | 13 | fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> { |
478 | 13 | let counts = emit_to.take_needed(&mut self.counts); |
479 | 13 | let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); // zero copy, no nulls |
480 | 13 | Ok(vec![Arc::new(counts) as ArrayRef]) |
481 | 13 | } |
482 | | |
483 | | /// Converts an input batch directly to a state batch |
484 | | /// |
485 | | /// The state of `COUNT` is always a single Int64Array: |
486 | | /// * `1` (for non-null, non filtered values) |
487 | | /// * `0` (for null values) |
488 | 2 | fn convert_to_state( |
489 | 2 | &self, |
490 | 2 | values: &[ArrayRef], |
491 | 2 | opt_filter: Option<&BooleanArray>, |
492 | 2 | ) -> Result<Vec<ArrayRef>> { |
493 | 2 | let values = &values[0]; |
494 | | |
495 | 2 | let state_array = match (values.logical_nulls(), opt_filter) { |
496 | | (None, None) => { |
497 | | // In case there is no nulls in input and no filter, returning array of 1 |
498 | 2 | Arc::new(Int64Array::from_value(1, values.len())) |
499 | | } |
500 | 0 | (Some(nulls), None) => { |
501 | 0 | // If there are any nulls in input values -- casting `nulls` (true for values, false for nulls) |
502 | 0 | // of input array to Int64 |
503 | 0 | let nulls = BooleanArray::new(nulls.into_inner(), None); |
504 | 0 | compute::cast(&nulls, &DataType::Int64)? |
505 | | } |
506 | 0 | (None, Some(filter)) => { |
507 | 0 | // If there is only filter |
508 | 0 | // - applying filter null mask to filter values by bitand filter values and nulls buffers |
509 | 0 | // (using buffers guarantees absence of nulls in result) |
510 | 0 | // - casting result of bitand to Int64 array |
511 | 0 | let (filter_values, filter_nulls) = filter.clone().into_parts(); |
512 | | |
513 | 0 | let state_buf = match filter_nulls { |
514 | 0 | Some(filter_nulls) => &filter_values & filter_nulls.inner(), |
515 | 0 | None => filter_values, |
516 | | }; |
517 | | |
518 | 0 | let boolean_state = BooleanArray::new(state_buf, None); |
519 | 0 | compute::cast(&boolean_state, &DataType::Int64)? |
520 | | } |
521 | 0 | (Some(nulls), Some(filter)) => { |
522 | 0 | // For both input nulls and filter |
523 | 0 | // - applying filter null mask to filter values by bitand filter values and nulls buffers |
524 | 0 | // (using buffers guarantees absence of nulls in result) |
525 | 0 | // - applying values null mask to filter buffer by another bitand on filter result and |
526 | 0 | // nulls from input values |
527 | 0 | // - casting result to Int64 array |
528 | 0 | let (filter_values, filter_nulls) = filter.clone().into_parts(); |
529 | | |
530 | 0 | let filter_buf = match filter_nulls { |
531 | 0 | Some(filter_nulls) => &filter_values & filter_nulls.inner(), |
532 | 0 | None => filter_values, |
533 | | }; |
534 | 0 | let state_buf = &filter_buf & nulls.inner(); |
535 | 0 |
|
536 | 0 | let boolean_state = BooleanArray::new(state_buf, None); |
537 | 0 | compute::cast(&boolean_state, &DataType::Int64)? |
538 | | } |
539 | | }; |
540 | | |
541 | 2 | Ok(vec![state_array]) |
542 | 2 | } |
543 | | |
544 | 11 | fn supports_convert_to_state(&self) -> bool { |
545 | 11 | true |
546 | 11 | } |
547 | | |
548 | 75 | fn size(&self) -> usize { |
549 | 75 | self.counts.capacity() * std::mem::size_of::<usize>() |
550 | 75 | } |
551 | | } |
552 | | |
553 | | /// count null values for multiple columns |
554 | | /// for each row if one column value is null, then null_count + 1 |
555 | 12 | fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize { |
556 | 12 | if values.len() > 1 { |
557 | 0 | let result_bool_buf: Option<BooleanBuffer> = values |
558 | 0 | .iter() |
559 | 0 | .map(|a| a.logical_nulls()) |
560 | 0 | .fold(None, |acc, b| match (acc, b) { |
561 | 0 | (Some(acc), Some(b)) => Some(acc.bitand(b.inner())), |
562 | 0 | (Some(acc), None) => Some(acc), |
563 | 0 | (None, Some(b)) => Some(b.into_inner()), |
564 | 0 | _ => None, |
565 | 0 | }); |
566 | 0 | result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits()) |
567 | | } else { |
568 | 12 | values[0] |
569 | 12 | .logical_nulls() |
570 | 12 | .map_or(0, |nulls| nulls.null_count()0 ) |
571 | | } |
572 | 12 | } |
573 | | |
574 | | /// General purpose distinct accumulator that works for any DataType by using |
575 | | /// [`ScalarValue`]. |
576 | | /// |
577 | | /// It stores intermediate results as a `ListArray` |
578 | | /// |
579 | | /// Note that many types have specialized accumulators that are (much) |
580 | | /// more efficient such as [`PrimitiveDistinctCountAccumulator`] and |
581 | | /// [`BytesDistinctCountAccumulator`] |
582 | | #[derive(Debug)] |
583 | | struct DistinctCountAccumulator { |
584 | | values: HashSet<ScalarValue, RandomState>, |
585 | | state_data_type: DataType, |
586 | | } |
587 | | |
588 | | impl DistinctCountAccumulator { |
589 | | // calculating the size for fixed length values, taking first batch size * |
590 | | // number of batches This method is faster than .full_size(), however it is |
591 | | // not suitable for variable length values like strings or complex types |
592 | 0 | fn fixed_size(&self) -> usize { |
593 | 0 | std::mem::size_of_val(self) |
594 | 0 | + (std::mem::size_of::<ScalarValue>() * self.values.capacity()) |
595 | 0 | + self |
596 | 0 | .values |
597 | 0 | .iter() |
598 | 0 | .next() |
599 | 0 | .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) |
600 | 0 | .unwrap_or(0) |
601 | 0 | + std::mem::size_of::<DataType>() |
602 | 0 | } |
603 | | |
604 | | // calculates the size as accurately as possible. Note that calling this |
605 | | // method is expensive |
606 | 0 | fn full_size(&self) -> usize { |
607 | 0 | std::mem::size_of_val(self) |
608 | 0 | + (std::mem::size_of::<ScalarValue>() * self.values.capacity()) |
609 | 0 | + self |
610 | 0 | .values |
611 | 0 | .iter() |
612 | 0 | .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) |
613 | 0 | .sum::<usize>() |
614 | 0 | + std::mem::size_of::<DataType>() |
615 | 0 | } |
616 | | } |
617 | | |
618 | | impl Accumulator for DistinctCountAccumulator { |
619 | | /// Returns the distinct values seen so far as (one element) ListArray. |
620 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
621 | 0 | let scalars = self.values.iter().cloned().collect::<Vec<_>>(); |
622 | 0 | let arr = |
623 | 0 | ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type); |
624 | 0 | Ok(vec![ScalarValue::List(arr)]) |
625 | 0 | } |
626 | | |
627 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
628 | 0 | if values.is_empty() { |
629 | 0 | return Ok(()); |
630 | 0 | } |
631 | 0 |
|
632 | 0 | let arr = &values[0]; |
633 | 0 | if arr.data_type() == &DataType::Null { |
634 | 0 | return Ok(()); |
635 | 0 | } |
636 | 0 |
|
637 | 0 | (0..arr.len()).try_for_each(|index| { |
638 | 0 | if !arr.is_null(index) { |
639 | 0 | let scalar = ScalarValue::try_from_array(arr, index)?; |
640 | 0 | self.values.insert(scalar); |
641 | 0 | } |
642 | 0 | Ok(()) |
643 | 0 | }) |
644 | 0 | } |
645 | | |
646 | | /// Merges multiple sets of distinct values into the current set. |
647 | | /// |
648 | | /// The input to this function is a `ListArray` with **multiple** rows, |
649 | | /// where each row contains the values from a partial aggregate's phase (e.g. |
650 | | /// the result of calling `Self::state` on multiple accumulators). |
651 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
652 | 0 | if states.is_empty() { |
653 | 0 | return Ok(()); |
654 | 0 | } |
655 | 0 | assert_eq!(states.len(), 1, "array_agg states must be singleton!"); |
656 | 0 | let array = &states[0]; |
657 | 0 | let list_array = array.as_list::<i32>(); |
658 | 0 | for inner_array in list_array.iter() { |
659 | 0 | let Some(inner_array) = inner_array else { |
660 | 0 | return internal_err!( |
661 | 0 | "Intermediate results of COUNT DISTINCT should always be non null" |
662 | 0 | ); |
663 | | }; |
664 | 0 | self.update_batch(&[inner_array])?; |
665 | | } |
666 | 0 | Ok(()) |
667 | 0 | } |
668 | | |
669 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
670 | 0 | Ok(ScalarValue::Int64(Some(self.values.len() as i64))) |
671 | 0 | } |
672 | | |
673 | 0 | fn size(&self) -> usize { |
674 | 0 | match &self.state_data_type { |
675 | 0 | DataType::Boolean | DataType::Null => self.fixed_size(), |
676 | 0 | d if d.is_primitive() => self.fixed_size(), |
677 | 0 | _ => self.full_size(), |
678 | | } |
679 | 0 | } |
680 | | } |