/Users/andrewlamb/Software/datafusion/datafusion/physical-expr-common/src/datum.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use arrow::array::BooleanArray; |
19 | | use arrow::array::{make_comparator, ArrayRef, Datum}; |
20 | | use arrow::buffer::NullBuffer; |
21 | | use arrow::compute::SortOptions; |
22 | | use arrow::error::ArrowError; |
23 | | use datafusion_common::DataFusionError; |
24 | | use datafusion_common::{arrow_datafusion_err, internal_err}; |
25 | | use datafusion_common::{Result, ScalarValue}; |
26 | | use datafusion_expr_common::columnar_value::ColumnarValue; |
27 | | use datafusion_expr_common::operator::Operator; |
28 | | use std::sync::Arc; |
29 | | |
30 | | /// Applies a binary [`Datum`] kernel `f` to `lhs` and `rhs` |
31 | | /// |
32 | | /// This maps arrow-rs' [`Datum`] kernels to DataFusion's [`ColumnarValue`] abstraction |
33 | 110k | pub fn apply( |
34 | 110k | lhs: &ColumnarValue, |
35 | 110k | rhs: &ColumnarValue, |
36 | 110k | f: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>, |
37 | 110k | ) -> Result<ColumnarValue> { |
38 | 110k | match (&lhs, &rhs) { |
39 | 21.6k | (ColumnarValue::Array(left), ColumnarValue::Array(right)) => { |
40 | 21.6k | Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?0 )) |
41 | | } |
42 | 0 | (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok( |
43 | 0 | ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?), |
44 | | ), |
45 | 88.3k | (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok( |
46 | 88.3k | ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?0 )?0 ), |
47 | | ), |
48 | 0 | (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => { |
49 | 0 | let array = f(&left.to_scalar()?, &right.to_scalar()?)?; |
50 | 0 | let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?; |
51 | 0 | Ok(ColumnarValue::Scalar(scalar)) |
52 | | } |
53 | | } |
54 | 110k | } |
55 | | |
56 | | /// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` |
57 | 44.3k | pub fn apply_cmp( |
58 | 44.3k | lhs: &ColumnarValue, |
59 | 44.3k | rhs: &ColumnarValue, |
60 | 44.3k | f: impl Fn(&dyn Datum, &dyn Datum) -> Result<BooleanArray, ArrowError>, |
61 | 44.3k | ) -> Result<ColumnarValue> { |
62 | 44.3k | apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?0 ))) |
63 | 44.3k | } |
64 | | |
65 | | /// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` for nested type like |
66 | | /// List, FixedSizeList, LargeList, Struct, Union, Map, or a dictionary of a nested type |
67 | 0 | pub fn apply_cmp_for_nested( |
68 | 0 | op: Operator, |
69 | 0 | lhs: &ColumnarValue, |
70 | 0 | rhs: &ColumnarValue, |
71 | 0 | ) -> Result<ColumnarValue> { |
72 | 0 | if matches!( |
73 | 0 | op, |
74 | | Operator::Eq |
75 | | | Operator::NotEq |
76 | | | Operator::Lt |
77 | | | Operator::Gt |
78 | | | Operator::LtEq |
79 | | | Operator::GtEq |
80 | | | Operator::IsDistinctFrom |
81 | | | Operator::IsNotDistinctFrom |
82 | | ) { |
83 | 0 | apply(lhs, rhs, |l, r| { |
84 | 0 | Ok(Arc::new(compare_op_for_nested(op, l, r)?)) |
85 | 0 | }) |
86 | | } else { |
87 | 0 | internal_err!("invalid operator for nested") |
88 | | } |
89 | 0 | } |
90 | | |
91 | | /// Compare with eq with either nested or non-nested |
92 | 0 | pub fn compare_with_eq( |
93 | 0 | lhs: &dyn Datum, |
94 | 0 | rhs: &dyn Datum, |
95 | 0 | is_nested: bool, |
96 | 0 | ) -> Result<BooleanArray> { |
97 | 0 | if is_nested { |
98 | 0 | compare_op_for_nested(Operator::Eq, lhs, rhs) |
99 | | } else { |
100 | 0 | arrow::compute::kernels::cmp::eq(lhs, rhs).map_err(|e| arrow_datafusion_err!(e)) |
101 | | } |
102 | 0 | } |
103 | | |
104 | | /// Compare on nested type List, Struct, and so on |
105 | 3 | pub fn compare_op_for_nested( |
106 | 3 | op: Operator, |
107 | 3 | lhs: &dyn Datum, |
108 | 3 | rhs: &dyn Datum, |
109 | 3 | ) -> Result<BooleanArray> { |
110 | 3 | let (l, is_l_scalar) = lhs.get(); |
111 | 3 | let (r, is_r_scalar) = rhs.get(); |
112 | 3 | let l_len = l.len(); |
113 | 3 | let r_len = r.len(); |
114 | 3 | |
115 | 3 | if l_len != r_len && !is_l_scalar0 && !is_r_scalar0 { |
116 | 0 | return internal_err!("len mismatch"); |
117 | 3 | } |
118 | | |
119 | 3 | let len = match is_l_scalar { |
120 | 0 | true => r_len, |
121 | 3 | false => l_len, |
122 | | }; |
123 | | |
124 | | // fast path, if compare with one null and operator is not 'distinct', then we can return null array directly |
125 | 3 | if !matches!2 (op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) |
126 | 2 | && (is_l_scalar && l.null_count() == 10 || is_r_scalar && r.null_count() == 10 ) |
127 | | { |
128 | 0 | return Ok(BooleanArray::new_null(len)); |
129 | 3 | } |
130 | | |
131 | | // TODO: make SortOptions configurable |
132 | | // we choose the default behaviour from arrow-rs which has null-first that follow spark's behaviour |
133 | 3 | let cmp = make_comparator(l, r, SortOptions::default())?0 ; |
134 | | |
135 | 5 | let cmp_with_op3 = |i, j| match op { |
136 | 5 | Operator::Eq | Operator::IsNotDistinctFrom => cmp(i, j).is_eq(), |
137 | 0 | Operator::Lt => cmp(i, j).is_lt(), |
138 | 0 | Operator::Gt => cmp(i, j).is_gt(), |
139 | 0 | Operator::LtEq => !cmp(i, j).is_gt(), |
140 | 0 | Operator::GtEq => !cmp(i, j).is_lt(), |
141 | 0 | Operator::NotEq | Operator::IsDistinctFrom => !cmp(i, j).is_eq(), |
142 | 0 | _ => unreachable!("unexpected operator found"), |
143 | 5 | }; |
144 | | |
145 | 3 | let values = match (is_l_scalar, is_r_scalar) { |
146 | 5 | (false, false) => (0..len).map(3 |i| cmp_with_op(i, i)).collect()3 , |
147 | 0 | (true, false) => (0..len).map(|i| cmp_with_op(0, i)).collect(), |
148 | 0 | (false, true) => (0..len).map(|i| cmp_with_op(i, 0)).collect(), |
149 | 0 | (true, true) => std::iter::once(cmp_with_op(0, 0)).collect(), |
150 | | }; |
151 | | |
152 | | // Distinct understand how to compare with NULL |
153 | | // i.e NULL is distinct from NULL -> false |
154 | 3 | if matches!2 (op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) { |
155 | 1 | Ok(BooleanArray::new(values, None)) |
156 | | } else { |
157 | | // If one of the side is NULL, we returns NULL |
158 | | // i.e. NULL eq NULL -> NULL |
159 | 2 | let nulls = NullBuffer::union(l.nulls(), r.nulls()); |
160 | 2 | Ok(BooleanArray::new(values, nulls)) |
161 | | } |
162 | 3 | } |