Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/count.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use ahash::RandomState;
19
use datafusion_common::stats::Precision;
20
use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
21
use datafusion_physical_expr::expressions;
22
use std::collections::HashSet;
23
use std::ops::BitAnd;
24
use std::{fmt::Debug, sync::Arc};
25
26
use arrow::{
27
    array::{ArrayRef, AsArray},
28
    compute,
29
    datatypes::{
30
        DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
31
        Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
32
        Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
33
        Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
34
        TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
35
        UInt16Type, UInt32Type, UInt64Type, UInt8Type,
36
    },
37
};
38
39
use arrow::{
40
    array::{Array, BooleanArray, Int64Array, PrimitiveArray},
41
    buffer::BooleanBuffer,
42
};
43
use datafusion_common::{
44
    downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue,
45
};
46
use datafusion_expr::function::StateFieldsArgs;
47
use datafusion_expr::{
48
    function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
49
    EmitTo, GroupsAccumulator, Signature, Volatility,
50
};
51
use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature};
52
use datafusion_functions_aggregate_common::aggregate::count_distinct::{
53
    BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
54
    PrimitiveDistinctCountAccumulator,
55
};
56
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
57
use datafusion_physical_expr_common::binary_map::OutputType;
58
59
use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
60
// Macro invocation wiring up the `Count` UDAF. The macro definition is not
// visible here; presumably it generates the `count(expr)` expression helper
// (documented by the string literal) and the `count_udaf()` accessor that
// `count_distinct` calls -- confirm against the macro definition.
make_udaf_expr_and_func!(
61
    Count,
62
    count,
63
    expr,
64
    "Count the number of non-null values in the column",
65
    count_udaf
66
);
67
68
0
pub fn count_distinct(expr: Expr) -> Expr {
69
0
    Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
70
0
        count_udaf(),
71
0
        vec![expr],
72
0
        true,
73
0
        None,
74
0
        None,
75
0
        None,
76
0
    ))
77
0
}
78
79
pub struct Count {
80
    signature: Signature,
81
}
82
83
impl Debug for Count {
84
0
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
85
0
        f.debug_struct("Count")
86
0
            .field("name", &self.name())
87
0
            .field("signature", &self.signature)
88
0
            .finish()
89
0
    }
90
}
91
92
impl Default for Count {
93
1
    fn default() -> Self {
94
1
        Self::new()
95
1
    }
96
}
97
98
impl Count {
99
1
    pub fn new() -> Self {
100
1
        Self {
101
1
            signature: Signature::one_of(
102
1
                // TypeSignature::Any(0) is required to handle `Count()` with no args
103
1
                vec![TypeSignature::VariadicAny, TypeSignature::Any(0)],
104
1
                Volatility::Immutable,
105
1
            ),
106
1
        }
107
1
    }
108
}
109
110
impl AggregateUDFImpl for Count {
111
0
    fn as_any(&self) -> &dyn std::any::Any {
112
0
        self
113
0
    }
114
115
11
    fn name(&self) -> &str {
116
11
        "count"
117
11
    }
118
119
10
    fn signature(&self) -> &Signature {
120
10
        &self.signature
121
10
    }
122
123
10
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
124
10
        Ok(DataType::Int64)
125
10
    }
126
127
10
    fn is_nullable(&self) -> bool {
128
10
        false
129
10
    }
130
131
26
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
132
26
        if args.is_distinct {
133
0
            Ok(vec![Field::new_list(
134
0
                format_state_name(args.name, "count distinct"),
135
0
                // See COMMENTS.md to understand why nullable is set to true
136
0
                Field::new("item", args.input_types[0].clone(), true),
137
0
                false,
138
0
            )])
139
        } else {
140
26
            Ok(vec![Field::new(
141
26
                format_state_name(args.name, "count"),
142
26
                DataType::Int64,
143
26
                false,
144
26
            )])
145
        }
146
26
    }
147
148
3
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
149
3
        if !acc_args.is_distinct {
150
3
            return Ok(Box::new(CountAccumulator::new()));
151
0
        }
152
0
153
0
        if acc_args.exprs.len() > 1 {
154
0
            return not_impl_err!("COUNT DISTINCT with multiple arguments");
155
0
        }
156
157
0
        let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?;
158
0
        Ok(match data_type {
159
            // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator
160
0
            DataType::Int8 => Box::new(
161
0
                PrimitiveDistinctCountAccumulator::<Int8Type>::new(data_type),
162
0
            ),
163
0
            DataType::Int16 => Box::new(
164
0
                PrimitiveDistinctCountAccumulator::<Int16Type>::new(data_type),
165
0
            ),
166
0
            DataType::Int32 => Box::new(
167
0
                PrimitiveDistinctCountAccumulator::<Int32Type>::new(data_type),
168
0
            ),
169
0
            DataType::Int64 => Box::new(
170
0
                PrimitiveDistinctCountAccumulator::<Int64Type>::new(data_type),
171
0
            ),
172
0
            DataType::UInt8 => Box::new(
173
0
                PrimitiveDistinctCountAccumulator::<UInt8Type>::new(data_type),
174
0
            ),
175
0
            DataType::UInt16 => Box::new(
176
0
                PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
177
0
            ),
178
0
            DataType::UInt32 => Box::new(
179
0
                PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
180
0
            ),
181
0
            DataType::UInt64 => Box::new(
182
0
                PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
183
0
            ),
184
0
            DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
185
0
                Decimal128Type,
186
0
            >::new(data_type)),
187
0
            DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
188
0
                Decimal256Type,
189
0
            >::new(data_type)),
190
191
0
            DataType::Date32 => Box::new(
192
0
                PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
193
0
            ),
194
0
            DataType::Date64 => Box::new(
195
0
                PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
196
0
            ),
197
0
            DataType::Time32(TimeUnit::Millisecond) => Box::new(
198
0
                PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(
199
0
                    data_type,
200
0
                ),
201
0
            ),
202
0
            DataType::Time32(TimeUnit::Second) => Box::new(
203
0
                PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
204
0
            ),
205
0
            DataType::Time64(TimeUnit::Microsecond) => Box::new(
206
0
                PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(
207
0
                    data_type,
208
0
                ),
209
0
            ),
210
0
            DataType::Time64(TimeUnit::Nanosecond) => Box::new(
211
0
                PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
212
0
            ),
213
0
            DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
214
0
                PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(
215
0
                    data_type,
216
0
                ),
217
0
            ),
218
0
            DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
219
0
                PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(
220
0
                    data_type,
221
0
                ),
222
0
            ),
223
0
            DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
224
0
                PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(
225
0
                    data_type,
226
0
                ),
227
0
            ),
228
0
            DataType::Timestamp(TimeUnit::Second, _) => Box::new(
229
0
                PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
230
0
            ),
231
232
            DataType::Float16 => {
233
0
                Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
234
            }
235
            DataType::Float32 => {
236
0
                Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
237
            }
238
            DataType::Float64 => {
239
0
                Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
240
            }
241
242
            DataType::Utf8 => {
243
0
                Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
244
            }
245
            DataType::Utf8View => {
246
0
                Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View))
247
            }
248
            DataType::LargeUtf8 => {
249
0
                Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
250
            }
251
0
            DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
252
0
                OutputType::Binary,
253
0
            )),
254
0
            DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new(
255
0
                OutputType::BinaryView,
256
0
            )),
257
0
            DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
258
0
                OutputType::Binary,
259
0
            )),
260
261
            // Use the generic accumulator based on `ScalarValue` for all other types
262
0
            _ => Box::new(DistinctCountAccumulator {
263
0
                values: HashSet::default(),
264
0
                state_data_type: data_type.clone(),
265
0
            }),
266
        })
267
3
    }
268
269
0
    fn aliases(&self) -> &[String] {
270
0
        &[]
271
0
    }
272
273
15
    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
274
15
        // groups accumulator only supports `COUNT(c1)`, not
275
15
        // `COUNT(c1, c2)`, etc
276
15
        if args.is_distinct {
277
0
            return false;
278
15
        }
279
15
        args.exprs.len() == 1
280
15
    }
281
282
15
    fn create_groups_accumulator(
283
15
        &self,
284
15
        _args: AccumulatorArgs,
285
15
    ) -> Result<Box<dyn GroupsAccumulator>> {
286
15
        // instantiate specialized accumulator
287
15
        Ok(Box::new(CountGroupsAccumulator::new()))
288
15
    }
289
290
0
    fn reverse_expr(&self) -> ReversedUDAF {
291
0
        ReversedUDAF::Identical
292
0
    }
293
294
0
    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
295
0
        Ok(ScalarValue::Int64(Some(0)))
296
0
    }
297
298
0
    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
299
0
        if statistics_args.is_distinct {
300
0
            return None;
301
0
        }
302
0
        if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows {
303
0
            if statistics_args.exprs.len() == 1 {
304
                // TODO optimize with exprs other than Column
305
0
                if let Some(col_expr) = statistics_args.exprs[0]
306
0
                    .as_any()
307
0
                    .downcast_ref::<expressions::Column>()
308
                {
309
0
                    let current_val = &statistics_args.statistics.column_statistics
310
0
                        [col_expr.index()]
311
0
                    .null_count;
312
0
                    if let &Precision::Exact(val) = current_val {
313
0
                        return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
314
0
                    }
315
0
                } else if let Some(lit_expr) = statistics_args.exprs[0]
316
0
                    .as_any()
317
0
                    .downcast_ref::<expressions::Literal>()
318
                {
319
0
                    if lit_expr.value() == &COUNT_STAR_EXPANSION {
320
0
                        return Some(ScalarValue::Int64(Some(num_rows as i64)));
321
0
                    }
322
0
                }
323
0
            }
324
0
        }
325
0
        None
326
0
    }
327
}
328
329
/// Accumulator for non-distinct `count`: a single running total.
#[derive(Debug)]
struct CountAccumulator {
    /// Number of rows counted so far.
    count: i64,
}

impl CountAccumulator {
    /// Creates an accumulator whose count starts at zero.
    pub fn new() -> Self {
        CountAccumulator { count: 0 }
    }
}
340
341
impl Accumulator for CountAccumulator {
342
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
343
0
        Ok(vec![ScalarValue::Int64(Some(self.count))])
344
0
    }
345
346
6
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
347
6
        let array = &values[0];
348
6
        self.count += (array.len() - null_count_for_multiple_cols(values)) as i64;
349
6
        Ok(())
350
6
    }
351
352
6
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
353
6
        let array = &values[0];
354
6
        self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64;
355
6
        Ok(())
356
6
    }
357
358
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
359
0
        let counts = downcast_value!(states[0], Int64Array);
360
0
        let delta = &arrow::compute::sum(counts);
361
0
        if let Some(d) = delta {
362
0
            self.count += *d;
363
0
        }
364
0
        Ok(())
365
0
    }
366
367
8
    fn evaluate(&mut self) -> Result<ScalarValue> {
368
8
        Ok(ScalarValue::Int64(Some(self.count)))
369
8
    }
370
371
3
    fn supports_retract_batch(&self) -> bool {
372
3
        true
373
3
    }
374
375
0
    fn size(&self) -> usize {
376
0
        std::mem::size_of_val(self)
377
0
    }
378
}
379
380
/// Groups accumulator for `count`: keeps one running total per group.
///
/// Unlike most other accumulators, COUNT never produces NULLs. If no
/// non-null values are seen in any group the output is 0. Thus, this
/// accumulator has no additional null or seen filter tracking.
#[derive(Debug)]
struct CountGroupsAccumulator {
    /// Count per group, indexed by group index.
    ///
    /// Note this is an i64 and not a u64 (or usize) because the
    /// output type of count is `DataType::Int64`. Thus by using `i64`
    /// for the counts, the output `Int64Array` can be created
    /// without copy.
    counts: Vec<i64>,
}

impl CountGroupsAccumulator {
    /// Creates an accumulator that tracks no groups yet.
    pub fn new() -> Self {
        let counts = Vec::new();
        Self { counts }
    }
}
402
403
impl GroupsAccumulator for CountGroupsAccumulator {
404
63
    fn update_batch(
405
63
        &mut self,
406
63
        values: &[ArrayRef],
407
63
        group_indices: &[usize],
408
63
        opt_filter: Option<&BooleanArray>,
409
63
        total_num_groups: usize,
410
63
    ) -> Result<()> {
411
63
        assert_eq!(values.len(), 1, 
"single argument to update_batch"0
);
412
63
        let values = &values[0];
413
63
414
63
        // Add one to each group's counter for each non null, non
415
63
        // filtered value
416
63
        self.counts.resize(total_num_groups, 0);
417
63
        accumulate_indices(
418
63
            group_indices,
419
63
            values.logical_nulls().as_ref(),
420
63
            opt_filter,
421
98.5k
            |group_index| {
422
98.5k
                self.counts[group_index] += 1;
423
98.5k
            },
424
63
        );
425
63
426
63
        Ok(())
427
63
    }
428
429
8
    fn merge_batch(
430
8
        &mut self,
431
8
        values: &[ArrayRef],
432
8
        group_indices: &[usize],
433
8
        opt_filter: Option<&BooleanArray>,
434
8
        total_num_groups: usize,
435
8
    ) -> Result<()> {
436
8
        assert_eq!(values.len(), 1, 
"one argument to merge_batch"0
);
437
        // first batch is counts, second is partial sums
438
8
        let partial_counts = values[0].as_primitive::<Int64Type>();
439
8
440
8
        // intermediate counts are always created as non null
441
8
        assert_eq!(partial_counts.null_count(), 0);
442
8
        let partial_counts = partial_counts.values();
443
8
444
8
        // Adds the counts with the partial counts
445
8
        self.counts.resize(total_num_groups, 0);
446
8
        match opt_filter {
447
0
            Some(filter) => filter
448
0
                .iter()
449
0
                .zip(group_indices.iter())
450
0
                .zip(partial_counts.iter())
451
0
                .for_each(|((filter_value, &group_index), partial_count)| {
452
0
                    if let Some(true) = filter_value {
453
0
                        self.counts[group_index] += partial_count;
454
0
                    }
455
0
                }),
456
8
            None => group_indices.iter().zip(partial_counts.iter()).for_each(
457
48
                |(&group_index, partial_count)| {
458
48
                    self.counts[group_index] += partial_count;
459
48
                },
460
8
            ),
461
        }
462
463
8
        Ok(())
464
8
    }
465
466
4
    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
467
4
        let counts = emit_to.take_needed(&mut self.counts);
468
4
469
4
        // Count is always non null (null inputs just don't contribute to the overall values)
470
4
        let nulls = None;
471
4
        let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
472
4
473
4
        Ok(Arc::new(array))
474
4
    }
475
476
    // return arrays for counts
477
13
    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
478
13
        let counts = emit_to.take_needed(&mut self.counts);
479
13
        let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); // zero copy, no nulls
480
13
        Ok(vec![Arc::new(counts) as ArrayRef])
481
13
    }
482
483
    /// Converts an input batch directly to a state batch
484
    ///
485
    /// The state of `COUNT` is always a single Int64Array:
486
    /// * `1` (for non-null, non filtered values)
487
    /// * `0` (for null values)
488
2
    fn convert_to_state(
489
2
        &self,
490
2
        values: &[ArrayRef],
491
2
        opt_filter: Option<&BooleanArray>,
492
2
    ) -> Result<Vec<ArrayRef>> {
493
2
        let values = &values[0];
494
495
2
        let state_array = match (values.logical_nulls(), opt_filter) {
496
            (None, None) => {
497
                // In case there is no nulls in input and no filter, returning array of 1
498
2
                Arc::new(Int64Array::from_value(1, values.len()))
499
            }
500
0
            (Some(nulls), None) => {
501
0
                // If there are any nulls in input values -- casting `nulls` (true for values, false for nulls)
502
0
                // of input array to Int64
503
0
                let nulls = BooleanArray::new(nulls.into_inner(), None);
504
0
                compute::cast(&nulls, &DataType::Int64)?
505
            }
506
0
            (None, Some(filter)) => {
507
0
                // If there is only filter
508
0
                // - applying filter null mask to filter values by bitand filter values and nulls buffers
509
0
                //   (using buffers guarantees absence of nulls in result)
510
0
                // - casting result of bitand to Int64 array
511
0
                let (filter_values, filter_nulls) = filter.clone().into_parts();
512
513
0
                let state_buf = match filter_nulls {
514
0
                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
515
0
                    None => filter_values,
516
                };
517
518
0
                let boolean_state = BooleanArray::new(state_buf, None);
519
0
                compute::cast(&boolean_state, &DataType::Int64)?
520
            }
521
0
            (Some(nulls), Some(filter)) => {
522
0
                // For both input nulls and filter
523
0
                // - applying filter null mask to filter values by bitand filter values and nulls buffers
524
0
                //   (using buffers guarantees absence of nulls in result)
525
0
                // - applying values null mask to filter buffer by another bitand on filter result and
526
0
                //   nulls from input values
527
0
                // - casting result to Int64 array
528
0
                let (filter_values, filter_nulls) = filter.clone().into_parts();
529
530
0
                let filter_buf = match filter_nulls {
531
0
                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
532
0
                    None => filter_values,
533
                };
534
0
                let state_buf = &filter_buf & nulls.inner();
535
0
536
0
                let boolean_state = BooleanArray::new(state_buf, None);
537
0
                compute::cast(&boolean_state, &DataType::Int64)?
538
            }
539
        };
540
541
2
        Ok(vec![state_array])
542
2
    }
543
544
11
    fn supports_convert_to_state(&self) -> bool {
545
11
        true
546
11
    }
547
548
75
    fn size(&self) -> usize {
549
75
        self.counts.capacity() * std::mem::size_of::<usize>()
550
75
    }
551
}
552
553
/// count null values for multiple columns
554
/// for each row if one column value is null, then null_count + 1
555
12
fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
556
12
    if values.len() > 1 {
557
0
        let result_bool_buf: Option<BooleanBuffer> = values
558
0
            .iter()
559
0
            .map(|a| a.logical_nulls())
560
0
            .fold(None, |acc, b| match (acc, b) {
561
0
                (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
562
0
                (Some(acc), None) => Some(acc),
563
0
                (None, Some(b)) => Some(b.into_inner()),
564
0
                _ => None,
565
0
            });
566
0
        result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
567
    } else {
568
12
        values[0]
569
12
            .logical_nulls()
570
12
            .map_or(0, |nulls| 
nulls.null_count()0
)
571
    }
572
12
}
573
574
/// General purpose distinct accumulator that works for any DataType by using
575
/// [`ScalarValue`].
576
///
577
/// It stores intermediate results as a `ListArray`
578
///
579
/// Note that many types have specialized accumulators that are (much)
580
/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
581
/// [`BytesDistinctCountAccumulator`]
582
#[derive(Debug)]
583
struct DistinctCountAccumulator {
584
    values: HashSet<ScalarValue, RandomState>,
585
    state_data_type: DataType,
586
}
587
588
impl DistinctCountAccumulator {
589
    // calculating the size for fixed length values, taking first batch size *
590
    // number of batches This method is faster than .full_size(), however it is
591
    // not suitable for variable length values like strings or complex types
592
0
    fn fixed_size(&self) -> usize {
593
0
        std::mem::size_of_val(self)
594
0
            + (std::mem::size_of::<ScalarValue>() * self.values.capacity())
595
0
            + self
596
0
                .values
597
0
                .iter()
598
0
                .next()
599
0
                .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals))
600
0
                .unwrap_or(0)
601
0
            + std::mem::size_of::<DataType>()
602
0
    }
603
604
    // calculates the size as accurately as possible. Note that calling this
605
    // method is expensive
606
0
    fn full_size(&self) -> usize {
607
0
        std::mem::size_of_val(self)
608
0
            + (std::mem::size_of::<ScalarValue>() * self.values.capacity())
609
0
            + self
610
0
                .values
611
0
                .iter()
612
0
                .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals))
613
0
                .sum::<usize>()
614
0
            + std::mem::size_of::<DataType>()
615
0
    }
616
}
617
618
impl Accumulator for DistinctCountAccumulator {
619
    /// Returns the distinct values seen so far as (one element) ListArray.
620
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
621
0
        let scalars = self.values.iter().cloned().collect::<Vec<_>>();
622
0
        let arr =
623
0
            ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type);
624
0
        Ok(vec![ScalarValue::List(arr)])
625
0
    }
626
627
0
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
628
0
        if values.is_empty() {
629
0
            return Ok(());
630
0
        }
631
0
632
0
        let arr = &values[0];
633
0
        if arr.data_type() == &DataType::Null {
634
0
            return Ok(());
635
0
        }
636
0
637
0
        (0..arr.len()).try_for_each(|index| {
638
0
            if !arr.is_null(index) {
639
0
                let scalar = ScalarValue::try_from_array(arr, index)?;
640
0
                self.values.insert(scalar);
641
0
            }
642
0
            Ok(())
643
0
        })
644
0
    }
645
646
    /// Merges multiple sets of distinct values into the current set.
647
    ///
648
    /// The input to this function is a `ListArray` with **multiple** rows,
649
    /// where each row contains the values from a partial aggregate's phase (e.g.
650
    /// the result of calling `Self::state` on multiple accumulators).
651
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
652
0
        if states.is_empty() {
653
0
            return Ok(());
654
0
        }
655
0
        assert_eq!(states.len(), 1, "array_agg states must be singleton!");
656
0
        let array = &states[0];
657
0
        let list_array = array.as_list::<i32>();
658
0
        for inner_array in list_array.iter() {
659
0
            let Some(inner_array) = inner_array else {
660
0
                return internal_err!(
661
0
                    "Intermediate results of COUNT DISTINCT should always be non null"
662
0
                );
663
            };
664
0
            self.update_batch(&[inner_array])?;
665
        }
666
0
        Ok(())
667
0
    }
668
669
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
670
0
        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
671
0
    }
672
673
0
    fn size(&self) -> usize {
674
0
        match &self.state_data_type {
675
0
            DataType::Boolean | DataType::Null => self.fixed_size(),
676
0
            d if d.is_primitive() => self.fixed_size(),
677
0
            _ => self.full_size(),
678
        }
679
0
    }
680
}