/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! [`GroupsAccumulator`] helpers: [`NullState`] and [`accumulate_indices`] |
19 | | //! |
20 | | //! [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator |
21 | | |
22 | | use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray}; |
23 | | use arrow::buffer::{BooleanBuffer, NullBuffer}; |
24 | | use arrow::datatypes::ArrowPrimitiveType; |
25 | | |
26 | | use datafusion_expr_common::groups_accumulator::EmitTo; |
27 | | /// Track the accumulator null state per row: if any values for that |
28 | | /// group were null and if any values have been seen at all for that group. |
29 | | /// |
30 | | /// This is part of the inner loop for many [`GroupsAccumulator`]s, |
31 | | /// and thus the performance is critical and so there are multiple |
32 | | /// specialized implementations, invoked depending on the specific |
33 | | /// combinations of the input. |
34 | | /// |
35 | | /// Typically there are 4 potential combinations of inputs must be |
36 | | /// special cased for performance: |
37 | | /// |
38 | | /// * With / Without filter |
39 | | /// * With / Without nulls in the input |
40 | | /// |
41 | | /// If the input has nulls, then the accumulator must potentially |
42 | | /// handle each input null value specially (e.g. for `SUM` to mark the |
43 | | /// corresponding sum as null) |
44 | | /// |
45 | | /// If there are filters present, `NullState` tracks if it has seen |
46 | | /// *any* value for that group (as some values may be filtered |
47 | | /// out). Without a filter, the accumulator is only passed groups that |
48 | | /// had at least one value to accumulate so they do not need to track |
49 | | /// if they have seen values for a particular group. |
50 | | /// |
51 | | /// [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator |
52 | | #[derive(Debug)] |
53 | | pub struct NullState { |
54 | | /// Have we seen any non-filtered input values for `group_index`? |
55 | | /// |
56 | | /// If `seen_values[i]` is true, have seen at least one non null |
57 | | /// value for group `i` |
58 | | /// |
59 | | /// If `seen_values[i]` is false, have not seen any values that |
60 | | /// pass the filter yet for group `i` |
61 | | seen_values: BooleanBufferBuilder, |
62 | | } |
63 | | |
64 | | impl Default for NullState { |
65 | 0 | fn default() -> Self { |
66 | 0 | Self::new() |
67 | 0 | } |
68 | | } |
69 | | |
70 | | impl NullState { |
71 | 15 | pub fn new() -> Self { |
72 | 15 | Self { |
73 | 15 | seen_values: BooleanBufferBuilder::new(0), |
74 | 15 | } |
75 | 15 | } |
76 | | |
77 | | /// return the size of all buffers allocated by this null state, not including self |
78 | 3 | pub fn size(&self) -> usize { |
79 | 3 | // capacity is in bits, so convert to bytes |
80 | 3 | self.seen_values.capacity() / 8 |
81 | 3 | } |
82 | | |
83 | | /// Invokes `value_fn(group_index, value)` for each non null, non |
84 | | /// filtered value of `value`, while tracking which groups have |
85 | | /// seen null inputs and which groups have seen any inputs if necessary |
86 | | // |
87 | | /// # Arguments: |
88 | | /// |
89 | | /// * `values`: the input arguments to the accumulator |
90 | | /// * `group_indices`: To which groups do the rows in `values` belong, (aka group_index) |
91 | | /// * `opt_filter`: if present, only rows for which is Some(true) are included |
92 | | /// * `value_fn`: function invoked for (group_index, value) where value is non null |
93 | | /// |
94 | | /// See [`accumulate`], for more details on how value_fn is called |
95 | | /// |
96 | | /// When value_fn is called it also sets |
97 | | /// |
98 | | /// 1. `self.seen_values[group_index]` to true for all rows that had a non null vale |
99 | 46 | pub fn accumulate<T, F>( |
100 | 46 | &mut self, |
101 | 46 | group_indices: &[usize], |
102 | 46 | values: &PrimitiveArray<T>, |
103 | 46 | opt_filter: Option<&BooleanArray>, |
104 | 46 | total_num_groups: usize, |
105 | 46 | mut value_fn: F, |
106 | 46 | ) where |
107 | 46 | T: ArrowPrimitiveType + Send, |
108 | 46 | F: FnMut(usize, T::Native) + Send, |
109 | 46 | { |
110 | 46 | // ensure the seen_values is big enough (start everything at |
111 | 46 | // "not seen" valid) |
112 | 46 | let seen_values = |
113 | 46 | initialize_builder(&mut self.seen_values, total_num_groups, false); |
114 | 123 | accumulate(group_indices, values, opt_filter, |group_index, value| { |
115 | 123 | seen_values.set_bit(group_index, true); |
116 | 123 | value_fn(group_index, value); |
117 | 123 | }); |
118 | 46 | } |
119 | | |
120 | | /// Invokes `value_fn(group_index, value)` for each non null, non |
121 | | /// filtered value in `values`, while tracking which groups have |
122 | | /// seen null inputs and which groups have seen any inputs, for |
123 | | /// [`BooleanArray`]s. |
124 | | /// |
125 | | /// Since `BooleanArray` is not a [`PrimitiveArray`] it must be |
126 | | /// handled specially. |
127 | | /// |
128 | | /// See [`Self::accumulate`], which handles `PrimitiveArray`s, for |
129 | | /// more details on other arguments. |
130 | 0 | pub fn accumulate_boolean<F>( |
131 | 0 | &mut self, |
132 | 0 | group_indices: &[usize], |
133 | 0 | values: &BooleanArray, |
134 | 0 | opt_filter: Option<&BooleanArray>, |
135 | 0 | total_num_groups: usize, |
136 | 0 | mut value_fn: F, |
137 | 0 | ) where |
138 | 0 | F: FnMut(usize, bool) + Send, |
139 | 0 | { |
140 | 0 | let data = values.values(); |
141 | 0 | assert_eq!(data.len(), group_indices.len()); |
142 | | |
143 | | // ensure the seen_values is big enough (start everything at |
144 | | // "not seen" valid) |
145 | 0 | let seen_values = |
146 | 0 | initialize_builder(&mut self.seen_values, total_num_groups, false); |
147 | 0 |
|
148 | 0 | // These could be made more performant by iterating in chunks of 64 bits at a time |
149 | 0 | match (values.null_count() > 0, opt_filter) { |
150 | | // no nulls, no filter, |
151 | | (false, None) => { |
152 | | // if we have previously seen nulls, ensure the null |
153 | | // buffer is big enough (start everything at valid) |
154 | 0 | group_indices.iter().zip(data.iter()).for_each( |
155 | 0 | |(&group_index, new_value)| { |
156 | 0 | seen_values.set_bit(group_index, true); |
157 | 0 | value_fn(group_index, new_value) |
158 | 0 | }, |
159 | 0 | ) |
160 | | } |
161 | | // nulls, no filter |
162 | | (true, None) => { |
163 | 0 | let nulls = values.nulls().unwrap(); |
164 | 0 | group_indices |
165 | 0 | .iter() |
166 | 0 | .zip(data.iter()) |
167 | 0 | .zip(nulls.iter()) |
168 | 0 | .for_each(|((&group_index, new_value), is_valid)| { |
169 | 0 | if is_valid { |
170 | 0 | seen_values.set_bit(group_index, true); |
171 | 0 | value_fn(group_index, new_value); |
172 | 0 | } |
173 | 0 | }) |
174 | | } |
175 | | // no nulls, but a filter |
176 | 0 | (false, Some(filter)) => { |
177 | 0 | assert_eq!(filter.len(), group_indices.len()); |
178 | | |
179 | 0 | group_indices |
180 | 0 | .iter() |
181 | 0 | .zip(data.iter()) |
182 | 0 | .zip(filter.iter()) |
183 | 0 | .for_each(|((&group_index, new_value), filter_value)| { |
184 | 0 | if let Some(true) = filter_value { |
185 | 0 | seen_values.set_bit(group_index, true); |
186 | 0 | value_fn(group_index, new_value); |
187 | 0 | } |
188 | 0 | }) |
189 | | } |
190 | | // both null values and filters |
191 | 0 | (true, Some(filter)) => { |
192 | 0 | assert_eq!(filter.len(), group_indices.len()); |
193 | 0 | filter |
194 | 0 | .iter() |
195 | 0 | .zip(group_indices.iter()) |
196 | 0 | .zip(values.iter()) |
197 | 0 | .for_each(|((filter_value, &group_index), new_value)| { |
198 | 0 | if let Some(true) = filter_value { |
199 | 0 | if let Some(new_value) = new_value { |
200 | 0 | seen_values.set_bit(group_index, true); |
201 | 0 | value_fn(group_index, new_value) |
202 | 0 | } |
203 | 0 | } |
204 | 0 | }) |
205 | | } |
206 | | } |
207 | 0 | } |
208 | | |
209 | | /// Creates the a [`NullBuffer`] representing which group_indices |
210 | | /// should have null values (because they never saw any values) |
211 | | /// for the `emit_to` rows. |
212 | | /// |
213 | | /// resets the internal state appropriately |
214 | 29 | pub fn build(&mut self, emit_to: EmitTo) -> NullBuffer { |
215 | 29 | let nulls: BooleanBuffer = self.seen_values.finish(); |
216 | | |
217 | 29 | let nulls = match emit_to { |
218 | 17 | EmitTo::All => nulls, |
219 | 12 | EmitTo::First(n) => { |
220 | 12 | // split off the first N values in seen_values |
221 | 12 | // |
222 | 12 | // TODO make this more efficient rather than two |
223 | 12 | // copies and bitwise manipulation |
224 | 12 | let first_n_null: BooleanBuffer = nulls.iter().take(n).collect(); |
225 | | // reset the existing seen buffer |
226 | 12 | for seen in nulls.iter().skip(n) { |
227 | 12 | self.seen_values.append(seen); |
228 | 12 | } |
229 | 12 | first_n_null |
230 | | } |
231 | | }; |
232 | 29 | NullBuffer::new(nulls) |
233 | 29 | } |
234 | | } |
235 | | |
236 | | /// Invokes `value_fn(group_index, value)` for each non null, non |
237 | | /// filtered value of `value`, |
238 | | /// |
239 | | /// # Arguments: |
240 | | /// |
241 | | /// * `group_indices`: To which groups do the rows in `values` belong, (aka group_index) |
242 | | /// * `values`: the input arguments to the accumulator |
243 | | /// * `opt_filter`: if present, only rows for which is Some(true) are included |
244 | | /// * `value_fn`: function invoked for (group_index, value) where value is non null |
245 | | /// |
246 | | /// # Example |
247 | | /// |
248 | | /// ```text |
249 | | /// ┌─────────┐ ┌─────────┐ ┌ ─ ─ ─ ─ ┐ |
250 | | /// │ ┌─────┐ │ │ ┌─────┐ │ ┌─────┐ |
251 | | /// │ │ 2 │ │ │ │ 200 │ │ │ │ t │ │ |
252 | | /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ |
253 | | /// │ │ 2 │ │ │ │ 100 │ │ │ │ f │ │ |
254 | | /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ |
255 | | /// │ │ 0 │ │ │ │ 200 │ │ │ │ t │ │ |
256 | | /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ |
257 | | /// │ │ 1 │ │ │ │ 200 │ │ │ │NULL │ │ |
258 | | /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ |
259 | | /// │ │ 0 │ │ │ │ 300 │ │ │ │ t │ │ |
260 | | /// │ └─────┘ │ │ └─────┘ │ └─────┘ |
261 | | /// └─────────┘ └─────────┘ └ ─ ─ ─ ─ ┘ |
262 | | /// |
263 | | /// group_indices values opt_filter |
264 | | /// ``` |
265 | | /// |
266 | | /// In the example above, `value_fn` is invoked for each (group_index, |
267 | | /// value) pair where `opt_filter[i]` is true and values is non null |
268 | | /// |
269 | | /// ```text |
270 | | /// value_fn(2, 200) |
271 | | /// value_fn(0, 200) |
272 | | /// value_fn(0, 300) |
273 | | /// ``` |
274 | 46 | pub fn accumulate<T, F>( |
275 | 46 | group_indices: &[usize], |
276 | 46 | values: &PrimitiveArray<T>, |
277 | 46 | opt_filter: Option<&BooleanArray>, |
278 | 46 | mut value_fn: F, |
279 | 46 | ) where |
280 | 46 | T: ArrowPrimitiveType + Send, |
281 | 46 | F: FnMut(usize, T::Native) + Send, |
282 | 46 | { |
283 | 46 | let data: &[T::Native] = values.values(); |
284 | 46 | assert_eq!(data.len(), group_indices.len()); |
285 | | |
286 | 46 | match (values.null_count() > 0, opt_filter) { |
287 | | // no nulls, no filter, |
288 | | (false, None) => { |
289 | 46 | let iter = group_indices.iter().zip(data.iter()); |
290 | 169 | for (&group_index, &new_value123 ) in iter { |
291 | 123 | value_fn(group_index, new_value); |
292 | 123 | } |
293 | | } |
294 | | // nulls, no filter |
295 | 0 | (true, None) => { |
296 | 0 | let nulls = values.nulls().unwrap(); |
297 | 0 | // This is based on (ahem, COPY/PASTE) arrow::compute::aggregate::sum |
298 | 0 | // iterate over in chunks of 64 bits for more efficient null checking |
299 | 0 | let group_indices_chunks = group_indices.chunks_exact(64); |
300 | 0 | let data_chunks = data.chunks_exact(64); |
301 | 0 | let bit_chunks = nulls.inner().bit_chunks(); |
302 | 0 |
|
303 | 0 | let group_indices_remainder = group_indices_chunks.remainder(); |
304 | 0 | let data_remainder = data_chunks.remainder(); |
305 | 0 |
|
306 | 0 | group_indices_chunks |
307 | 0 | .zip(data_chunks) |
308 | 0 | .zip(bit_chunks.iter()) |
309 | 0 | .for_each(|((group_index_chunk, data_chunk), mask)| { |
310 | 0 | // index_mask has value 1 << i in the loop |
311 | 0 | let mut index_mask = 1; |
312 | 0 | group_index_chunk.iter().zip(data_chunk.iter()).for_each( |
313 | 0 | |(&group_index, &new_value)| { |
314 | 0 | // valid bit was set, real value |
315 | 0 | let is_valid = (mask & index_mask) != 0; |
316 | 0 | if is_valid { |
317 | 0 | value_fn(group_index, new_value); |
318 | 0 | } |
319 | 0 | index_mask <<= 1; |
320 | 0 | }, |
321 | 0 | ) |
322 | 0 | }); |
323 | 0 |
|
324 | 0 | // handle any remaining bits (after the initial 64) |
325 | 0 | let remainder_bits = bit_chunks.remainder_bits(); |
326 | 0 | group_indices_remainder |
327 | 0 | .iter() |
328 | 0 | .zip(data_remainder.iter()) |
329 | 0 | .enumerate() |
330 | 0 | .for_each(|(i, (&group_index, &new_value))| { |
331 | 0 | let is_valid = remainder_bits & (1 << i) != 0; |
332 | 0 | if is_valid { |
333 | 0 | value_fn(group_index, new_value); |
334 | 0 | } |
335 | 0 | }); |
336 | 0 | } |
337 | | // no nulls, but a filter |
338 | 0 | (false, Some(filter)) => { |
339 | 0 | assert_eq!(filter.len(), group_indices.len()); |
340 | | // The performance with a filter could be improved by |
341 | | // iterating over the filter in chunks, rather than a single |
342 | | // iterator. TODO file a ticket |
343 | 0 | group_indices |
344 | 0 | .iter() |
345 | 0 | .zip(data.iter()) |
346 | 0 | .zip(filter.iter()) |
347 | 0 | .for_each(|((&group_index, &new_value), filter_value)| { |
348 | 0 | if let Some(true) = filter_value { |
349 | 0 | value_fn(group_index, new_value); |
350 | 0 | } |
351 | 0 | }) |
352 | | } |
353 | | // both null values and filters |
354 | 0 | (true, Some(filter)) => { |
355 | 0 | assert_eq!(filter.len(), group_indices.len()); |
356 | | // The performance with a filter could be improved by |
357 | | // iterating over the filter in chunks, rather than using |
358 | | // iterators. TODO file a ticket |
359 | 0 | filter |
360 | 0 | .iter() |
361 | 0 | .zip(group_indices.iter()) |
362 | 0 | .zip(values.iter()) |
363 | 0 | .for_each(|((filter_value, &group_index), new_value)| { |
364 | 0 | if let Some(true) = filter_value { |
365 | 0 | if let Some(new_value) = new_value { |
366 | 0 | value_fn(group_index, new_value) |
367 | 0 | } |
368 | 0 | } |
369 | 0 | }) |
370 | | } |
371 | | } |
372 | 46 | } |
373 | | |
374 | | /// This function is called to update the accumulator state per row |
375 | | /// when the value is not needed (e.g. COUNT) |
376 | | /// |
377 | | /// `F`: Invoked like `value_fn(group_index) for all non null values |
378 | | /// passing the filter. Note that no tracking is done for null inputs |
379 | | /// or which groups have seen any values |
380 | | /// |
381 | | /// See [`NullState::accumulate`], for more details on other |
382 | | /// arguments. |
383 | 63 | pub fn accumulate_indices<F>( |
384 | 63 | group_indices: &[usize], |
385 | 63 | nulls: Option<&NullBuffer>, |
386 | 63 | opt_filter: Option<&BooleanArray>, |
387 | 63 | mut index_fn: F, |
388 | 63 | ) where |
389 | 63 | F: FnMut(usize) + Send, |
390 | 63 | { |
391 | 63 | match (nulls, opt_filter) { |
392 | | (None, None) => { |
393 | 98.5k | for &group_index in group_indices.iter()63 { |
394 | 98.5k | index_fn(group_index) |
395 | | } |
396 | | } |
397 | 0 | (None, Some(filter)) => { |
398 | 0 | assert_eq!(filter.len(), group_indices.len()); |
399 | | // The performance with a filter could be improved by |
400 | | // iterating over the filter in chunks, rather than a single |
401 | | // iterator. TODO file a ticket |
402 | 0 | let iter = group_indices.iter().zip(filter.iter()); |
403 | 0 | for (&group_index, filter_value) in iter { |
404 | 0 | if let Some(true) = filter_value { |
405 | 0 | index_fn(group_index) |
406 | 0 | } |
407 | | } |
408 | | } |
409 | 0 | (Some(valids), None) => { |
410 | 0 | assert_eq!(valids.len(), group_indices.len()); |
411 | | // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum |
412 | | // iterate over in chunks of 64 bits for more efficient null checking |
413 | 0 | let group_indices_chunks = group_indices.chunks_exact(64); |
414 | 0 | let bit_chunks = valids.inner().bit_chunks(); |
415 | 0 |
|
416 | 0 | let group_indices_remainder = group_indices_chunks.remainder(); |
417 | 0 |
|
418 | 0 | group_indices_chunks.zip(bit_chunks.iter()).for_each( |
419 | 0 | |(group_index_chunk, mask)| { |
420 | 0 | // index_mask has value 1 << i in the loop |
421 | 0 | let mut index_mask = 1; |
422 | 0 | group_index_chunk.iter().for_each(|&group_index| { |
423 | 0 | // valid bit was set, real vale |
424 | 0 | let is_valid = (mask & index_mask) != 0; |
425 | 0 | if is_valid { |
426 | 0 | index_fn(group_index); |
427 | 0 | } |
428 | 0 | index_mask <<= 1; |
429 | 0 | }) |
430 | 0 | }, |
431 | 0 | ); |
432 | 0 |
|
433 | 0 | // handle any remaining bits (after the initial 64) |
434 | 0 | let remainder_bits = bit_chunks.remainder_bits(); |
435 | 0 | group_indices_remainder |
436 | 0 | .iter() |
437 | 0 | .enumerate() |
438 | 0 | .for_each(|(i, &group_index)| { |
439 | 0 | let is_valid = remainder_bits & (1 << i) != 0; |
440 | 0 | if is_valid { |
441 | 0 | index_fn(group_index) |
442 | 0 | } |
443 | 0 | }); |
444 | 0 | } |
445 | | |
446 | 0 | (Some(valids), Some(filter)) => { |
447 | 0 | assert_eq!(filter.len(), group_indices.len()); |
448 | 0 | assert_eq!(valids.len(), group_indices.len()); |
449 | | // The performance with a filter could likely be improved by |
450 | | // iterating over the filter in chunks, rather than using |
451 | | // iterators. TODO file a ticket |
452 | 0 | filter |
453 | 0 | .iter() |
454 | 0 | .zip(group_indices.iter()) |
455 | 0 | .zip(valids.iter()) |
456 | 0 | .for_each(|((filter_value, &group_index), is_valid)| { |
457 | 0 | if let (Some(true), true) = (filter_value, is_valid) { |
458 | 0 | index_fn(group_index) |
459 | 0 | } |
460 | 0 | }) |
461 | | } |
462 | | } |
463 | 63 | } |
464 | | |
465 | | /// Ensures that `builder` contains a `BooleanBufferBuilder with at |
466 | | /// least `total_num_groups`. |
467 | | /// |
468 | | /// All new entries are initialized to `default_value` |
469 | 46 | fn initialize_builder( |
470 | 46 | builder: &mut BooleanBufferBuilder, |
471 | 46 | total_num_groups: usize, |
472 | 46 | default_value: bool, |
473 | 46 | ) -> &mut BooleanBufferBuilder { |
474 | 46 | if builder.len() < total_num_groups { |
475 | 28 | let new_groups = total_num_groups - builder.len(); |
476 | 28 | builder.append_n(new_groups, default_value); |
477 | 28 | }18 |
478 | 46 | builder |
479 | 46 | } |
480 | | |
481 | | #[cfg(test)] |
482 | | mod test { |
483 | | use super::*; |
484 | | |
485 | | use arrow::array::UInt32Array; |
486 | | use rand::{rngs::ThreadRng, Rng}; |
487 | | use std::collections::HashSet; |
488 | | |
489 | | #[test] |
490 | | fn accumulate() { |
491 | | let group_indices = (0..100).collect(); |
492 | | let values = (0..100).map(|i| (i + 1) * 10).collect(); |
493 | | let values_with_nulls = (0..100) |
494 | | .map(|i| if i % 3 == 0 { None } else { Some((i + 1) * 10) }) |
495 | | .collect(); |
496 | | |
497 | | // default to every fifth value being false, every even |
498 | | // being null |
499 | | let filter: BooleanArray = (0..100) |
500 | | .map(|i| { |
501 | | let is_even = i % 2 == 0; |
502 | | let is_fifth = i % 5 == 0; |
503 | | if is_even { |
504 | | None |
505 | | } else if is_fifth { |
506 | | Some(false) |
507 | | } else { |
508 | | Some(true) |
509 | | } |
510 | | }) |
511 | | .collect(); |
512 | | |
513 | | Fixture { |
514 | | group_indices, |
515 | | values, |
516 | | values_with_nulls, |
517 | | filter, |
518 | | } |
519 | | .run() |
520 | | } |
521 | | |
522 | | #[test] |
523 | | fn accumulate_fuzz() { |
524 | | let mut rng = rand::thread_rng(); |
525 | | for _ in 0..100 { |
526 | | Fixture::new_random(&mut rng).run(); |
527 | | } |
528 | | } |
529 | | |
530 | | /// Values for testing (there are enough values to exercise the 64 bit chunks |
531 | | struct Fixture { |
532 | | /// 100..0 |
533 | | group_indices: Vec<usize>, |
534 | | |
535 | | /// 10, 20, ... 1010 |
536 | | values: Vec<u32>, |
537 | | |
538 | | /// same as values, but every third is null: |
539 | | /// None, Some(20), Some(30), None ... |
540 | | values_with_nulls: Vec<Option<u32>>, |
541 | | |
542 | | /// filter (defaults to None) |
543 | | filter: BooleanArray, |
544 | | } |
545 | | |
546 | | impl Fixture { |
547 | | fn new_random(rng: &mut ThreadRng) -> Self { |
548 | | // Number of input values in a batch |
549 | | let num_values: usize = rng.gen_range(1..200); |
550 | | // number of distinct groups |
551 | | let num_groups: usize = rng.gen_range(2..1000); |
552 | | let max_group = num_groups - 1; |
553 | | |
554 | | let group_indices: Vec<usize> = (0..num_values) |
555 | | .map(|_| rng.gen_range(0..max_group)) |
556 | | .collect(); |
557 | | |
558 | | let values: Vec<u32> = (0..num_values).map(|_| rng.gen()).collect(); |
559 | | |
560 | | // 10% chance of false |
561 | | // 10% change of null |
562 | | // 80% chance of true |
563 | | let filter: BooleanArray = (0..num_values) |
564 | | .map(|_| { |
565 | | let filter_value = rng.gen_range(0.0..1.0); |
566 | | if filter_value < 0.1 { |
567 | | Some(false) |
568 | | } else if filter_value < 0.2 { |
569 | | None |
570 | | } else { |
571 | | Some(true) |
572 | | } |
573 | | }) |
574 | | .collect(); |
575 | | |
576 | | // random values with random number and location of nulls |
577 | | // random null percentage |
578 | | let null_pct: f32 = rng.gen_range(0.0..1.0); |
579 | | let values_with_nulls: Vec<Option<u32>> = (0..num_values) |
580 | | .map(|_| { |
581 | | let is_null = null_pct < rng.gen_range(0.0..1.0); |
582 | | if is_null { |
583 | | None |
584 | | } else { |
585 | | Some(rng.gen()) |
586 | | } |
587 | | }) |
588 | | .collect(); |
589 | | |
590 | | Self { |
591 | | group_indices, |
592 | | values, |
593 | | values_with_nulls, |
594 | | filter, |
595 | | } |
596 | | } |
597 | | |
598 | | /// returns `Self::values` an Array |
599 | | fn values_array(&self) -> UInt32Array { |
600 | | UInt32Array::from(self.values.clone()) |
601 | | } |
602 | | |
603 | | /// returns `Self::values_with_nulls` as an Array |
604 | | fn values_with_nulls_array(&self) -> UInt32Array { |
605 | | UInt32Array::from(self.values_with_nulls.clone()) |
606 | | } |
607 | | |
608 | | /// Calls `NullState::accumulate` and `accumulate_indices` |
609 | | /// with all combinations of nulls and filter values |
610 | | fn run(&self) { |
611 | | let total_num_groups = *self.group_indices.iter().max().unwrap() + 1; |
612 | | |
613 | | let group_indices = &self.group_indices; |
614 | | let values_array = self.values_array(); |
615 | | let values_with_nulls_array = self.values_with_nulls_array(); |
616 | | let filter = &self.filter; |
617 | | |
618 | | // no null, no filters |
619 | | Self::accumulate_test(group_indices, &values_array, None, total_num_groups); |
620 | | |
621 | | // nulls, no filters |
622 | | Self::accumulate_test( |
623 | | group_indices, |
624 | | &values_with_nulls_array, |
625 | | None, |
626 | | total_num_groups, |
627 | | ); |
628 | | |
629 | | // no nulls, filters |
630 | | Self::accumulate_test( |
631 | | group_indices, |
632 | | &values_array, |
633 | | Some(filter), |
634 | | total_num_groups, |
635 | | ); |
636 | | |
637 | | // nulls, filters |
638 | | Self::accumulate_test( |
639 | | group_indices, |
640 | | &values_with_nulls_array, |
641 | | Some(filter), |
642 | | total_num_groups, |
643 | | ); |
644 | | } |
645 | | |
646 | | /// Calls `NullState::accumulate` and `accumulate_indices` to |
647 | | /// ensure it generates the correct values. |
648 | | /// |
649 | | fn accumulate_test( |
650 | | group_indices: &[usize], |
651 | | values: &UInt32Array, |
652 | | opt_filter: Option<&BooleanArray>, |
653 | | total_num_groups: usize, |
654 | | ) { |
655 | | Self::accumulate_values_test( |
656 | | group_indices, |
657 | | values, |
658 | | opt_filter, |
659 | | total_num_groups, |
660 | | ); |
661 | | Self::accumulate_indices_test(group_indices, values.nulls(), opt_filter); |
662 | | |
663 | | // Convert values into a boolean array (anything above the |
664 | | // average is true, otherwise false) |
665 | | let avg: usize = values.iter().filter_map(|v| v.map(|v| v as usize)).sum(); |
666 | | let boolean_values: BooleanArray = |
667 | | values.iter().map(|v| v.map(|v| v as usize > avg)).collect(); |
668 | | Self::accumulate_boolean_test( |
669 | | group_indices, |
670 | | &boolean_values, |
671 | | opt_filter, |
672 | | total_num_groups, |
673 | | ); |
674 | | } |
675 | | |
676 | | /// This is effectively a different implementation of |
677 | | /// accumulate that we compare with the above implementation |
678 | | fn accumulate_values_test( |
679 | | group_indices: &[usize], |
680 | | values: &UInt32Array, |
681 | | opt_filter: Option<&BooleanArray>, |
682 | | total_num_groups: usize, |
683 | | ) { |
684 | | let mut accumulated_values = vec![]; |
685 | | let mut null_state = NullState::new(); |
686 | | |
687 | | null_state.accumulate( |
688 | | group_indices, |
689 | | values, |
690 | | opt_filter, |
691 | | total_num_groups, |
692 | | |group_index, value| { |
693 | | accumulated_values.push((group_index, value)); |
694 | | }, |
695 | | ); |
696 | | |
697 | | // Figure out the expected values |
698 | | let mut expected_values = vec![]; |
699 | | let mut mock = MockNullState::new(); |
700 | | |
701 | | match opt_filter { |
702 | | None => group_indices.iter().zip(values.iter()).for_each( |
703 | | |(&group_index, value)| { |
704 | | if let Some(value) = value { |
705 | | mock.saw_value(group_index); |
706 | | expected_values.push((group_index, value)); |
707 | | } |
708 | | }, |
709 | | ), |
710 | | Some(filter) => { |
711 | | group_indices |
712 | | .iter() |
713 | | .zip(values.iter()) |
714 | | .zip(filter.iter()) |
715 | | .for_each(|((&group_index, value), is_included)| { |
716 | | // if value passed filter |
717 | | if let Some(true) = is_included { |
718 | | if let Some(value) = value { |
719 | | mock.saw_value(group_index); |
720 | | expected_values.push((group_index, value)); |
721 | | } |
722 | | } |
723 | | }); |
724 | | } |
725 | | } |
726 | | |
727 | | assert_eq!(accumulated_values, expected_values, |
728 | | "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"); |
729 | | let seen_values = null_state.seen_values.finish_cloned(); |
730 | | mock.validate_seen_values(&seen_values); |
731 | | |
732 | | // Validate the final buffer (one value per group) |
733 | | let expected_null_buffer = mock.expected_null_buffer(total_num_groups); |
734 | | |
735 | | let null_buffer = null_state.build(EmitTo::All); |
736 | | |
737 | | assert_eq!(null_buffer, expected_null_buffer); |
738 | | } |
739 | | |
740 | | // Calls `accumulate_indices` |
741 | | // and opt_filter and ensures it calls the right values |
742 | | fn accumulate_indices_test( |
743 | | group_indices: &[usize], |
744 | | nulls: Option<&NullBuffer>, |
745 | | opt_filter: Option<&BooleanArray>, |
746 | | ) { |
747 | | let mut accumulated_values = vec![]; |
748 | | |
749 | | accumulate_indices(group_indices, nulls, opt_filter, |group_index| { |
750 | | accumulated_values.push(group_index); |
751 | | }); |
752 | | |
753 | | // Figure out the expected values |
754 | | let mut expected_values = vec![]; |
755 | | |
756 | | match (nulls, opt_filter) { |
757 | | (None, None) => group_indices.iter().for_each(|&group_index| { |
758 | | expected_values.push(group_index); |
759 | | }), |
760 | | (Some(nulls), None) => group_indices.iter().zip(nulls.iter()).for_each( |
761 | | |(&group_index, is_valid)| { |
762 | | if is_valid { |
763 | | expected_values.push(group_index); |
764 | | } |
765 | | }, |
766 | | ), |
767 | | (None, Some(filter)) => group_indices.iter().zip(filter.iter()).for_each( |
768 | | |(&group_index, is_included)| { |
769 | | if let Some(true) = is_included { |
770 | | expected_values.push(group_index); |
771 | | } |
772 | | }, |
773 | | ), |
774 | | (Some(nulls), Some(filter)) => { |
775 | | group_indices |
776 | | .iter() |
777 | | .zip(nulls.iter()) |
778 | | .zip(filter.iter()) |
779 | | .for_each(|((&group_index, is_valid), is_included)| { |
780 | | // if value passed filter |
781 | | if let (true, Some(true)) = (is_valid, is_included) { |
782 | | expected_values.push(group_index); |
783 | | } |
784 | | }); |
785 | | } |
786 | | } |
787 | | |
788 | | assert_eq!(accumulated_values, expected_values, |
789 | | "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"); |
790 | | } |
791 | | |
792 | | /// This is effectively a different implementation of |
793 | | /// accumulate_boolean that we compare with the above implementation |
794 | | fn accumulate_boolean_test( |
795 | | group_indices: &[usize], |
796 | | values: &BooleanArray, |
797 | | opt_filter: Option<&BooleanArray>, |
798 | | total_num_groups: usize, |
799 | | ) { |
800 | | let mut accumulated_values = vec![]; |
801 | | let mut null_state = NullState::new(); |
802 | | |
803 | | null_state.accumulate_boolean( |
804 | | group_indices, |
805 | | values, |
806 | | opt_filter, |
807 | | total_num_groups, |
808 | | |group_index, value| { |
809 | | accumulated_values.push((group_index, value)); |
810 | | }, |
811 | | ); |
812 | | |
813 | | // Figure out the expected values |
814 | | let mut expected_values = vec![]; |
815 | | let mut mock = MockNullState::new(); |
816 | | |
817 | | match opt_filter { |
818 | | None => group_indices.iter().zip(values.iter()).for_each( |
819 | | |(&group_index, value)| { |
820 | | if let Some(value) = value { |
821 | | mock.saw_value(group_index); |
822 | | expected_values.push((group_index, value)); |
823 | | } |
824 | | }, |
825 | | ), |
826 | | Some(filter) => { |
827 | | group_indices |
828 | | .iter() |
829 | | .zip(values.iter()) |
830 | | .zip(filter.iter()) |
831 | | .for_each(|((&group_index, value), is_included)| { |
832 | | // if value passed filter |
833 | | if let Some(true) = is_included { |
834 | | if let Some(value) = value { |
835 | | mock.saw_value(group_index); |
836 | | expected_values.push((group_index, value)); |
837 | | } |
838 | | } |
839 | | }); |
840 | | } |
841 | | } |
842 | | |
843 | | assert_eq!(accumulated_values, expected_values, |
844 | | "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"); |
845 | | |
846 | | let seen_values = null_state.seen_values.finish_cloned(); |
847 | | mock.validate_seen_values(&seen_values); |
848 | | |
849 | | // Validate the final buffer (one value per group) |
850 | | let expected_null_buffer = mock.expected_null_buffer(total_num_groups); |
851 | | |
852 | | let null_buffer = null_state.build(EmitTo::All); |
853 | | |
854 | | assert_eq!(null_buffer, expected_null_buffer); |
855 | | } |
856 | | } |
857 | | |
858 | | /// Parallel implementation of NullState to check expected values |
859 | | #[derive(Debug, Default)] |
860 | | struct MockNullState { |
861 | | /// group indices that had values that passed the filter |
862 | | seen_values: HashSet<usize>, |
863 | | } |
864 | | |
865 | | impl MockNullState { |
866 | | fn new() -> Self { |
867 | | Default::default() |
868 | | } |
869 | | |
870 | | fn saw_value(&mut self, group_index: usize) { |
871 | | self.seen_values.insert(group_index); |
872 | | } |
873 | | |
874 | | /// did this group index see any input? |
875 | | fn expected_seen(&self, group_index: usize) -> bool { |
876 | | self.seen_values.contains(&group_index) |
877 | | } |
878 | | |
879 | | /// Validate that the seen_values matches self.seen_values |
880 | | fn validate_seen_values(&self, seen_values: &BooleanBuffer) { |
881 | | for (group_index, is_seen) in seen_values.iter().enumerate() { |
882 | | let expected_seen = self.expected_seen(group_index); |
883 | | assert_eq!( |
884 | | expected_seen, is_seen, |
885 | | "mismatch at for group {group_index}" |
886 | | ); |
887 | | } |
888 | | } |
889 | | |
890 | | /// Create the expected null buffer based on if the input had nulls and a filter |
891 | | fn expected_null_buffer(&self, total_num_groups: usize) -> NullBuffer { |
892 | | (0..total_num_groups) |
893 | | .map(|group_index| self.expected_seen(group_index)) |
894 | | .collect() |
895 | | } |
896 | | } |
897 | | } |