Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-expr-common/src/datum.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use arrow::array::BooleanArray;
19
use arrow::array::{make_comparator, ArrayRef, Datum};
20
use arrow::buffer::NullBuffer;
21
use arrow::compute::SortOptions;
22
use arrow::error::ArrowError;
23
use datafusion_common::DataFusionError;
24
use datafusion_common::{arrow_datafusion_err, internal_err};
25
use datafusion_common::{Result, ScalarValue};
26
use datafusion_expr_common::columnar_value::ColumnarValue;
27
use datafusion_expr_common::operator::Operator;
28
use std::sync::Arc;
29
30
/// Applies a binary [`Datum`] kernel `f` to `lhs` and `rhs`
31
///
32
/// This maps arrow-rs' [`Datum`] kernels to DataFusion's [`ColumnarValue`] abstraction
33
110k
pub fn apply(
34
110k
    lhs: &ColumnarValue,
35
110k
    rhs: &ColumnarValue,
36
110k
    f: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
37
110k
) -> Result<ColumnarValue> {
38
110k
    match (&lhs, &rhs) {
39
21.6k
        (ColumnarValue::Array(left), ColumnarValue::Array(right)) => {
40
21.6k
            Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())
?0
))
41
        }
42
0
        (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok(
43
0
            ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?),
44
        ),
45
88.3k
        (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok(
46
88.3k
            ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()
?0
)
?0
),
47
        ),
48
0
        (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => {
49
0
            let array = f(&left.to_scalar()?, &right.to_scalar()?)?;
50
0
            let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?;
51
0
            Ok(ColumnarValue::Scalar(scalar))
52
        }
53
    }
54
110k
}
55
56
/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs`
57
44.3k
pub fn apply_cmp(
58
44.3k
    lhs: &ColumnarValue,
59
44.3k
    rhs: &ColumnarValue,
60
44.3k
    f: impl Fn(&dyn Datum, &dyn Datum) -> Result<BooleanArray, ArrowError>,
61
44.3k
) -> Result<ColumnarValue> {
62
44.3k
    apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)
?0
)))
63
44.3k
}
64
65
/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` for nested type like
66
/// List, FixedSizeList, LargeList, Struct, Union, Map, or a dictionary of a nested type
67
0
pub fn apply_cmp_for_nested(
68
0
    op: Operator,
69
0
    lhs: &ColumnarValue,
70
0
    rhs: &ColumnarValue,
71
0
) -> Result<ColumnarValue> {
72
0
    if matches!(
73
0
        op,
74
        Operator::Eq
75
            | Operator::NotEq
76
            | Operator::Lt
77
            | Operator::Gt
78
            | Operator::LtEq
79
            | Operator::GtEq
80
            | Operator::IsDistinctFrom
81
            | Operator::IsNotDistinctFrom
82
    ) {
83
0
        apply(lhs, rhs, |l, r| {
84
0
            Ok(Arc::new(compare_op_for_nested(op, l, r)?))
85
0
        })
86
    } else {
87
0
        internal_err!("invalid operator for nested")
88
    }
89
0
}
90
91
/// Compare with eq with either nested or non-nested
92
0
pub fn compare_with_eq(
93
0
    lhs: &dyn Datum,
94
0
    rhs: &dyn Datum,
95
0
    is_nested: bool,
96
0
) -> Result<BooleanArray> {
97
0
    if is_nested {
98
0
        compare_op_for_nested(Operator::Eq, lhs, rhs)
99
    } else {
100
0
        arrow::compute::kernels::cmp::eq(lhs, rhs).map_err(|e| arrow_datafusion_err!(e))
101
    }
102
0
}
103
104
/// Compare on nested type List, Struct, and so on
105
3
pub fn compare_op_for_nested(
106
3
    op: Operator,
107
3
    lhs: &dyn Datum,
108
3
    rhs: &dyn Datum,
109
3
) -> Result<BooleanArray> {
110
3
    let (l, is_l_scalar) = lhs.get();
111
3
    let (r, is_r_scalar) = rhs.get();
112
3
    let l_len = l.len();
113
3
    let r_len = r.len();
114
3
115
3
    if l_len != r_len && 
!is_l_scalar0
&&
!is_r_scalar0
{
116
0
        return internal_err!("len mismatch");
117
3
    }
118
119
3
    let len = match is_l_scalar {
120
0
        true => r_len,
121
3
        false => l_len,
122
    };
123
124
    // fast path, if compare with one null and operator is not 'distinct', then we can return null array directly
125
3
    if !
matches!2
(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom)
126
2
        && (is_l_scalar && 
l.null_count() == 10
|| is_r_scalar &&
r.null_count() == 10
)
127
    {
128
0
        return Ok(BooleanArray::new_null(len));
129
3
    }
130
131
    // TODO: make SortOptions configurable
132
    // we choose the default behaviour from arrow-rs which has null-first that follow spark's behaviour
133
3
    let cmp = make_comparator(l, r, SortOptions::default())
?0
;
134
135
5
    let 
cmp_with_op3
= |i, j| match op {
136
5
        Operator::Eq | Operator::IsNotDistinctFrom => cmp(i, j).is_eq(),
137
0
        Operator::Lt => cmp(i, j).is_lt(),
138
0
        Operator::Gt => cmp(i, j).is_gt(),
139
0
        Operator::LtEq => !cmp(i, j).is_gt(),
140
0
        Operator::GtEq => !cmp(i, j).is_lt(),
141
0
        Operator::NotEq | Operator::IsDistinctFrom => !cmp(i, j).is_eq(),
142
0
        _ => unreachable!("unexpected operator found"),
143
5
    };
144
145
3
    let values = match (is_l_scalar, is_r_scalar) {
146
5
        (false, false) => 
(0..len).map(3
|i| cmp_with_op(i, i)
).collect()3
,
147
0
        (true, false) => (0..len).map(|i| cmp_with_op(0, i)).collect(),
148
0
        (false, true) => (0..len).map(|i| cmp_with_op(i, 0)).collect(),
149
0
        (true, true) => std::iter::once(cmp_with_op(0, 0)).collect(),
150
    };
151
152
    // Distinct understand how to compare with NULL
153
    // i.e NULL is distinct from NULL -> false
154
3
    if 
matches!2
(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) {
155
1
        Ok(BooleanArray::new(values, None))
156
    } else {
157
        // If one of the side is NULL, we returns NULL
158
        // i.e. NULL eq NULL -> NULL
159
2
        let nulls = NullBuffer::union(l.nulls(), r.nulls());
160
2
        Ok(BooleanArray::new(values, nulls))
161
    }
162
3
}