Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/expressions/binary.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
mod kernels;
19
20
use std::hash::{Hash, Hasher};
21
use std::{any::Any, sync::Arc};
22
23
use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
24
use crate::physical_expr::down_cast_any_ref;
25
use crate::PhysicalExpr;
26
27
use arrow::array::*;
28
use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene};
29
use arrow::compute::kernels::cmp::*;
30
use arrow::compute::kernels::comparison::{regexp_is_match, regexp_is_match_scalar};
31
use arrow::compute::kernels::concat_elements::concat_elements_utf8;
32
use arrow::compute::{cast, ilike, like, nilike, nlike};
33
use arrow::datatypes::*;
34
use arrow_schema::ArrowError;
35
use datafusion_common::cast::as_boolean_array;
36
use datafusion_common::{internal_err, Result, ScalarValue};
37
use datafusion_expr::interval_arithmetic::{apply_operator, Interval};
38
use datafusion_expr::sort_properties::ExprProperties;
39
use datafusion_expr::type_coercion::binary::get_result_type;
40
use datafusion_expr::{ColumnarValue, Operator};
41
use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested};
42
43
use crate::expressions::binary::kernels::concat_elements_utf8view;
44
use kernels::{
45
    bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar,
46
    bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn,
47
    bitwise_shift_right_dyn_scalar, bitwise_xor_dyn, bitwise_xor_dyn_scalar,
48
};
49
50
/// Binary expression
///
/// Physical expression evaluating `left op right`, where both operands are
/// themselves physical expressions and `op` is a logical [`Operator`].
#[derive(Debug, Hash, Clone)]
pub struct BinaryExpr {
    /// Left-hand operand.
    left: Arc<dyn PhysicalExpr>,
    /// Operator applied to the two operands.
    op: Operator,
    /// Right-hand operand.
    right: Arc<dyn PhysicalExpr>,
    /// Specifies whether an error is returned on overflow or not
    ///
    /// During evaluation this only selects between checked and wrapping
    /// kernels for `Plus`, `Minus` and `Multiply`.
    fail_on_overflow: bool,
}
59
60
impl BinaryExpr {
61
    /// Create new binary expression
62
2.70k
    pub fn new(
63
2.70k
        left: Arc<dyn PhysicalExpr>,
64
2.70k
        op: Operator,
65
2.70k
        right: Arc<dyn PhysicalExpr>,
66
2.70k
    ) -> Self {
67
2.70k
        Self {
68
2.70k
            left,
69
2.70k
            op,
70
2.70k
            right,
71
2.70k
            fail_on_overflow: false,
72
2.70k
        }
73
2.70k
    }
74
75
    /// Create new binary expression with explicit fail_on_overflow value
76
66
    pub fn with_fail_on_overflow(self, fail_on_overflow: bool) -> Self {
77
66
        Self {
78
66
            left: self.left,
79
66
            op: self.op,
80
66
            right: self.right,
81
66
            fail_on_overflow,
82
66
        }
83
66
    }
84
85
    /// Get the left side of the binary expression
86
98
    pub fn left(&self) -> &Arc<dyn PhysicalExpr> {
87
98
        &self.left
88
98
    }
89
90
    /// Get the right side of the binary expression
91
91
    pub fn right(&self) -> &Arc<dyn PhysicalExpr> {
92
91
        &self.right
93
91
    }
94
95
    /// Get the operator for this binary expression
96
183
    pub fn op(&self) -> &Operator {
97
183
        &self.op
98
183
    }
99
}
100
101
impl std::fmt::Display for BinaryExpr {
102
0
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
103
        // Put parentheses around child binary expressions so that we can see the difference
104
        // between `(a OR b) AND c` and `a OR (b AND c)`. We only insert parentheses when needed,
105
        // based on operator precedence. For example, `(a AND b) OR c` and `a AND b OR c` are
106
        // equivalent and the parentheses are not necessary.
107
108
0
        fn write_child(
109
0
            f: &mut std::fmt::Formatter,
110
0
            expr: &dyn PhysicalExpr,
111
0
            precedence: u8,
112
0
        ) -> std::fmt::Result {
113
0
            if let Some(child) = expr.as_any().downcast_ref::<BinaryExpr>() {
114
0
                let p = child.op.precedence();
115
0
                if p == 0 || p < precedence {
116
0
                    write!(f, "({child})")?;
117
                } else {
118
0
                    write!(f, "{child}")?;
119
                }
120
            } else {
121
0
                write!(f, "{expr}")?;
122
            }
123
124
0
            Ok(())
125
0
        }
126
127
0
        let precedence = self.op.precedence();
128
0
        write_child(f, self.left.as_ref(), precedence)?;
129
0
        write!(f, " {} ", self.op)?;
130
0
        write_child(f, self.right.as_ref(), precedence)
131
0
    }
132
}
133
134
/// Invoke a boolean kernel on a pair of arrays
135
#[inline]
136
22.1k
fn boolean_op(
137
22.1k
    left: &dyn Array,
138
22.1k
    right: &dyn Array,
139
22.1k
    op: impl FnOnce(&BooleanArray, &BooleanArray) -> Result<BooleanArray, ArrowError>,
140
22.1k
) -> Result<Arc<(dyn Array + 'static)>, ArrowError> {
141
22.1k
    let ll = as_boolean_array(left).expect("boolean_op failed to downcast left array");
142
22.1k
    let rr = as_boolean_array(right).expect("boolean_op failed to downcast right array");
143
22.1k
    op(ll, rr).map(|t| Arc::new(t) as _)
144
22.1k
}
145
146
/// Dispatch an array-array regex kernel on the physical string type of `$LEFT`.
///
/// `$NOT` negates the kernel result and `$FLAG` requests case-insensitive
/// matching. Expands to a `Result<ArrayRef>` expression. It may also
/// early-return (via `?`) from the enclosing function when a cast or the kernel
/// fails.
macro_rules! binary_string_array_flag_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{
        match $LEFT.data_type() {
            DataType::Utf8 => {
                compute_utf8_flag_op!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG)
            },
            DataType::Utf8View => {
                // A `Utf8View` array is a `StringViewArray`, which can never be
                // downcast to `StringArray` (the downcast would panic). Cast both
                // operands to `Utf8` so the `StringArray` kernel can be used.
                let utf8_left = cast(&$LEFT, &DataType::Utf8)?;
                let utf8_right = cast(&$RIGHT, &DataType::Utf8)?;
                compute_utf8_flag_op!(utf8_left, utf8_right, $OP, StringArray, $NOT, $FLAG)
            },
            DataType::LargeUtf8 => {
                compute_utf8_flag_op!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG)
            },
            other => internal_err!(
                "Data type {:?} not supported for binary_string_array_flag_op operation '{}' on string array",
                other, stringify!($OP)
            ),
        }
    }};
}
162
163
/// Invoke a compute kernel on a pair of binary data arrays with flags
///
/// Both operands are downcast to `$ARRAYTYPE` (panicking if that is not their
/// concrete type). When `$FLAG` is true, a per-row flags array filled with `"i"`
/// is passed to `$OP`. When `$NOT` is true, the boolean result is negated.
macro_rules! compute_utf8_flag_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{
        let left_strings = $LEFT
            .as_any()
            .downcast_ref::<$ARRAYTYPE>()
            .expect("compute_utf8_flag_op failed to downcast array");
        let right_strings = $RIGHT
            .as_any()
            .downcast_ref::<$ARRAYTYPE>()
            .expect("compute_utf8_flag_op failed to downcast array");

        // One "i" flag per row, built only when case-insensitive matching is requested.
        let case_flags = $FLAG.then(|| $ARRAYTYPE::from(vec!["i"; left_strings.len()]));
        let matched = $OP(left_strings, right_strings, case_flags.as_ref())?;
        let outcome = if $NOT { not(&matched).unwrap() } else { matched };
        Ok(Arc::new(outcome))
    }};
}
187
188
/// Dispatch an array-scalar regex kernel on the physical string type of `$LEFT`.
///
/// Expands to `Some(Result<Arc<dyn Array>>)`. It may also early-return (via `?`)
/// from the enclosing function when a cast or the kernel fails.
macro_rules! binary_string_array_flag_op_scalar {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{
        let result: Result<Arc<dyn Array>> = match $LEFT.data_type() {
            DataType::Utf8 => {
                compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG)
            },
            DataType::Utf8View => {
                // A `Utf8View` array is a `StringViewArray`, which can never be
                // downcast to `StringArray` (the downcast would panic). Cast the
                // array to `Utf8`, and turn a `Utf8View` literal into a `Utf8`
                // literal so the scalar matcher accepts it.
                let utf8_left = cast($LEFT, &DataType::Utf8)?;
                let utf8_right = match $RIGHT {
                    ScalarValue::Utf8View(value) => ScalarValue::Utf8(value),
                    other => other,
                };
                compute_utf8_flag_op_scalar!(utf8_left, utf8_right, $OP, StringArray, $NOT, $FLAG)
            },
            DataType::LargeUtf8 => {
                compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG)
            },
            other => internal_err!(
                "Data type {:?} not supported for binary_string_array_flag_op_scalar operation '{}' on string array",
                other, stringify!($OP)
            ),
        };
        Some(result)
    }};
}
205
206
/// Invoke a compute kernel on a data array and a scalar value with flag
///
/// `$LEFT` is downcast to `$ARRAYTYPE` (panicking on mismatch). `$RIGHT` must be a
/// non-null `Utf8` or `LargeUtf8` literal. Any other literal yields an internal
/// error. The kernel called is `$OP` with a `_scalar` suffix appended.
macro_rules! compute_utf8_flag_op_scalar {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{
        let strings = $LEFT
            .as_any()
            .downcast_ref::<$ARRAYTYPE>()
            .expect("compute_utf8_flag_op_scalar failed to downcast array");

        match $RIGHT {
            ScalarValue::Utf8(Some(pattern)) | ScalarValue::LargeUtf8(Some(pattern)) => {
                let case_flag = $FLAG.then_some("i");
                let matched =
                    paste::expr! {[<$OP _scalar>]}(strings, &pattern, case_flag)?;
                let outcome = if $NOT { not(&matched).unwrap() } else { matched };
                Ok(Arc::new(outcome))
            }
            other => internal_err!(
                "compute_utf8_flag_op_scalar failed to cast literal value {} for operation '{}'",
                other, stringify!($OP)
            ),
        }
    }};
}
230
231
impl PhysicalExpr for BinaryExpr {
    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Output type of `left op right`, computed by `get_result_type` from the
    /// data types of both children.
    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        get_result_type(
            &self.left.data_type(input_schema)?,
            &self.op,
            &self.right.data_type(input_schema)?,
        )
    }

    /// The result is nullable if either child is nullable.
    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
        Ok(self.left.nullable(input_schema)? || self.right.nullable(input_schema)?)
    }

    /// Evaluates both children against `batch`, then applies the operator.
    ///
    /// Dispatch order:
    /// 1. If the left value's type is nested (per `DataType::is_nested`), both
    ///    sides must have the same type. The work is delegated to
    ///    `apply_cmp_for_nested`.
    /// 2. Arithmetic, comparison, distinctness and LIKE operators are handed to
    ///    `apply` / `apply_cmp` with the matching arrow kernel and return directly.
    /// 3. All other operators first try `evaluate_array_scalar` when the left side
    ///    is an array and the right side a non-null scalar. Otherwise (or if that
    ///    returns `None`), both sides are materialized as arrays and passed to
    ///    `evaluate_with_resolved_args`.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        use arrow::compute::kernels::numeric::*;

        let lhs = self.left.evaluate(batch)?;
        let rhs = self.right.evaluate(batch)?;
        let left_data_type = lhs.data_type();
        let right_data_type = rhs.data_type();

        let schema = batch.schema();
        let input_schema = schema.as_ref();

        if left_data_type.is_nested() {
            if right_data_type != left_data_type {
                return internal_err!("type mismatch");
            }
            return apply_cmp_for_nested(self.op, &lhs, &rhs);
        }

        match self.op {
            // `fail_on_overflow` picks the checked kernels (`add`, `sub`, `mul`)
            // over the wrapping ones.
            Operator::Plus if self.fail_on_overflow => return apply(&lhs, &rhs, add),
            Operator::Plus => return apply(&lhs, &rhs, add_wrapping),
            Operator::Minus if self.fail_on_overflow => return apply(&lhs, &rhs, sub),
            Operator::Minus => return apply(&lhs, &rhs, sub_wrapping),
            Operator::Multiply if self.fail_on_overflow => return apply(&lhs, &rhs, mul),
            Operator::Multiply => return apply(&lhs, &rhs, mul_wrapping),
            Operator::Divide => return apply(&lhs, &rhs, div),
            Operator::Modulo => return apply(&lhs, &rhs, rem),
            Operator::Eq => return apply_cmp(&lhs, &rhs, eq),
            Operator::NotEq => return apply_cmp(&lhs, &rhs, neq),
            Operator::Lt => return apply_cmp(&lhs, &rhs, lt),
            Operator::Gt => return apply_cmp(&lhs, &rhs, gt),
            Operator::LtEq => return apply_cmp(&lhs, &rhs, lt_eq),
            Operator::GtEq => return apply_cmp(&lhs, &rhs, gt_eq),
            Operator::IsDistinctFrom => return apply_cmp(&lhs, &rhs, distinct),
            Operator::IsNotDistinctFrom => return apply_cmp(&lhs, &rhs, not_distinct),
            Operator::LikeMatch => return apply_cmp(&lhs, &rhs, like),
            Operator::ILikeMatch => return apply_cmp(&lhs, &rhs, ilike),
            Operator::NotLikeMatch => return apply_cmp(&lhs, &rhs, nlike),
            Operator::NotILikeMatch => return apply_cmp(&lhs, &rhs, nilike),
            // Everything else is handled by the generic path below.
            _ => {}
        }

        // Needed to cast array-scalar results back to the declared output type.
        let result_type = self.data_type(input_schema)?;

        // Attempt to use special kernels if one input is scalar and the other is an array
        let scalar_result = match (&lhs, &rhs) {
            (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => {
                // if left is array and right is literal(not NULL) - use scalar operations
                if scalar.is_null() {
                    None
                } else {
                    self.evaluate_array_scalar(array, scalar.clone())?.map(|r| {
                        r.and_then(|a| to_result_type_array(&self.op, a, &result_type))
                    })
                }
            }
            (_, _) => None, // default to array implementation
        };

        if let Some(result) = scalar_result {
            return result.map(ColumnarValue::Array);
        }

        // if both arrays or both literals - extract arrays and continue execution
        let (left, right) = (
            lhs.into_array(batch.num_rows())?,
            rhs.into_array(batch.num_rows())?,
        );
        self.evaluate_with_resolved_args(left, &left_data_type, right, &right_data_type)
            .map(ColumnarValue::Array)
    }

    /// Returns the two operands, left first.
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.left, &self.right]
    }

    /// Rebuilds this expression over `children[0]` (left) and `children[1]`
    /// (right), keeping the operator and the `fail_on_overflow` flag.
    /// Panics if fewer than two children are supplied (direct indexing).
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(
            BinaryExpr::new(Arc::clone(&children[0]), self.op, Arc::clone(&children[1]))
                .with_fail_on_overflow(self.fail_on_overflow),
        ))
    }

    /// Computes this node's interval from its children's intervals by applying
    /// the operator with `apply_operator`.
    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
        // Get children intervals:
        let left_interval = children[0];
        let right_interval = children[1];
        // Calculate current node's interval:
        apply_operator(&self.op, left_interval, right_interval)
    }

    /// Narrows the children's intervals given that this node's value lies in
    /// `interval`.
    ///
    /// Conventions used below:
    /// - `Ok(None)`: the constraint is infeasible.
    /// - `Ok(Some(vec![]))`: the children's intervals are left unchanged.
    /// - `Ok(Some(vec![left, right]))`: the refined child intervals.
    ///
    /// `And` / `Or` are handled here with boolean reasoning. Other comparison
    /// operators delegate to `propagate_comparison`, and everything else to
    /// `propagate_arithmetic`.
    fn propagate_constraints(
        &self,
        interval: &Interval,
        children: &[&Interval],
    ) -> Result<Option<Vec<Interval>>> {
        // Get children intervals.
        let left_interval = children[0];
        let right_interval = children[1];

        if self.op.eq(&Operator::And) {
            if interval.eq(&Interval::CERTAINLY_TRUE) {
                // A certainly true logical conjunction can only derive from possibly
                // true operands. Otherwise, we prove infeasibility.
                Ok((!left_interval.eq(&Interval::CERTAINLY_FALSE)
                    && !right_interval.eq(&Interval::CERTAINLY_FALSE))
                .then(|| vec![Interval::CERTAINLY_TRUE, Interval::CERTAINLY_TRUE]))
            } else if interval.eq(&Interval::CERTAINLY_FALSE) {
                // If the logical conjunction is certainly false, one of the
                // operands must be false. However, it's not always possible to
                // determine which operand is false, leading to different scenarios.

                // If one operand is certainly true and the other one is uncertain,
                // then the latter must be certainly false.
                if left_interval.eq(&Interval::CERTAINLY_TRUE)
                    && right_interval.eq(&Interval::UNCERTAIN)
                {
                    Ok(Some(vec![
                        Interval::CERTAINLY_TRUE,
                        Interval::CERTAINLY_FALSE,
                    ]))
                } else if right_interval.eq(&Interval::CERTAINLY_TRUE)
                    && left_interval.eq(&Interval::UNCERTAIN)
                {
                    Ok(Some(vec![
                        Interval::CERTAINLY_FALSE,
                        Interval::CERTAINLY_TRUE,
                    ]))
                }
                // If both children are uncertain, or if one is certainly false,
                // we cannot conclusively refine their intervals. In this case,
                // propagation does not result in any interval changes.
                else {
                    Ok(Some(vec![]))
                }
            } else {
                // An uncertain logical conjunction result can not shrink the
                // end-points of its children.
                Ok(Some(vec![]))
            }
        } else if self.op.eq(&Operator::Or) {
            if interval.eq(&Interval::CERTAINLY_FALSE) {
                // A certainly false logical disjunction can only derive from certainly
                // false operands. Otherwise, we prove infeasibility.
                Ok((!left_interval.eq(&Interval::CERTAINLY_TRUE)
                    && !right_interval.eq(&Interval::CERTAINLY_TRUE))
                .then(|| vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_FALSE]))
            } else if interval.eq(&Interval::CERTAINLY_TRUE) {
                // If the logical disjunction is certainly true, one of the
                // operands must be true. However, it's not always possible to
                // determine which operand is true, leading to different scenarios.

                // If one operand is certainly false and the other one is uncertain,
                // then the latter must be certainly true.
                if left_interval.eq(&Interval::CERTAINLY_FALSE)
                    && right_interval.eq(&Interval::UNCERTAIN)
                {
                    Ok(Some(vec![
                        Interval::CERTAINLY_FALSE,
                        Interval::CERTAINLY_TRUE,
                    ]))
                } else if right_interval.eq(&Interval::CERTAINLY_FALSE)
                    && left_interval.eq(&Interval::UNCERTAIN)
                {
                    Ok(Some(vec![
                        Interval::CERTAINLY_TRUE,
                        Interval::CERTAINLY_FALSE,
                    ]))
                }
                // If both children are uncertain, or if one is certainly true,
                // we cannot conclusively refine their intervals. In this case,
                // propagation does not result in any interval changes.
                else {
                    Ok(Some(vec![]))
                }
            } else {
                // An uncertain logical disjunction result can not shrink the
                // end-points of its children.
                Ok(Some(vec![]))
            }
        } else if self.op.is_comparison_operator() {
            Ok(
                propagate_comparison(&self.op, interval, left_interval, right_interval)?
                    .map(|(left, right)| vec![left, right]),
            )
        } else {
            Ok(
                propagate_arithmetic(&self.op, interval, left_interval, right_interval)?
                    .map(|(left, right)| vec![left, right]),
            )
        }
    }

    /// Feeds this expression into a type-erased hasher using the derived `Hash`.
    fn dyn_hash(&self, state: &mut dyn Hasher) {
        let mut s = state;
        self.hash(&mut s);
    }

    /// For each operator, [`BinaryExpr`] has distinct rules.
    /// TODO: There may be rules specific to some data types and expression ranges.
    ///
    /// `Lt` / `LtEq` reuse `gt_or_gteq` with the operands swapped, because
    /// `a < b` orders the same way as `b > a`. Operators not listed fall back to
    /// `ExprProperties::new_unknown()`.
    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
        let (l_order, l_range) = (children[0].sort_properties, &children[0].range);
        let (r_order, r_range) = (children[1].sort_properties, &children[1].range);
        match self.op() {
            Operator::Plus => Ok(ExprProperties {
                sort_properties: l_order.add(&r_order),
                range: l_range.add(r_range)?,
            }),
            Operator::Minus => Ok(ExprProperties {
                sort_properties: l_order.sub(&r_order),
                range: l_range.sub(r_range)?,
            }),
            Operator::Gt => Ok(ExprProperties {
                sort_properties: l_order.gt_or_gteq(&r_order),
                range: l_range.gt(r_range)?,
            }),
            Operator::GtEq => Ok(ExprProperties {
                sort_properties: l_order.gt_or_gteq(&r_order),
                range: l_range.gt_eq(r_range)?,
            }),
            Operator::Lt => Ok(ExprProperties {
                sort_properties: r_order.gt_or_gteq(&l_order),
                range: l_range.lt(r_range)?,
            }),
            Operator::LtEq => Ok(ExprProperties {
                sort_properties: r_order.gt_or_gteq(&l_order),
                range: l_range.lt_eq(r_range)?,
            }),
            Operator::And => Ok(ExprProperties {
                sort_properties: r_order.and_or(&l_order),
                range: l_range.and(r_range)?,
            }),
            Operator::Or => Ok(ExprProperties {
                sort_properties: r_order.and_or(&l_order),
                range: l_range.or(r_range)?,
            }),
            _ => Ok(ExprProperties::new_unknown()),
        }
    }
}
492
493
impl PartialEq<dyn Any> for BinaryExpr {
494
99.2k
    fn eq(&self, other: &dyn Any) -> bool {
495
99.2k
        down_cast_any_ref(other)
496
99.2k
            .downcast_ref::<Self>()
497
99.2k
            .map(|x| {
498
26.8k
                self.left.eq(&x.left)
499
3.60k
                    && self.op == x.op
500
2.73k
                    && self.right.eq(&x.right)
501
348
                    && self.fail_on_overflow.eq(&x.fail_on_overflow)
502
99.2k
            
}26.8k
)
503
99.2k
            .unwrap_or(false)
504
99.2k
    }
505
}
506
507
/// Casts dictionary array to result type for binary numerical operators. Such operators
508
/// between array and scalar produce a dictionary array other than primitive array of the
509
/// same operators between array and array. This leads to inconsistent result types causing
510
/// errors in the following query execution. For such operators between array and scalar,
511
/// we cast the dictionary array to primitive array.
512
0
fn to_result_type_array(
513
0
    op: &Operator,
514
0
    array: ArrayRef,
515
0
    result_type: &DataType,
516
0
) -> Result<ArrayRef> {
517
0
    if array.data_type() == result_type {
518
0
        Ok(array)
519
0
    } else if op.is_numerical_operators() {
520
0
        match array.data_type() {
521
0
            DataType::Dictionary(_, value_type) => {
522
0
                if value_type.as_ref() == result_type {
523
0
                    Ok(cast(&array, result_type)?)
524
                } else {
525
0
                    internal_err!(
526
0
                            "Incompatible Dictionary value type {value_type:?} with result type {result_type:?} of Binary operator {op:?}"
527
0
                        )
528
                }
529
            }
530
0
            _ => Ok(array),
531
        }
532
    } else {
533
0
        Ok(array)
534
    }
535
0
}
536
537
impl BinaryExpr {
538
    /// Evaluate the expression of the left input is an array and
539
    /// right is literal - use scalar operations
540
0
    fn evaluate_array_scalar(
541
0
        &self,
542
0
        array: &dyn Array,
543
0
        scalar: ScalarValue,
544
0
    ) -> Result<Option<Result<ArrayRef>>> {
545
        use Operator::*;
546
0
        let scalar_result = match &self.op {
547
0
            RegexMatch => binary_string_array_flag_op_scalar!(
548
0
                array,
549
0
                scalar,
550
                regexp_is_match,
551
0
                false,
552
                false
553
            ),
554
0
            RegexIMatch => binary_string_array_flag_op_scalar!(
555
0
                array,
556
0
                scalar,
557
                regexp_is_match,
558
0
                false,
559
                true
560
            ),
561
0
            RegexNotMatch => binary_string_array_flag_op_scalar!(
562
0
                array,
563
0
                scalar,
564
                regexp_is_match,
565
0
                true,
566
                false
567
            ),
568
0
            RegexNotIMatch => binary_string_array_flag_op_scalar!(
569
0
                array,
570
0
                scalar,
571
                regexp_is_match,
572
0
                true,
573
                true
574
            ),
575
0
            BitwiseAnd => bitwise_and_dyn_scalar(array, scalar),
576
0
            BitwiseOr => bitwise_or_dyn_scalar(array, scalar),
577
0
            BitwiseXor => bitwise_xor_dyn_scalar(array, scalar),
578
0
            BitwiseShiftRight => bitwise_shift_right_dyn_scalar(array, scalar),
579
0
            BitwiseShiftLeft => bitwise_shift_left_dyn_scalar(array, scalar),
580
            // if scalar operation is not supported - fallback to array implementation
581
0
            _ => None,
582
        };
583
584
0
        Ok(scalar_result)
585
0
    }
586
587
22.1k
    fn evaluate_with_resolved_args(
588
22.1k
        &self,
589
22.1k
        left: Arc<dyn Array>,
590
22.1k
        left_data_type: &DataType,
591
22.1k
        right: Arc<dyn Array>,
592
22.1k
        right_data_type: &DataType,
593
22.1k
    ) -> Result<ArrayRef> {
594
        use Operator::*;
595
22.1k
        match &self.op {
596
            IsDistinctFrom | IsNotDistinctFrom | Lt | LtEq | Gt | GtEq | Eq | NotEq
597
            | Plus | Minus | Multiply | Divide | Modulo | LikeMatch | ILikeMatch
598
0
            | NotLikeMatch | NotILikeMatch => unreachable!(),
599
            And => {
600
22.1k
                if left_data_type == &DataType::Boolean {
601
22.1k
                    Ok(boolean_op(&left, &right, and_kleene)
?0
)
602
                } else {
603
0
                    internal_err!(
604
0
                        "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
605
0
                        self.op,
606
0
                        left.data_type(),
607
0
                        right.data_type()
608
0
                    )
609
                }
610
            }
611
            Or => {
612
0
                if left_data_type == &DataType::Boolean {
613
0
                    Ok(boolean_op(&left, &right, or_kleene)?)
614
                } else {
615
0
                    internal_err!(
616
0
                        "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
617
0
                        self.op,
618
0
                        left_data_type,
619
0
                        right_data_type
620
0
                    )
621
                }
622
            }
623
            RegexMatch => {
624
0
                binary_string_array_flag_op!(left, right, regexp_is_match, false, false)
625
            }
626
            RegexIMatch => {
627
0
                binary_string_array_flag_op!(left, right, regexp_is_match, false, true)
628
            }
629
            RegexNotMatch => {
630
0
                binary_string_array_flag_op!(left, right, regexp_is_match, true, false)
631
            }
632
            RegexNotIMatch => {
633
0
                binary_string_array_flag_op!(left, right, regexp_is_match, true, true)
634
            }
635
0
            BitwiseAnd => bitwise_and_dyn(left, right),
636
0
            BitwiseOr => bitwise_or_dyn(left, right),
637
0
            BitwiseXor => bitwise_xor_dyn(left, right),
638
0
            BitwiseShiftRight => bitwise_shift_right_dyn(left, right),
639
0
            BitwiseShiftLeft => bitwise_shift_left_dyn(left, right),
640
0
            StringConcat => concat_elements(left, right),
641
            AtArrow | ArrowAt => {
642
0
                unreachable!("ArrowAt and AtArrow should be rewritten to function")
643
            }
644
        }
645
22.1k
    }
646
}
647
648
0
fn concat_elements(left: Arc<dyn Array>, right: Arc<dyn Array>) -> Result<ArrayRef> {
649
0
    Ok(match left.data_type() {
650
0
        DataType::Utf8 => Arc::new(concat_elements_utf8(
651
0
            left.as_string::<i32>(),
652
0
            right.as_string::<i32>(),
653
0
        )?),
654
0
        DataType::LargeUtf8 => Arc::new(concat_elements_utf8(
655
0
            left.as_string::<i64>(),
656
0
            right.as_string::<i64>(),
657
0
        )?),
658
0
        DataType::Utf8View => Arc::new(concat_elements_utf8view(
659
0
            left.as_string_view(),
660
0
            right.as_string_view(),
661
0
        )?),
662
0
        other => {
663
0
            return internal_err!(
664
0
                "Data type {other:?} not supported for binary operation 'concat_elements' on string arrays"
665
0
            );
666
        }
667
    })
668
0
}
669
670
/// Create a binary expression whose arguments are correctly coerced.
671
/// This function errors if it is not possible to coerce the arguments
672
/// to computational types supported by the operator.
673
458
pub fn binary(
674
458
    lhs: Arc<dyn PhysicalExpr>,
675
458
    op: Operator,
676
458
    rhs: Arc<dyn PhysicalExpr>,
677
458
    _input_schema: &Schema,
678
458
) -> Result<Arc<dyn PhysicalExpr>> {
679
458
    Ok(Arc::new(BinaryExpr::new(lhs, op, rhs)))
680
458
}
681
682
/// Create a similar to expression
683
0
pub fn similar_to(
684
0
    negated: bool,
685
0
    case_insensitive: bool,
686
0
    expr: Arc<dyn PhysicalExpr>,
687
0
    pattern: Arc<dyn PhysicalExpr>,
688
0
) -> Result<Arc<dyn PhysicalExpr>> {
689
0
    let binary_op = match (negated, case_insensitive) {
690
0
        (false, false) => Operator::RegexMatch,
691
0
        (false, true) => Operator::RegexIMatch,
692
0
        (true, false) => Operator::RegexNotMatch,
693
0
        (true, true) => Operator::RegexNotIMatch,
694
    };
695
0
    Ok(Arc::new(BinaryExpr::new(expr, binary_op, pattern)))
696
0
}
697
698
#[cfg(test)]
699
mod tests {
700
    use super::*;
701
    use crate::expressions::{col, lit, try_cast, Column, Literal};
702
    use datafusion_common::plan_datafusion_err;
703
    use datafusion_expr::type_coercion::binary::get_input_types;
704
705
    /// Performs a binary operation, applying any type coercion necessary
706
    fn binary_op(
707
        left: Arc<dyn PhysicalExpr>,
708
        op: Operator,
709
        right: Arc<dyn PhysicalExpr>,
710
        schema: &Schema,
711
    ) -> Result<Arc<dyn PhysicalExpr>> {
712
        let left_type = left.data_type(schema)?;
713
        let right_type = right.data_type(schema)?;
714
        let (lhs, rhs) = get_input_types(&left_type, &op, &right_type)?;
715
716
        let left_expr = try_cast(left, schema, lhs)?;
717
        let right_expr = try_cast(right, schema, rhs)?;
718
        binary(left_expr, op, right_expr, schema)
719
    }
720
721
    #[test]
722
    fn binary_comparison() -> Result<()> {
723
        let schema = Schema::new(vec![
724
            Field::new("a", DataType::Int32, false),
725
            Field::new("b", DataType::Int32, false),
726
        ]);
727
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
728
        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
729
730
        // expression: "a < b"
731
        let lt = binary(
732
            col("a", &schema)?,
733
            Operator::Lt,
734
            col("b", &schema)?,
735
            &schema,
736
        )?;
737
        let batch =
738
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
739
740
        let result = lt
741
            .evaluate(&batch)?
742
            .into_array(batch.num_rows())
743
            .expect("Failed to convert to array");
744
        assert_eq!(result.len(), 5);
745
746
        let expected = [false, false, true, true, true];
747
        let result =
748
            as_boolean_array(&result).expect("failed to downcast to BooleanArray");
749
        for (i, &expected_item) in expected.iter().enumerate().take(5) {
750
            assert_eq!(result.value(i), expected_item);
751
        }
752
753
        Ok(())
754
    }
755
756
    #[test]
757
    fn binary_nested() -> Result<()> {
758
        let schema = Schema::new(vec![
759
            Field::new("a", DataType::Int32, false),
760
            Field::new("b", DataType::Int32, false),
761
        ]);
762
        let a = Int32Array::from(vec![2, 4, 6, 8, 10]);
763
        let b = Int32Array::from(vec![2, 5, 4, 8, 8]);
764
765
        // expression: "a < b OR a == b"
766
        let expr = binary(
767
            binary(
768
                col("a", &schema)?,
769
                Operator::Lt,
770
                col("b", &schema)?,
771
                &schema,
772
            )?,
773
            Operator::Or,
774
            binary(
775
                col("a", &schema)?,
776
                Operator::Eq,
777
                col("b", &schema)?,
778
                &schema,
779
            )?,
780
            &schema,
781
        )?;
782
        let batch =
783
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
784
785
        assert_eq!("a@0 < b@1 OR a@0 = b@1", format!("{expr}"));
786
787
        let result = expr
788
            .evaluate(&batch)?
789
            .into_array(batch.num_rows())
790
            .expect("Failed to convert to array");
791
        assert_eq!(result.len(), 5);
792
793
        let expected = [true, true, false, true, false];
794
        let result =
795
            as_boolean_array(&result).expect("failed to downcast to BooleanArray");
796
        for (i, &expected_item) in expected.iter().enumerate().take(5) {
797
            assert_eq!(result.value(i), expected_item);
798
        }
799
800
        Ok(())
801
    }
802
803
    // runs an end-to-end test of physical type coercion:
804
    // 1. construct a record batch with two columns of type A and B
805
    //  (*_ARRAY is the Rust Arrow array type, and *_TYPE is the DataType of the elements)
806
    // 2. construct a physical expression of A OP B
807
    // 3. evaluate the expression
808
    // 4. verify that the resulting expression is of type C
809
    // 5. verify that the results of evaluation are $VEC
810
    macro_rules! test_coercion {
811
        ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident, $B_TYPE:expr, $B_VEC:expr, $OP:expr, $C_ARRAY:ident, $C_TYPE:expr, $VEC:expr,) => {{
812
            let schema = Schema::new(vec![
813
                Field::new("a", $A_TYPE, false),
814
                Field::new("b", $B_TYPE, false),
815
            ]);
816
            let a = $A_ARRAY::from($A_VEC);
817
            let b = $B_ARRAY::from($B_VEC);
818
            let (lhs, rhs) = get_input_types(&$A_TYPE, &$OP, &$B_TYPE)?;
819
820
            let left = try_cast(col("a", &schema)?, &schema, lhs)?;
821
            let right = try_cast(col("b", &schema)?, &schema, rhs)?;
822
823
            // verify that we can construct the expression
824
            let expression = binary(left, $OP, right, &schema)?;
825
            let batch = RecordBatch::try_new(
826
                Arc::new(schema.clone()),
827
                vec![Arc::new(a), Arc::new(b)],
828
            )?;
829
830
            // verify that the expression's type is correct
831
            assert_eq!(expression.data_type(&schema)?, $C_TYPE);
832
833
            // compute
834
            let result = expression.evaluate(&batch)?.into_array(batch.num_rows()).expect("Failed to convert to array");
835
836
            // verify that the array's data_type is correct
837
            assert_eq!(*result.data_type(), $C_TYPE);
838
839
            // verify that the data itself is downcastable
840
            let result = result
841
                .as_any()
842
                .downcast_ref::<$C_ARRAY>()
843
                .expect("failed to downcast");
844
            // verify that the result itself is correct
845
            for (i, x) in $VEC.iter().enumerate() {
846
                let v = result.value(i);
847
                assert_eq!(
848
                    v,
849
                    *x,
850
                    "Unexpected output at position {i}:\n\nActual:\n{v}\n\nExpected:\n{x}"
851
                );
852
            }
853
        }};
854
    }
855
856
    #[test]
857
    fn test_type_coercion() -> Result<()> {
858
        test_coercion!(
859
            Int32Array,
860
            DataType::Int32,
861
            vec![1i32, 2i32],
862
            UInt32Array,
863
            DataType::UInt32,
864
            vec![1u32, 2u32],
865
            Operator::Plus,
866
            Int32Array,
867
            DataType::Int32,
868
            [2i32, 4i32],
869
        );
870
        test_coercion!(
871
            Int32Array,
872
            DataType::Int32,
873
            vec![1i32],
874
            UInt16Array,
875
            DataType::UInt16,
876
            vec![1u16],
877
            Operator::Plus,
878
            Int32Array,
879
            DataType::Int32,
880
            [2i32],
881
        );
882
        test_coercion!(
883
            Float32Array,
884
            DataType::Float32,
885
            vec![1f32],
886
            UInt16Array,
887
            DataType::UInt16,
888
            vec![1u16],
889
            Operator::Plus,
890
            Float32Array,
891
            DataType::Float32,
892
            [2f32],
893
        );
894
        test_coercion!(
895
            Float32Array,
896
            DataType::Float32,
897
            vec![2f32],
898
            UInt16Array,
899
            DataType::UInt16,
900
            vec![1u16],
901
            Operator::Multiply,
902
            Float32Array,
903
            DataType::Float32,
904
            [2f32],
905
        );
906
        test_coercion!(
907
            StringArray,
908
            DataType::Utf8,
909
            vec!["1994-12-13", "1995-01-26"],
910
            Date32Array,
911
            DataType::Date32,
912
            vec![9112, 9156],
913
            Operator::Eq,
914
            BooleanArray,
915
            DataType::Boolean,
916
            [true, true],
917
        );
918
        test_coercion!(
919
            StringArray,
920
            DataType::Utf8,
921
            vec!["1994-12-13", "1995-01-26"],
922
            Date32Array,
923
            DataType::Date32,
924
            vec![9113, 9154],
925
            Operator::Lt,
926
            BooleanArray,
927
            DataType::Boolean,
928
            [true, false],
929
        );
930
        test_coercion!(
931
            StringArray,
932
            DataType::Utf8,
933
            vec!["1994-12-13T12:34:56", "1995-01-26T01:23:45"],
934
            Date64Array,
935
            DataType::Date64,
936
            vec![787322096000, 791083425000],
937
            Operator::Eq,
938
            BooleanArray,
939
            DataType::Boolean,
940
            [true, true],
941
        );
942
        test_coercion!(
943
            StringArray,
944
            DataType::Utf8,
945
            vec!["1994-12-13T12:34:56", "1995-01-26T01:23:45"],
946
            Date64Array,
947
            DataType::Date64,
948
            vec![787322096001, 791083424999],
949
            Operator::Lt,
950
            BooleanArray,
951
            DataType::Boolean,
952
            [true, false],
953
        );
954
        test_coercion!(
955
            StringViewArray,
956
            DataType::Utf8View,
957
            vec!["abc"; 5],
958
            StringArray,
959
            DataType::Utf8,
960
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
961
            Operator::RegexMatch,
962
            BooleanArray,
963
            DataType::Boolean,
964
            [true, false, true, false, false],
965
        );
966
        test_coercion!(
967
            StringViewArray,
968
            DataType::Utf8View,
969
            vec!["abc"; 5],
970
            StringArray,
971
            DataType::Utf8,
972
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
973
            Operator::RegexIMatch,
974
            BooleanArray,
975
            DataType::Boolean,
976
            [true, true, true, true, false],
977
        );
978
        test_coercion!(
979
            StringArray,
980
            DataType::Utf8,
981
            vec!["abc"; 5],
982
            StringViewArray,
983
            DataType::Utf8View,
984
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
985
            Operator::RegexNotMatch,
986
            BooleanArray,
987
            DataType::Boolean,
988
            [false, true, false, true, true],
989
        );
990
        test_coercion!(
991
            StringArray,
992
            DataType::Utf8,
993
            vec!["abc"; 5],
994
            StringViewArray,
995
            DataType::Utf8View,
996
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
997
            Operator::RegexNotIMatch,
998
            BooleanArray,
999
            DataType::Boolean,
1000
            [false, false, false, false, true],
1001
        );
1002
        test_coercion!(
1003
            StringArray,
1004
            DataType::Utf8,
1005
            vec!["abc"; 5],
1006
            StringArray,
1007
            DataType::Utf8,
1008
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1009
            Operator::RegexMatch,
1010
            BooleanArray,
1011
            DataType::Boolean,
1012
            [true, false, true, false, false],
1013
        );
1014
        test_coercion!(
1015
            StringArray,
1016
            DataType::Utf8,
1017
            vec!["abc"; 5],
1018
            StringArray,
1019
            DataType::Utf8,
1020
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1021
            Operator::RegexIMatch,
1022
            BooleanArray,
1023
            DataType::Boolean,
1024
            [true, true, true, true, false],
1025
        );
1026
        test_coercion!(
1027
            StringArray,
1028
            DataType::Utf8,
1029
            vec!["abc"; 5],
1030
            StringArray,
1031
            DataType::Utf8,
1032
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1033
            Operator::RegexNotMatch,
1034
            BooleanArray,
1035
            DataType::Boolean,
1036
            [false, true, false, true, true],
1037
        );
1038
        test_coercion!(
1039
            StringArray,
1040
            DataType::Utf8,
1041
            vec!["abc"; 5],
1042
            StringArray,
1043
            DataType::Utf8,
1044
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1045
            Operator::RegexNotIMatch,
1046
            BooleanArray,
1047
            DataType::Boolean,
1048
            [false, false, false, false, true],
1049
        );
1050
        test_coercion!(
1051
            LargeStringArray,
1052
            DataType::LargeUtf8,
1053
            vec!["abc"; 5],
1054
            LargeStringArray,
1055
            DataType::LargeUtf8,
1056
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1057
            Operator::RegexMatch,
1058
            BooleanArray,
1059
            DataType::Boolean,
1060
            [true, false, true, false, false],
1061
        );
1062
        test_coercion!(
1063
            LargeStringArray,
1064
            DataType::LargeUtf8,
1065
            vec!["abc"; 5],
1066
            LargeStringArray,
1067
            DataType::LargeUtf8,
1068
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1069
            Operator::RegexIMatch,
1070
            BooleanArray,
1071
            DataType::Boolean,
1072
            [true, true, true, true, false],
1073
        );
1074
        test_coercion!(
1075
            LargeStringArray,
1076
            DataType::LargeUtf8,
1077
            vec!["abc"; 5],
1078
            LargeStringArray,
1079
            DataType::LargeUtf8,
1080
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1081
            Operator::RegexNotMatch,
1082
            BooleanArray,
1083
            DataType::Boolean,
1084
            [false, true, false, true, true],
1085
        );
1086
        test_coercion!(
1087
            LargeStringArray,
1088
            DataType::LargeUtf8,
1089
            vec!["abc"; 5],
1090
            LargeStringArray,
1091
            DataType::LargeUtf8,
1092
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1093
            Operator::RegexNotIMatch,
1094
            BooleanArray,
1095
            DataType::Boolean,
1096
            [false, false, false, false, true],
1097
        );
1098
        test_coercion!(
1099
            StringArray,
1100
            DataType::Utf8,
1101
            vec!["abc"; 5],
1102
            StringArray,
1103
            DataType::Utf8,
1104
            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1105
            Operator::LikeMatch,
1106
            BooleanArray,
1107
            DataType::Boolean,
1108
            [true, false, false, true, false],
1109
        );
1110
        test_coercion!(
1111
            StringArray,
1112
            DataType::Utf8,
1113
            vec!["abc"; 5],
1114
            StringArray,
1115
            DataType::Utf8,
1116
            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1117
            Operator::ILikeMatch,
1118
            BooleanArray,
1119
            DataType::Boolean,
1120
            [true, true, false, true, true],
1121
        );
1122
        test_coercion!(
1123
            StringArray,
1124
            DataType::Utf8,
1125
            vec!["abc"; 5],
1126
            StringArray,
1127
            DataType::Utf8,
1128
            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1129
            Operator::NotLikeMatch,
1130
            BooleanArray,
1131
            DataType::Boolean,
1132
            [false, true, true, false, true],
1133
        );
1134
        test_coercion!(
1135
            StringArray,
1136
            DataType::Utf8,
1137
            vec!["abc"; 5],
1138
            StringArray,
1139
            DataType::Utf8,
1140
            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1141
            Operator::NotILikeMatch,
1142
            BooleanArray,
1143
            DataType::Boolean,
1144
            [false, false, true, false, false],
1145
        );
1146
        test_coercion!(
1147
            LargeStringArray,
1148
            DataType::LargeUtf8,
1149
            vec!["abc"; 5],
1150
            LargeStringArray,
1151
            DataType::LargeUtf8,
1152
            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1153
            Operator::LikeMatch,
1154
            BooleanArray,
1155
            DataType::Boolean,
1156
            [true, false, false, true, false],
1157
        );
1158
        test_coercion!(
1159
            LargeStringArray,
1160
            DataType::LargeUtf8,
1161
            vec!["abc"; 5],
1162
            LargeStringArray,
1163
            DataType::LargeUtf8,
1164
            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1165
            Operator::ILikeMatch,
1166
            BooleanArray,
1167
            DataType::Boolean,
1168
            [true, true, false, true, true],
1169
        );
1170
        test_coercion!(
1171
            LargeStringArray,
1172
            DataType::LargeUtf8,
1173
            vec!["abc"; 5],
1174
            LargeStringArray,
1175
            DataType::LargeUtf8,
1176
            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1177
            Operator::NotLikeMatch,
1178
            BooleanArray,
1179
            DataType::Boolean,
1180
            [false, true, true, false, true],
1181
        );
1182
        test_coercion!(
1183
            LargeStringArray,
1184
            DataType::LargeUtf8,
1185
            vec!["abc"; 5],
1186
            LargeStringArray,
1187
            DataType::LargeUtf8,
1188
            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1189
            Operator::NotILikeMatch,
1190
            BooleanArray,
1191
            DataType::Boolean,
1192
            [false, false, true, false, false],
1193
        );
1194
        test_coercion!(
1195
            Int16Array,
1196
            DataType::Int16,
1197
            vec![1i16, 2i16, 3i16],
1198
            Int64Array,
1199
            DataType::Int64,
1200
            vec![10i64, 4i64, 5i64],
1201
            Operator::BitwiseAnd,
1202
            Int64Array,
1203
            DataType::Int64,
1204
            [0i64, 0i64, 1i64],
1205
        );
1206
        test_coercion!(
1207
            UInt16Array,
1208
            DataType::UInt16,
1209
            vec![1u16, 2u16, 3u16],
1210
            UInt64Array,
1211
            DataType::UInt64,
1212
            vec![10u64, 4u64, 5u64],
1213
            Operator::BitwiseAnd,
1214
            UInt64Array,
1215
            DataType::UInt64,
1216
            [0u64, 0u64, 1u64],
1217
        );
1218
        test_coercion!(
1219
            Int16Array,
1220
            DataType::Int16,
1221
            vec![3i16, 2i16, 3i16],
1222
            Int64Array,
1223
            DataType::Int64,
1224
            vec![10i64, 6i64, 5i64],
1225
            Operator::BitwiseOr,
1226
            Int64Array,
1227
            DataType::Int64,
1228
            [11i64, 6i64, 7i64],
1229
        );
1230
        test_coercion!(
1231
            UInt16Array,
1232
            DataType::UInt16,
1233
            vec![1u16, 2u16, 3u16],
1234
            UInt64Array,
1235
            DataType::UInt64,
1236
            vec![10u64, 4u64, 5u64],
1237
            Operator::BitwiseOr,
1238
            UInt64Array,
1239
            DataType::UInt64,
1240
            [11u64, 6u64, 7u64],
1241
        );
1242
        test_coercion!(
1243
            Int16Array,
1244
            DataType::Int16,
1245
            vec![3i16, 2i16, 3i16],
1246
            Int64Array,
1247
            DataType::Int64,
1248
            vec![10i64, 6i64, 5i64],
1249
            Operator::BitwiseXor,
1250
            Int64Array,
1251
            DataType::Int64,
1252
            [9i64, 4i64, 6i64],
1253
        );
1254
        test_coercion!(
1255
            UInt16Array,
1256
            DataType::UInt16,
1257
            vec![3u16, 2u16, 3u16],
1258
            UInt64Array,
1259
            DataType::UInt64,
1260
            vec![10u64, 6u64, 5u64],
1261
            Operator::BitwiseXor,
1262
            UInt64Array,
1263
            DataType::UInt64,
1264
            [9u64, 4u64, 6u64],
1265
        );
1266
        test_coercion!(
1267
            Int16Array,
1268
            DataType::Int16,
1269
            vec![4i16, 27i16, 35i16],
1270
            Int64Array,
1271
            DataType::Int64,
1272
            vec![2i64, 3i64, 4i64],
1273
            Operator::BitwiseShiftRight,
1274
            Int64Array,
1275
            DataType::Int64,
1276
            [1i64, 3i64, 2i64],
1277
        );
1278
        test_coercion!(
1279
            UInt16Array,
1280
            DataType::UInt16,
1281
            vec![4u16, 27u16, 35u16],
1282
            UInt64Array,
1283
            DataType::UInt64,
1284
            vec![2u64, 3u64, 4u64],
1285
            Operator::BitwiseShiftRight,
1286
            UInt64Array,
1287
            DataType::UInt64,
1288
            [1u64, 3u64, 2u64],
1289
        );
1290
        test_coercion!(
1291
            Int16Array,
1292
            DataType::Int16,
1293
            vec![2i16, 3i16, 4i16],
1294
            Int64Array,
1295
            DataType::Int64,
1296
            vec![4i64, 12i64, 7i64],
1297
            Operator::BitwiseShiftLeft,
1298
            Int64Array,
1299
            DataType::Int64,
1300
            [32i64, 12288i64, 512i64],
1301
        );
1302
        test_coercion!(
1303
            UInt16Array,
1304
            DataType::UInt16,
1305
            vec![2u16, 3u16, 4u16],
1306
            UInt64Array,
1307
            DataType::UInt64,
1308
            vec![4u64, 12u64, 7u64],
1309
            Operator::BitwiseShiftLeft,
1310
            UInt64Array,
1311
            DataType::UInt64,
1312
            [32u64, 12288u64, 512u64],
1313
        );
1314
        Ok(())
1315
    }
1316
1317
    // Note it would be nice to use the same test_coercion macro as
1318
    // above, but sadly the type of the values of the dictionary are
1319
    // not encoded in the rust type of the DictionaryArray. Thus there
1320
    // is no way at the time of this writing to create a dictionary
1321
    // array using the `From` trait
1322
    #[test]
1323
    fn test_dictionary_type_to_array_coercion() -> Result<()> {
1324
        // Test string  a string dictionary
1325
        let dict_type =
1326
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
1327
        let string_type = DataType::Utf8;
1328
1329
        // build dictionary
1330
        let mut dict_builder = StringDictionaryBuilder::<Int32Type>::new();
1331
1332
        dict_builder.append("one")?;
1333
        dict_builder.append_null();
1334
        dict_builder.append("three")?;
1335
        dict_builder.append("four")?;
1336
        let dict_array = Arc::new(dict_builder.finish()) as ArrayRef;
1337
1338
        let str_array = Arc::new(StringArray::from(vec![
1339
            Some("not one"),
1340
            Some("two"),
1341
            None,
1342
            Some("four"),
1343
        ])) as ArrayRef;
1344
1345
        let schema = Arc::new(Schema::new(vec![
1346
            Field::new("a", dict_type.clone(), true),
1347
            Field::new("b", string_type.clone(), true),
1348
        ]));
1349
1350
        // Test 1: a = b
1351
        let result = BooleanArray::from(vec![Some(false), None, None, Some(true)]);
1352
        apply_logic_op(&schema, &dict_array, &str_array, Operator::Eq, result)?;
1353
1354
        // Test 2: now test the other direction
1355
        // b = a
1356
        let schema = Arc::new(Schema::new(vec![
1357
            Field::new("a", string_type, true),
1358
            Field::new("b", dict_type, true),
1359
        ]));
1360
        let result = BooleanArray::from(vec![Some(false), None, None, Some(true)]);
1361
        apply_logic_op(&schema, &str_array, &dict_array, Operator::Eq, result)?;
1362
1363
        Ok(())
1364
    }
1365
1366
    #[test]
1367
    fn plus_op() -> Result<()> {
1368
        let schema = Schema::new(vec![
1369
            Field::new("a", DataType::Int32, false),
1370
            Field::new("b", DataType::Int32, false),
1371
        ]);
1372
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1373
        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
1374
1375
        apply_arithmetic::<Int32Type>(
1376
            Arc::new(schema),
1377
            vec![Arc::new(a), Arc::new(b)],
1378
            Operator::Plus,
1379
            Int32Array::from(vec![2, 4, 7, 12, 21]),
1380
        )?;
1381
1382
        Ok(())
1383
    }
1384
1385
    #[test]
1386
    fn plus_op_dict() -> Result<()> {
1387
        let schema = Schema::new(vec![
1388
            Field::new(
1389
                "a",
1390
                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
1391
                true,
1392
            ),
1393
            Field::new(
1394
                "b",
1395
                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
1396
                true,
1397
            ),
1398
        ]);
1399
1400
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1401
        let keys = Int8Array::from(vec![Some(0), None, Some(1), Some(3), None]);
1402
        let a = DictionaryArray::try_new(keys, Arc::new(a))?;
1403
1404
        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
1405
        let keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
1406
        let b = DictionaryArray::try_new(keys, Arc::new(b))?;
1407
1408
        apply_arithmetic::<Int32Type>(
1409
            Arc::new(schema),
1410
            vec![Arc::new(a), Arc::new(b)],
1411
            Operator::Plus,
1412
            Int32Array::from(vec![Some(2), None, Some(4), Some(8), None]),
1413
        )?;
1414
1415
        Ok(())
1416
    }
1417
1418
    #[test]
1419
    fn plus_op_dict_decimal() -> Result<()> {
1420
        let schema = Schema::new(vec![
1421
            Field::new(
1422
                "a",
1423
                DataType::Dictionary(
1424
                    Box::new(DataType::Int8),
1425
                    Box::new(DataType::Decimal128(10, 0)),
1426
                ),
1427
                true,
1428
            ),
1429
            Field::new(
1430
                "b",
1431
                DataType::Dictionary(
1432
                    Box::new(DataType::Int8),
1433
                    Box::new(DataType::Decimal128(10, 0)),
1434
                ),
1435
                true,
1436
            ),
1437
        ]);
1438
1439
        let value = 123;
1440
        let decimal_array = Arc::new(create_decimal_array(
1441
            &[
1442
                Some(value),
1443
                Some(value + 2),
1444
                Some(value - 1),
1445
                Some(value + 1),
1446
            ],
1447
            10,
1448
            0,
1449
        ));
1450
1451
        let keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
1452
        let a = DictionaryArray::try_new(keys, decimal_array)?;
1453
1454
        let keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
1455
        let decimal_array = Arc::new(create_decimal_array(
1456
            &[
1457
                Some(value + 1),
1458
                Some(value + 3),
1459
                Some(value),
1460
                Some(value + 2),
1461
            ],
1462
            10,
1463
            0,
1464
        ));
1465
        let b = DictionaryArray::try_new(keys, decimal_array)?;
1466
1467
        apply_arithmetic(
1468
            Arc::new(schema),
1469
            vec![Arc::new(a), Arc::new(b)],
1470
            Operator::Plus,
1471
            create_decimal_array(&[Some(247), None, None, Some(247), Some(246)], 11, 0),
1472
        )?;
1473
1474
        Ok(())
1475
    }
1476
1477
    #[test]
1478
    fn plus_op_scalar() -> Result<()> {
1479
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1480
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1481
1482
        apply_arithmetic_scalar(
1483
            Arc::new(schema),
1484
            vec![Arc::new(a)],
1485
            Operator::Plus,
1486
            ScalarValue::Int32(Some(1)),
1487
            Arc::new(Int32Array::from(vec![2, 3, 4, 5, 6])),
1488
        )?;
1489
1490
        Ok(())
1491
    }
1492
1493
    #[test]
1494
    fn plus_op_dict_scalar() -> Result<()> {
1495
        let schema = Schema::new(vec![Field::new(
1496
            "a",
1497
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
1498
            true,
1499
        )]);
1500
1501
        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
1502
1503
        dict_builder.append(1)?;
1504
        dict_builder.append_null();
1505
        dict_builder.append(2)?;
1506
        dict_builder.append(5)?;
1507
1508
        let a = dict_builder.finish();
1509
1510
        let expected: PrimitiveArray<Int32Type> =
1511
            PrimitiveArray::from(vec![Some(2), None, Some(3), Some(6)]);
1512
1513
        apply_arithmetic_scalar(
1514
            Arc::new(schema),
1515
            vec![Arc::new(a)],
1516
            Operator::Plus,
1517
            ScalarValue::Dictionary(
1518
                Box::new(DataType::Int8),
1519
                Box::new(ScalarValue::Int32(Some(1))),
1520
            ),
1521
            Arc::new(expected),
1522
        )?;
1523
1524
        Ok(())
1525
    }
1526
1527
    #[test]
1528
    fn plus_op_dict_scalar_decimal() -> Result<()> {
1529
        let schema = Schema::new(vec![Field::new(
1530
            "a",
1531
            DataType::Dictionary(
1532
                Box::new(DataType::Int8),
1533
                Box::new(DataType::Decimal128(10, 0)),
1534
            ),
1535
            true,
1536
        )]);
1537
1538
        let value = 123;
1539
        let decimal_array = Arc::new(create_decimal_array(
1540
            &[Some(value), None, Some(value - 1), Some(value + 1)],
1541
            10,
1542
            0,
1543
        ));
1544
1545
        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
1546
        let a = DictionaryArray::try_new(keys, decimal_array)?;
1547
1548
        let decimal_array = Arc::new(create_decimal_array(
1549
            &[
1550
                Some(value + 1),
1551
                Some(value),
1552
                None,
1553
                Some(value + 2),
1554
                Some(value + 1),
1555
            ],
1556
            11,
1557
            0,
1558
        ));
1559
1560
        apply_arithmetic_scalar(
1561
            Arc::new(schema),
1562
            vec![Arc::new(a)],
1563
            Operator::Plus,
1564
            ScalarValue::Dictionary(
1565
                Box::new(DataType::Int8),
1566
                Box::new(ScalarValue::Decimal128(Some(1), 10, 0)),
1567
            ),
1568
            decimal_array,
1569
        )?;
1570
1571
        Ok(())
1572
    }
1573
1574
    /// Column-minus-column on `Int32`, in both operand orders.
    #[test]
    fn minus_op() -> Result<()> {
        let lhs: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
        let rhs: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]));

        // lhs >= rhs element-wise, so every difference is non-negative
        apply_arithmetic::<Int32Type>(
            Arc::clone(&schema),
            vec![Arc::clone(&lhs), Arc::clone(&rhs)],
            Operator::Minus,
            Int32Array::from(vec![0, 0, 1, 4, 11]),
        )?;

        // Swapping the operands must produce negative results for a signed type
        apply_arithmetic::<Int32Type>(
            schema,
            vec![rhs, lhs],
            Operator::Minus,
            Int32Array::from(vec![0, 0, -1, -4, -11]),
        )
    }
1603
1604
    /// `dict<int32> - dict<int32>` where both sides contain nulls via keys.
    #[test]
    fn minus_op_dict() -> Result<()> {
        let dict_type =
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
        let schema = Schema::new(vec![
            Field::new("a", dict_type.clone(), true),
            Field::new("b", dict_type, true),
        ]);

        // a = [1, NULL, 2, 4, NULL]
        let a_keys = Int8Array::from(vec![Some(0), None, Some(1), Some(3), None]);
        let a_values = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let a = DictionaryArray::try_new(a_keys, Arc::new(a_values))?;

        // b = [1, 2, 2, 4, 2]
        let b_keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
        let b_values = Int32Array::from(vec![1, 2, 4, 8, 16]);
        let b = DictionaryArray::try_new(b_keys, Arc::new(b_values))?;

        apply_arithmetic::<Int32Type>(
            Arc::new(schema),
            vec![Arc::new(a), Arc::new(b)],
            Operator::Minus,
            Int32Array::from(vec![Some(0), None, Some(0), Some(0), None]),
        )
    }
1636
1637
    /// `dict<decimal(10,0)> - dict<decimal(10,0)>`; a null on either side yields null.
    #[test]
    fn minus_op_dict_decimal() -> Result<()> {
        let dict_type = DataType::Dictionary(
            Box::new(DataType::Int8),
            Box::new(DataType::Decimal128(10, 0)),
        );
        let schema = Schema::new(vec![
            Field::new("a", dict_type.clone(), true),
            Field::new("b", dict_type, true),
        ]);

        let base = 123;

        // a = [123, 122, NULL, 124, 123]
        let a_values = create_decimal_array(
            &[Some(base), Some(base + 2), Some(base - 1), Some(base + 1)],
            10,
            0,
        );
        let a_keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
        let a = DictionaryArray::try_new(a_keys, Arc::new(a_values))?;

        // b = [124, NULL, 125, 123, 123]
        let b_values = create_decimal_array(
            &[Some(base + 1), Some(base + 3), Some(base), Some(base + 2)],
            10,
            0,
        );
        let b_keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
        let b = DictionaryArray::try_new(b_keys, Arc::new(b_values))?;

        // The expected difference carries precision 11
        let expected =
            create_decimal_array(&[Some(-1), None, None, Some(1), Some(0)], 11, 0);

        apply_arithmetic(
            Arc::new(schema),
            vec![Arc::new(a), Arc::new(b)],
            Operator::Minus,
            expected,
        )
    }
1695
1696
    /// `int32 column - 1`.
    #[test]
    fn minus_op_scalar() -> Result<()> {
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let input: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let expected: ArrayRef = Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4]));

        apply_arithmetic_scalar(
            schema,
            vec![input],
            Operator::Minus,
            ScalarValue::Int32(Some(1)),
            expected,
        )
    }
1711
1712
    /// `dict<int32> - dict-scalar 1`, with a null entry in the input.
    #[test]
    fn minus_op_dict_scalar() -> Result<()> {
        let dict_type =
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
        let schema = Arc::new(Schema::new(vec![Field::new("a", dict_type, true)]));

        // a = [1, NULL, 2, 5]
        let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
        for entry in [Some(1), None, Some(2), Some(5)] {
            match entry {
                Some(v) => {
                    builder.append(v)?;
                }
                None => builder.append_null(),
            }
        }
        let a = builder.finish();

        let subtrahend = ScalarValue::Dictionary(
            Box::new(DataType::Int8),
            Box::new(ScalarValue::Int32(Some(1))),
        );
        let expected: ArrayRef = Arc::new(PrimitiveArray::<Int32Type>::from(vec![
            Some(0),
            None,
            Some(1),
            Some(4),
        ]));

        apply_arithmetic_scalar(
            schema,
            vec![Arc::new(a)],
            Operator::Minus,
            subtrahend,
            expected,
        )
    }
1745
1746
    /// `dict<decimal(10,0)> - dict-scalar 1`; nulls must pass through untouched.
    #[test]
    fn minus_op_dict_scalar_decimal() -> Result<()> {
        let dict_type = DataType::Dictionary(
            Box::new(DataType::Int8),
            Box::new(DataType::Decimal128(10, 0)),
        );
        let schema = Arc::new(Schema::new(vec![Field::new("a", dict_type, true)]));

        let base = 123;

        // a = [123, 122, NULL, 124, 123]
        let values = create_decimal_array(
            &[Some(base), None, Some(base - 1), Some(base + 1)],
            10,
            0,
        );
        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
        let a = DictionaryArray::try_new(keys, Arc::new(values))?;

        // The expected result carries precision 11
        let expected: ArrayRef = Arc::new(create_decimal_array(
            &[Some(base - 1), Some(base - 2), None, Some(base), Some(base - 1)],
            11,
            0,
        ));
        let subtrahend = ScalarValue::Dictionary(
            Box::new(DataType::Int8),
            Box::new(ScalarValue::Decimal128(Some(1), 10, 0)),
        );

        apply_arithmetic_scalar(
            schema,
            vec![Arc::new(a)],
            Operator::Minus,
            subtrahend,
            expected,
        )
    }
1792
1793
    /// Column-times-column on `Int32`.
    #[test]
    fn multiply_op() -> Result<()> {
        let lhs: ArrayRef = Arc::new(Int32Array::from(vec![4, 8, 16, 32, 64]));
        let rhs: ArrayRef = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32]));
        let fields = vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ];
        let expected = Int32Array::from(vec![8, 32, 128, 512, 2048]);

        apply_arithmetic::<Int32Type>(
            Arc::new(Schema::new(fields)),
            vec![lhs, rhs],
            Operator::Multiply,
            expected,
        )
    }
1811
1812
    /// `dict<int32> * dict<int32>` where the left side contains nulls via keys.
    #[test]
    fn multiply_op_dict() -> Result<()> {
        let dict_type =
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
        let schema = Schema::new(vec![
            Field::new("a", dict_type.clone(), true),
            Field::new("b", dict_type, true),
        ]);

        // a = [1, NULL, 2, 4, NULL]
        let a_keys = Int8Array::from(vec![Some(0), None, Some(1), Some(3), None]);
        let a_values = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let a = DictionaryArray::try_new(a_keys, Arc::new(a_values))?;

        // b = [1, 2, 2, 4, 2]
        let b_keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
        let b_values = Int32Array::from(vec![1, 2, 4, 8, 16]);
        let b = DictionaryArray::try_new(b_keys, Arc::new(b_values))?;

        apply_arithmetic::<Int32Type>(
            Arc::new(schema),
            vec![Arc::new(a), Arc::new(b)],
            Operator::Multiply,
            Int32Array::from(vec![Some(1), None, Some(4), Some(16), None]),
        )
    }
1844
1845
    /// `dict<decimal(10,0)> * dict<decimal(10,0)>`; a null on either side yields null.
    #[test]
    fn multiply_op_dict_decimal() -> Result<()> {
        let dict_type = DataType::Dictionary(
            Box::new(DataType::Int8),
            Box::new(DataType::Decimal128(10, 0)),
        );
        let schema = Schema::new(vec![
            Field::new("a", dict_type.clone(), true),
            Field::new("b", dict_type, true),
        ]);

        let base = 123;

        // a = [123, 122, NULL, 124, 123]
        let a_values = create_decimal_array(
            &[Some(base), Some(base + 2), Some(base - 1), Some(base + 1)],
            10,
            0,
        );
        let a_keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
        let a = DictionaryArray::try_new(a_keys, Arc::new(a_values))?;

        // b = [124, NULL, 125, 123, 123]
        let b_values = create_decimal_array(
            &[Some(base + 1), Some(base + 3), Some(base), Some(base + 2)],
            10,
            0,
        );
        let b_keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
        let b = DictionaryArray::try_new(b_keys, Arc::new(b_values))?;

        // 123 * 124 = 15252, 124 * 123 = 15252, 123 * 123 = 15129; precision 21
        let expected = create_decimal_array(
            &[Some(15252), None, None, Some(15252), Some(15129)],
            21,
            0,
        );

        apply_arithmetic(
            Arc::new(schema),
            vec![Arc::new(a), Arc::new(b)],
            Operator::Multiply,
            expected,
        )
    }
1907
1908
    /// `int32 column * 2`.
    #[test]
    fn multiply_op_scalar() -> Result<()> {
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let input: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let expected: ArrayRef = Arc::new(Int32Array::from(vec![2, 4, 6, 8, 10]));

        apply_arithmetic_scalar(
            schema,
            vec![input],
            Operator::Multiply,
            ScalarValue::Int32(Some(2)),
            expected,
        )
    }
1923
1924
    /// `dict<int32> * dict-scalar 2`, with a null entry in the input.
    #[test]
    fn multiply_op_dict_scalar() -> Result<()> {
        let dict_type =
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
        let schema = Arc::new(Schema::new(vec![Field::new("a", dict_type, true)]));

        // a = [1, NULL, 2, 5]
        let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
        for entry in [Some(1), None, Some(2), Some(5)] {
            match entry {
                Some(v) => {
                    builder.append(v)?;
                }
                None => builder.append_null(),
            }
        }
        let a = builder.finish();

        let factor = ScalarValue::Dictionary(
            Box::new(DataType::Int8),
            Box::new(ScalarValue::Int32(Some(2))),
        );
        let expected: ArrayRef = Arc::new(PrimitiveArray::<Int32Type>::from(vec![
            Some(2),
            None,
            Some(4),
            Some(10),
        ]));

        apply_arithmetic_scalar(
            schema,
            vec![Arc::new(a)],
            Operator::Multiply,
            factor,
            expected,
        )
    }
1957
1958
    /// `dict<decimal(10,0)> * dict-scalar 2`; nulls must pass through untouched.
    #[test]
    fn multiply_op_dict_scalar_decimal() -> Result<()> {
        let dict_type = DataType::Dictionary(
            Box::new(DataType::Int8),
            Box::new(DataType::Decimal128(10, 0)),
        );
        let schema = Arc::new(Schema::new(vec![Field::new("a", dict_type, true)]));

        let base = 123;

        // a = [123, 122, NULL, 124, 123]
        let values = create_decimal_array(
            &[Some(base), None, Some(base - 1), Some(base + 1)],
            10,
            0,
        );
        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
        let a = DictionaryArray::try_new(keys, Arc::new(values))?;

        // Each valid element doubled (246, 244, 248, 246); precision 21
        let expected: ArrayRef = Arc::new(create_decimal_array(
            &[
                Some(2 * base),
                Some(2 * (base - 1)),
                None,
                Some(2 * (base + 1)),
                Some(2 * base),
            ],
            21,
            0,
        ));
        let factor = ScalarValue::Dictionary(
            Box::new(DataType::Int8),
            Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
        );

        apply_arithmetic_scalar(
            schema,
            vec![Arc::new(a)],
            Operator::Multiply,
            factor,
            expected,
        )
    }
1998
1999
    /// Column-divided-by-column on `Int32` (all divisions are exact).
    #[test]
    fn divide_op() -> Result<()> {
        let numerators: ArrayRef = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048]));
        let denominators: ArrayRef = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32]));
        let fields = vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ];
        let expected = Int32Array::from(vec![4, 8, 16, 32, 64]);

        apply_arithmetic::<Int32Type>(
            Arc::new(Schema::new(fields)),
            vec![numerators, denominators],
            Operator::Divide,
            expected,
        )
    }
2017
2018
    /// `dict<int32> / dict<int32>` with a null numerator and a zero numerator.
    #[test]
    fn divide_op_dict() -> Result<()> {
        let dict_type =
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
        let schema = Schema::new(vec![
            Field::new("a", dict_type.clone(), true),
            Field::new("b", dict_type, true),
        ]);

        // a = [1, NULL, 2, 5, 0]
        let mut a_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
        for entry in [Some(1), None, Some(2), Some(5), Some(0)] {
            match entry {
                Some(v) => {
                    a_builder.append(v)?;
                }
                None => a_builder.append_null(),
            }
        }
        let a = a_builder.finish();

        // b = [1, 2, 2, 4, 2]
        let b_keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
        let b_values = Int32Array::from(vec![1, 2, 4, 8, 16]);
        let b = DictionaryArray::try_new(b_keys, Arc::new(b_values))?;

        // Integer division discards the remainder: 5 / 4 == 1
        apply_arithmetic::<Int32Type>(
            Arc::new(schema),
            vec![Arc::new(a), Arc::new(b)],
            Operator::Divide,
            Int32Array::from(vec![Some(1), None, Some(1), Some(1), Some(0)]),
        )
    }
2056
2057
    /// `dict<decimal(10,0)> / dict<decimal(10,0)>`; a null on either side yields null.
    #[test]
    fn divide_op_dict_decimal() -> Result<()> {
        let dict_type = DataType::Dictionary(
            Box::new(DataType::Int8),
            Box::new(DataType::Decimal128(10, 0)),
        );
        let schema = Schema::new(vec![
            Field::new("a", dict_type.clone(), true),
            Field::new("b", dict_type, true),
        ]);

        let base = 123;

        // a = [123, 122, NULL, 124, 123]
        let a_values = create_decimal_array(
            &[Some(base), Some(base + 2), Some(base - 1), Some(base + 1)],
            10,
            0,
        );
        let a_keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
        let a = DictionaryArray::try_new(a_keys, Arc::new(a_values))?;

        // b = [124, NULL, 125, 123, 123]
        let b_values = create_decimal_array(
            &[Some(base + 1), Some(base + 3), Some(base), Some(base + 2)],
            10,
            0,
        );
        let b_keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
        let b = DictionaryArray::try_new(b_keys, Arc::new(b_values))?;

        // Quotients expressed as decimal(14, 4)
        let expected = create_decimal_array(
            &[
                Some(9919), // 123 / 124 ~= 0.9919
                None,
                None,
                Some(10081), // 124 / 123 ~= 1.0081
                Some(10000), // 123 / 123 == 1.0
            ],
            14,
            4,
        );

        apply_arithmetic(
            Arc::new(schema),
            vec![Arc::new(a), Arc::new(b)],
            Operator::Divide,
            expected,
        )
    }
2125
2126
    /// `int32 column / 2` (remainders are discarded).
    #[test]
    fn divide_op_scalar() -> Result<()> {
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let input: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let expected: ArrayRef = Arc::new(Int32Array::from(vec![0, 1, 1, 2, 2]));

        apply_arithmetic_scalar(
            schema,
            vec![input],
            Operator::Divide,
            ScalarValue::Int32(Some(2)),
            expected,
        )
    }
2141
2142
    /// `dict<int32> / dict-scalar 2`, with a null entry in the input.
    #[test]
    fn divide_op_dict_scalar() -> Result<()> {
        let dict_type =
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
        let schema = Arc::new(Schema::new(vec![Field::new("a", dict_type, true)]));

        // a = [1, NULL, 2, 5]
        let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
        for entry in [Some(1), None, Some(2), Some(5)] {
            match entry {
                Some(v) => {
                    builder.append(v)?;
                }
                None => builder.append_null(),
            }
        }
        let a = builder.finish();

        let divisor = ScalarValue::Dictionary(
            Box::new(DataType::Int8),
            Box::new(ScalarValue::Int32(Some(2))),
        );
        let expected: ArrayRef = Arc::new(PrimitiveArray::<Int32Type>::from(vec![
            Some(0),
            None,
            Some(1),
            Some(2),
        ]));

        apply_arithmetic_scalar(
            schema,
            vec![Arc::new(a)],
            Operator::Divide,
            divisor,
            expected,
        )
    }
2175
2176
    /// `dict<decimal(10,0)> / dict-scalar 2`; nulls must pass through untouched.
    #[test]
    fn divide_op_dict_scalar_decimal() -> Result<()> {
        let dict_type = DataType::Dictionary(
            Box::new(DataType::Int8),
            Box::new(DataType::Decimal128(10, 0)),
        );
        let schema = Arc::new(Schema::new(vec![Field::new("a", dict_type, true)]));

        let base = 123;

        // a = [123, 122, NULL, 124, 123]
        let values = create_decimal_array(
            &[Some(base), None, Some(base - 1), Some(base + 1)],
            10,
            0,
        );
        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
        let a = DictionaryArray::try_new(keys, Arc::new(values))?;

        // Halves expressed as decimal(14, 4): 61.5, 61.0, NULL, 62.0, 61.5
        let expected: ArrayRef = Arc::new(create_decimal_array(
            &[Some(615000), Some(610000), None, Some(620000), Some(615000)],
            14,
            4,
        ));
        let divisor = ScalarValue::Dictionary(
            Box::new(DataType::Int8),
            Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
        );

        apply_arithmetic_scalar(
            schema,
            vec![Arc::new(a)],
            Operator::Divide,
            divisor,
            expected,
        )
    }
2216
2217
    /// Column-modulo-column on `Int32`.
    #[test]
    fn modulus_op() -> Result<()> {
        let dividends: ArrayRef = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048]));
        let divisors: ArrayRef = Arc::new(Int32Array::from(vec![2, 4, 7, 14, 32]));
        let fields = vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ];
        // 128 % 7 == 2 and 512 % 14 == 8; the rest divide evenly
        let expected = Int32Array::from(vec![0, 0, 2, 8, 0]);

        apply_arithmetic::<Int32Type>(
            Arc::new(Schema::new(fields)),
            vec![dividends, divisors],
            Operator::Modulo,
            expected,
        )
    }
2235
2236
    /// `dict<int32> % dict<int32>` with a null dividend and a zero dividend.
    #[test]
    fn modulus_op_dict() -> Result<()> {
        let dict_type =
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
        let schema = Schema::new(vec![
            Field::new("a", dict_type.clone(), true),
            Field::new("b", dict_type, true),
        ]);

        // a = [1, NULL, 2, 5, 0]
        let mut a_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
        for entry in [Some(1), None, Some(2), Some(5), Some(0)] {
            match entry {
                Some(v) => {
                    a_builder.append(v)?;
                }
                None => a_builder.append_null(),
            }
        }
        let a = a_builder.finish();

        // b = [1, 2, 2, 4, 2]
        let b_keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
        let b_values = Int32Array::from(vec![1, 2, 4, 8, 16]);
        let b = DictionaryArray::try_new(b_keys, Arc::new(b_values))?;

        apply_arithmetic::<Int32Type>(
            Arc::new(schema),
            vec![Arc::new(a), Arc::new(b)],
            Operator::Modulo,
            Int32Array::from(vec![Some(0), None, Some(0), Some(1), Some(0)]),
        )
    }
2274
2275
    /// `dict<decimal(10,0)> % dict<decimal(10,0)>`; a null on either side yields null.
    #[test]
    fn modulus_op_dict_decimal() -> Result<()> {
        let dict_type = DataType::Dictionary(
            Box::new(DataType::Int8),
            Box::new(DataType::Decimal128(10, 0)),
        );
        let schema = Schema::new(vec![
            Field::new("a", dict_type.clone(), true),
            Field::new("b", dict_type, true),
        ]);

        let base = 123;

        // a = [123, 122, NULL, 124, 123]
        let a_values = create_decimal_array(
            &[Some(base), Some(base + 2), Some(base - 1), Some(base + 1)],
            10,
            0,
        );
        let a_keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
        let a = DictionaryArray::try_new(a_keys, Arc::new(a_values))?;

        // b = [124, NULL, 125, 123, 123]
        let b_values = create_decimal_array(
            &[Some(base + 1), Some(base + 3), Some(base), Some(base + 2)],
            10,
            0,
        );
        let b_keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
        let b = DictionaryArray::try_new(b_keys, Arc::new(b_values))?;

        // 123 % 124 == 123, 124 % 123 == 1, 123 % 123 == 0
        let expected =
            create_decimal_array(&[Some(123), None, None, Some(1), Some(0)], 10, 0);

        apply_arithmetic(
            Arc::new(schema),
            vec![Arc::new(a), Arc::new(b)],
            Operator::Modulo,
            expected,
        )
    }
2333
2334
    /// `int32 column % 2`.
    #[test]
    fn modulus_op_scalar() -> Result<()> {
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let input: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let expected: ArrayRef = Arc::new(Int32Array::from(vec![1, 0, 1, 0, 1]));

        apply_arithmetic_scalar(
            schema,
            vec![input],
            Operator::Modulo,
            ScalarValue::Int32(Some(2)),
            expected,
        )
    }
2349
2350
    /// `dict<int32> % dict-scalar 2`, with a null entry in the input.
    #[test]
    fn modules_op_dict_scalar() -> Result<()> {
        let dict_type =
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
        let schema = Arc::new(Schema::new(vec![Field::new("a", dict_type, true)]));

        // a = [1, NULL, 2, 5]
        let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
        for entry in [Some(1), None, Some(2), Some(5)] {
            match entry {
                Some(v) => {
                    builder.append(v)?;
                }
                None => builder.append_null(),
            }
        }
        let a = builder.finish();

        let divisor = ScalarValue::Dictionary(
            Box::new(DataType::Int8),
            Box::new(ScalarValue::Int32(Some(2))),
        );
        let expected: ArrayRef = Arc::new(PrimitiveArray::<Int32Type>::from(vec![
            Some(1),
            None,
            Some(0),
            Some(1),
        ]));

        apply_arithmetic_scalar(
            schema,
            vec![Arc::new(a)],
            Operator::Modulo,
            divisor,
            expected,
        )
    }
2383
2384
    /// `dict<decimal(10,0)> % dict-scalar 2`; nulls must pass through untouched.
    #[test]
    fn modulus_op_dict_scalar_decimal() -> Result<()> {
        let dict_type = DataType::Dictionary(
            Box::new(DataType::Int8),
            Box::new(DataType::Decimal128(10, 0)),
        );
        let schema = Arc::new(Schema::new(vec![Field::new("a", dict_type, true)]));

        let base = 123;

        // a = [123, 122, NULL, 124, 123]
        let values = create_decimal_array(
            &[Some(base), None, Some(base - 1), Some(base + 1)],
            10,
            0,
        );
        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
        let a = DictionaryArray::try_new(keys, Arc::new(values))?;

        // Parity of each valid element; precision stays at 10
        let expected: ArrayRef = Arc::new(create_decimal_array(
            &[Some(1), Some(0), None, Some(0), Some(1)],
            10,
            0,
        ));
        let divisor = ScalarValue::Dictionary(
            Box::new(DataType::Int8),
            Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
        );

        apply_arithmetic_scalar(
            schema,
            vec![Arc::new(a)],
            Operator::Modulo,
            divisor,
            expected,
        )
    }
2424
2425
    fn apply_arithmetic<T: ArrowNumericType>(
2426
        schema: SchemaRef,
2427
        data: Vec<ArrayRef>,
2428
        op: Operator,
2429
        expected: PrimitiveArray<T>,
2430
    ) -> Result<()> {
2431
        let arithmetic_op =
2432
            binary_op(col("a", &schema)?, op, col("b", &schema)?, &schema)?;
2433
        let batch = RecordBatch::try_new(schema, data)?;
2434
        let result = arithmetic_op
2435
            .evaluate(&batch)?
2436
            .into_array(batch.num_rows())
2437
            .expect("Failed to convert to array");
2438
2439
        assert_eq!(result.as_ref(), &expected);
2440
        Ok(())
2441
    }
2442
2443
    fn apply_arithmetic_scalar(
2444
        schema: SchemaRef,
2445
        data: Vec<ArrayRef>,
2446
        op: Operator,
2447
        literal: ScalarValue,
2448
        expected: ArrayRef,
2449
    ) -> Result<()> {
2450
        let lit = Arc::new(Literal::new(literal));
2451
        let arithmetic_op = binary_op(col("a", &schema)?, op, lit, &schema)?;
2452
        let batch = RecordBatch::try_new(schema, data)?;
2453
        let result = arithmetic_op
2454
            .evaluate(&batch)?
2455
            .into_array(batch.num_rows())
2456
            .expect("Failed to convert to array");
2457
2458
        assert_eq!(&result, &expected);
2459
        Ok(())
2460
    }
2461
2462
    fn apply_logic_op(
2463
        schema: &SchemaRef,
2464
        left: &ArrayRef,
2465
        right: &ArrayRef,
2466
        op: Operator,
2467
        expected: BooleanArray,
2468
    ) -> Result<()> {
2469
        let op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?;
2470
        let data: Vec<ArrayRef> = vec![Arc::clone(left), Arc::clone(right)];
2471
        let batch = RecordBatch::try_new(Arc::clone(schema), data)?;
2472
        let result = op
2473
            .evaluate(&batch)?
2474
            .into_array(batch.num_rows())
2475
            .expect("Failed to convert to array");
2476
2477
        assert_eq!(result.as_ref(), &expected);
2478
        Ok(())
2479
    }
2480
2481
    // Test `scalar <op> arr` produces expected
2482
    fn apply_logic_op_scalar_arr(
2483
        schema: &SchemaRef,
2484
        scalar: &ScalarValue,
2485
        arr: &ArrayRef,
2486
        op: Operator,
2487
        expected: &BooleanArray,
2488
    ) -> Result<()> {
2489
        let scalar = lit(scalar.clone());
2490
        let op = binary_op(scalar, op, col("a", schema)?, schema)?;
2491
        let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
2492
        let result = op
2493
            .evaluate(&batch)?
2494
            .into_array(batch.num_rows())
2495
            .expect("Failed to convert to array");
2496
        assert_eq!(result.as_ref(), expected);
2497
2498
        Ok(())
2499
    }
2500
2501
    // Test `arr <op> scalar` produces expected
2502
    fn apply_logic_op_arr_scalar(
2503
        schema: &SchemaRef,
2504
        arr: &ArrayRef,
2505
        scalar: &ScalarValue,
2506
        op: Operator,
2507
        expected: &BooleanArray,
2508
    ) -> Result<()> {
2509
        let scalar = lit(scalar.clone());
2510
        let op = binary_op(col("a", schema)?, op, scalar, schema)?;
2511
        let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
2512
        let result = op
2513
            .evaluate(&batch)?
2514
            .into_array(batch.num_rows())
2515
            .expect("Failed to convert to array");
2516
        assert_eq!(result.as_ref(), expected);
2517
2518
        Ok(())
2519
    }
2520
2521
    #[test]
    fn and_with_nulls_op() -> Result<()> {
        // Truth table for `a AND b` over every pairing of true/false/NULL,
        // written as (a, b, expected) rows.
        let rows = [
            (Some(true), Some(true), Some(true)),
            (Some(false), Some(true), Some(false)),
            (None, Some(true), None),
            (Some(true), Some(false), Some(false)),
            (Some(false), Some(false), Some(false)),
            (None, Some(false), Some(false)),
            (Some(true), None, None),
            (Some(false), None, Some(false)),
            (None, None, None),
        ];
        let a: ArrayRef = Arc::new(rows.iter().map(|row| row.0).collect::<BooleanArray>());
        let b: ArrayRef = Arc::new(rows.iter().map(|row| row.1).collect::<BooleanArray>());
        let expected: BooleanArray = rows.iter().map(|row| row.2).collect();

        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Boolean, true),
            Field::new("b", DataType::Boolean, true),
        ]));
        apply_logic_op(&schema, &a, &b, Operator::And, expected)?;

        Ok(())
    }
2565
2566
    #[test]
    fn regex_with_nulls() -> Result<()> {
        // Rows with a NULL on either side must produce NULL for every regex
        // operator; the remaining rows match case-sensitively and
        // case-insensitively alike.
        let haystack = [Some("abc"), None, Some("abc"), None, Some("abc")];
        let patterns = [Some("^a"), Some("^A"), None, None, Some("^(b|c)")];

        let matches = BooleanArray::from(vec![Some(true), None, None, None, Some(false)]);
        let not_matches =
            BooleanArray::from(vec![Some(false), None, None, None, Some(true)]);

        // Run the identical checks for both Utf8 and LargeUtf8 columns.
        let inputs = [
            (
                DataType::Utf8,
                Arc::new(StringArray::from(haystack.to_vec())) as ArrayRef,
                Arc::new(StringArray::from(patterns.to_vec())) as ArrayRef,
            ),
            (
                DataType::LargeUtf8,
                Arc::new(LargeStringArray::from(haystack.to_vec())) as ArrayRef,
                Arc::new(LargeStringArray::from(patterns.to_vec())) as ArrayRef,
            ),
        ];

        for (data_type, a, b) in inputs {
            let schema = Arc::new(Schema::new(vec![
                Field::new("a", data_type.clone(), true),
                Field::new("b", data_type, true),
            ]));
            for (op, expected) in [
                (Operator::RegexMatch, &matches),
                (Operator::RegexIMatch, &matches),
                (Operator::RegexNotMatch, &not_matches),
                (Operator::RegexNotIMatch, &not_matches),
            ] {
                apply_logic_op(&schema, &a, &b, op, expected.clone())?;
            }
        }

        Ok(())
    }
2670
2671
    #[test]
    fn or_with_nulls_op() -> Result<()> {
        // Truth table for `a OR b` over every pairing of true/false/NULL,
        // written as (a, b, expected) rows.
        let rows = [
            (Some(true), Some(true), Some(true)),
            (Some(false), Some(true), Some(true)),
            (None, Some(true), Some(true)),
            (Some(true), Some(false), Some(true)),
            (Some(false), Some(false), Some(false)),
            (None, Some(false), None),
            (Some(true), None, Some(true)),
            (Some(false), None, None),
            (None, None, None),
        ];
        let a: ArrayRef = Arc::new(rows.iter().map(|row| row.0).collect::<BooleanArray>());
        let b: ArrayRef = Arc::new(rows.iter().map(|row| row.1).collect::<BooleanArray>());
        let expected: BooleanArray = rows.iter().map(|row| row.2).collect();

        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Boolean, true),
            Field::new("b", DataType::Boolean, true),
        ]));
        apply_logic_op(&schema, &a, &b, Operator::Or, expected)?;

        Ok(())
    }
2715
2716
    /// Returns (schema, a: BooleanArray, b: BooleanArray) with all possible inputs
2717
    ///
2718
    /// a: [true, true, true,  NULL, NULL, NULL,  false, false, false]
2719
    /// b: [true, NULL, false, true, NULL, false, true,  NULL,  false]
2720
    fn bool_test_arrays() -> (SchemaRef, ArrayRef, ArrayRef) {
2721
        let schema = Schema::new(vec![
2722
            Field::new("a", DataType::Boolean, true),
2723
            Field::new("b", DataType::Boolean, true),
2724
        ]);
2725
        let a: BooleanArray = [
2726
            Some(true),
2727
            Some(true),
2728
            Some(true),
2729
            None,
2730
            None,
2731
            None,
2732
            Some(false),
2733
            Some(false),
2734
            Some(false),
2735
        ]
2736
        .iter()
2737
        .collect();
2738
        let b: BooleanArray = [
2739
            Some(true),
2740
            None,
2741
            Some(false),
2742
            Some(true),
2743
            None,
2744
            Some(false),
2745
            Some(true),
2746
            None,
2747
            Some(false),
2748
        ]
2749
        .iter()
2750
        .collect();
2751
        (Arc::new(schema), Arc::new(a), Arc::new(b))
2752
    }
2753
2754
    /// Returns (schema, BooleanArray) with [true, NULL, false]
2755
    fn scalar_bool_test_array() -> (SchemaRef, ArrayRef) {
2756
        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
2757
        let a: BooleanArray = [Some(true), None, Some(false)].iter().collect();
2758
        (Arc::new(schema), Arc::new(a))
2759
    }
2760
2761
    #[test]
    fn eq_op_bool() {
        let (schema, lhs, rhs) = bool_test_arrays();
        // Expected `a = b` per row; any NULL operand yields NULL.
        let expected = BooleanArray::from(vec![
            Some(true),  // true = true
            None,        // true = NULL
            Some(false), // true = false
            None,        // NULL = true
            None,        // NULL = NULL
            None,        // NULL = false
            Some(false), // false = true
            None,        // false = NULL
            Some(true),  // false = false
        ]);
        apply_logic_op(&schema, &lhs, &rhs, Operator::Eq, expected).unwrap();
    }
2779
2780
    #[test]
    fn eq_op_bool_scalar() {
        let (schema, a) = scalar_bool_test_array();
        // Each case: (scalar, expected `scalar = a`, expected `a = scalar`),
        // where `a` is [true, NULL, false].
        let cases = [
            (true, [Some(true), None, Some(false)], [Some(true), None, Some(false)]),
            (false, [Some(false), None, Some(true)], [Some(false), None, Some(true)]),
        ];
        for (value, scalar_lhs, scalar_rhs) in cases {
            let scalar = ScalarValue::from(value);
            let scalar_lhs: BooleanArray = scalar_lhs.iter().collect();
            let scalar_rhs: BooleanArray = scalar_rhs.iter().collect();
            apply_logic_op_scalar_arr(&schema, &scalar, &a, Operator::Eq, &scalar_lhs)
                .unwrap();
            apply_logic_op_arr_scalar(&schema, &a, &scalar, Operator::Eq, &scalar_rhs)
                .unwrap();
        }
    }
2819
2820
    #[test]
    fn neq_op_bool() {
        let (schema, lhs, rhs) = bool_test_arrays();
        // Expected `a != b` per row; any NULL operand yields NULL.
        let expected = BooleanArray::from(vec![
            Some(false), // true != true
            None,        // true != NULL
            Some(true),  // true != false
            None,        // NULL != true
            None,        // NULL != NULL
            None,        // NULL != false
            Some(true),  // false != true
            None,        // false != NULL
            Some(false), // false != false
        ]);
        apply_logic_op(&schema, &lhs, &rhs, Operator::NotEq, expected).unwrap();
    }
2838
2839
    #[test]
    fn neq_op_bool_scalar() {
        let (schema, a) = scalar_bool_test_array();
        // Each case: (scalar, expected `scalar != a`, expected `a != scalar`),
        // where `a` is [true, NULL, false].
        let cases = [
            (true, [Some(false), None, Some(true)], [Some(false), None, Some(true)]),
            (false, [Some(true), None, Some(false)], [Some(true), None, Some(false)]),
        ];
        for (value, scalar_lhs, scalar_rhs) in cases {
            let scalar = ScalarValue::from(value);
            let scalar_lhs: BooleanArray = scalar_lhs.iter().collect();
            let scalar_rhs: BooleanArray = scalar_rhs.iter().collect();
            apply_logic_op_scalar_arr(&schema, &scalar, &a, Operator::NotEq, &scalar_lhs)
                .unwrap();
            apply_logic_op_arr_scalar(&schema, &a, &scalar, Operator::NotEq, &scalar_rhs)
                .unwrap();
        }
    }
2878
2879
    #[test]
    fn lt_op_bool() {
        let (schema, lhs, rhs) = bool_test_arrays();
        // Expected `a < b` per row (false orders before true); NULL yields NULL.
        let expected = BooleanArray::from(vec![
            Some(false), // true < true
            None,        // true < NULL
            Some(false), // true < false
            None,        // NULL < true
            None,        // NULL < NULL
            None,        // NULL < false
            Some(true),  // false < true
            None,        // false < NULL
            Some(false), // false < false
        ]);
        apply_logic_op(&schema, &lhs, &rhs, Operator::Lt, expected).unwrap();
    }
2897
2898
    #[test]
    fn lt_op_bool_scalar() {
        let (schema, a) = scalar_bool_test_array();
        // Each case: (scalar, expected `scalar < a`, expected `a < scalar`),
        // where `a` is [true, NULL, false].
        let cases = [
            (true, [Some(false), None, Some(false)], [Some(false), None, Some(true)]),
            (false, [Some(true), None, Some(false)], [Some(false), None, Some(false)]),
        ];
        for (value, scalar_lhs, scalar_rhs) in cases {
            let scalar = ScalarValue::from(value);
            let scalar_lhs: BooleanArray = scalar_lhs.iter().collect();
            let scalar_rhs: BooleanArray = scalar_rhs.iter().collect();
            apply_logic_op_scalar_arr(&schema, &scalar, &a, Operator::Lt, &scalar_lhs)
                .unwrap();
            apply_logic_op_arr_scalar(&schema, &a, &scalar, Operator::Lt, &scalar_rhs)
                .unwrap();
        }
    }
2941
2942
    #[test]
    fn lt_eq_op_bool() {
        let (schema, lhs, rhs) = bool_test_arrays();
        // Expected `a <= b` per row (false orders before true); NULL yields NULL.
        let expected = BooleanArray::from(vec![
            Some(true),  // true <= true
            None,        // true <= NULL
            Some(false), // true <= false
            None,        // NULL <= true
            None,        // NULL <= NULL
            None,        // NULL <= false
            Some(true),  // false <= true
            None,        // false <= NULL
            Some(true),  // false <= false
        ]);
        apply_logic_op(&schema, &lhs, &rhs, Operator::LtEq, expected).unwrap();
    }
2960
2961
    #[test]
    fn lt_eq_op_bool_scalar() {
        let (schema, a) = scalar_bool_test_array();
        // Each case: (scalar, expected `scalar <= a`, expected `a <= scalar`),
        // where `a` is [true, NULL, false].
        let cases = [
            (true, [Some(true), None, Some(false)], [Some(true), None, Some(true)]),
            (false, [Some(true), None, Some(true)], [Some(false), None, Some(true)]),
        ];
        for (value, scalar_lhs, scalar_rhs) in cases {
            let scalar = ScalarValue::from(value);
            let scalar_lhs: BooleanArray = scalar_lhs.iter().collect();
            let scalar_rhs: BooleanArray = scalar_rhs.iter().collect();
            apply_logic_op_scalar_arr(&schema, &scalar, &a, Operator::LtEq, &scalar_lhs)
                .unwrap();
            apply_logic_op_arr_scalar(&schema, &a, &scalar, Operator::LtEq, &scalar_rhs)
                .unwrap();
        }
    }
3004
3005
    #[test]
    fn gt_op_bool() {
        let (schema, lhs, rhs) = bool_test_arrays();
        // Expected `a > b` per row (true orders after false); NULL yields NULL.
        let expected = BooleanArray::from(vec![
            Some(false), // true > true
            None,        // true > NULL
            Some(true),  // true > false
            None,        // NULL > true
            None,        // NULL > NULL
            None,        // NULL > false
            Some(false), // false > true
            None,        // false > NULL
            Some(false), // false > false
        ]);
        apply_logic_op(&schema, &lhs, &rhs, Operator::Gt, expected).unwrap();
    }
3023
3024
    #[test]
    fn gt_op_bool_scalar() {
        let (schema, a) = scalar_bool_test_array();
        // Each case: (scalar, expected `scalar > a`, expected `a > scalar`),
        // where `a` is [true, NULL, false].
        let cases = [
            (true, [Some(false), None, Some(true)], [Some(false), None, Some(false)]),
            (false, [Some(false), None, Some(false)], [Some(true), None, Some(false)]),
        ];
        for (value, scalar_lhs, scalar_rhs) in cases {
            let scalar = ScalarValue::from(value);
            let scalar_lhs: BooleanArray = scalar_lhs.iter().collect();
            let scalar_rhs: BooleanArray = scalar_rhs.iter().collect();
            apply_logic_op_scalar_arr(&schema, &scalar, &a, Operator::Gt, &scalar_lhs)
                .unwrap();
            apply_logic_op_arr_scalar(&schema, &a, &scalar, Operator::Gt, &scalar_rhs)
                .unwrap();
        }
    }
3067
3068
    #[test]
    fn gt_eq_op_bool() {
        let (schema, lhs, rhs) = bool_test_arrays();
        // Expected `a >= b` per row (true orders after false); NULL yields NULL.
        let expected = BooleanArray::from(vec![
            Some(true),  // true >= true
            None,        // true >= NULL
            Some(true),  // true >= false
            None,        // NULL >= true
            None,        // NULL >= NULL
            None,        // NULL >= false
            Some(false), // false >= true
            None,        // false >= NULL
            Some(true),  // false >= false
        ]);
        apply_logic_op(&schema, &lhs, &rhs, Operator::GtEq, expected).unwrap();
    }
3086
3087
    #[test]
    fn gt_eq_op_bool_scalar() {
        let (schema, a) = scalar_bool_test_array();
        // Each case: (scalar, expected `scalar >= a`, expected `a >= scalar`),
        // where `a` is [true, NULL, false].
        let cases = [
            (true, [Some(true), None, Some(true)], [Some(true), None, Some(false)]),
            (false, [Some(false), None, Some(true)], [Some(true), None, Some(true)]),
        ];
        for (value, scalar_lhs, scalar_rhs) in cases {
            let scalar = ScalarValue::from(value);
            let scalar_lhs: BooleanArray = scalar_lhs.iter().collect();
            let scalar_rhs: BooleanArray = scalar_rhs.iter().collect();
            apply_logic_op_scalar_arr(&schema, &scalar, &a, Operator::GtEq, &scalar_lhs)
                .unwrap();
            apply_logic_op_arr_scalar(&schema, &a, &scalar, Operator::GtEq, &scalar_rhs)
                .unwrap();
        }
    }
3130
3131
    #[test]
    fn is_distinct_from_op_bool() {
        let (schema, lhs, rhs) = bool_test_arrays();
        // Expected `a IS DISTINCT FROM b`: NULL is compared like an ordinary
        // value, so every row produces a non-NULL result.
        let expected = BooleanArray::from(vec![
            Some(false), // true vs true
            Some(true),  // true vs NULL
            Some(true),  // true vs false
            Some(true),  // NULL vs true
            Some(false), // NULL vs NULL
            Some(true),  // NULL vs false
            Some(true),  // false vs true
            Some(true),  // false vs NULL
            Some(false), // false vs false
        ]);
        apply_logic_op(&schema, &lhs, &rhs, Operator::IsDistinctFrom, expected).unwrap();
    }
3149
3150
    #[test]
    fn is_not_distinct_from_op_bool() {
        let (schema, lhs, rhs) = bool_test_arrays();
        // Expected `a IS NOT DISTINCT FROM b`: NULL is compared like an
        // ordinary value, so every row produces a non-NULL result.
        let expected = BooleanArray::from(vec![
            Some(true),  // true vs true
            Some(false), // true vs NULL
            Some(false), // true vs false
            Some(false), // NULL vs true
            Some(true),  // NULL vs NULL
            Some(false), // NULL vs false
            Some(false), // false vs true
            Some(false), // false vs NULL
            Some(true),  // false vs false
        ]);
        apply_logic_op(&schema, &lhs, &rhs, Operator::IsNotDistinctFrom, expected)
            .unwrap();
    }
3168
3169
    #[test]
    fn relatively_deeply_nested() {
        // Reproducer for https://github.com/apache/datafusion/issues/419,
        // where even relatively shallow binary expressions overflowed the
        // stack in debug builds.
        let input: Vec<Option<i32>> = (1..=5).map(Some).collect();
        let column = input.iter().collect::<Int32Array>();
        let batch =
            RecordBatch::try_from_iter(vec![("a", Arc::new(column) as ArrayRef)]).unwrap();
        let schema = batch.schema();

        // Grow a left-deep sum one `+ a` at a time: ((((a + a) + a) + a) ...
        // so that the tree references column `a` exactly `tree_depth` times.
        let tree_depth: i32 = 100;
        let mut expr = col("a", schema.as_ref()).unwrap();
        for _ in 1..tree_depth {
            let rhs = col("a", schema.as_ref()).unwrap();
            expr = binary(expr, Operator::Plus, rhs, &schema).unwrap();
        }

        let evaluated = expr.evaluate(&batch).expect("evaluation");
        let result = evaluated
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");

        // Summing `a` with itself `tree_depth` times multiplies each row by it.
        let expected: Int32Array = input
            .into_iter()
            .map(|slot| slot.map(|v| v * tree_depth))
            .collect();
        assert_eq!(result.as_ref(), &expected);
    }
3201
3202
    fn create_decimal_array(
3203
        array: &[Option<i128>],
3204
        precision: u8,
3205
        scale: i8,
3206
    ) -> Decimal128Array {
3207
        let mut decimal_builder = Decimal128Builder::with_capacity(array.len());
3208
        for value in array.iter().copied() {
3209
            decimal_builder.append_option(value)
3210
        }
3211
        decimal_builder
3212
            .finish()
3213
            .with_precision_and_scale(precision, scale)
3214
            .unwrap()
3215
    }
3216
3217
    #[test]
    fn comparison_dict_decimal_scalar_expr_test() -> Result<()> {
        // Compare an Int8-keyed dictionary scalar wrapping a Decimal128(25, 3)
        // value against a dictionary-encoded Decimal128(25, 3) column.
        let value_i128 = 123;
        let decimal_scalar = ScalarValue::Dictionary(
            Box::new(DataType::Int8),
            Box::new(ScalarValue::Decimal128(Some(value_i128), 25, 3)),
        );
        let schema = Arc::new(Schema::new(vec![Field::new(
            "a",
            DataType::Dictionary(
                Box::new(DataType::Int8),
                Box::new(DataType::Decimal128(25, 3)),
            ),
            true,
        )]));

        let values = Arc::new(create_decimal_array(
            &[
                Some(value_i128),
                None,
                Some(value_i128 - 1),
                Some(value_i128 + 1),
            ],
            25,
            3,
        ));
        // Key slot 1 is NULL, so the column decodes to
        // [value, NULL, value - 1, value + 1].
        let keys = Int8Array::from(vec![Some(0), None, Some(2), Some(3)]);
        let dictionary = Arc::new(DictionaryArray::try_new(keys, values)?) as ArrayRef;

        // Each case checks `dictionary <op> scalar` against the expected rows.
        let cases = [
            (Operator::Eq, [Some(true), None, Some(false), Some(false)]),
            (Operator::NotEq, [Some(false), None, Some(true), Some(true)]),
            (Operator::Lt, [Some(false), None, Some(true), Some(false)]),
            (Operator::LtEq, [Some(true), None, Some(true), Some(false)]),
            (Operator::Gt, [Some(false), None, Some(false), Some(true)]),
            (Operator::GtEq, [Some(true), None, Some(false), Some(true)]),
        ];
        for (op, expected) in cases {
            apply_logic_op_arr_scalar(
                &schema,
                &dictionary,
                &decimal_scalar,
                op,
                &BooleanArray::from(expected.to_vec()),
            )
            .unwrap();
        }

        Ok(())
    }
3307
3308
    #[test]
3309
    fn comparison_decimal_expr_test() -> Result<()> {
3310
        // scalar of decimal compare with decimal array
3311
        let value_i128 = 123;
3312
        let decimal_scalar = ScalarValue::Decimal128(Some(value_i128), 25, 3);
3313
        let schema = Arc::new(Schema::new(vec![Field::new(
3314
            "a",
3315
            DataType::Decimal128(25, 3),
3316
            true,
3317
        )]));
3318
        let decimal_array = Arc::new(create_decimal_array(
3319
            &[
3320
                Some(value_i128),
3321
                None,
3322
                Some(value_i128 - 1),
3323
                Some(value_i128 + 1),
3324
            ],
3325
            25,
3326
            3,
3327
        )) as ArrayRef;
3328
        // array = scalar
3329
        apply_logic_op_arr_scalar(
3330
            &schema,
3331
            &decimal_array,
3332
            &decimal_scalar,
3333
            Operator::Eq,
3334
            &BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]),
3335
        )
3336
        .unwrap();
3337
        // array != scalar
3338
        apply_logic_op_arr_scalar(
3339
            &schema,
3340
            &decimal_array,
3341
            &decimal_scalar,
3342
            Operator::NotEq,
3343
            &BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]),
3344
        )
3345
        .unwrap();
3346
        //  array < scalar
3347
        apply_logic_op_arr_scalar(
3348
            &schema,
3349
            &decimal_array,
3350
            &decimal_scalar,
3351
            Operator::Lt,
3352
            &BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
3353
        )
3354
        .unwrap();
3355
3356
        //  array <= scalar
3357
        apply_logic_op_arr_scalar(
3358
            &schema,
3359
            &decimal_array,
3360
            &decimal_scalar,
3361
            Operator::LtEq,
3362
            &BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
3363
        )
3364
        .unwrap();
3365
        // array > scalar
3366
        apply_logic_op_arr_scalar(
3367
            &schema,
3368
            &decimal_array,
3369
            &decimal_scalar,
3370
            Operator::Gt,
3371
            &BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]),
3372
        )
3373
        .unwrap();
3374
3375
        // array >= scalar
3376
        apply_logic_op_arr_scalar(
3377
            &schema,
3378
            &decimal_array,
3379
            &decimal_scalar,
3380
            Operator::GtEq,
3381
            &BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3382
        )
3383
        .unwrap();
3384
3385
        // scalar of different data type with decimal array
3386
        let decimal_scalar = ScalarValue::Decimal128(Some(123_456), 10, 3);
3387
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
3388
        // scalar == array
3389
        apply_logic_op_scalar_arr(
3390
            &schema,
3391
            &decimal_scalar,
3392
            &(Arc::new(Int64Array::from(vec![Some(124), None])) as ArrayRef),
3393
            Operator::Eq,
3394
            &BooleanArray::from(vec![Some(false), None]),
3395
        )
3396
        .unwrap();
3397
3398
        // array != scalar
3399
        apply_logic_op_arr_scalar(
3400
            &schema,
3401
            &(Arc::new(Int64Array::from(vec![Some(123), None, Some(1)])) as ArrayRef),
3402
            &decimal_scalar,
3403
            Operator::NotEq,
3404
            &BooleanArray::from(vec![Some(true), None, Some(true)]),
3405
        )
3406
        .unwrap();
3407
3408
        // array < scalar
3409
        apply_logic_op_arr_scalar(
3410
            &schema,
3411
            &(Arc::new(Int64Array::from(vec![Some(123), None, Some(124)])) as ArrayRef),
3412
            &decimal_scalar,
3413
            Operator::Lt,
3414
            &BooleanArray::from(vec![Some(true), None, Some(false)]),
3415
        )
3416
        .unwrap();
3417
3418
        // array > scalar
3419
        apply_logic_op_arr_scalar(
3420
            &schema,
3421
            &(Arc::new(Int64Array::from(vec![Some(123), None, Some(124)])) as ArrayRef),
3422
            &decimal_scalar,
3423
            Operator::Gt,
3424
            &BooleanArray::from(vec![Some(false), None, Some(true)]),
3425
        )
3426
        .unwrap();
3427
3428
        let schema =
3429
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)]));
3430
        // array == scalar
3431
        apply_logic_op_arr_scalar(
3432
            &schema,
3433
            &(Arc::new(Float64Array::from(vec![Some(123.456), None, Some(123.457)]))
3434
                as ArrayRef),
3435
            &decimal_scalar,
3436
            Operator::Eq,
3437
            &BooleanArray::from(vec![Some(true), None, Some(false)]),
3438
        )
3439
        .unwrap();
3440
3441
        // array <= scalar
3442
        apply_logic_op_arr_scalar(
3443
            &schema,
3444
            &(Arc::new(Float64Array::from(vec![
3445
                Some(123.456),
3446
                None,
3447
                Some(123.457),
3448
                Some(123.45),
3449
            ])) as ArrayRef),
3450
            &decimal_scalar,
3451
            Operator::LtEq,
3452
            &BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3453
        )
3454
        .unwrap();
3455
        // array >= scalar
3456
        apply_logic_op_arr_scalar(
3457
            &schema,
3458
            &(Arc::new(Float64Array::from(vec![
3459
                Some(123.456),
3460
                None,
3461
                Some(123.457),
3462
                Some(123.45),
3463
            ])) as ArrayRef),
3464
            &decimal_scalar,
3465
            Operator::GtEq,
3466
            &BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
3467
        )
3468
        .unwrap();
3469
3470
        let value: i128 = 123;
3471
        let decimal_array = Arc::new(create_decimal_array(
3472
            &[Some(value), None, Some(value - 1), Some(value + 1)],
3473
            10,
3474
            0,
3475
        )) as ArrayRef;
3476
3477
        // comparison array op for decimal array
3478
        let schema = Arc::new(Schema::new(vec![
3479
            Field::new("a", DataType::Decimal128(10, 0), true),
3480
            Field::new("b", DataType::Decimal128(10, 0), true),
3481
        ]));
3482
        let right_decimal_array = Arc::new(create_decimal_array(
3483
            &[
3484
                Some(value - 1),
3485
                Some(value),
3486
                Some(value + 1),
3487
                Some(value + 1),
3488
            ],
3489
            10,
3490
            0,
3491
        )) as ArrayRef;
3492
3493
        apply_logic_op(
3494
            &schema,
3495
            &decimal_array,
3496
            &right_decimal_array,
3497
            Operator::Eq,
3498
            BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]),
3499
        )
3500
        .unwrap();
3501
3502
        apply_logic_op(
3503
            &schema,
3504
            &decimal_array,
3505
            &right_decimal_array,
3506
            Operator::NotEq,
3507
            BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
3508
        )
3509
        .unwrap();
3510
3511
        apply_logic_op(
3512
            &schema,
3513
            &decimal_array,
3514
            &right_decimal_array,
3515
            Operator::Lt,
3516
            BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
3517
        )
3518
        .unwrap();
3519
3520
        apply_logic_op(
3521
            &schema,
3522
            &decimal_array,
3523
            &right_decimal_array,
3524
            Operator::LtEq,
3525
            BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]),
3526
        )
3527
        .unwrap();
3528
3529
        apply_logic_op(
3530
            &schema,
3531
            &decimal_array,
3532
            &right_decimal_array,
3533
            Operator::Gt,
3534
            BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]),
3535
        )
3536
        .unwrap();
3537
3538
        apply_logic_op(
3539
            &schema,
3540
            &decimal_array,
3541
            &right_decimal_array,
3542
            Operator::GtEq,
3543
            BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3544
        )
3545
        .unwrap();
3546
3547
        // compare decimal array with other array type
3548
        let value: i64 = 123;
3549
        let schema = Arc::new(Schema::new(vec![
3550
            Field::new("a", DataType::Int64, true),
3551
            Field::new("b", DataType::Decimal128(10, 0), true),
3552
        ]));
3553
3554
        let int64_array = Arc::new(Int64Array::from(vec![
3555
            Some(value),
3556
            Some(value - 1),
3557
            Some(value),
3558
            Some(value + 1),
3559
        ])) as ArrayRef;
3560
3561
        // eq: int64array == decimal array
3562
        apply_logic_op(
3563
            &schema,
3564
            &int64_array,
3565
            &decimal_array,
3566
            Operator::Eq,
3567
            BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3568
        )
3569
        .unwrap();
3570
        // neq: int64array != decimal array
3571
        apply_logic_op(
3572
            &schema,
3573
            &int64_array,
3574
            &decimal_array,
3575
            Operator::NotEq,
3576
            BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
3577
        )
3578
        .unwrap();
3579
3580
        let schema = Arc::new(Schema::new(vec![
3581
            Field::new("a", DataType::Float64, true),
3582
            Field::new("b", DataType::Decimal128(10, 2), true),
3583
        ]));
3584
3585
        let value: i128 = 123;
3586
        let decimal_array = Arc::new(create_decimal_array(
3587
            &[
3588
                Some(value), // 1.23
3589
                None,
3590
                Some(value - 1), // 1.22
3591
                Some(value + 1), // 1.24
3592
            ],
3593
            10,
3594
            2,
3595
        )) as ArrayRef;
3596
        let float64_array = Arc::new(Float64Array::from(vec![
3597
            Some(1.23),
3598
            Some(1.22),
3599
            Some(1.23),
3600
            Some(1.24),
3601
        ])) as ArrayRef;
3602
        // lt: float64array < decimal array
3603
        apply_logic_op(
3604
            &schema,
3605
            &float64_array,
3606
            &decimal_array,
3607
            Operator::Lt,
3608
            BooleanArray::from(vec![Some(false), None, Some(false), Some(false)]),
3609
        )
3610
        .unwrap();
3611
        // lt_eq: float64array <= decimal array
3612
        apply_logic_op(
3613
            &schema,
3614
            &float64_array,
3615
            &decimal_array,
3616
            Operator::LtEq,
3617
            BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3618
        )
3619
        .unwrap();
3620
        // gt: float64array > decimal array
3621
        apply_logic_op(
3622
            &schema,
3623
            &float64_array,
3624
            &decimal_array,
3625
            Operator::Gt,
3626
            BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
3627
        )
3628
        .unwrap();
3629
        apply_logic_op(
3630
            &schema,
3631
            &float64_array,
3632
            &decimal_array,
3633
            Operator::GtEq,
3634
            BooleanArray::from(vec![Some(true), None, Some(true), Some(true)]),
3635
        )
3636
        .unwrap();
3637
        // is distinct: float64array is distinct decimal array
3638
        // TODO: now we do not refactor the `is distinct or is not distinct` rule of coercion.
3639
        // traced by https://github.com/apache/datafusion/issues/1590
3640
        // the decimal array will be casted to float64array
3641
        apply_logic_op(
3642
            &schema,
3643
            &float64_array,
3644
            &decimal_array,
3645
            Operator::IsDistinctFrom,
3646
            BooleanArray::from(vec![Some(false), Some(true), Some(true), Some(false)]),
3647
        )
3648
        .unwrap();
3649
        // is not distinct
3650
        apply_logic_op(
3651
            &schema,
3652
            &float64_array,
3653
            &decimal_array,
3654
            Operator::IsNotDistinctFrom,
3655
            BooleanArray::from(vec![Some(true), Some(false), Some(false), Some(true)]),
3656
        )
3657
        .unwrap();
3658
3659
        Ok(())
3660
    }
3661
3662
    fn apply_decimal_arithmetic_op(
3663
        schema: &SchemaRef,
3664
        left: &ArrayRef,
3665
        right: &ArrayRef,
3666
        op: Operator,
3667
        expected: ArrayRef,
3668
    ) -> Result<()> {
3669
        let arithmetic_op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?;
3670
        let data: Vec<ArrayRef> = vec![Arc::clone(left), Arc::clone(right)];
3671
        let batch = RecordBatch::try_new(Arc::clone(schema), data)?;
3672
        let result = arithmetic_op
3673
            .evaluate(&batch)?
3674
            .into_array(batch.num_rows())
3675
            .expect("Failed to convert to array");
3676
3677
        assert_eq!(result.as_ref(), expected.as_ref());
3678
        Ok(())
3679
    }
3680
3681
    #[test]
    fn arithmetic_decimal_expr_test() -> Result<()> {
        // Schemas for the two operand orders exercised below.
        let int_then_decimal = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Decimal128(10, 2), true),
        ]));
        let decimal_then_int = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Decimal128(10, 2), true),
            Field::new("b", DataType::Int32, true),
        ]));

        let base: i128 = 123;
        // Decimal128(10, 2) values: 1.23, NULL, 1.22, 1.24
        let decimal_array = Arc::new(create_decimal_array(
            &[Some(base), None, Some(base - 1), Some(base + 1)],
            10,
            2,
        )) as ArrayRef;
        let int32_array = Arc::new(Int32Array::from(vec![
            Some(123),
            Some(122),
            Some(123),
            Some(124),
        ])) as ArrayRef;

        // (schema, left, right, operator, expected unscaled values, precision, scale)
        let cases = [
            // add: int32 array + decimal array
            (
                &int_then_decimal,
                &int32_array,
                &decimal_array,
                Operator::Plus,
                [Some(12423), None, Some(12422), Some(12524)],
                13,
                2,
            ),
            // subtract: decimal array - int32 array
            (
                &decimal_then_int,
                &decimal_array,
                &int32_array,
                Operator::Minus,
                [Some(-12177), None, Some(-12178), Some(-12276)],
                13,
                2,
            ),
            // multiply: decimal array * int32 array
            (
                &decimal_then_int,
                &decimal_array,
                &int32_array,
                Operator::Multiply,
                [Some(15129), None, Some(15006), Some(15376)],
                21,
                2,
            ),
            // divide: int32 array / decimal array
            (
                &int_then_decimal,
                &int32_array,
                &decimal_array,
                Operator::Divide,
                [Some(1000000), None, Some(1008196), Some(1000000)],
                16,
                4,
            ),
            // modulus: int32 array % decimal array
            (
                &int_then_decimal,
                &int32_array,
                &decimal_array,
                Operator::Modulo,
                [Some(0), None, Some(100), Some(0)],
                10,
                2,
            ),
        ];

        for (schema, left, right, op, unscaled, precision, scale) in cases {
            let expected =
                Arc::new(create_decimal_array(&unscaled, precision, scale)) as ArrayRef;
            apply_decimal_arithmetic_op(schema, left, right, op, expected).unwrap();
        }

        Ok(())
    }
3794
3795
    #[test]
    fn arithmetic_decimal_float_expr_test() -> Result<()> {
        // Every case evaluates `a <op> b` where `a` is a Float64 column and `b` is a
        // Decimal128(10, 2) column; the expected results are Float64 arrays.
        // One schema serves all cases because the operand order never changes.
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float64, true),
            Field::new("b", DataType::Decimal128(10, 2), true),
        ]));
        let value: i128 = 123;
        let decimal_array = Arc::new(create_decimal_array(
            &[
                Some(value), // 1.23
                None,
                Some(value - 1), // 1.22
                Some(value + 1), // 1.24
            ],
            10,
            2,
        )) as ArrayRef;
        let float64_array = Arc::new(Float64Array::from(vec![
            Some(123.0),
            Some(122.0),
            Some(123.0),
            Some(124.0),
        ])) as ArrayRef;

        // (operator, expected Float64 values)
        let cases = [
            // add: float64 array + decimal array
            (
                Operator::Plus,
                [Some(124.23), None, Some(124.22), Some(125.24)],
            ),
            // subtract: float64 array - decimal array (e.g. 123.0 - 1.23 = 121.77)
            (
                Operator::Minus,
                [Some(121.77), None, Some(121.78), Some(122.76)],
            ),
            // multiply: float64 array * decimal array
            (
                Operator::Multiply,
                [Some(151.29), None, Some(150.06), Some(153.76)],
            ),
            // divide: float64 array / decimal array
            (
                Operator::Divide,
                [Some(100.0), None, Some(100.81967213114754), Some(100.0)],
            ),
            // modulus: float64 array % decimal array. The exact remainders would be
            // 0, 1.00 and 0; the tiny offsets come from 1.23 / 1.22 / 1.24 not being
            // exactly representable in binary floating point.
            (
                Operator::Modulo,
                [
                    Some(1.7763568394002505e-15),
                    None,
                    Some(1.0000000000000027),
                    Some(8.881784197001252e-16),
                ],
            ),
        ];

        for (op, values) in cases {
            let expected = Arc::new(Float64Array::from(values.to_vec())) as ArrayRef;
            apply_decimal_arithmetic_op(
                &schema,
                &float64_array,
                &decimal_array,
                op,
                expected,
            )
            .unwrap();
        }

        Ok(())
    }
3913
3914
    #[test]
    fn arithmetic_divide_zero() -> Result<()> {
        // Division by zero must be reported as an error for both integer and
        // decimal operands. The error text is checked directly. A pattern such as
        // `ref _x` inside `matches!` is a fresh binding that matches any value,
        // so it cannot serve as an assertion.

        // other data type
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
        ]));
        let a = Arc::new(Int32Array::from(vec![100]));
        let b = Arc::new(Int32Array::from(vec![0]));

        // The expected array is irrelevant: the call is required to fail.
        let err = apply_arithmetic::<Int32Type>(
            schema,
            vec![a, b],
            Operator::Divide,
            Int32Array::from(vec![Some(4), Some(8), Some(16), Some(32), Some(64)]),
        )
        .unwrap_err();
        assert!(err.to_string().contains("Divide by zero"), "{err}");

        // decimal
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Decimal128(25, 3), true),
            Field::new("b", DataType::Decimal128(25, 3), true),
        ]));
        let left_decimal_array = Arc::new(create_decimal_array(&[Some(1234567)], 25, 3));
        let right_decimal_array = Arc::new(create_decimal_array(&[Some(0)], 25, 3));

        let err = apply_arithmetic::<Decimal128Type>(
            schema,
            vec![left_decimal_array, right_decimal_array],
            Operator::Divide,
            create_decimal_array(
                &[Some(12345670000000000000000000000000000), None],
                38,
                29,
            ),
        )
        .unwrap_err();
        assert!(err.to_string().contains("Divide by zero"), "{err}");

        Ok(())
    }
3960
3961
    #[test]
    fn bitwise_array_test() -> Result<()> {
        // Signed 32-bit operands: 12 = 0b1100, 11 = 0b1011 against 1, 3, 7.
        let lhs = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
        let rhs = Arc::new(Int32Array::from(vec![Some(1), Some(3), Some(7)])) as ArrayRef;

        let and_out = bitwise_and_dyn(Arc::clone(&lhs), Arc::clone(&rhs))?;
        assert_eq!(
            and_out.as_ref(),
            &Int32Array::from(vec![Some(0), None, Some(3)])
        );

        let or_out = bitwise_or_dyn(Arc::clone(&lhs), Arc::clone(&rhs))?;
        assert_eq!(
            or_out.as_ref(),
            &Int32Array::from(vec![Some(13), None, Some(15)])
        );

        let xor_out = bitwise_xor_dyn(Arc::clone(&lhs), Arc::clone(&rhs))?;
        assert_eq!(
            xor_out.as_ref(),
            &Int32Array::from(vec![Some(13), None, Some(12)])
        );

        // Unsigned 32-bit operands with the same bit patterns.
        let lhs = Arc::new(UInt32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
        let rhs = Arc::new(UInt32Array::from(vec![Some(1), Some(3), Some(7)])) as ArrayRef;

        let and_out = bitwise_and_dyn(Arc::clone(&lhs), Arc::clone(&rhs))?;
        assert_eq!(
            and_out.as_ref(),
            &UInt32Array::from(vec![Some(0), None, Some(3)])
        );

        let or_out = bitwise_or_dyn(Arc::clone(&lhs), Arc::clone(&rhs))?;
        assert_eq!(
            or_out.as_ref(),
            &UInt32Array::from(vec![Some(13), None, Some(15)])
        );

        let xor_out = bitwise_xor_dyn(Arc::clone(&lhs), Arc::clone(&rhs))?;
        assert_eq!(
            xor_out.as_ref(),
            &UInt32Array::from(vec![Some(13), None, Some(12)])
        );

        Ok(())
    }
3996
3997
    #[test]
    fn bitwise_shift_array_test() -> Result<()> {
        // Signed: shift left by per-row amounts, then shift back to recover the input.
        let values = Arc::new(Int32Array::from(vec![Some(2), None, Some(10)])) as ArrayRef;
        let shifts = Arc::new(Int32Array::from(vec![Some(2), Some(4), Some(8)])) as ArrayRef;

        let shifted_left = bitwise_shift_left_dyn(Arc::clone(&values), Arc::clone(&shifts))?;
        assert_eq!(
            shifted_left.as_ref(),
            &Int32Array::from(vec![Some(8), None, Some(2560)])
        );
        let restored =
            bitwise_shift_right_dyn(Arc::clone(&shifted_left), Arc::clone(&shifts))?;
        assert_eq!(restored.as_ref(), &values);

        // Unsigned: same round trip.
        let values =
            Arc::new(UInt32Array::from(vec![Some(2), None, Some(10)])) as ArrayRef;
        let shifts =
            Arc::new(UInt32Array::from(vec![Some(2), Some(4), Some(8)])) as ArrayRef;

        let shifted_left = bitwise_shift_left_dyn(Arc::clone(&values), Arc::clone(&shifts))?;
        assert_eq!(
            shifted_left.as_ref(),
            &UInt32Array::from(vec![Some(8), None, Some(2560)])
        );
        let restored =
            bitwise_shift_right_dyn(Arc::clone(&shifted_left), Arc::clone(&shifts))?;
        assert_eq!(restored.as_ref(), &values);

        Ok(())
    }
4025
4026
    #[test]
    fn bitwise_shift_array_overflow_test() -> Result<()> {
        // A shift amount of 100 exceeds the 32-bit width. The call still succeeds,
        // and the expected 32 equals 2 << (100 % 32).
        let value = Arc::new(Int32Array::from(vec![Some(2)])) as ArrayRef;
        let shift = Arc::new(Int32Array::from(vec![Some(100)])) as ArrayRef;
        let shifted = bitwise_shift_left_dyn(Arc::clone(&value), Arc::clone(&shift))?;
        assert_eq!(shifted.as_ref(), &Int32Array::from(vec![Some(32)]));

        // Same expectation for unsigned operands.
        let value = Arc::new(UInt32Array::from(vec![Some(2)])) as ArrayRef;
        let shift = Arc::new(UInt32Array::from(vec![Some(100)])) as ArrayRef;
        let shifted = bitwise_shift_left_dyn(Arc::clone(&value), Arc::clone(&shift))?;
        assert_eq!(shifted.as_ref(), &UInt32Array::from(vec![Some(32)]));

        Ok(())
    }
4043
4044
    #[test]
    fn bitwise_scalar_test() -> Result<()> {
        // Signed array combined with the scalar 3 (0b11).
        let signed = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
        let three = ScalarValue::from(3i32);

        let anded = bitwise_and_dyn_scalar(&signed, three.clone()).unwrap()?;
        assert_eq!(
            anded.as_ref(),
            &Int32Array::from(vec![Some(0), None, Some(3)])
        );
        let ored = bitwise_or_dyn_scalar(&signed, three.clone()).unwrap()?;
        assert_eq!(
            ored.as_ref(),
            &Int32Array::from(vec![Some(15), None, Some(11)])
        );
        let xored = bitwise_xor_dyn_scalar(&signed, three).unwrap()?;
        assert_eq!(
            xored.as_ref(),
            &Int32Array::from(vec![Some(15), None, Some(8)])
        );

        // Unsigned array combined with the unsigned scalar 3.
        let unsigned =
            Arc::new(UInt32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
        let three = ScalarValue::from(3u32);

        let anded = bitwise_and_dyn_scalar(&unsigned, three.clone()).unwrap()?;
        assert_eq!(
            anded.as_ref(),
            &UInt32Array::from(vec![Some(0), None, Some(3)])
        );
        let ored = bitwise_or_dyn_scalar(&unsigned, three.clone()).unwrap()?;
        assert_eq!(
            ored.as_ref(),
            &UInt32Array::from(vec![Some(15), None, Some(11)])
        );
        let xored = bitwise_xor_dyn_scalar(&unsigned, three).unwrap()?;
        assert_eq!(
            xored.as_ref(),
            &UInt32Array::from(vec![Some(15), None, Some(8)])
        );

        Ok(())
    }
4076
4077
    #[test]
    fn bitwise_shift_scalar_test() -> Result<()> {
        // Signed: shift left by a scalar 10, then right by 10 to recover the input.
        let values = Arc::new(Int32Array::from(vec![Some(2), None, Some(4)])) as ArrayRef;
        let ten = ScalarValue::from(10i32);

        let shifted = bitwise_shift_left_dyn_scalar(&values, ten.clone()).unwrap()?;
        assert_eq!(
            shifted.as_ref(),
            &Int32Array::from(vec![Some(2048), None, Some(4096)])
        );
        let restored = bitwise_shift_right_dyn_scalar(&shifted, ten).unwrap()?;
        assert_eq!(restored.as_ref(), &values);

        // Unsigned: same round trip.
        let values = Arc::new(UInt32Array::from(vec![Some(2), None, Some(4)])) as ArrayRef;
        let ten = ScalarValue::from(10u32);

        let shifted = bitwise_shift_left_dyn_scalar(&values, ten.clone()).unwrap()?;
        assert_eq!(
            shifted.as_ref(),
            &UInt32Array::from(vec![Some(2048), None, Some(4096)])
        );
        let restored = bitwise_shift_right_dyn_scalar(&shifted, ten).unwrap()?;
        assert_eq!(restored.as_ref(), &values);

        Ok(())
    }
4102
4103
    #[test]
    fn test_display_and_or_combo() {
        /// Builds `(1 <inner> 2) <outer> (3 <inner> 4)` from Int32 literals.
        fn nested(inner: Operator, outer: Operator) -> BinaryExpr {
            let pair = |x: i32, y: i32| -> Arc<dyn PhysicalExpr> {
                Arc::new(BinaryExpr::new(
                    lit(ScalarValue::from(x)),
                    inner,
                    lit(ScalarValue::from(y)),
                ))
            };
            BinaryExpr::new(pair(1, 2), outer, pair(3, 4))
        }

        // (inner operator, outer operator, expected rendering)
        let cases = [
            (Operator::And, Operator::And, "1 AND 2 AND 3 AND 4"),
            (Operator::Or, Operator::Or, "1 OR 2 OR 3 OR 4"),
            (Operator::And, Operator::Or, "1 AND 2 OR 3 AND 4"),
            (Operator::Or, Operator::And, "(1 OR 2) AND (3 OR 4)"),
        ];
        for (inner, outer, rendered) in cases {
            assert_eq!(nested(inner, outer).to_string(), rendered);
        }
    }
4165
4166
    #[test]
    fn test_to_result_type_array() {
        // Dictionary with Int8 keys over Int32 values; the NULL key makes it
        // decode to [1, NULL, 3, 4].
        let dict_values = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
        let dict_keys = Int8Array::from(vec![Some(0), None, Some(2), Some(3)]);
        let dictionary =
            Arc::new(DictionaryArray::try_new(dict_keys, dict_values).unwrap()) as ArrayRef;

        // Numeric operator with a different target type: cast to Int32.
        let unpacked = to_result_type_array(
            &Operator::Plus,
            Arc::clone(&dictionary),
            &DataType::Int32,
        )
        .unwrap();
        let expected_int32 =
            Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(4)])) as ArrayRef;
        assert_eq!(&unpacked, &expected_int32);

        // Target type equals the array's own type: returned without casting.
        let same_type = to_result_type_array(
            &Operator::Plus,
            Arc::clone(&dictionary),
            dictionary.data_type(),
        )
        .unwrap();
        assert_eq!(&same_type, &dictionary);

        // Non-numerical operator: returned without casting.
        let comparison = to_result_type_array(
            &Operator::Eq,
            Arc::clone(&dictionary),
            &DataType::Int32,
        )
        .unwrap();
        assert_eq!(&comparison, &dictionary);
    }
4204
4205
    #[test]
    fn test_add_with_overflow() -> Result<()> {
        // The second row computes i32::MAX + 1, which does not fit in Int32.
        let schema = Arc::new(Schema::new(vec![
            Field::new("l", DataType::Int32, false),
            Field::new("r", DataType::Int32, false),
        ]));
        let columns: Vec<ArrayRef> = vec![
            Arc::new(Int32Array::from(vec![1, i32::MAX])),
            Arc::new(Int32Array::from(vec![2, 1])),
        ];
        let batch = RecordBatch::try_new(schema, columns)?;

        // Checked addition: overflow must surface as an error.
        let sum = BinaryExpr::new(
            Arc::new(Column::new("l", 0)),
            Operator::Plus,
            Arc::new(Column::new("r", 1)),
        )
        .with_fail_on_overflow(true);

        let error_text = sum.evaluate(&batch).err().unwrap().to_string();
        assert!(error_text.contains("Overflow happened on: 2147483647 + 1"));
        Ok(())
    }
4233
4234
    #[test]
    fn test_subtract_with_overflow() -> Result<()> {
        // The second row computes i32::MIN - 1, which does not fit in Int32.
        let schema = Arc::new(Schema::new(vec![
            Field::new("l", DataType::Int32, false),
            Field::new("r", DataType::Int32, false),
        ]));
        let columns: Vec<ArrayRef> = vec![
            Arc::new(Int32Array::from(vec![1, i32::MIN])),
            Arc::new(Int32Array::from(vec![2, 1])),
        ];
        let batch = RecordBatch::try_new(schema, columns)?;

        // Checked subtraction: overflow must surface as an error.
        let difference = BinaryExpr::new(
            Arc::new(Column::new("l", 0)),
            Operator::Minus,
            Arc::new(Column::new("r", 1)),
        )
        .with_fail_on_overflow(true);

        let error_text = difference.evaluate(&batch).err().unwrap().to_string();
        assert!(error_text.contains("Overflow happened on: -2147483648 - 1"));
        Ok(())
    }
4262
4263
    #[test]
    fn test_mul_with_overflow() -> Result<()> {
        // The second row computes i32::MAX * 2, which does not fit in Int32.
        let schema = Arc::new(Schema::new(vec![
            Field::new("l", DataType::Int32, false),
            Field::new("r", DataType::Int32, false),
        ]));
        let columns: Vec<ArrayRef> = vec![
            Arc::new(Int32Array::from(vec![1, i32::MAX])),
            Arc::new(Int32Array::from(vec![2, 2])),
        ];
        let batch = RecordBatch::try_new(schema, columns)?;

        // Checked multiplication: overflow must surface as an error.
        let product = BinaryExpr::new(
            Arc::new(Column::new("l", 0)),
            Operator::Multiply,
            Arc::new(Column::new("r", 1)),
        )
        .with_fail_on_overflow(true);

        let error_text = product.evaluate(&batch).err().unwrap().to_string();
        assert!(error_text.contains("Overflow happened on: 2147483647 * 2"));
        Ok(())
    }
4291
4292
    /// Test helper for SIMILAR TO binary operation
4293
    fn apply_similar_to(
4294
        schema: &SchemaRef,
4295
        va: Vec<&str>,
4296
        vb: Vec<&str>,
4297
        negated: bool,
4298
        case_insensitive: bool,
4299
        expected: &BooleanArray,
4300
    ) -> Result<()> {
4301
        let a = StringArray::from(va);
4302
        let b = StringArray::from(vb);
4303
        let op = similar_to(
4304
            negated,
4305
            case_insensitive,
4306
            col("a", schema)?,
4307
            col("b", schema)?,
4308
        )?;
4309
        let batch =
4310
            RecordBatch::try_new(Arc::clone(schema), vec![Arc::new(a), Arc::new(b)])?;
4311
        let result = op
4312
            .evaluate(&batch)?
4313
            .into_array(batch.num_rows())
4314
            .expect("Failed to convert to array");
4315
        assert_eq!(result.as_ref(), expected);
4316
4317
        Ok(())
4318
    }
4319
4320
    #[test]
    fn test_similar_to() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Utf8, false),
        ]));
        // Both scenarios expect the first row to match and the second not to.
        let expected: BooleanArray = [Some(true), Some(false)].iter().collect();
        let patterns = vec!["hello.*", "hello.*"];

        // case-sensitive: "Hello World" must not match the lowercase pattern
        apply_similar_to(
            &schema,
            vec!["hello world", "Hello World"],
            patterns.clone(),
            false,
            false,
            &expected,
        )
        .unwrap();

        // case-insensitive: "bye" still must not match
        apply_similar_to(
            &schema,
            vec!["hello world", "bye"],
            patterns,
            false,
            true,
            &expected,
        )
        .unwrap();
    }
4349
}