Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! [`GroupsAccumulator`] helpers: [`NullState`] and [`accumulate_indices`]
19
//!
20
//! [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator
21
22
use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
23
use arrow::buffer::{BooleanBuffer, NullBuffer};
24
use arrow::datatypes::ArrowPrimitiveType;
25
26
use datafusion_expr_common::groups_accumulator::EmitTo;
27
/// Track the accumulator null state per row: if any values for that
28
/// group were null and if any values have been seen at all for that group.
29
///
30
/// This is part of the inner loop for many [`GroupsAccumulator`]s,
31
/// and thus the performance is critical and so there are multiple
32
/// specialized implementations, invoked depending on the specific
33
/// combinations of the input.
34
///
35
/// Typically there are 4 potential combinations of inputs must be
36
/// special cased for performance:
37
///
38
/// * With / Without filter
39
/// * With / Without nulls in the input
40
///
41
/// If the input has nulls, then the accumulator must potentially
42
/// handle each input null value specially (e.g. for `SUM` to mark the
43
/// corresponding sum as null)
44
///
45
/// If there are filters present, `NullState` tracks if it has seen
46
/// *any* value for that group (as some values may be filtered
47
/// out). Without a filter, the accumulator is only passed groups that
48
/// had at least one value to accumulate so they do not need to track
49
/// if they have seen values for a particular group.
50
///
51
/// [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator
52
#[derive(Debug)]
53
pub struct NullState {
54
    /// Have we seen any non-filtered input values for `group_index`?
55
    ///
56
    /// If `seen_values[i]` is true, have seen at least one non null
57
    /// value for group `i`
58
    ///
59
    /// If `seen_values[i]` is false, have not seen any values that
60
    /// pass the filter yet for group `i`
61
    seen_values: BooleanBufferBuilder,
62
}
63
64
impl Default for NullState {
65
0
    fn default() -> Self {
66
0
        Self::new()
67
0
    }
68
}
69
70
impl NullState {
71
15
    pub fn new() -> Self {
72
15
        Self {
73
15
            seen_values: BooleanBufferBuilder::new(0),
74
15
        }
75
15
    }
76
77
    /// return the size of all buffers allocated by this null state, not including self
78
3
    pub fn size(&self) -> usize {
79
3
        // capacity is in bits, so convert to bytes
80
3
        self.seen_values.capacity() / 8
81
3
    }
82
83
    /// Invokes `value_fn(group_index, value)` for each non null, non
84
    /// filtered value of `value`, while tracking which groups have
85
    /// seen null inputs and which groups have seen any inputs if necessary
86
    //
87
    /// # Arguments:
88
    ///
89
    /// * `values`: the input arguments to the accumulator
90
    /// * `group_indices`:  To which groups do the rows in `values` belong, (aka group_index)
91
    /// * `opt_filter`: if present, only rows for which is Some(true) are included
92
    /// * `value_fn`: function invoked for  (group_index, value) where value is non null
93
    ///
94
    /// See [`accumulate`], for more details on how value_fn is called
95
    ///
96
    /// When value_fn is called it also sets
97
    ///
98
    /// 1. `self.seen_values[group_index]` to true for all rows that had a non null vale
99
46
    pub fn accumulate<T, F>(
100
46
        &mut self,
101
46
        group_indices: &[usize],
102
46
        values: &PrimitiveArray<T>,
103
46
        opt_filter: Option<&BooleanArray>,
104
46
        total_num_groups: usize,
105
46
        mut value_fn: F,
106
46
    ) where
107
46
        T: ArrowPrimitiveType + Send,
108
46
        F: FnMut(usize, T::Native) + Send,
109
46
    {
110
46
        // ensure the seen_values is big enough (start everything at
111
46
        // "not seen" valid)
112
46
        let seen_values =
113
46
            initialize_builder(&mut self.seen_values, total_num_groups, false);
114
123
        accumulate(group_indices, values, opt_filter, |group_index, value| {
115
123
            seen_values.set_bit(group_index, true);
116
123
            value_fn(group_index, value);
117
123
        });
118
46
    }
119
120
    /// Invokes `value_fn(group_index, value)` for each non null, non
121
    /// filtered value in `values`, while tracking which groups have
122
    /// seen null inputs and which groups have seen any inputs, for
123
    /// [`BooleanArray`]s.
124
    ///
125
    /// Since `BooleanArray` is not a [`PrimitiveArray`] it must be
126
    /// handled specially.
127
    ///
128
    /// See [`Self::accumulate`], which handles `PrimitiveArray`s, for
129
    /// more details on other arguments.
130
0
    pub fn accumulate_boolean<F>(
131
0
        &mut self,
132
0
        group_indices: &[usize],
133
0
        values: &BooleanArray,
134
0
        opt_filter: Option<&BooleanArray>,
135
0
        total_num_groups: usize,
136
0
        mut value_fn: F,
137
0
    ) where
138
0
        F: FnMut(usize, bool) + Send,
139
0
    {
140
0
        let data = values.values();
141
0
        assert_eq!(data.len(), group_indices.len());
142
143
        // ensure the seen_values is big enough (start everything at
144
        // "not seen" valid)
145
0
        let seen_values =
146
0
            initialize_builder(&mut self.seen_values, total_num_groups, false);
147
0
148
0
        // These could be made more performant by iterating in chunks of 64 bits at a time
149
0
        match (values.null_count() > 0, opt_filter) {
150
            // no nulls, no filter,
151
            (false, None) => {
152
                // if we have previously seen nulls, ensure the null
153
                // buffer is big enough (start everything at valid)
154
0
                group_indices.iter().zip(data.iter()).for_each(
155
0
                    |(&group_index, new_value)| {
156
0
                        seen_values.set_bit(group_index, true);
157
0
                        value_fn(group_index, new_value)
158
0
                    },
159
0
                )
160
            }
161
            // nulls, no filter
162
            (true, None) => {
163
0
                let nulls = values.nulls().unwrap();
164
0
                group_indices
165
0
                    .iter()
166
0
                    .zip(data.iter())
167
0
                    .zip(nulls.iter())
168
0
                    .for_each(|((&group_index, new_value), is_valid)| {
169
0
                        if is_valid {
170
0
                            seen_values.set_bit(group_index, true);
171
0
                            value_fn(group_index, new_value);
172
0
                        }
173
0
                    })
174
            }
175
            // no nulls, but a filter
176
0
            (false, Some(filter)) => {
177
0
                assert_eq!(filter.len(), group_indices.len());
178
179
0
                group_indices
180
0
                    .iter()
181
0
                    .zip(data.iter())
182
0
                    .zip(filter.iter())
183
0
                    .for_each(|((&group_index, new_value), filter_value)| {
184
0
                        if let Some(true) = filter_value {
185
0
                            seen_values.set_bit(group_index, true);
186
0
                            value_fn(group_index, new_value);
187
0
                        }
188
0
                    })
189
            }
190
            // both null values and filters
191
0
            (true, Some(filter)) => {
192
0
                assert_eq!(filter.len(), group_indices.len());
193
0
                filter
194
0
                    .iter()
195
0
                    .zip(group_indices.iter())
196
0
                    .zip(values.iter())
197
0
                    .for_each(|((filter_value, &group_index), new_value)| {
198
0
                        if let Some(true) = filter_value {
199
0
                            if let Some(new_value) = new_value {
200
0
                                seen_values.set_bit(group_index, true);
201
0
                                value_fn(group_index, new_value)
202
0
                            }
203
0
                        }
204
0
                    })
205
            }
206
        }
207
0
    }
208
209
    /// Creates the a [`NullBuffer`] representing which group_indices
210
    /// should have null values (because they never saw any values)
211
    /// for the `emit_to` rows.
212
    ///
213
    /// resets the internal state appropriately
214
29
    pub fn build(&mut self, emit_to: EmitTo) -> NullBuffer {
215
29
        let nulls: BooleanBuffer = self.seen_values.finish();
216
217
29
        let nulls = match emit_to {
218
17
            EmitTo::All => nulls,
219
12
            EmitTo::First(n) => {
220
12
                // split off the first N values in seen_values
221
12
                //
222
12
                // TODO make this more efficient rather than two
223
12
                // copies and bitwise manipulation
224
12
                let first_n_null: BooleanBuffer = nulls.iter().take(n).collect();
225
                // reset the existing seen buffer
226
12
                for seen in nulls.iter().skip(n) {
227
12
                    self.seen_values.append(seen);
228
12
                }
229
12
                first_n_null
230
            }
231
        };
232
29
        NullBuffer::new(nulls)
233
29
    }
234
}
235
236
/// Invokes `value_fn(group_index, value)` for each non null, non
237
/// filtered value of `value`,
238
///
239
/// # Arguments:
240
///
241
/// * `group_indices`:  To which groups do the rows in `values` belong, (aka group_index)
242
/// * `values`: the input arguments to the accumulator
243
/// * `opt_filter`: if present, only rows for which is Some(true) are included
244
/// * `value_fn`: function invoked for  (group_index, value) where value is non null
245
///
246
/// # Example
247
///
248
/// ```text
249
///  ┌─────────┐   ┌─────────┐   ┌ ─ ─ ─ ─ ┐
250
///  │ ┌─────┐ │   │ ┌─────┐ │     ┌─────┐
251
///  │ │  2  │ │   │ │ 200 │ │   │ │  t  │ │
252
///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
253
///  │ │  2  │ │   │ │ 100 │ │   │ │  f  │ │
254
///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
255
///  │ │  0  │ │   │ │ 200 │ │   │ │  t  │ │
256
///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
257
///  │ │  1  │ │   │ │ 200 │ │   │ │NULL │ │
258
///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
259
///  │ │  0  │ │   │ │ 300 │ │   │ │  t  │ │
260
///  │ └─────┘ │   │ └─────┘ │     └─────┘
261
///  └─────────┘   └─────────┘   └ ─ ─ ─ ─ ┘
262
///
263
/// group_indices   values        opt_filter
264
/// ```
265
///
266
/// In the example above, `value_fn` is invoked for each (group_index,
267
/// value) pair where `opt_filter[i]` is true and values is non null
268
///
269
/// ```text
270
/// value_fn(2, 200)
271
/// value_fn(0, 200)
272
/// value_fn(0, 300)
273
/// ```
274
46
pub fn accumulate<T, F>(
275
46
    group_indices: &[usize],
276
46
    values: &PrimitiveArray<T>,
277
46
    opt_filter: Option<&BooleanArray>,
278
46
    mut value_fn: F,
279
46
) where
280
46
    T: ArrowPrimitiveType + Send,
281
46
    F: FnMut(usize, T::Native) + Send,
282
46
{
283
46
    let data: &[T::Native] = values.values();
284
46
    assert_eq!(data.len(), group_indices.len());
285
286
46
    match (values.null_count() > 0, opt_filter) {
287
        // no nulls, no filter,
288
        (false, None) => {
289
46
            let iter = group_indices.iter().zip(data.iter());
290
169
            for (&
group_index, &new_value123
) in iter {
291
123
                value_fn(group_index, new_value);
292
123
            }
293
        }
294
        // nulls, no filter
295
0
        (true, None) => {
296
0
            let nulls = values.nulls().unwrap();
297
0
            // This is based on (ahem, COPY/PASTE) arrow::compute::aggregate::sum
298
0
            // iterate over in chunks of 64 bits for more efficient null checking
299
0
            let group_indices_chunks = group_indices.chunks_exact(64);
300
0
            let data_chunks = data.chunks_exact(64);
301
0
            let bit_chunks = nulls.inner().bit_chunks();
302
0
303
0
            let group_indices_remainder = group_indices_chunks.remainder();
304
0
            let data_remainder = data_chunks.remainder();
305
0
306
0
            group_indices_chunks
307
0
                .zip(data_chunks)
308
0
                .zip(bit_chunks.iter())
309
0
                .for_each(|((group_index_chunk, data_chunk), mask)| {
310
0
                    // index_mask has value 1 << i in the loop
311
0
                    let mut index_mask = 1;
312
0
                    group_index_chunk.iter().zip(data_chunk.iter()).for_each(
313
0
                        |(&group_index, &new_value)| {
314
0
                            // valid bit was set, real value
315
0
                            let is_valid = (mask & index_mask) != 0;
316
0
                            if is_valid {
317
0
                                value_fn(group_index, new_value);
318
0
                            }
319
0
                            index_mask <<= 1;
320
0
                        },
321
0
                    )
322
0
                });
323
0
324
0
            // handle any remaining bits (after the initial 64)
325
0
            let remainder_bits = bit_chunks.remainder_bits();
326
0
            group_indices_remainder
327
0
                .iter()
328
0
                .zip(data_remainder.iter())
329
0
                .enumerate()
330
0
                .for_each(|(i, (&group_index, &new_value))| {
331
0
                    let is_valid = remainder_bits & (1 << i) != 0;
332
0
                    if is_valid {
333
0
                        value_fn(group_index, new_value);
334
0
                    }
335
0
                });
336
0
        }
337
        // no nulls, but a filter
338
0
        (false, Some(filter)) => {
339
0
            assert_eq!(filter.len(), group_indices.len());
340
            // The performance with a filter could be improved by
341
            // iterating over the filter in chunks, rather than a single
342
            // iterator. TODO file a ticket
343
0
            group_indices
344
0
                .iter()
345
0
                .zip(data.iter())
346
0
                .zip(filter.iter())
347
0
                .for_each(|((&group_index, &new_value), filter_value)| {
348
0
                    if let Some(true) = filter_value {
349
0
                        value_fn(group_index, new_value);
350
0
                    }
351
0
                })
352
        }
353
        // both null values and filters
354
0
        (true, Some(filter)) => {
355
0
            assert_eq!(filter.len(), group_indices.len());
356
            // The performance with a filter could be improved by
357
            // iterating over the filter in chunks, rather than using
358
            // iterators. TODO file a ticket
359
0
            filter
360
0
                .iter()
361
0
                .zip(group_indices.iter())
362
0
                .zip(values.iter())
363
0
                .for_each(|((filter_value, &group_index), new_value)| {
364
0
                    if let Some(true) = filter_value {
365
0
                        if let Some(new_value) = new_value {
366
0
                            value_fn(group_index, new_value)
367
0
                        }
368
0
                    }
369
0
                })
370
        }
371
    }
372
46
}
373
374
/// This function is called to update the accumulator state per row
375
/// when the value is not needed (e.g. COUNT)
376
///
377
/// `F`: Invoked like `value_fn(group_index) for all non null values
378
/// passing the filter. Note that no tracking is done for null inputs
379
/// or which groups have seen any values
380
///
381
/// See [`NullState::accumulate`], for more details on other
382
/// arguments.
383
63
pub fn accumulate_indices<F>(
384
63
    group_indices: &[usize],
385
63
    nulls: Option<&NullBuffer>,
386
63
    opt_filter: Option<&BooleanArray>,
387
63
    mut index_fn: F,
388
63
) where
389
63
    F: FnMut(usize) + Send,
390
63
{
391
63
    match (nulls, opt_filter) {
392
        (None, None) => {
393
98.5k
            for &group_index in 
group_indices.iter()63
{
394
98.5k
                index_fn(group_index)
395
            }
396
        }
397
0
        (None, Some(filter)) => {
398
0
            assert_eq!(filter.len(), group_indices.len());
399
            // The performance with a filter could be improved by
400
            // iterating over the filter in chunks, rather than a single
401
            // iterator. TODO file a ticket
402
0
            let iter = group_indices.iter().zip(filter.iter());
403
0
            for (&group_index, filter_value) in iter {
404
0
                if let Some(true) = filter_value {
405
0
                    index_fn(group_index)
406
0
                }
407
            }
408
        }
409
0
        (Some(valids), None) => {
410
0
            assert_eq!(valids.len(), group_indices.len());
411
            // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum
412
            // iterate over in chunks of 64 bits for more efficient null checking
413
0
            let group_indices_chunks = group_indices.chunks_exact(64);
414
0
            let bit_chunks = valids.inner().bit_chunks();
415
0
416
0
            let group_indices_remainder = group_indices_chunks.remainder();
417
0
418
0
            group_indices_chunks.zip(bit_chunks.iter()).for_each(
419
0
                |(group_index_chunk, mask)| {
420
0
                    // index_mask has value 1 << i in the loop
421
0
                    let mut index_mask = 1;
422
0
                    group_index_chunk.iter().for_each(|&group_index| {
423
0
                        // valid bit was set, real vale
424
0
                        let is_valid = (mask & index_mask) != 0;
425
0
                        if is_valid {
426
0
                            index_fn(group_index);
427
0
                        }
428
0
                        index_mask <<= 1;
429
0
                    })
430
0
                },
431
0
            );
432
0
433
0
            // handle any remaining bits (after the initial 64)
434
0
            let remainder_bits = bit_chunks.remainder_bits();
435
0
            group_indices_remainder
436
0
                .iter()
437
0
                .enumerate()
438
0
                .for_each(|(i, &group_index)| {
439
0
                    let is_valid = remainder_bits & (1 << i) != 0;
440
0
                    if is_valid {
441
0
                        index_fn(group_index)
442
0
                    }
443
0
                });
444
0
        }
445
446
0
        (Some(valids), Some(filter)) => {
447
0
            assert_eq!(filter.len(), group_indices.len());
448
0
            assert_eq!(valids.len(), group_indices.len());
449
            // The performance with a filter could likely be improved by
450
            // iterating over the filter in chunks, rather than using
451
            // iterators. TODO file a ticket
452
0
            filter
453
0
                .iter()
454
0
                .zip(group_indices.iter())
455
0
                .zip(valids.iter())
456
0
                .for_each(|((filter_value, &group_index), is_valid)| {
457
0
                    if let (Some(true), true) = (filter_value, is_valid) {
458
0
                        index_fn(group_index)
459
0
                    }
460
0
                })
461
        }
462
    }
463
63
}
464
465
/// Ensures that `builder` contains a `BooleanBufferBuilder with at
466
/// least `total_num_groups`.
467
///
468
/// All new entries are initialized to `default_value`
469
46
fn initialize_builder(
470
46
    builder: &mut BooleanBufferBuilder,
471
46
    total_num_groups: usize,
472
46
    default_value: bool,
473
46
) -> &mut BooleanBufferBuilder {
474
46
    if builder.len() < total_num_groups {
475
28
        let new_groups = total_num_groups - builder.len();
476
28
        builder.append_n(new_groups, default_value);
477
28
    }
18
478
46
    builder
479
46
}
480
481
#[cfg(test)]
482
mod test {
483
    use super::*;
484
485
    use arrow::array::UInt32Array;
486
    use rand::{rngs::ThreadRng, Rng};
487
    use std::collections::HashSet;
488
489
    #[test]
    fn accumulate() {
        // one row per group: 0, 1, ... 99
        let group_indices = (0..100).collect();
        // 10, 20, ... 1000
        let values = (0..100).map(|i| (i + 1) * 10).collect();
        // same as `values`, but every third entry (starting with the first) is null
        let values_with_nulls = (0..100)
            .map(|i| (i % 3 != 0).then_some((i + 1) * 10))
            .collect();

        // even positions are null, odd multiples of five are false,
        // everything else is true
        let filter: BooleanArray = (0..100)
            .map(|i| match (i % 2 == 0, i % 5 == 0) {
                (true, _) => None,
                (false, true) => Some(false),
                (false, false) => Some(true),
            })
            .collect();

        let fixture = Fixture {
            group_indices,
            values,
            values_with_nulls,
            filter,
        };
        fixture.run()
    }
521
522
    #[test]
    fn accumulate_fuzz() {
        // run 100 randomly generated fixtures through every combination
        let mut rng = rand::thread_rng();
        (0..100).for_each(|_| Fixture::new_random(&mut rng).run());
    }
529
530
    /// Values for testing (there are enough values to exercise the 64 bit chunks
531
    struct Fixture {
532
        /// 100..0
533
        group_indices: Vec<usize>,
534
535
        /// 10, 20, ... 1010
536
        values: Vec<u32>,
537
538
        /// same as values, but every third is null:
539
        /// None, Some(20), Some(30), None ...
540
        values_with_nulls: Vec<Option<u32>>,
541
542
        /// filter (defaults to None)
543
        filter: BooleanArray,
544
    }
545
546
    impl Fixture {
547
        fn new_random(rng: &mut ThreadRng) -> Self {
548
            // Number of input values in a batch
549
            let num_values: usize = rng.gen_range(1..200);
550
            // number of distinct groups
551
            let num_groups: usize = rng.gen_range(2..1000);
552
            let max_group = num_groups - 1;
553
554
            let group_indices: Vec<usize> = (0..num_values)
555
                .map(|_| rng.gen_range(0..max_group))
556
                .collect();
557
558
            let values: Vec<u32> = (0..num_values).map(|_| rng.gen()).collect();
559
560
            // 10% chance of false
561
            // 10% change of null
562
            // 80% chance of true
563
            let filter: BooleanArray = (0..num_values)
564
                .map(|_| {
565
                    let filter_value = rng.gen_range(0.0..1.0);
566
                    if filter_value < 0.1 {
567
                        Some(false)
568
                    } else if filter_value < 0.2 {
569
                        None
570
                    } else {
571
                        Some(true)
572
                    }
573
                })
574
                .collect();
575
576
            // random values with random number and location of nulls
577
            // random null percentage
578
            let null_pct: f32 = rng.gen_range(0.0..1.0);
579
            let values_with_nulls: Vec<Option<u32>> = (0..num_values)
580
                .map(|_| {
581
                    let is_null = null_pct < rng.gen_range(0.0..1.0);
582
                    if is_null {
583
                        None
584
                    } else {
585
                        Some(rng.gen())
586
                    }
587
                })
588
                .collect();
589
590
            Self {
591
                group_indices,
592
                values,
593
                values_with_nulls,
594
                filter,
595
            }
596
        }
597
598
        /// returns `Self::values` an Array
599
        fn values_array(&self) -> UInt32Array {
600
            UInt32Array::from(self.values.clone())
601
        }
602
603
        /// returns `Self::values_with_nulls` as an Array
604
        fn values_with_nulls_array(&self) -> UInt32Array {
605
            UInt32Array::from(self.values_with_nulls.clone())
606
        }
607
608
        /// Calls `NullState::accumulate` and `accumulate_indices`
609
        /// with all combinations of nulls and filter values
610
        fn run(&self) {
611
            let total_num_groups = *self.group_indices.iter().max().unwrap() + 1;
612
613
            let group_indices = &self.group_indices;
614
            let values_array = self.values_array();
615
            let values_with_nulls_array = self.values_with_nulls_array();
616
            let filter = &self.filter;
617
618
            // no null, no filters
619
            Self::accumulate_test(group_indices, &values_array, None, total_num_groups);
620
621
            // nulls, no filters
622
            Self::accumulate_test(
623
                group_indices,
624
                &values_with_nulls_array,
625
                None,
626
                total_num_groups,
627
            );
628
629
            // no nulls, filters
630
            Self::accumulate_test(
631
                group_indices,
632
                &values_array,
633
                Some(filter),
634
                total_num_groups,
635
            );
636
637
            // nulls, filters
638
            Self::accumulate_test(
639
                group_indices,
640
                &values_with_nulls_array,
641
                Some(filter),
642
                total_num_groups,
643
            );
644
        }
645
646
        /// Calls `NullState::accumulate` and `accumulate_indices` to
647
        /// ensure it generates the correct values.
648
        ///
649
        fn accumulate_test(
650
            group_indices: &[usize],
651
            values: &UInt32Array,
652
            opt_filter: Option<&BooleanArray>,
653
            total_num_groups: usize,
654
        ) {
655
            Self::accumulate_values_test(
656
                group_indices,
657
                values,
658
                opt_filter,
659
                total_num_groups,
660
            );
661
            Self::accumulate_indices_test(group_indices, values.nulls(), opt_filter);
662
663
            // Convert values into a boolean array (anything above the
664
            // average is true, otherwise false)
665
            let avg: usize = values.iter().filter_map(|v| v.map(|v| v as usize)).sum();
666
            let boolean_values: BooleanArray =
667
                values.iter().map(|v| v.map(|v| v as usize > avg)).collect();
668
            Self::accumulate_boolean_test(
669
                group_indices,
670
                &boolean_values,
671
                opt_filter,
672
                total_num_groups,
673
            );
674
        }
675
676
        /// This is effectively a different implementation of
677
        /// accumulate that we compare with the above implementation
678
        fn accumulate_values_test(
679
            group_indices: &[usize],
680
            values: &UInt32Array,
681
            opt_filter: Option<&BooleanArray>,
682
            total_num_groups: usize,
683
        ) {
684
            let mut accumulated_values = vec![];
685
            let mut null_state = NullState::new();
686
687
            null_state.accumulate(
688
                group_indices,
689
                values,
690
                opt_filter,
691
                total_num_groups,
692
                |group_index, value| {
693
                    accumulated_values.push((group_index, value));
694
                },
695
            );
696
697
            // Figure out the expected values
698
            let mut expected_values = vec![];
699
            let mut mock = MockNullState::new();
700
701
            match opt_filter {
702
                None => group_indices.iter().zip(values.iter()).for_each(
703
                    |(&group_index, value)| {
704
                        if let Some(value) = value {
705
                            mock.saw_value(group_index);
706
                            expected_values.push((group_index, value));
707
                        }
708
                    },
709
                ),
710
                Some(filter) => {
711
                    group_indices
712
                        .iter()
713
                        .zip(values.iter())
714
                        .zip(filter.iter())
715
                        .for_each(|((&group_index, value), is_included)| {
716
                            // if value passed filter
717
                            if let Some(true) = is_included {
718
                                if let Some(value) = value {
719
                                    mock.saw_value(group_index);
720
                                    expected_values.push((group_index, value));
721
                                }
722
                            }
723
                        });
724
                }
725
            }
726
727
            assert_eq!(accumulated_values, expected_values,
728
                       "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}");
729
            let seen_values = null_state.seen_values.finish_cloned();
730
            mock.validate_seen_values(&seen_values);
731
732
            // Validate the final buffer (one value per group)
733
            let expected_null_buffer = mock.expected_null_buffer(total_num_groups);
734
735
            let null_buffer = null_state.build(EmitTo::All);
736
737
            assert_eq!(null_buffer, expected_null_buffer);
738
        }
739
740
        // Calls `accumulate_indices`
741
        // and opt_filter and ensures it calls the right values
742
        fn accumulate_indices_test(
743
            group_indices: &[usize],
744
            nulls: Option<&NullBuffer>,
745
            opt_filter: Option<&BooleanArray>,
746
        ) {
747
            let mut accumulated_values = vec![];
748
749
            accumulate_indices(group_indices, nulls, opt_filter, |group_index| {
750
                accumulated_values.push(group_index);
751
            });
752
753
            // Figure out the expected values
754
            let mut expected_values = vec![];
755
756
            match (nulls, opt_filter) {
757
                (None, None) => group_indices.iter().for_each(|&group_index| {
758
                    expected_values.push(group_index);
759
                }),
760
                (Some(nulls), None) => group_indices.iter().zip(nulls.iter()).for_each(
761
                    |(&group_index, is_valid)| {
762
                        if is_valid {
763
                            expected_values.push(group_index);
764
                        }
765
                    },
766
                ),
767
                (None, Some(filter)) => group_indices.iter().zip(filter.iter()).for_each(
768
                    |(&group_index, is_included)| {
769
                        if let Some(true) = is_included {
770
                            expected_values.push(group_index);
771
                        }
772
                    },
773
                ),
774
                (Some(nulls), Some(filter)) => {
775
                    group_indices
776
                        .iter()
777
                        .zip(nulls.iter())
778
                        .zip(filter.iter())
779
                        .for_each(|((&group_index, is_valid), is_included)| {
780
                            // if value passed filter
781
                            if let (true, Some(true)) = (is_valid, is_included) {
782
                                expected_values.push(group_index);
783
                            }
784
                        });
785
                }
786
            }
787
788
            assert_eq!(accumulated_values, expected_values,
789
                       "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}");
790
        }
791
792
        /// This is effectively a different implementation of
793
        /// accumulate_boolean that we compare with the above implementation
794
        fn accumulate_boolean_test(
795
            group_indices: &[usize],
796
            values: &BooleanArray,
797
            opt_filter: Option<&BooleanArray>,
798
            total_num_groups: usize,
799
        ) {
800
            let mut accumulated_values = vec![];
801
            let mut null_state = NullState::new();
802
803
            null_state.accumulate_boolean(
804
                group_indices,
805
                values,
806
                opt_filter,
807
                total_num_groups,
808
                |group_index, value| {
809
                    accumulated_values.push((group_index, value));
810
                },
811
            );
812
813
            // Figure out the expected values
814
            let mut expected_values = vec![];
815
            let mut mock = MockNullState::new();
816
817
            match opt_filter {
818
                None => group_indices.iter().zip(values.iter()).for_each(
819
                    |(&group_index, value)| {
820
                        if let Some(value) = value {
821
                            mock.saw_value(group_index);
822
                            expected_values.push((group_index, value));
823
                        }
824
                    },
825
                ),
826
                Some(filter) => {
827
                    group_indices
828
                        .iter()
829
                        .zip(values.iter())
830
                        .zip(filter.iter())
831
                        .for_each(|((&group_index, value), is_included)| {
832
                            // if value passed filter
833
                            if let Some(true) = is_included {
834
                                if let Some(value) = value {
835
                                    mock.saw_value(group_index);
836
                                    expected_values.push((group_index, value));
837
                                }
838
                            }
839
                        });
840
                }
841
            }
842
843
            assert_eq!(accumulated_values, expected_values,
844
                       "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}");
845
846
            let seen_values = null_state.seen_values.finish_cloned();
847
            mock.validate_seen_values(&seen_values);
848
849
            // Validate the final buffer (one value per group)
850
            let expected_null_buffer = mock.expected_null_buffer(total_num_groups);
851
852
            let null_buffer = null_state.build(EmitTo::All);
853
854
            assert_eq!(null_buffer, expected_null_buffer);
855
        }
856
    }
857
858
    /// Reference model of `NullState` used by the tests to compute the
    /// expected results independently of the real implementation.
    #[derive(Default, Debug)]
    struct MockNullState {
        /// Indices of the groups for which `saw_value` has been called,
        /// i.e. groups that had at least one value pass the filter
        seen_values: HashSet<usize>,
    }
864
865
    impl MockNullState {
866
        fn new() -> Self {
867
            Default::default()
868
        }
869
870
        fn saw_value(&mut self, group_index: usize) {
871
            self.seen_values.insert(group_index);
872
        }
873
874
        /// did this group index see any input?
875
        fn expected_seen(&self, group_index: usize) -> bool {
876
            self.seen_values.contains(&group_index)
877
        }
878
879
        /// Validate that the seen_values matches self.seen_values
880
        fn validate_seen_values(&self, seen_values: &BooleanBuffer) {
881
            for (group_index, is_seen) in seen_values.iter().enumerate() {
882
                let expected_seen = self.expected_seen(group_index);
883
                assert_eq!(
884
                    expected_seen, is_seen,
885
                    "mismatch at for group {group_index}"
886
                );
887
            }
888
        }
889
890
        /// Create the expected null buffer based on if the input had nulls and a filter
891
        fn expected_null_buffer(&self, total_num_groups: usize) -> NullBuffer {
892
            (0..total_num_groups)
893
                .map(|group_index| self.expected_seen(group_index))
894
                .collect()
895
        }
896
    }
897
}