/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/bit_and_or_xor.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Defines `BitAnd`, `BitOr`, `BitXor` and `BitXor DISTINCT` aggregate accumulators |
19 | | |
20 | | use std::any::Any; |
21 | | use std::collections::HashSet; |
22 | | use std::fmt::{Display, Formatter}; |
23 | | |
24 | | use ahash::RandomState; |
25 | | use arrow::array::{downcast_integer, Array, ArrayRef, AsArray}; |
26 | | use arrow::datatypes::{ |
27 | | ArrowNativeType, ArrowNumericType, DataType, Int16Type, Int32Type, Int64Type, |
28 | | Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, |
29 | | }; |
30 | | use arrow_schema::Field; |
31 | | |
32 | | use datafusion_common::cast::as_list_array; |
33 | | use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; |
34 | | use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; |
35 | | use datafusion_expr::type_coercion::aggregates::INTEGERS; |
36 | | use datafusion_expr::utils::format_state_name; |
37 | | use datafusion_expr::{ |
38 | | Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF, |
39 | | Signature, Volatility, |
40 | | }; |
41 | | |
42 | | use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; |
43 | | use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; |
44 | | use std::ops::{BitAndAssign, BitOrAssign, BitXorAssign}; |
45 | | use std::sync::OnceLock; |
46 | | |
47 | | /// This macro helps create group accumulators based on bitwise operations typically used internally |
48 | | /// and might not be necessary for users to call directly. |
49 | | macro_rules! group_accumulator_helper { |
50 | | ($t:ty, $dt:expr, $opr:expr) => { |
51 | | match $opr { |
52 | | BitwiseOperationType::And => Ok(Box::new( |
53 | 0 | PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| x.bitand_assign(y)) |
54 | | .with_starting_value(!0), |
55 | | )), |
56 | | BitwiseOperationType::Or => Ok(Box::new( |
57 | 0 | PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| x.bitor_assign(y)), |
58 | | )), |
59 | | BitwiseOperationType::Xor => Ok(Box::new( |
60 | 0 | PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| x.bitxor_assign(y)), |
61 | | )), |
62 | | } |
63 | | }; |
64 | | } |
65 | | |
66 | | /// `accumulator_helper` is a macro accepting (ArrowPrimitiveType, BitwiseOperationType, bool) |
67 | | macro_rules! accumulator_helper { |
68 | | ($t:ty, $opr:expr, $is_distinct: expr) => { |
69 | | match $opr { |
70 | | BitwiseOperationType::And => Ok(Box::<BitAndAccumulator<$t>>::default()), |
71 | | BitwiseOperationType::Or => Ok(Box::<BitOrAccumulator<$t>>::default()), |
72 | | BitwiseOperationType::Xor => { |
73 | | if $is_distinct { |
74 | | Ok(Box::<DistinctBitXorAccumulator<$t>>::default()) |
75 | | } else { |
76 | | Ok(Box::<BitXorAccumulator<$t>>::default()) |
77 | | } |
78 | | } |
79 | | } |
80 | | }; |
81 | | } |
82 | | |
83 | | /// AND, OR and XOR only supports a subset of numeric types |
84 | | /// |
85 | | /// `args` is [AccumulatorArgs] |
86 | | /// `opr` is [BitwiseOperationType] |
87 | | /// `is_distinct` is boolean value indicating whether the operation is distinct or not. |
88 | | macro_rules! downcast_bitwise_accumulator { |
89 | | ($args:ident, $opr:expr, $is_distinct: expr) => { |
90 | | match $args.return_type { |
91 | | DataType::Int8 => accumulator_helper!(Int8Type, $opr, $is_distinct), |
92 | | DataType::Int16 => accumulator_helper!(Int16Type, $opr, $is_distinct), |
93 | | DataType::Int32 => accumulator_helper!(Int32Type, $opr, $is_distinct), |
94 | | DataType::Int64 => accumulator_helper!(Int64Type, $opr, $is_distinct), |
95 | | DataType::UInt8 => accumulator_helper!(UInt8Type, $opr, $is_distinct), |
96 | | DataType::UInt16 => accumulator_helper!(UInt16Type, $opr, $is_distinct), |
97 | | DataType::UInt32 => accumulator_helper!(UInt32Type, $opr, $is_distinct), |
98 | | DataType::UInt64 => accumulator_helper!(UInt64Type, $opr, $is_distinct), |
99 | | _ => { |
100 | | not_impl_err!( |
101 | | "{} not supported for {}: {}", |
102 | | stringify!($opr), |
103 | | $args.name, |
104 | | $args.return_type |
105 | | ) |
106 | | } |
107 | | } |
108 | | }; |
109 | | } |
110 | | |
111 | | /// Simplifies the creation of User-Defined Aggregate Functions (UDAFs) for performing bitwise operations in a declarative manner. |
112 | | /// |
113 | | /// `EXPR_FN` identifier used to name the generated expression function. |
114 | | /// `AGGREGATE_UDF_FN` is an identifier used to name the underlying UDAF function. |
115 | | /// `OPR_TYPE` is an expression that evaluates to the type of bitwise operation to be performed. |
116 | | /// `DOCUMENTATION` documentation for the UDAF |
117 | | macro_rules! make_bitwise_udaf_expr_and_func { |
118 | | ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $OPR_TYPE:expr, $DOCUMENTATION:expr) => { |
119 | | make_udaf_expr!( |
120 | | $EXPR_FN, |
121 | | expr_x, |
122 | | concat!( |
123 | | "Returns the bitwise", |
124 | | stringify!($OPR_TYPE), |
125 | | "of a group of values" |
126 | | ), |
127 | | $AGGREGATE_UDF_FN |
128 | | ); |
129 | | create_func!( |
130 | | $EXPR_FN, |
131 | | $AGGREGATE_UDF_FN, |
132 | | BitwiseOperation::new($OPR_TYPE, stringify!($EXPR_FN), $DOCUMENTATION) |
133 | | ); |
134 | | }; |
135 | | } |
136 | | |
137 | | static BIT_AND_DOC: OnceLock<Documentation> = OnceLock::new(); |
138 | | |
139 | 0 | fn get_bit_and_doc() -> &'static Documentation { |
140 | 0 | BIT_AND_DOC.get_or_init(|| { |
141 | 0 | Documentation::builder() |
142 | 0 | .with_doc_section(DOC_SECTION_GENERAL) |
143 | 0 | .with_description("Computes the bitwise AND of all non-null input values.") |
144 | 0 | .with_syntax_example("bit_and(expression)") |
145 | 0 | .with_standard_argument("expression", "Integer") |
146 | 0 | .build() |
147 | 0 | .unwrap() |
148 | 0 | }) |
149 | 0 | } |
150 | | |
151 | | static BIT_OR_DOC: OnceLock<Documentation> = OnceLock::new(); |
152 | | |
153 | 0 | fn get_bit_or_doc() -> &'static Documentation { |
154 | 0 | BIT_OR_DOC.get_or_init(|| { |
155 | 0 | Documentation::builder() |
156 | 0 | .with_doc_section(DOC_SECTION_GENERAL) |
157 | 0 | .with_description("Computes the bitwise OR of all non-null input values.") |
158 | 0 | .with_syntax_example("bit_or(expression)") |
159 | 0 | .with_standard_argument("expression", "Integer") |
160 | 0 | .build() |
161 | 0 | .unwrap() |
162 | 0 | }) |
163 | 0 | } |
164 | | |
165 | | static BIT_XOR_DOC: OnceLock<Documentation> = OnceLock::new(); |
166 | | |
167 | 0 | fn get_bit_xor_doc() -> &'static Documentation { |
168 | 0 | BIT_XOR_DOC.get_or_init(|| { |
169 | 0 | Documentation::builder() |
170 | 0 | .with_doc_section(DOC_SECTION_GENERAL) |
171 | 0 | .with_description( |
172 | 0 | "Computes the bitwise exclusive OR of all non-null input values.", |
173 | 0 | ) |
174 | 0 | .with_syntax_example("bit_xor(expression)") |
175 | 0 | .with_standard_argument("expression", "Integer") |
176 | 0 | .build() |
177 | 0 | .unwrap() |
178 | 0 | }) |
179 | 0 | } |
180 | | |
181 | | make_bitwise_udaf_expr_and_func!( |
182 | | bit_and, |
183 | | bit_and_udaf, |
184 | | BitwiseOperationType::And, |
185 | | get_bit_and_doc() |
186 | | ); |
187 | | make_bitwise_udaf_expr_and_func!( |
188 | | bit_or, |
189 | | bit_or_udaf, |
190 | | BitwiseOperationType::Or, |
191 | | get_bit_or_doc() |
192 | | ); |
193 | | make_bitwise_udaf_expr_and_func!( |
194 | | bit_xor, |
195 | | bit_xor_udaf, |
196 | | BitwiseOperationType::Xor, |
197 | | get_bit_xor_doc() |
198 | | ); |
199 | | |
200 | | /// The different types of bitwise operations that can be performed. |
201 | | #[derive(Debug, Clone, Eq, PartialEq)] |
202 | | enum BitwiseOperationType { |
203 | | And, |
204 | | Or, |
205 | | Xor, |
206 | | } |
207 | | |
208 | | impl Display for BitwiseOperationType { |
209 | 0 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
210 | 0 | write!(f, "{:?}", self) |
211 | 0 | } |
212 | | } |
213 | | |
214 | | /// [BitwiseOperation] struct encapsulates information about a bitwise operation. |
215 | | #[derive(Debug)] |
216 | | struct BitwiseOperation { |
217 | | signature: Signature, |
218 | | /// `operation` indicates the type of bitwise operation to be performed. |
219 | | operation: BitwiseOperationType, |
220 | | func_name: &'static str, |
221 | | documentation: &'static Documentation, |
222 | | } |
223 | | |
224 | | impl BitwiseOperation { |
225 | 0 | pub fn new( |
226 | 0 | operator: BitwiseOperationType, |
227 | 0 | func_name: &'static str, |
228 | 0 | documentation: &'static Documentation, |
229 | 0 | ) -> Self { |
230 | 0 | Self { |
231 | 0 | operation: operator, |
232 | 0 | signature: Signature::uniform(1, INTEGERS.to_vec(), Volatility::Immutable), |
233 | 0 | func_name, |
234 | 0 | documentation, |
235 | 0 | } |
236 | 0 | } |
237 | | } |
238 | | |
239 | | impl AggregateUDFImpl for BitwiseOperation { |
240 | 0 | fn as_any(&self) -> &dyn Any { |
241 | 0 | self |
242 | 0 | } |
243 | | |
244 | 0 | fn name(&self) -> &str { |
245 | 0 | self.func_name |
246 | 0 | } |
247 | | |
248 | 0 | fn signature(&self) -> &Signature { |
249 | 0 | &self.signature |
250 | 0 | } |
251 | | |
252 | 0 | fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { |
253 | 0 | let arg_type = &arg_types[0]; |
254 | 0 | if !arg_type.is_integer() { |
255 | 0 | return exec_err!( |
256 | 0 | "[return_type] {} not supported for {}", |
257 | 0 | self.name(), |
258 | 0 | arg_type |
259 | 0 | ); |
260 | 0 | } |
261 | 0 | Ok(arg_type.clone()) |
262 | 0 | } |
263 | | |
264 | 0 | fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
265 | 0 | downcast_bitwise_accumulator!(acc_args, self.operation, acc_args.is_distinct) |
266 | 0 | } |
267 | | |
268 | 0 | fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { |
269 | 0 | if self.operation == BitwiseOperationType::Xor && args.is_distinct { |
270 | 0 | Ok(vec![Field::new_list( |
271 | 0 | format_state_name( |
272 | 0 | args.name, |
273 | 0 | format!("{} distinct", self.name()).as_str(), |
274 | 0 | ), |
275 | 0 | // See COMMENTS.md to understand why nullable is set to true |
276 | 0 | Field::new("item", args.return_type.clone(), true), |
277 | 0 | false, |
278 | 0 | )]) |
279 | | } else { |
280 | 0 | Ok(vec![Field::new( |
281 | 0 | format_state_name(args.name, self.name()), |
282 | 0 | args.return_type.clone(), |
283 | 0 | true, |
284 | 0 | )]) |
285 | | } |
286 | 0 | } |
287 | | |
288 | 0 | fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { |
289 | 0 | true |
290 | 0 | } |
291 | | |
292 | 0 | fn create_groups_accumulator( |
293 | 0 | &self, |
294 | 0 | args: AccumulatorArgs, |
295 | 0 | ) -> Result<Box<dyn GroupsAccumulator>> { |
296 | 0 | let data_type = args.return_type; |
297 | 0 | let operation = &self.operation; |
298 | 0 | downcast_integer! { |
299 | 0 | data_type => (group_accumulator_helper, data_type, operation), |
300 | 0 | _ => not_impl_err!( |
301 | 0 | "GroupsAccumulator not supported for {} with {}", |
302 | 0 | self.name(), |
303 | 0 | data_type |
304 | 0 | ), |
305 | | } |
306 | 0 | } |
307 | | |
308 | 0 | fn reverse_expr(&self) -> ReversedUDAF { |
309 | 0 | ReversedUDAF::Identical |
310 | 0 | } |
311 | | |
312 | 0 | fn documentation(&self) -> Option<&Documentation> { |
313 | 0 | Some(self.documentation) |
314 | 0 | } |
315 | | } |
316 | | |
317 | | struct BitAndAccumulator<T: ArrowNumericType> { |
318 | | value: Option<T::Native>, |
319 | | } |
320 | | |
321 | | impl<T: ArrowNumericType> std::fmt::Debug for BitAndAccumulator<T> { |
322 | 0 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
323 | 0 | write!(f, "BitAndAccumulator({})", T::DATA_TYPE) |
324 | 0 | } |
325 | | } |
326 | | |
327 | | impl<T: ArrowNumericType> Default for BitAndAccumulator<T> { |
328 | 0 | fn default() -> Self { |
329 | 0 | Self { value: None } |
330 | 0 | } |
331 | | } |
332 | | |
333 | | impl<T: ArrowNumericType> Accumulator for BitAndAccumulator<T> |
334 | | where |
335 | | T::Native: std::ops::BitAnd<Output = T::Native>, |
336 | | { |
337 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
338 | 0 | if let Some(x) = arrow::compute::bit_and(values[0].as_primitive::<T>()) { |
339 | 0 | let v = self.value.get_or_insert(x); |
340 | 0 | *v = *v & x; |
341 | 0 | } |
342 | 0 | Ok(()) |
343 | 0 | } |
344 | | |
345 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
346 | 0 | ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE) |
347 | 0 | } |
348 | | |
349 | 0 | fn size(&self) -> usize { |
350 | 0 | std::mem::size_of_val(self) |
351 | 0 | } |
352 | | |
353 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
354 | 0 | Ok(vec![self.evaluate()?]) |
355 | 0 | } |
356 | | |
357 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
358 | 0 | self.update_batch(states) |
359 | 0 | } |
360 | | } |
361 | | |
362 | | struct BitOrAccumulator<T: ArrowNumericType> { |
363 | | value: Option<T::Native>, |
364 | | } |
365 | | |
366 | | impl<T: ArrowNumericType> std::fmt::Debug for BitOrAccumulator<T> { |
367 | 0 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
368 | 0 | write!(f, "BitOrAccumulator({})", T::DATA_TYPE) |
369 | 0 | } |
370 | | } |
371 | | |
372 | | impl<T: ArrowNumericType> Default for BitOrAccumulator<T> { |
373 | 0 | fn default() -> Self { |
374 | 0 | Self { value: None } |
375 | 0 | } |
376 | | } |
377 | | |
378 | | impl<T: ArrowNumericType> Accumulator for BitOrAccumulator<T> |
379 | | where |
380 | | T::Native: std::ops::BitOr<Output = T::Native>, |
381 | | { |
382 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
383 | 0 | if let Some(x) = arrow::compute::bit_or(values[0].as_primitive::<T>()) { |
384 | 0 | let v = self.value.get_or_insert(T::Native::usize_as(0)); |
385 | 0 | *v = *v | x; |
386 | 0 | } |
387 | 0 | Ok(()) |
388 | 0 | } |
389 | | |
390 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
391 | 0 | ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE) |
392 | 0 | } |
393 | | |
394 | 0 | fn size(&self) -> usize { |
395 | 0 | std::mem::size_of_val(self) |
396 | 0 | } |
397 | | |
398 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
399 | 0 | Ok(vec![self.evaluate()?]) |
400 | 0 | } |
401 | | |
402 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
403 | 0 | self.update_batch(states) |
404 | 0 | } |
405 | | } |
406 | | |
407 | | struct BitXorAccumulator<T: ArrowNumericType> { |
408 | | value: Option<T::Native>, |
409 | | } |
410 | | |
411 | | impl<T: ArrowNumericType> std::fmt::Debug for BitXorAccumulator<T> { |
412 | 0 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
413 | 0 | write!(f, "BitXorAccumulator({})", T::DATA_TYPE) |
414 | 0 | } |
415 | | } |
416 | | |
417 | | impl<T: ArrowNumericType> Default for BitXorAccumulator<T> { |
418 | 0 | fn default() -> Self { |
419 | 0 | Self { value: None } |
420 | 0 | } |
421 | | } |
422 | | |
423 | | impl<T: ArrowNumericType> Accumulator for BitXorAccumulator<T> |
424 | | where |
425 | | T::Native: std::ops::BitXor<Output = T::Native>, |
426 | | { |
427 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
428 | 0 | if let Some(x) = arrow::compute::bit_xor(values[0].as_primitive::<T>()) { |
429 | 0 | let v = self.value.get_or_insert(T::Native::usize_as(0)); |
430 | 0 | *v = *v ^ x; |
431 | 0 | } |
432 | 0 | Ok(()) |
433 | 0 | } |
434 | | |
435 | 0 | fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
436 | 0 | // XOR is it's own inverse |
437 | 0 | self.update_batch(values) |
438 | 0 | } |
439 | | |
440 | 0 | fn supports_retract_batch(&self) -> bool { |
441 | 0 | true |
442 | 0 | } |
443 | | |
444 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
445 | 0 | ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE) |
446 | 0 | } |
447 | | |
448 | 0 | fn size(&self) -> usize { |
449 | 0 | std::mem::size_of_val(self) |
450 | 0 | } |
451 | | |
452 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
453 | 0 | Ok(vec![self.evaluate()?]) |
454 | 0 | } |
455 | | |
456 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
457 | 0 | self.update_batch(states) |
458 | 0 | } |
459 | | } |
460 | | |
461 | | struct DistinctBitXorAccumulator<T: ArrowNumericType> { |
462 | | values: HashSet<T::Native, RandomState>, |
463 | | } |
464 | | |
465 | | impl<T: ArrowNumericType> std::fmt::Debug for DistinctBitXorAccumulator<T> { |
466 | 0 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
467 | 0 | write!(f, "DistinctBitXorAccumulator({})", T::DATA_TYPE) |
468 | 0 | } |
469 | | } |
470 | | |
471 | | impl<T: ArrowNumericType> Default for DistinctBitXorAccumulator<T> { |
472 | 0 | fn default() -> Self { |
473 | 0 | Self { |
474 | 0 | values: HashSet::default(), |
475 | 0 | } |
476 | 0 | } |
477 | | } |
478 | | |
479 | | impl<T: ArrowNumericType> Accumulator for DistinctBitXorAccumulator<T> |
480 | | where |
481 | | T::Native: std::ops::BitXor<Output = T::Native> + std::hash::Hash + Eq, |
482 | | { |
483 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
484 | 0 | if values.is_empty() { |
485 | 0 | return Ok(()); |
486 | 0 | } |
487 | 0 |
|
488 | 0 | let array = values[0].as_primitive::<T>(); |
489 | 0 | match array.nulls().filter(|x| x.null_count() > 0) { |
490 | 0 | Some(n) => { |
491 | 0 | for idx in n.valid_indices() { |
492 | 0 | self.values.insert(array.value(idx)); |
493 | 0 | } |
494 | | } |
495 | 0 | None => array.values().iter().for_each(|x| { |
496 | 0 | self.values.insert(*x); |
497 | 0 | }), |
498 | | } |
499 | 0 | Ok(()) |
500 | 0 | } |
501 | | |
502 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
503 | 0 | let mut acc = T::Native::usize_as(0); |
504 | 0 | for distinct_value in self.values.iter() { |
505 | 0 | acc = acc ^ *distinct_value; |
506 | 0 | } |
507 | 0 | let v = (!self.values.is_empty()).then_some(acc); |
508 | 0 | ScalarValue::new_primitive::<T>(v, &T::DATA_TYPE) |
509 | 0 | } |
510 | | |
511 | 0 | fn size(&self) -> usize { |
512 | 0 | std::mem::size_of_val(self) |
513 | 0 | + self.values.capacity() * std::mem::size_of::<T::Native>() |
514 | 0 | } |
515 | | |
516 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
517 | | // 1. Stores aggregate state in `ScalarValue::List` |
518 | | // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set |
519 | 0 | let state_out = { |
520 | 0 | let values = self |
521 | 0 | .values |
522 | 0 | .iter() |
523 | 0 | .map(|x| ScalarValue::new_primitive::<T>(Some(*x), &T::DATA_TYPE)) |
524 | 0 | .collect::<Result<Vec<_>>>()?; |
525 | | |
526 | 0 | let arr = ScalarValue::new_list_nullable(&values, &T::DATA_TYPE); |
527 | 0 | vec![ScalarValue::List(arr)] |
528 | 0 | }; |
529 | 0 | Ok(state_out) |
530 | 0 | } |
531 | | |
532 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
533 | 0 | if let Some(state) = states.first() { |
534 | 0 | let list_arr = as_list_array(state)?; |
535 | 0 | for arr in list_arr.iter().flatten() { |
536 | 0 | self.update_batch(&[arr])?; |
537 | | } |
538 | 0 | } |
539 | 0 | Ok(()) |
540 | 0 | } |
541 | | } |
542 | | |
543 | | #[cfg(test)] |
544 | | mod tests { |
545 | | use std::sync::Arc; |
546 | | |
547 | | use arrow::array::{ArrayRef, UInt64Array}; |
548 | | use arrow::datatypes::UInt64Type; |
549 | | use datafusion_common::ScalarValue; |
550 | | |
551 | | use crate::bit_and_or_xor::BitXorAccumulator; |
552 | | use datafusion_expr::Accumulator; |
553 | | |
554 | | #[test] |
555 | | fn test_bit_xor_accumulator() { |
556 | | let mut accumulator = BitXorAccumulator::<UInt64Type> { value: None }; |
557 | | let batches: Vec<_> = vec![vec![1, 2], vec![1]] |
558 | | .into_iter() |
559 | | .map(|b| Arc::new(b.into_iter().collect::<UInt64Array>()) as ArrayRef) |
560 | | .collect(); |
561 | | |
562 | | let added = &[Arc::clone(&batches[0])]; |
563 | | let retracted = &[Arc::clone(&batches[1])]; |
564 | | |
565 | | // XOR of 1..3 is 3 |
566 | | accumulator.update_batch(added).unwrap(); |
567 | | assert_eq!( |
568 | | accumulator.evaluate().unwrap(), |
569 | | ScalarValue::UInt64(Some(3)) |
570 | | ); |
571 | | |
572 | | // Removing [1] ^ 3 = 2 |
573 | | accumulator.retract_batch(retracted).unwrap(); |
574 | | assert_eq!( |
575 | | accumulator.evaluate().unwrap(), |
576 | | ScalarValue::UInt64(Some(2)) |
577 | | ); |
578 | | } |
579 | | } |