Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/bit_and_or_xor.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Defines `BitAnd`, `BitOr`, `BitXor` and `BitXor DISTINCT` aggregate accumulators
19
20
use std::any::Any;
21
use std::collections::HashSet;
22
use std::fmt::{Display, Formatter};
23
24
use ahash::RandomState;
25
use arrow::array::{downcast_integer, Array, ArrayRef, AsArray};
26
use arrow::datatypes::{
27
    ArrowNativeType, ArrowNumericType, DataType, Int16Type, Int32Type, Int64Type,
28
    Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
29
};
30
use arrow_schema::Field;
31
32
use datafusion_common::cast::as_list_array;
33
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
34
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
35
use datafusion_expr::type_coercion::aggregates::INTEGERS;
36
use datafusion_expr::utils::format_state_name;
37
use datafusion_expr::{
38
    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF,
39
    Signature, Volatility,
40
};
41
42
use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL;
43
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
44
use std::ops::{BitAndAssign, BitOrAssign, BitXorAssign};
45
use std::sync::OnceLock;
46
47
/// Builds a boxed [`PrimitiveGroupsAccumulator`] for a single Arrow primitive type.
///
/// * `$t` is the Arrow primitive type being aggregated (for example `Int32Type`).
/// * `$dt` is the data type forwarded to `PrimitiveGroupsAccumulator::new`.
/// * `$opr` is the [`BitwiseOperationType`] selecting the operator folded into each group.
///
/// This is an internal helper and is not expected to be invoked by users directly.
macro_rules! group_accumulator_helper {
    ($t:ty, $dt:expr, $opr:expr) => {
        match $opr {
            BitwiseOperationType::And => {
                // AND is the only operation that overrides the starting value:
                // it begins from `!0` (all bits set).
                let grouped =
                    PrimitiveGroupsAccumulator::<$t, _>::new($dt, |acc, v| acc.bitand_assign(v))
                        .with_starting_value(!0);
                Ok(Box::new(grouped))
            }
            BitwiseOperationType::Or => {
                let grouped =
                    PrimitiveGroupsAccumulator::<$t, _>::new($dt, |acc, v| acc.bitor_assign(v));
                Ok(Box::new(grouped))
            }
            BitwiseOperationType::Xor => {
                let grouped =
                    PrimitiveGroupsAccumulator::<$t, _>::new($dt, |acc, v| acc.bitxor_assign(v));
                Ok(Box::new(grouped))
            }
        }
    };
}
65
66
/// `accumulator_helper` takes `(ArrowPrimitiveType, BitwiseOperationType, bool)` and
/// evaluates to `Ok(Box<..>)` holding the default-constructed row accumulator for
/// that combination.
///
/// The boolean (`$is_distinct`) only matters for XOR, where it selects
/// `DistinctBitXorAccumulator` over `BitXorAccumulator`.
macro_rules! accumulator_helper {
    ($t:ty, $opr:expr, $is_distinct: expr) => {
        // Match on a borrow so the operation value is never moved out of its owner.
        match (&$opr, $is_distinct) {
            (BitwiseOperationType::And, _) => Ok(Box::<BitAndAccumulator<$t>>::default()),
            (BitwiseOperationType::Or, _) => Ok(Box::<BitOrAccumulator<$t>>::default()),
            (BitwiseOperationType::Xor, true) => {
                Ok(Box::<DistinctBitXorAccumulator<$t>>::default())
            }
            (BitwiseOperationType::Xor, false) => Ok(Box::<BitXorAccumulator<$t>>::default()),
        }
    };
}
82
83
/// AND, OR and XOR only support a subset of numeric types (the 8/16/32/64-bit
/// signed and unsigned integers).
///
/// Dispatches on `$args.return_type` and delegates to `accumulator_helper!` with
/// the matching Arrow primitive type. Any other data type yields a
/// "not implemented" error.
///
/// `args` is [AccumulatorArgs]
/// `opr` is [BitwiseOperationType]
/// `is_distinct` is boolean value indicating whether the operation is distinct or not.
macro_rules! downcast_bitwise_accumulator {
    ($args:ident, $opr:expr, $is_distinct: expr) => {
        match $args.return_type {
            DataType::Int8 => accumulator_helper!(Int8Type, $opr, $is_distinct),
            DataType::Int16 => accumulator_helper!(Int16Type, $opr, $is_distinct),
            DataType::Int32 => accumulator_helper!(Int32Type, $opr, $is_distinct),
            DataType::Int64 => accumulator_helper!(Int64Type, $opr, $is_distinct),
            DataType::UInt8 => accumulator_helper!(UInt8Type, $opr, $is_distinct),
            DataType::UInt16 => accumulator_helper!(UInt16Type, $opr, $is_distinct),
            DataType::UInt32 => accumulator_helper!(UInt32Type, $opr, $is_distinct),
            DataType::UInt64 => accumulator_helper!(UInt64Type, $opr, $is_distinct),
            _ => {
                // Format the operation *value* (via its `Display` impl). Using
                // `stringify!($opr)` here would print the caller's expression text
                // (e.g. `self.operation`) rather than `And` / `Or` / `Xor`.
                not_impl_err!(
                    "{} not supported for {}: {}",
                    $opr,
                    $args.name,
                    $args.return_type
                )
            }
        }
    };
}
110
111
/// Simplifies the creation of User-Defined Aggregate Functions (UDAFs) for performing bitwise operations in a declarative manner.
///
/// Expands to a `make_udaf_expr!` invocation (the expression function) followed by
/// a `create_func!` invocation (the UDAF backed by [`BitwiseOperation`]).
///
/// `EXPR_FN` identifier used to name the generated expression function.
/// `AGGREGATE_UDF_FN` is an identifier used to name the underlying UDAF function.
/// `OPR_TYPE` is an expression that evaluates to the type of bitwise operation to be performed.
/// `DOCUMENTATION` documentation for the UDAF
macro_rules! make_bitwise_udaf_expr_and_func {
    ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $OPR_TYPE:expr, $DOCUMENTATION:expr) => {
        make_udaf_expr!(
            $EXPR_FN,
            expr_x,
            // `concat!` inserts no separators, so the surrounding spaces must be
            // part of the literals; otherwise the words run together.
            concat!(
                "Returns the bitwise ",
                stringify!($OPR_TYPE),
                " of a group of values"
            ),
            $AGGREGATE_UDF_FN
        );
        create_func!(
            $EXPR_FN,
            $AGGREGATE_UDF_FN,
            BitwiseOperation::new($OPR_TYPE, stringify!($EXPR_FN), $DOCUMENTATION)
        );
    };
}
136
137
static BIT_AND_DOC: OnceLock<Documentation> = OnceLock::new();
138
139
0
fn get_bit_and_doc() -> &'static Documentation {
140
0
    BIT_AND_DOC.get_or_init(|| {
141
0
        Documentation::builder()
142
0
            .with_doc_section(DOC_SECTION_GENERAL)
143
0
            .with_description("Computes the bitwise AND of all non-null input values.")
144
0
            .with_syntax_example("bit_and(expression)")
145
0
            .with_standard_argument("expression", "Integer")
146
0
            .build()
147
0
            .unwrap()
148
0
    })
149
0
}
150
151
static BIT_OR_DOC: OnceLock<Documentation> = OnceLock::new();
152
153
0
fn get_bit_or_doc() -> &'static Documentation {
154
0
    BIT_OR_DOC.get_or_init(|| {
155
0
        Documentation::builder()
156
0
            .with_doc_section(DOC_SECTION_GENERAL)
157
0
            .with_description("Computes the bitwise OR of all non-null input values.")
158
0
            .with_syntax_example("bit_or(expression)")
159
0
            .with_standard_argument("expression", "Integer")
160
0
            .build()
161
0
            .unwrap()
162
0
    })
163
0
}
164
165
static BIT_XOR_DOC: OnceLock<Documentation> = OnceLock::new();
166
167
0
fn get_bit_xor_doc() -> &'static Documentation {
168
0
    BIT_XOR_DOC.get_or_init(|| {
169
0
        Documentation::builder()
170
0
            .with_doc_section(DOC_SECTION_GENERAL)
171
0
            .with_description(
172
0
                "Computes the bitwise exclusive OR of all non-null input values.",
173
0
            )
174
0
            .with_syntax_example("bit_xor(expression)")
175
0
            .with_standard_argument("expression", "Integer")
176
0
            .build()
177
0
            .unwrap()
178
0
    })
179
0
}
180
181
make_bitwise_udaf_expr_and_func!(
182
    bit_and,
183
    bit_and_udaf,
184
    BitwiseOperationType::And,
185
    get_bit_and_doc()
186
);
187
make_bitwise_udaf_expr_and_func!(
188
    bit_or,
189
    bit_or_udaf,
190
    BitwiseOperationType::Or,
191
    get_bit_or_doc()
192
);
193
make_bitwise_udaf_expr_and_func!(
194
    bit_xor,
195
    bit_xor_udaf,
196
    BitwiseOperationType::Xor,
197
    get_bit_xor_doc()
198
);
199
200
/// The bitwise operations an aggregate in this module can perform.
#[derive(Debug, Clone, Eq, PartialEq)]
enum BitwiseOperationType {
    /// Bitwise AND.
    And,
    /// Bitwise OR.
    Or,
    /// Bitwise exclusive OR.
    Xor,
}

impl Display for BitwiseOperationType {
    /// Writes the variant name (`And`, `Or` or `Xor`), the same text the derived
    /// `Debug` implementation produces.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::And => "And",
            Self::Or => "Or",
            Self::Xor => "Xor",
        };
        f.write_str(label)
    }
}
213
214
/// [BitwiseOperation] struct encapsulates information about a bitwise operation.
215
#[derive(Debug)]
216
struct BitwiseOperation {
217
    signature: Signature,
218
    /// `operation` indicates the type of bitwise operation to be performed.
219
    operation: BitwiseOperationType,
220
    func_name: &'static str,
221
    documentation: &'static Documentation,
222
}
223
224
impl BitwiseOperation {
225
0
    pub fn new(
226
0
        operator: BitwiseOperationType,
227
0
        func_name: &'static str,
228
0
        documentation: &'static Documentation,
229
0
    ) -> Self {
230
0
        Self {
231
0
            operation: operator,
232
0
            signature: Signature::uniform(1, INTEGERS.to_vec(), Volatility::Immutable),
233
0
            func_name,
234
0
            documentation,
235
0
        }
236
0
    }
237
}
238
239
impl AggregateUDFImpl for BitwiseOperation {
240
0
    fn as_any(&self) -> &dyn Any {
241
0
        self
242
0
    }
243
244
0
    fn name(&self) -> &str {
245
0
        self.func_name
246
0
    }
247
248
0
    fn signature(&self) -> &Signature {
249
0
        &self.signature
250
0
    }
251
252
0
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
253
0
        let arg_type = &arg_types[0];
254
0
        if !arg_type.is_integer() {
255
0
            return exec_err!(
256
0
                "[return_type] {} not supported for {}",
257
0
                self.name(),
258
0
                arg_type
259
0
            );
260
0
        }
261
0
        Ok(arg_type.clone())
262
0
    }
263
264
0
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
265
0
        downcast_bitwise_accumulator!(acc_args, self.operation, acc_args.is_distinct)
266
0
    }
267
268
0
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
269
0
        if self.operation == BitwiseOperationType::Xor && args.is_distinct {
270
0
            Ok(vec![Field::new_list(
271
0
                format_state_name(
272
0
                    args.name,
273
0
                    format!("{} distinct", self.name()).as_str(),
274
0
                ),
275
0
                // See COMMENTS.md to understand why nullable is set to true
276
0
                Field::new("item", args.return_type.clone(), true),
277
0
                false,
278
0
            )])
279
        } else {
280
0
            Ok(vec![Field::new(
281
0
                format_state_name(args.name, self.name()),
282
0
                args.return_type.clone(),
283
0
                true,
284
0
            )])
285
        }
286
0
    }
287
288
0
    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
289
0
        true
290
0
    }
291
292
0
    fn create_groups_accumulator(
293
0
        &self,
294
0
        args: AccumulatorArgs,
295
0
    ) -> Result<Box<dyn GroupsAccumulator>> {
296
0
        let data_type = args.return_type;
297
0
        let operation = &self.operation;
298
0
        downcast_integer! {
299
0
            data_type => (group_accumulator_helper, data_type, operation),
300
0
            _ => not_impl_err!(
301
0
                "GroupsAccumulator not supported for {} with {}",
302
0
                self.name(),
303
0
                data_type
304
0
            ),
305
        }
306
0
    }
307
308
0
    fn reverse_expr(&self) -> ReversedUDAF {
309
0
        ReversedUDAF::Identical
310
0
    }
311
312
0
    fn documentation(&self) -> Option<&Documentation> {
313
0
        Some(self.documentation)
314
0
    }
315
}
316
317
struct BitAndAccumulator<T: ArrowNumericType> {
318
    value: Option<T::Native>,
319
}
320
321
impl<T: ArrowNumericType> std::fmt::Debug for BitAndAccumulator<T> {
322
0
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
323
0
        write!(f, "BitAndAccumulator({})", T::DATA_TYPE)
324
0
    }
325
}
326
327
impl<T: ArrowNumericType> Default for BitAndAccumulator<T> {
328
0
    fn default() -> Self {
329
0
        Self { value: None }
330
0
    }
331
}
332
333
impl<T: ArrowNumericType> Accumulator for BitAndAccumulator<T>
334
where
335
    T::Native: std::ops::BitAnd<Output = T::Native>,
336
{
337
0
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
338
0
        if let Some(x) = arrow::compute::bit_and(values[0].as_primitive::<T>()) {
339
0
            let v = self.value.get_or_insert(x);
340
0
            *v = *v & x;
341
0
        }
342
0
        Ok(())
343
0
    }
344
345
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
346
0
        ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
347
0
    }
348
349
0
    fn size(&self) -> usize {
350
0
        std::mem::size_of_val(self)
351
0
    }
352
353
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
354
0
        Ok(vec![self.evaluate()?])
355
0
    }
356
357
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
358
0
        self.update_batch(states)
359
0
    }
360
}
361
362
struct BitOrAccumulator<T: ArrowNumericType> {
363
    value: Option<T::Native>,
364
}
365
366
impl<T: ArrowNumericType> std::fmt::Debug for BitOrAccumulator<T> {
367
0
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
368
0
        write!(f, "BitOrAccumulator({})", T::DATA_TYPE)
369
0
    }
370
}
371
372
impl<T: ArrowNumericType> Default for BitOrAccumulator<T> {
373
0
    fn default() -> Self {
374
0
        Self { value: None }
375
0
    }
376
}
377
378
impl<T: ArrowNumericType> Accumulator for BitOrAccumulator<T>
379
where
380
    T::Native: std::ops::BitOr<Output = T::Native>,
381
{
382
0
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
383
0
        if let Some(x) = arrow::compute::bit_or(values[0].as_primitive::<T>()) {
384
0
            let v = self.value.get_or_insert(T::Native::usize_as(0));
385
0
            *v = *v | x;
386
0
        }
387
0
        Ok(())
388
0
    }
389
390
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
391
0
        ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
392
0
    }
393
394
0
    fn size(&self) -> usize {
395
0
        std::mem::size_of_val(self)
396
0
    }
397
398
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
399
0
        Ok(vec![self.evaluate()?])
400
0
    }
401
402
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
403
0
        self.update_batch(states)
404
0
    }
405
}
406
407
struct BitXorAccumulator<T: ArrowNumericType> {
408
    value: Option<T::Native>,
409
}
410
411
impl<T: ArrowNumericType> std::fmt::Debug for BitXorAccumulator<T> {
412
0
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
413
0
        write!(f, "BitXorAccumulator({})", T::DATA_TYPE)
414
0
    }
415
}
416
417
impl<T: ArrowNumericType> Default for BitXorAccumulator<T> {
418
0
    fn default() -> Self {
419
0
        Self { value: None }
420
0
    }
421
}
422
423
impl<T: ArrowNumericType> Accumulator for BitXorAccumulator<T>
424
where
425
    T::Native: std::ops::BitXor<Output = T::Native>,
426
{
427
0
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
428
0
        if let Some(x) = arrow::compute::bit_xor(values[0].as_primitive::<T>()) {
429
0
            let v = self.value.get_or_insert(T::Native::usize_as(0));
430
0
            *v = *v ^ x;
431
0
        }
432
0
        Ok(())
433
0
    }
434
435
0
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
436
0
        // XOR is it's own inverse
437
0
        self.update_batch(values)
438
0
    }
439
440
0
    fn supports_retract_batch(&self) -> bool {
441
0
        true
442
0
    }
443
444
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
445
0
        ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
446
0
    }
447
448
0
    fn size(&self) -> usize {
449
0
        std::mem::size_of_val(self)
450
0
    }
451
452
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
453
0
        Ok(vec![self.evaluate()?])
454
0
    }
455
456
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
457
0
        self.update_batch(states)
458
0
    }
459
}
460
461
struct DistinctBitXorAccumulator<T: ArrowNumericType> {
462
    values: HashSet<T::Native, RandomState>,
463
}
464
465
impl<T: ArrowNumericType> std::fmt::Debug for DistinctBitXorAccumulator<T> {
466
0
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
467
0
        write!(f, "DistinctBitXorAccumulator({})", T::DATA_TYPE)
468
0
    }
469
}
470
471
impl<T: ArrowNumericType> Default for DistinctBitXorAccumulator<T> {
472
0
    fn default() -> Self {
473
0
        Self {
474
0
            values: HashSet::default(),
475
0
        }
476
0
    }
477
}
478
479
impl<T: ArrowNumericType> Accumulator for DistinctBitXorAccumulator<T>
480
where
481
    T::Native: std::ops::BitXor<Output = T::Native> + std::hash::Hash + Eq,
482
{
483
0
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
484
0
        if values.is_empty() {
485
0
            return Ok(());
486
0
        }
487
0
488
0
        let array = values[0].as_primitive::<T>();
489
0
        match array.nulls().filter(|x| x.null_count() > 0) {
490
0
            Some(n) => {
491
0
                for idx in n.valid_indices() {
492
0
                    self.values.insert(array.value(idx));
493
0
                }
494
            }
495
0
            None => array.values().iter().for_each(|x| {
496
0
                self.values.insert(*x);
497
0
            }),
498
        }
499
0
        Ok(())
500
0
    }
501
502
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
503
0
        let mut acc = T::Native::usize_as(0);
504
0
        for distinct_value in self.values.iter() {
505
0
            acc = acc ^ *distinct_value;
506
0
        }
507
0
        let v = (!self.values.is_empty()).then_some(acc);
508
0
        ScalarValue::new_primitive::<T>(v, &T::DATA_TYPE)
509
0
    }
510
511
0
    fn size(&self) -> usize {
512
0
        std::mem::size_of_val(self)
513
0
            + self.values.capacity() * std::mem::size_of::<T::Native>()
514
0
    }
515
516
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
517
        // 1. Stores aggregate state in `ScalarValue::List`
518
        // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set
519
0
        let state_out = {
520
0
            let values = self
521
0
                .values
522
0
                .iter()
523
0
                .map(|x| ScalarValue::new_primitive::<T>(Some(*x), &T::DATA_TYPE))
524
0
                .collect::<Result<Vec<_>>>()?;
525
526
0
            let arr = ScalarValue::new_list_nullable(&values, &T::DATA_TYPE);
527
0
            vec![ScalarValue::List(arr)]
528
0
        };
529
0
        Ok(state_out)
530
0
    }
531
532
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
533
0
        if let Some(state) = states.first() {
534
0
            let list_arr = as_list_array(state)?;
535
0
            for arr in list_arr.iter().flatten() {
536
0
                self.update_batch(&[arr])?;
537
            }
538
0
        }
539
0
        Ok(())
540
0
    }
541
}
542
543
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, UInt64Array};
    use arrow::datatypes::UInt64Type;
    use datafusion_common::ScalarValue;

    use crate::bit_and_or_xor::BitXorAccumulator;
    use datafusion_expr::Accumulator;

    /// Adding `[1, 2]` and then retracting `[1]` must leave `2`.
    #[test]
    fn test_bit_xor_accumulator() {
        let mut acc = BitXorAccumulator::<UInt64Type>::default();
        let added: ArrayRef = Arc::new(UInt64Array::from(vec![1_u64, 2]));
        let retracted: ArrayRef = Arc::new(UInt64Array::from(vec![1_u64]));

        // 1 ^ 2 = 3
        acc.update_batch(&[added]).unwrap();
        let after_update = acc.evaluate().unwrap();
        assert_eq!(after_update, ScalarValue::UInt64(Some(3)));

        // Retracting 1 XORs it back out: 3 ^ 1 = 2
        acc.retract_batch(&[retracted]).unwrap();
        let after_retract = acc.evaluate().unwrap();
        assert_eq!(after_retract, ScalarValue::UInt64(Some(2)));
    }
}