/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/expressions/in_list.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Implementation of `InList` expressions: [`InListExpr`] |
19 | | |
20 | | use std::any::Any; |
21 | | use std::fmt::Debug; |
22 | | use std::hash::{Hash, Hasher}; |
23 | | use std::sync::Arc; |
24 | | |
25 | | use crate::physical_expr::{down_cast_any_ref, physical_exprs_bag_equal}; |
26 | | use crate::PhysicalExpr; |
27 | | |
28 | | use arrow::array::*; |
29 | | use arrow::buffer::BooleanBuffer; |
30 | | use arrow::compute::kernels::boolean::{not, or_kleene}; |
31 | | use arrow::compute::take; |
32 | | use arrow::datatypes::*; |
33 | | use arrow::util::bit_iterator::BitIndexIterator; |
34 | | use arrow::{downcast_dictionary_array, downcast_primitive_array}; |
35 | | use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; |
36 | | use datafusion_common::cast::{ |
37 | | as_boolean_array, as_generic_binary_array, as_string_array, |
38 | | }; |
39 | | use datafusion_common::hash_utils::HashValue; |
40 | | use datafusion_common::{ |
41 | | exec_err, internal_err, not_impl_err, DFSchema, Result, ScalarValue, |
42 | | }; |
43 | | use datafusion_expr::ColumnarValue; |
44 | | use datafusion_physical_expr_common::datum::compare_with_eq; |
45 | | |
46 | | use ahash::RandomState; |
47 | | use hashbrown::hash_map::RawEntryMut; |
48 | | use hashbrown::HashMap; |
49 | | |
50 | | /// Physical expression for `expr [NOT] IN (list)`, with an optional pre-built [`Set`] (`static_filter`) used in place of `list` when present |
51 | | pub struct InListExpr { |
52 | | expr: Arc<dyn PhysicalExpr>, |
53 | | list: Vec<Arc<dyn PhysicalExpr>>, |
54 | | negated: bool, |
55 | | static_filter: Option<Arc<dyn Set>>, |
56 | | } |
57 | | |
58 | | impl Debug for InListExpr { |
59 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
60 | 0 | f.debug_struct("InListExpr") |
61 | 0 | .field("expr", &self.expr) |
62 | 0 | .field("list", &self.list) |
63 | 0 | .field("negated", &self.negated) |
64 | 0 | .finish() |
65 | 0 | } |
66 | | } |
67 | | |
68 | | /// A type-erased container of array elements |
69 | | pub trait Set: Send + Sync { |
70 | | fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray>; |
71 | | fn has_nulls(&self) -> bool; |
72 | | } |
73 | | |
74 | | struct ArrayHashSet { |
75 | | state: RandomState, |
76 | | /// Used to provide a lookup from value to in list index |
77 | | /// |
78 | | /// Note: usize::hash is not used, instead the raw entry |
79 | | /// API is used to store entries w.r.t their value |
80 | | map: HashMap<usize, (), ()>, |
81 | | } |
82 | | |
83 | | struct ArraySet<T> { |
84 | | array: T, |
85 | | hash_set: ArrayHashSet, |
86 | | } |
87 | | |
88 | | impl<T> ArraySet<T> |
89 | | where |
90 | | T: Array + From<ArrayData>, |
91 | | { |
92 | 0 | fn new(array: &T, hash_set: ArrayHashSet) -> Self { |
93 | 0 | Self { |
94 | 0 | array: downcast_array(array), |
95 | 0 | hash_set, |
96 | 0 | } |
97 | 0 | } |
98 | | } |
99 | | |
100 | | impl<T> Set for ArraySet<T> |
101 | | where |
102 | | T: Array + 'static, |
103 | | for<'a> &'a T: ArrayAccessor, |
104 | | for<'a> <&'a T as ArrayAccessor>::Item: IsEqual, |
105 | | { |
106 | 0 | fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> { |
107 | 0 | downcast_dictionary_array! { |
108 | | v => { |
109 | 0 | let values_contains = self.contains(v.values().as_ref(), negated)?; |
110 | 0 | let result = take(&values_contains, v.keys(), None)?; |
111 | 0 | return Ok(downcast_array(result.as_ref())) |
112 | | } |
113 | 0 | _ => {} |
114 | 0 | } |
115 | 0 |
|
116 | 0 | let v = v.as_any().downcast_ref::<T>().unwrap(); |
117 | 0 | let in_array = &self.array; |
118 | 0 | let has_nulls = in_array.null_count() != 0; |
119 | 0 |
|
120 | 0 | Ok(ArrayIter::new(v) |
121 | 0 | .map(|v| { |
122 | 0 | v.and_then(|v| { |
123 | 0 | let hash = v.hash_one(&self.hash_set.state); |
124 | 0 | let contains = self |
125 | 0 | .hash_set |
126 | 0 | .map |
127 | 0 | .raw_entry() |
128 | 0 | .from_hash(hash, |idx| in_array.value(*idx).is_equal(&v)) |
129 | 0 | .is_some(); |
130 | | |
131 | 0 | match contains { |
132 | 0 | true => Some(!negated), |
133 | 0 | false if has_nulls => None, |
134 | 0 | false => Some(negated), |
135 | | } |
136 | 0 | }) |
137 | 0 | }) |
138 | 0 | .collect()) |
139 | 0 | } |
140 | | |
141 | 0 | fn has_nulls(&self) -> bool { |
142 | 0 | self.array.null_count() != 0 |
143 | 0 | } |
144 | | } |
145 | | |
146 | | /// Computes an [`ArrayHashSet`] for the provided [`Array`], storing the |
147 | | /// index of each distinct value; when a null buffer is present only the |
148 | | /// indices yielded by its validity bit iterator are inserted. |
149 | | /// |
150 | | /// Note: This is split into a separate function as higher-rank trait bounds currently |
151 | | /// cause type inference to misbehave |
152 | 0 | fn make_hash_set<T>(array: T) -> ArrayHashSet |
153 | 0 | where |
154 | 0 | T: ArrayAccessor, |
155 | 0 | T::Item: IsEqual, |
156 | 0 | { |
157 | 0 | let state = RandomState::new(); |
158 | 0 | let mut map: HashMap<usize, (), ()> = |
159 | 0 | HashMap::with_capacity_and_hasher(array.len(), ()); |
160 | 0 |
|
161 | 0 | let insert_value = |idx| { |
162 | 0 | let value = array.value(idx); |
163 | 0 | let hash = value.hash_one(&state); |
164 | 0 | if let RawEntryMut::Vacant(v) = map |
165 | 0 | .raw_entry_mut() |
166 | 0 | .from_hash(hash, |x| array.value(*x).is_equal(&value)) |
167 | 0 | { |
168 | 0 | v.insert_with_hasher(hash, idx, (), |x| array.value(*x).hash_one(&state)); |
169 | 0 | } |
170 | 0 | }; |
171 | | |
172 | 0 | match array.nulls() { |
173 | 0 | Some(nulls) => { |
174 | 0 | BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) |
175 | 0 | .for_each(insert_value) |
176 | | } |
177 | 0 | None => (0..array.len()).for_each(insert_value), |
178 | | } |
179 | | |
180 | 0 | ArrayHashSet { state, map } |
181 | 0 | } |
182 | | |
183 | | /// Creates an `Arc<dyn Set>` from the already-evaluated `IN` list values in `array`; returns a not-implemented error for unsupported data types |
184 | 0 | fn make_set(array: &dyn Array) -> Result<Arc<dyn Set>> { |
185 | 0 | Ok(downcast_primitive_array! { |
186 | 0 | array => Arc::new(ArraySet::new(array, make_hash_set(array))), |
187 | | DataType::Boolean => { |
188 | 0 | let array = as_boolean_array(array)?; |
189 | 0 | Arc::new(ArraySet::new(array, make_hash_set(array))) |
190 | | }, |
191 | | DataType::Utf8 => { |
192 | 0 | let array = as_string_array(array)?; |
193 | 0 | Arc::new(ArraySet::new(array, make_hash_set(array))) |
194 | | } |
195 | | DataType::LargeUtf8 => { |
196 | 0 | let array = as_largestring_array(array); |
197 | 0 | Arc::new(ArraySet::new(array, make_hash_set(array))) |
198 | | } |
199 | | DataType::Binary => { |
200 | 0 | let array = as_generic_binary_array::<i32>(array)?; |
201 | 0 | Arc::new(ArraySet::new(array, make_hash_set(array))) |
202 | | } |
203 | | DataType::LargeBinary => { |
204 | 0 | let array = as_generic_binary_array::<i64>(array)?; |
205 | 0 | Arc::new(ArraySet::new(array, make_hash_set(array))) |
206 | | } |
207 | 0 | DataType::Dictionary(_, _) => unreachable!("dictionary should have been flattened"), |
208 | 0 | d => return not_impl_err!("DataType::{d} not supported in InList") |
209 | | }) |
210 | 0 | } |
211 | | |
212 | | /// Evaluates the list of expressions into an array, flattening any dictionaries |
213 | 0 | fn evaluate_list( |
214 | 0 | list: &[Arc<dyn PhysicalExpr>], |
215 | 0 | batch: &RecordBatch, |
216 | 0 | ) -> Result<ArrayRef> { |
217 | 0 | let scalars = list |
218 | 0 | .iter() |
219 | 0 | .map(|expr| { |
220 | 0 | expr.evaluate(batch).and_then(|r| match r { |
221 | | ColumnarValue::Array(_) => { |
222 | 0 | exec_err!("InList expression must evaluate to a scalar") |
223 | | } |
224 | | // Flatten dictionary values |
225 | 0 | ColumnarValue::Scalar(ScalarValue::Dictionary(_, v)) => Ok(*v), |
226 | 0 | ColumnarValue::Scalar(s) => Ok(s), |
227 | 0 | }) |
228 | 0 | }) |
229 | 0 | .collect::<Result<Vec<_>>>()?; |
230 | | |
231 | 0 | ScalarValue::iter_to_array(scalars) |
232 | 0 | } |
233 | | |
234 | 0 | fn try_cast_static_filter_to_set( |
235 | 0 | list: &[Arc<dyn PhysicalExpr>], |
236 | 0 | schema: &Schema, |
237 | 0 | ) -> Result<Arc<dyn Set>> { |
238 | 0 | let batch = RecordBatch::new_empty(Arc::new(schema.clone())); |
239 | 0 | make_set(evaluate_list(list, &batch)?.as_ref()) |
240 | 0 | } |
241 | | |
242 | | /// Custom equality check function which is used with [`ArrayHashSet`] for existence check. |
243 | | trait IsEqual: HashValue { |
244 | | fn is_equal(&self, other: &Self) -> bool; |
245 | | } |
246 | | |
247 | | impl<'a, T: IsEqual + ?Sized> IsEqual for &'a T { |
248 | 0 | fn is_equal(&self, other: &Self) -> bool { |
249 | 0 | T::is_equal(self, other) |
250 | 0 | } |
251 | | } |
252 | | |
253 | | macro_rules! is_equal { |
254 | | ($($t:ty),+) => { |
255 | | $(impl IsEqual for $t { |
256 | 0 | fn is_equal(&self, other: &Self) -> bool { |
257 | 0 | self == other |
258 | 0 | } |
259 | | })* |
260 | | }; |
261 | | } |
262 | | is_equal!(i8, i16, i32, i64, i128, i256, u8, u16, u32, u64); |
263 | | is_equal!(bool, str, [u8]); |
264 | | is_equal!(IntervalDayTime, IntervalMonthDayNano); |
265 | | |
266 | | macro_rules! is_equal_float { |
267 | | ($($t:ty),+) => { |
268 | | $(impl IsEqual for $t { |
269 | 0 | fn is_equal(&self, other: &Self) -> bool { |
270 | 0 | self.to_bits() == other.to_bits() |
271 | 0 | } |
272 | | })* |
273 | | }; |
274 | | } |
275 | | is_equal_float!(half::f16, f32, f64); |
276 | | |
277 | | impl InListExpr { |
278 | | /// Create a new InList expression |
279 | 0 | pub fn new( |
280 | 0 | expr: Arc<dyn PhysicalExpr>, |
281 | 0 | list: Vec<Arc<dyn PhysicalExpr>>, |
282 | 0 | negated: bool, |
283 | 0 | static_filter: Option<Arc<dyn Set>>, |
284 | 0 | ) -> Self { |
285 | 0 | Self { |
286 | 0 | expr, |
287 | 0 | list, |
288 | 0 | negated, |
289 | 0 | static_filter, |
290 | 0 | } |
291 | 0 | } |
292 | | |
293 | | /// Input expression |
294 | 0 | pub fn expr(&self) -> &Arc<dyn PhysicalExpr> { |
295 | 0 | &self.expr |
296 | 0 | } |
297 | | |
298 | | /// List to search in |
299 | 0 | pub fn list(&self) -> &[Arc<dyn PhysicalExpr>] { |
300 | 0 | &self.list |
301 | 0 | } |
302 | | |
303 | | /// Is this negated e.g. NOT IN LIST |
304 | 0 | pub fn negated(&self) -> bool { |
305 | 0 | self.negated |
306 | 0 | } |
307 | | } |
308 | | |
309 | | impl std::fmt::Display for InListExpr { |
310 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
311 | 0 | if self.negated { |
312 | 0 | if self.static_filter.is_some() { |
313 | 0 | write!(f, "{} NOT IN (SET) ({:?})", self.expr, self.list) |
314 | | } else { |
315 | 0 | write!(f, "{} NOT IN ({:?})", self.expr, self.list) |
316 | | } |
317 | 0 | } else if self.static_filter.is_some() { |
318 | 0 | write!(f, "Use {} IN (SET) ({:?})", self.expr, self.list) |
319 | | } else { |
320 | 0 | write!(f, "{} IN ({:?})", self.expr, self.list) |
321 | | } |
322 | 0 | } |
323 | | } |
324 | | |
325 | | impl PhysicalExpr for InListExpr { |
326 | | /// Return a reference to Any that can be used for downcasting |
327 | 0 | fn as_any(&self) -> &dyn Any { |
328 | 0 | self |
329 | 0 | } |
330 | | |
331 | 0 | fn data_type(&self, _input_schema: &Schema) -> Result<DataType> { |
332 | 0 | Ok(DataType::Boolean) |
333 | 0 | } |
334 | | |
335 | 0 | fn nullable(&self, input_schema: &Schema) -> Result<bool> { |
336 | 0 | if self.expr.nullable(input_schema)? { |
337 | 0 | return Ok(true); |
338 | 0 | } |
339 | | |
340 | 0 | if let Some(static_filter) = &self.static_filter { |
341 | 0 | Ok(static_filter.has_nulls()) |
342 | | } else { |
343 | 0 | for expr in &self.list { |
344 | 0 | if expr.nullable(input_schema)? { |
345 | 0 | return Ok(true); |
346 | 0 | } |
347 | | } |
348 | 0 | Ok(false) |
349 | | } |
350 | 0 | } |
351 | | |
352 | 0 | fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> { |
353 | 0 | let num_rows = batch.num_rows(); |
354 | 0 | let value = self.expr.evaluate(batch)?; |
355 | 0 | let r = match &self.static_filter { |
356 | 0 | Some(f) => f.contains(value.into_array(num_rows)?.as_ref(), self.negated)?, |
357 | | None => { |
358 | 0 | let value = value.into_array(num_rows)?; |
359 | 0 | let is_nested = value.data_type().is_nested(); |
360 | 0 | let found = self.list.iter().map(|expr| expr.evaluate(batch)).try_fold( |
361 | 0 | BooleanArray::new(BooleanBuffer::new_unset(num_rows), None), |
362 | 0 | |result, expr| -> Result<BooleanArray> { |
363 | 0 | let rhs = compare_with_eq( |
364 | 0 | &value, |
365 | 0 | &expr?.into_array(num_rows)?, |
366 | 0 | is_nested, |
367 | 0 | )?; |
368 | 0 | Ok(or_kleene(&result, &rhs)?) |
369 | 0 | }, |
370 | 0 | )?; |
371 | | |
372 | 0 | if self.negated { |
373 | 0 | not(&found)? |
374 | | } else { |
375 | 0 | found |
376 | | } |
377 | | } |
378 | | }; |
379 | 0 | Ok(ColumnarValue::Array(Arc::new(r))) |
380 | 0 | } |
381 | | |
382 | 0 | fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> { |
383 | 0 | let mut children = vec![]; |
384 | 0 | children.push(&self.expr); |
385 | 0 | children.extend(&self.list); |
386 | 0 | children |
387 | 0 | } |
388 | | |
389 | 0 | fn with_new_children( |
390 | 0 | self: Arc<Self>, |
391 | 0 | children: Vec<Arc<dyn PhysicalExpr>>, |
392 | 0 | ) -> Result<Arc<dyn PhysicalExpr>> { |
393 | 0 | // assume the static_filter will not change during the rewrite process |
394 | 0 | Ok(Arc::new(InListExpr::new( |
395 | 0 | Arc::clone(&children[0]), |
396 | 0 | children[1..].to_vec(), |
397 | 0 | self.negated, |
398 | 0 | self.static_filter.clone(), |
399 | 0 | ))) |
400 | 0 | } |
401 | | |
402 | 0 | fn dyn_hash(&self, state: &mut dyn Hasher) { |
403 | 0 | let mut s = state; |
404 | 0 | self.expr.hash(&mut s); |
405 | 0 | self.negated.hash(&mut s); |
406 | 0 | self.list.hash(&mut s); |
407 | 0 | // Add `self.static_filter` when hash is available |
408 | 0 | } |
409 | | } |
410 | | |
411 | | impl PartialEq<dyn Any> for InListExpr { |
412 | 0 | fn eq(&self, other: &dyn Any) -> bool { |
413 | 0 | down_cast_any_ref(other) |
414 | 0 | .downcast_ref::<Self>() |
415 | 0 | .map(|x| { |
416 | 0 | self.expr.eq(&x.expr) |
417 | 0 | && physical_exprs_bag_equal(&self.list, &x.list) |
418 | 0 | && self.negated == x.negated |
419 | 0 | }) |
420 | 0 | .unwrap_or(false) |
421 | 0 | } |
422 | | } |
423 | | |
424 | | /// Creates an [`InListExpr`]: errors if any list expression's type is not logically equal to `expr`'s type, and attaches a static filter set when one can be built from `list` |
425 | 0 | pub fn in_list( |
426 | 0 | expr: Arc<dyn PhysicalExpr>, |
427 | 0 | list: Vec<Arc<dyn PhysicalExpr>>, |
428 | 0 | negated: &bool, |
429 | 0 | schema: &Schema, |
430 | 0 | ) -> Result<Arc<dyn PhysicalExpr>> { |
431 | | // check the data type |
432 | 0 | let expr_data_type = expr.data_type(schema)?; |
433 | 0 | for list_expr in list.iter() { |
434 | 0 | let list_expr_data_type = list_expr.data_type(schema)?; |
435 | 0 | if !DFSchema::datatype_is_logically_equal(&expr_data_type, &list_expr_data_type) { |
436 | 0 | return internal_err!( |
437 | 0 | "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}" |
438 | 0 | ); |
439 | 0 | } |
440 | | } |
441 | 0 | let static_filter = try_cast_static_filter_to_set(&list, schema).ok(); |
442 | 0 | Ok(Arc::new(InListExpr::new( |
443 | 0 | expr, |
444 | 0 | list, |
445 | 0 | *negated, |
446 | 0 | static_filter, |
447 | 0 | ))) |
448 | 0 | } |
449 | | |
450 | | #[cfg(test)] |
451 | | mod tests { |
452 | | |
453 | | use super::*; |
454 | | use crate::expressions; |
455 | | use crate::expressions::{col, lit, try_cast}; |
456 | | use datafusion_common::plan_err; |
457 | | use datafusion_expr::type_coercion::binary::comparison_coercion; |
458 | | |
459 | | type InListCastResult = (Arc<dyn PhysicalExpr>, Vec<Arc<dyn PhysicalExpr>>); |
460 | | |
461 | | // Try to do the type coercion for list physical expr. |
462 | | // It's just used in the test |
463 | | fn in_list_cast( |
464 | | expr: Arc<dyn PhysicalExpr>, |
465 | | list: Vec<Arc<dyn PhysicalExpr>>, |
466 | | input_schema: &Schema, |
467 | | ) -> Result<InListCastResult> { |
468 | | let expr_type = &expr.data_type(input_schema)?; |
469 | | let list_types: Vec<DataType> = list |
470 | | .iter() |
471 | | .map(|list_expr| list_expr.data_type(input_schema).unwrap()) |
472 | | .collect(); |
473 | | let result_type = get_coerce_type(expr_type, &list_types); |
474 | | match result_type { |
475 | | None => plan_err!( |
476 | | "Can not find compatible types to compare {expr_type:?} with {list_types:?}" |
477 | | ), |
478 | | Some(data_type) => { |
479 | | // find the coerced type |
480 | | let cast_expr = try_cast(expr, input_schema, data_type.clone())?; |
481 | | let cast_list_expr = list |
482 | | .into_iter() |
483 | | .map(|list_expr| { |
484 | | try_cast(list_expr, input_schema, data_type.clone()).unwrap() |
485 | | }) |
486 | | .collect(); |
487 | | Ok((cast_expr, cast_list_expr)) |
488 | | } |
489 | | } |
490 | | } |
491 | | |
492 | | // Attempts to coerce the types of `list_type` to be comparable with the |
493 | | // `expr_type` |
494 | | fn get_coerce_type(expr_type: &DataType, list_type: &[DataType]) -> Option<DataType> { |
495 | | list_type |
496 | | .iter() |
497 | | .try_fold(expr_type.clone(), |left_type, right_type| { |
498 | | comparison_coercion(&left_type, right_type) |
499 | | }) |
500 | | } |
501 | | |
502 | | // applies the in_list expr to an input batch and list |
503 | | macro_rules! in_list { |
504 | | ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{ |
505 | | let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?; |
506 | | in_list_raw!( |
507 | | $BATCH, |
508 | | cast_list_exprs, |
509 | | $NEGATED, |
510 | | $EXPECTED, |
511 | | cast_expr, |
512 | | $SCHEMA |
513 | | ); |
514 | | }}; |
515 | | } |
516 | | |
517 | | // applies the in_list expr to an input batch and list without cast |
518 | | macro_rules! in_list_raw { |
519 | | ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{ |
520 | | let expr = in_list($COL, $LIST, $NEGATED, $SCHEMA).unwrap(); |
521 | | let result = expr |
522 | | .evaluate(&$BATCH)? |
523 | | .into_array($BATCH.num_rows()) |
524 | | .expect("Failed to convert to array"); |
525 | | let result = |
526 | | as_boolean_array(&result).expect("failed to downcast to BooleanArray"); |
527 | | let expected = &BooleanArray::from($EXPECTED); |
528 | | assert_eq!(expected, result); |
529 | | }}; |
530 | | } |
531 | | |
532 | | #[test] |
533 | | fn in_list_utf8() -> Result<()> { |
534 | | let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); |
535 | | let a = StringArray::from(vec![Some("a"), Some("d"), None]); |
536 | | let col_a = col("a", &schema)?; |
537 | | let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; |
538 | | |
539 | | // expression: "a in ("a", "b")" |
540 | | let list = vec![lit("a"), lit("b")]; |
541 | | in_list!( |
542 | | batch, |
543 | | list, |
544 | | &false, |
545 | | vec![Some(true), Some(false), None], |
546 | | Arc::clone(&col_a), |
547 | | &schema |
548 | | ); |
549 | | |
550 | | // expression: "a not in ("a", "b")" |
551 | | let list = vec![lit("a"), lit("b")]; |
552 | | in_list!( |
553 | | batch, |
554 | | list, |
555 | | &true, |
556 | | vec![Some(false), Some(true), None], |
557 | | Arc::clone(&col_a), |
558 | | &schema |
559 | | ); |
560 | | |
561 | | // expression: "a in ("a", "b", null)" |
562 | | let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; |
563 | | in_list!( |
564 | | batch, |
565 | | list, |
566 | | &false, |
567 | | vec![Some(true), None, None], |
568 | | Arc::clone(&col_a), |
569 | | &schema |
570 | | ); |
571 | | |
572 | | // expression: "a not in ("a", "b", null)" |
573 | | let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; |
574 | | in_list!( |
575 | | batch, |
576 | | list, |
577 | | &true, |
578 | | vec![Some(false), None, None], |
579 | | Arc::clone(&col_a), |
580 | | &schema |
581 | | ); |
582 | | |
583 | | Ok(()) |
584 | | } |
585 | | |
586 | | #[test] |
587 | | fn in_list_binary() -> Result<()> { |
588 | | let schema = Schema::new(vec![Field::new("a", DataType::Binary, true)]); |
589 | | let a = BinaryArray::from(vec![ |
590 | | Some([1, 2, 3].as_slice()), |
591 | | Some([1, 2, 2].as_slice()), |
592 | | None, |
593 | | ]); |
594 | | let col_a = col("a", &schema)?; |
595 | | let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; |
596 | | |
597 | | // expression: "a in ([1, 2, 3], [4, 5, 6])" |
598 | | let list = vec![lit([1, 2, 3].as_slice()), lit([4, 5, 6].as_slice())]; |
599 | | in_list!( |
600 | | batch, |
601 | | list.clone(), |
602 | | &false, |
603 | | vec![Some(true), Some(false), None], |
604 | | Arc::clone(&col_a), |
605 | | &schema |
606 | | ); |
607 | | |
608 | | // expression: "a not in ([1, 2, 3], [4, 5, 6])" |
609 | | in_list!( |
610 | | batch, |
611 | | list, |
612 | | &true, |
613 | | vec![Some(false), Some(true), None], |
614 | | Arc::clone(&col_a), |
615 | | &schema |
616 | | ); |
617 | | |
618 | | // expression: "a in ([1, 2, 3], [4, 5, 6], null)" |
619 | | let list = vec![ |
620 | | lit([1, 2, 3].as_slice()), |
621 | | lit([4, 5, 6].as_slice()), |
622 | | lit(ScalarValue::Binary(None)), |
623 | | ]; |
624 | | in_list!( |
625 | | batch, |
626 | | list.clone(), |
627 | | &false, |
628 | | vec![Some(true), None, None], |
629 | | Arc::clone(&col_a), |
630 | | &schema |
631 | | ); |
632 | | |
633 | | // expression: "a not in ([1, 2, 3], [4, 5, 6], null)" |
634 | | in_list!( |
635 | | batch, |
636 | | list, |
637 | | &true, |
638 | | vec![Some(false), None, None], |
639 | | Arc::clone(&col_a), |
640 | | &schema |
641 | | ); |
642 | | |
643 | | Ok(()) |
644 | | } |
645 | | |
646 | | #[test] |
647 | | fn in_list_int64() -> Result<()> { |
648 | | let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); |
649 | | let a = Int64Array::from(vec![Some(0), Some(2), None]); |
650 | | let col_a = col("a", &schema)?; |
651 | | let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; |
652 | | |
653 | | // expression: "a in (0, 1)" |
654 | | let list = vec![lit(0i64), lit(1i64)]; |
655 | | in_list!( |
656 | | batch, |
657 | | list, |
658 | | &false, |
659 | | vec![Some(true), Some(false), None], |
660 | | Arc::clone(&col_a), |
661 | | &schema |
662 | | ); |
663 | | |
664 | | // expression: "a not in (0, 1)" |
665 | | let list = vec![lit(0i64), lit(1i64)]; |
666 | | in_list!( |
667 | | batch, |
668 | | list, |
669 | | &true, |
670 | | vec![Some(false), Some(true), None], |
671 | | Arc::clone(&col_a), |
672 | | &schema |
673 | | ); |
674 | | |
675 | | // expression: "a in (0, 1, NULL)" |
676 | | let list = vec![lit(0i64), lit(1i64), lit(ScalarValue::Null)]; |
677 | | in_list!( |
678 | | batch, |
679 | | list, |
680 | | &false, |
681 | | vec![Some(true), None, None], |
682 | | Arc::clone(&col_a), |
683 | | &schema |
684 | | ); |
685 | | |
686 | | // expression: "a not in (0, 1, NULL)" |
687 | | let list = vec![lit(0i64), lit(1i64), lit(ScalarValue::Null)]; |
688 | | in_list!( |
689 | | batch, |
690 | | list, |
691 | | &true, |
692 | | vec![Some(false), None, None], |
693 | | Arc::clone(&col_a), |
694 | | &schema |
695 | | ); |
696 | | |
697 | | Ok(()) |
698 | | } |
699 | | |
700 | | #[test] |
701 | | fn in_list_float64() -> Result<()> { |
702 | | let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); |
703 | | let a = Float64Array::from(vec![ |
704 | | Some(0.0), |
705 | | Some(0.2), |
706 | | None, |
707 | | Some(f64::NAN), |
708 | | Some(-f64::NAN), |
709 | | ]); |
710 | | let col_a = col("a", &schema)?; |
711 | | let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; |
712 | | |
713 | | // expression: "a in (0.0, 0.1)" |
714 | | let list = vec![lit(0.0f64), lit(0.1f64)]; |
715 | | in_list!( |
716 | | batch, |
717 | | list, |
718 | | &false, |
719 | | vec![Some(true), Some(false), None, Some(false), Some(false)], |
720 | | Arc::clone(&col_a), |
721 | | &schema |
722 | | ); |
723 | | |
724 | | // expression: "a not in (0.0, 0.1)" |
725 | | let list = vec![lit(0.0f64), lit(0.1f64)]; |
726 | | in_list!( |
727 | | batch, |
728 | | list, |
729 | | &true, |
730 | | vec![Some(false), Some(true), None, Some(true), Some(true)], |
731 | | Arc::clone(&col_a), |
732 | | &schema |
733 | | ); |
734 | | |
735 | | // expression: "a in (0.0, 0.1, NULL)" |
736 | | let list = vec![lit(0.0f64), lit(0.1f64), lit(ScalarValue::Null)]; |
737 | | in_list!( |
738 | | batch, |
739 | | list, |
740 | | &false, |
741 | | vec![Some(true), None, None, None, None], |
742 | | Arc::clone(&col_a), |
743 | | &schema |
744 | | ); |
745 | | |
746 | | // expression: "a not in (0.0, 0.1, NULL)" |
747 | | let list = vec![lit(0.0f64), lit(0.1f64), lit(ScalarValue::Null)]; |
748 | | in_list!( |
749 | | batch, |
750 | | list, |
751 | | &true, |
752 | | vec![Some(false), None, None, None, None], |
753 | | Arc::clone(&col_a), |
754 | | &schema |
755 | | ); |
756 | | |
757 | | // expression: "a in (0.0, 0.1, NaN)" |
758 | | let list = vec![lit(0.0f64), lit(0.1f64), lit(f64::NAN)]; |
759 | | in_list!( |
760 | | batch, |
761 | | list, |
762 | | &false, |
763 | | vec![Some(true), Some(false), None, Some(true), Some(false)], |
764 | | Arc::clone(&col_a), |
765 | | &schema |
766 | | ); |
767 | | |
768 | | // expression: "a not in (0.0, 0.1, NaN)" |
769 | | let list = vec![lit(0.0f64), lit(0.1f64), lit(f64::NAN)]; |
770 | | in_list!( |
771 | | batch, |
772 | | list, |
773 | | &true, |
774 | | vec![Some(false), Some(true), None, Some(false), Some(true)], |
775 | | Arc::clone(&col_a), |
776 | | &schema |
777 | | ); |
778 | | |
779 | | // expression: "a in (0.0, 0.1, -NaN)" |
780 | | let list = vec![lit(0.0f64), lit(0.1f64), lit(-f64::NAN)]; |
781 | | in_list!( |
782 | | batch, |
783 | | list, |
784 | | &false, |
785 | | vec![Some(true), Some(false), None, Some(false), Some(true)], |
786 | | Arc::clone(&col_a), |
787 | | &schema |
788 | | ); |
789 | | |
790 | | // expression: "a not in (0.0, 0.1, -NaN)" |
791 | | let list = vec![lit(0.0f64), lit(0.1f64), lit(-f64::NAN)]; |
792 | | in_list!( |
793 | | batch, |
794 | | list, |
795 | | &true, |
796 | | vec![Some(false), Some(true), None, Some(true), Some(false)], |
797 | | Arc::clone(&col_a), |
798 | | &schema |
799 | | ); |
800 | | |
801 | | Ok(()) |
802 | | } |
803 | | |
804 | | #[test] |
805 | | fn in_list_bool() -> Result<()> { |
806 | | let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); |
807 | | let a = BooleanArray::from(vec![Some(true), None]); |
808 | | let col_a = col("a", &schema)?; |
809 | | let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; |
810 | | |
811 | | // expression: "a in (true)" |
812 | | let list = vec![lit(true)]; |
813 | | in_list!( |
814 | | batch, |
815 | | list, |
816 | | &false, |
817 | | vec![Some(true), None], |
818 | | Arc::clone(&col_a), |
819 | | &schema |
820 | | ); |
821 | | |
822 | | // expression: "a not in (true)" |
823 | | let list = vec![lit(true)]; |
824 | | in_list!( |
825 | | batch, |
826 | | list, |
827 | | &true, |
828 | | vec![Some(false), None], |
829 | | Arc::clone(&col_a), |
830 | | &schema |
831 | | ); |
832 | | |
833 | | // expression: "a in (true, NULL)" |
834 | | let list = vec![lit(true), lit(ScalarValue::Null)]; |
835 | | in_list!( |
836 | | batch, |
837 | | list, |
838 | | &false, |
839 | | vec![Some(true), None], |
840 | | Arc::clone(&col_a), |
841 | | &schema |
842 | | ); |
843 | | |
844 | | // expression: "a not in (true, NULL)" |
845 | | let list = vec![lit(true), lit(ScalarValue::Null)]; |
846 | | in_list!( |
847 | | batch, |
848 | | list, |
849 | | &true, |
850 | | vec![Some(false), None], |
851 | | Arc::clone(&col_a), |
852 | | &schema |
853 | | ); |
854 | | |
855 | | Ok(()) |
856 | | } |
857 | | |
858 | | #[test] |
859 | | fn in_list_date64() -> Result<()> { |
860 | | let schema = Schema::new(vec![Field::new("a", DataType::Date64, true)]); |
861 | | let a = Date64Array::from(vec![Some(0), Some(2), None]); |
862 | | let col_a = col("a", &schema)?; |
863 | | let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; |
864 | | |
865 | | // expression: "a in (0, 1)" |
866 | | let list = vec![ |
867 | | lit(ScalarValue::Date64(Some(0))), |
868 | | lit(ScalarValue::Date64(Some(1))), |
869 | | ]; |
870 | | in_list!( |
871 | | batch, |
872 | | list, |
873 | | &false, |
874 | | vec![Some(true), Some(false), None], |
875 | | Arc::clone(&col_a), |
876 | | &schema |
877 | | ); |
878 | | |
879 | | // expression: "a not in (0, 1)" |
880 | | let list = vec![ |
881 | | lit(ScalarValue::Date64(Some(0))), |
882 | | lit(ScalarValue::Date64(Some(1))), |
883 | | ]; |
884 | | in_list!( |
885 | | batch, |
886 | | list, |
887 | | &true, |
888 | | vec![Some(false), Some(true), None], |
889 | | Arc::clone(&col_a), |
890 | | &schema |
891 | | ); |
892 | | |
893 | | // expression: "a in (0, 1, NULL)" |
894 | | let list = vec![ |
895 | | lit(ScalarValue::Date64(Some(0))), |
896 | | lit(ScalarValue::Date64(Some(1))), |
897 | | lit(ScalarValue::Null), |
898 | | ]; |
899 | | in_list!( |
900 | | batch, |
901 | | list, |
902 | | &false, |
903 | | vec![Some(true), None, None], |
904 | | Arc::clone(&col_a), |
905 | | &schema |
906 | | ); |
907 | | |
908 | | // expression: "a not in (0, 1, NULL)" |
909 | | let list = vec![ |
910 | | lit(ScalarValue::Date64(Some(0))), |
911 | | lit(ScalarValue::Date64(Some(1))), |
912 | | lit(ScalarValue::Null), |
913 | | ]; |
914 | | in_list!( |
915 | | batch, |
916 | | list, |
917 | | &true, |
918 | | vec![Some(false), None, None], |
919 | | Arc::clone(&col_a), |
920 | | &schema |
921 | | ); |
922 | | |
923 | | Ok(()) |
924 | | } |
925 | | |
926 | | #[test] |
927 | | fn in_list_date32() -> Result<()> { |
928 | | let schema = Schema::new(vec![Field::new("a", DataType::Date32, true)]); |
929 | | let a = Date32Array::from(vec![Some(0), Some(2), None]); |
930 | | let col_a = col("a", &schema)?; |
931 | | let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; |
932 | | |
933 | | // Date32 column a = [0, 2, NULL]; expression: "a in (0, 1)": the listed value matches, the unlisted value does not, the NULL row stays NULL |
934 | | let list = vec![ |
935 | | lit(ScalarValue::Date32(Some(0))), |
936 | | lit(ScalarValue::Date32(Some(1))), |
937 | | ]; |
938 | | in_list!( |
939 | | batch, |
940 | | list, |
941 | | &false, |
942 | | vec![Some(true), Some(false), None], |
943 | | Arc::clone(&col_a), |
944 | | &schema |
945 | | ); |
946 | | |
947 | | // expression: "a not in (0, 1)": same list, negated, so the two non-NULL results flip |
948 | | let list = vec![ |
949 | | lit(ScalarValue::Date32(Some(0))), |
950 | | lit(ScalarValue::Date32(Some(1))), |
951 | | ]; |
952 | | in_list!( |
953 | | batch, |
954 | | list, |
955 | | &true, |
956 | | vec![Some(false), Some(true), None], |
957 | | Arc::clone(&col_a), |
958 | | &schema |
959 | | ); |
960 | | |
961 | | // expression: "a in (0, 1, NULL)": with a NULL in the list, the unmatched value is expected to be NULL instead of false |
962 | | let list = vec![ |
963 | | lit(ScalarValue::Date32(Some(0))), |
964 | | lit(ScalarValue::Date32(Some(1))), |
965 | | lit(ScalarValue::Null), |
966 | | ]; |
967 | | in_list!( |
968 | | batch, |
969 | | list, |
970 | | &false, |
971 | | vec![Some(true), None, None], |
972 | | Arc::clone(&col_a), |
973 | | &schema |
974 | | ); |
975 | | |
976 | | // expression: "a not in (0, 1, NULL)": the matched value becomes false, the unmatched value is still NULL |
977 | | let list = vec![ |
978 | | lit(ScalarValue::Date32(Some(0))), |
979 | | lit(ScalarValue::Date32(Some(1))), |
980 | | lit(ScalarValue::Null), |
981 | | ]; |
982 | | in_list!( |
983 | | batch, |
984 | | list, |
985 | | &true, |
986 | | vec![Some(false), None, None], |
987 | | Arc::clone(&col_a), |
988 | | &schema |
989 | | ); |
990 | | |
991 | | Ok(()) |
992 | | } |
993 | | |
994 | | #[test] |
995 | | fn in_list_decimal() -> Result<()> { |
996 | | // Decimal128(13, 4) column a = [100.0000, NULL, 200.5000], tested against list literals of other types (Int32, Float32, Null) |
997 | | let schema = |
998 | | Schema::new(vec![Field::new("a", DataType::Decimal128(13, 4), true)]); |
999 | | let array = vec![Some(100_0000_i128), None, Some(200_5000_i128)] |
1000 | | .into_iter() |
1001 | | .collect::<Decimal128Array>(); |
1002 | | let array = array.with_precision_and_scale(13, 4).unwrap(); |
1003 | | let col_a = col("a", &schema)?; |
1004 | | let batch = |
1005 | | RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)])?; |
1006 | | |
1007 | | // expression: "a in (100, 200)", the data type of list is INT32 |
1008 | | let list = vec![lit(100i32), lit(200i32)]; |
1009 | | in_list!( |
1010 | | batch, |
1011 | | list, |
1012 | | &false, |
1013 | | vec![Some(true), None, Some(false)], |
1014 | | Arc::clone(&col_a), |
1015 | | &schema |
1016 | | ); |
1017 | | // expression: "a not in (100, 200)", the data type of list is INT32 |
1018 | | let list = vec![lit(100i32), lit(200i32)]; |
1019 | | in_list!( |
1020 | | batch, |
1021 | | list, |
1022 | | &true, |
1023 | | vec![Some(false), None, Some(true)], |
1024 | | Arc::clone(&col_a), |
1025 | | &schema |
1026 | | ); |
1027 | | |
1028 | | // expression: "a in (100, NULL)", the data type of list is INT32 AND NULL |
1029 | | let list = vec![lit(ScalarValue::Int32(Some(100))), lit(ScalarValue::Null)]; |
1030 | | in_list!( |
1031 | | batch, |
1032 | | list.clone(), |
1033 | | &false, |
1034 | | vec![Some(true), None, None], |
1035 | | Arc::clone(&col_a), |
1036 | | &schema |
1037 | | ); |
1038 | | // expression: "a not in (100, NULL)", the data type of list is INT32 AND NULL |
1039 | | in_list!( |
1040 | | batch, |
1041 | | list, |
1042 | | &true, |
1043 | | vec![Some(false), None, None], |
1044 | | Arc::clone(&col_a), |
1045 | | &schema |
1046 | | ); |
1047 | | |
1048 | | // expression: "a in (200.5, 100)", the data type of list is FLOAT32 and INT32 |
1049 | | let list = vec![lit(200.50f32), lit(100i32)]; |
1050 | | in_list!( |
1051 | | batch, |
1052 | | list, |
1053 | | &false, |
1054 | | vec![Some(true), None, Some(true)], |
1055 | | Arc::clone(&col_a), |
1056 | | &schema |
1057 | | ); |
1058 | | |
1059 | | // expression: "a not in (200.5, 101)", the data type of list is FLOAT32 and INT32 (101 matches no row, so only 200.5 is excluded) |
1060 | | let list = vec![lit(200.50f32), lit(101i32)]; |
1061 | | in_list!( |
1062 | | batch, |
1063 | | list, |
1064 | | &true, |
1065 | | vec![Some(true), None, Some(false)], |
1066 | | Arc::clone(&col_a), |
1067 | | &schema |
1068 | | ); |
1069 | | |
1070 | | // test the optimization: set |
1071 | | // expression: "a in (99..300)", i.e. the integers 99 through 299; the data type of list is INT32 |
1072 | | let list = (99i32..300).map(lit).collect::<Vec<_>>(); |
1073 | | |
1074 | | in_list!( |
1075 | | batch, |
1076 | | list.clone(), |
1077 | | &false, |
1078 | | vec![Some(true), None, Some(false)], |
1079 | | Arc::clone(&col_a), |
1080 | | &schema |
1081 | | ); |
1082 | | |
1083 | | in_list!( |
1084 | | batch, |
1085 | | list, |
1086 | | &true, |
1087 | | vec![Some(false), None, Some(true)], |
1088 | | Arc::clone(&col_a), |
1089 | | &schema |
1090 | | ); |
1091 | | |
1092 | | Ok(()) |
1093 | | } |
1094 | | |
1095 | | #[test] |
1096 | | fn test_cast_static_filter_to_set() -> Result<()> { |
1097 | | // arbitrary schema; only the final case below references its column "a" |
1098 | | let schema = |
1099 | | Schema::new(vec![Field::new("a", DataType::Decimal128(13, 4), true)]); |
1100 | | |
1101 | | // list of physical exprs, all producing Int64: a bare literal, a cast literal and a try_cast literal (the assertion below expects 3.13 to be found as 3) |
1102 | | let mut phy_exprs = vec![ |
1103 | | lit(1i64), |
1104 | | expressions::cast(lit(2i32), &schema, DataType::Int64)?, |
1105 | | expressions::try_cast(lit(3.13f32), &schema, DataType::Int64)?, |
1106 | | ]; |
1107 | | let result = try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); |
1108 | | |
1109 | | let array = Int64Array::from(vec![1, 2, 3, 4]); |
1110 | | let r = result.contains(&array, false).unwrap(); |
1111 | | assert_eq!(r, BooleanArray::from(vec![true, true, true, false])); |
1112 | | |
1113 | | try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); |
1114 | | // cast(cast(lit())) where both casts target the same data type (Int64): one cast will be ignored, and building the set must still succeed |
1115 | | phy_exprs.push(expressions::cast( |
1116 | | expressions::cast(lit(2i32), &schema, DataType::Int64)?, |
1117 | | &schema, |
1118 | | DataType::Int64, |
1119 | | )?); |
1120 | | try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); |
1121 | | |
1122 | | phy_exprs.clear(); |
1123 | | |
1124 | | // cast(cast(lit())) where the casts target different data types (Int64, then Int32): building the set must also succeed |
1125 | | phy_exprs.push(expressions::cast( |
1126 | | expressions::cast(lit(2i32), &schema, DataType::Int64)?, |
1127 | | &schema, |
1128 | | DataType::Int32, |
1129 | | )?); |
1130 | | try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); |
1131 | | |
1132 | | // a column reference in the list: building the set is expected to fail |
1133 | | phy_exprs.push(expressions::col("a", &schema)?); |
1134 | | assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_err()); |
1135 | | |
1136 | | Ok(()) |
1137 | | } |
1138 | | |
1139 | | #[test] |
1140 | | fn in_list_timestamp() -> Result<()> { |
1141 | | let schema = Schema::new(vec![Field::new( |
1142 | | "a", |
1143 | | DataType::Timestamp(TimeUnit::Microsecond, None), |
1144 | | true, |
1145 | | )]); |
1146 | | let a = TimestampMicrosecondArray::from(vec![ |
1147 | | Some(1388588401000000000), |
1148 | | Some(1288588501000000000), |
1149 | | None, |
1150 | | ]); |
1151 | | let col_a = col("a", &schema)?; |
1152 | | let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; |
1153 | | /* Timestamp(Microsecond, no timezone) literals; only the first row of column a equals one of them */ |
1154 | | let list = vec![ |
1155 | | lit(ScalarValue::TimestampMicrosecond( |
1156 | | Some(1388588401000000000), |
1157 | | None, |
1158 | | )), |
1159 | | lit(ScalarValue::TimestampMicrosecond( |
1160 | | Some(1388588401000000001), |
1161 | | None, |
1162 | | )), |
1163 | | lit(ScalarValue::TimestampMicrosecond( |
1164 | | Some(1388588401000000002), |
1165 | | None, |
1166 | | )), |
1167 | | ]; |
1168 | | /* expression: a in (list): first row true, second row false, NULL row stays NULL */ |
1169 | | in_list!( |
1170 | | batch, |
1171 | | list.clone(), |
1172 | | &false, |
1173 | | vec![Some(true), Some(false), None], |
1174 | | Arc::clone(&col_a), |
1175 | | &schema |
1176 | | ); |
1177 | | /* expression: a not in (list): the two non-NULL results flip */ |
1178 | | in_list!( |
1179 | | batch, |
1180 | | list.clone(), |
1181 | | &true, |
1182 | | vec![Some(false), Some(true), None], |
1183 | | Arc::clone(&col_a), |
1184 | | &schema |
1185 | | ); |
1186 | | Ok(()) |
1187 | | } |
1188 | | |
1189 | | #[test] |
1190 | | fn in_expr_with_multiple_element_in_list() -> Result<()> { |
1191 | | let schema = Schema::new(vec![ |
1192 | | Field::new("a", DataType::Float64, true), |
1193 | | Field::new("b", DataType::Float64, true), |
1194 | | Field::new("c", DataType::Float64, true), |
1195 | | ]); |
1196 | | let a = Float64Array::from(vec![ |
1197 | | Some(0.0), |
1198 | | Some(1.0), |
1199 | | Some(2.0), |
1200 | | Some(f64::NAN), |
1201 | | Some(-f64::NAN), |
1202 | | ]); |
1203 | | let b = Float64Array::from(vec![ |
1204 | | Some(8.0), |
1205 | | Some(1.0), |
1206 | | Some(5.0), |
1207 | | Some(f64::NAN), |
1208 | | Some(3.0), |
1209 | | ]); |
1210 | | let c = Float64Array::from(vec![ |
1211 | | Some(6.0), |
1212 | | Some(7.0), |
1213 | | None, |
1214 | | Some(5.0), |
1215 | | Some(-f64::NAN), |
1216 | | ]); |
1217 | | let col_a = col("a", &schema)?; |
1218 | | let col_b = col("b", &schema)?; |
1219 | | let col_c = col("c", &schema)?; |
1220 | | let batch = RecordBatch::try_new( |
1221 | | Arc::new(schema.clone()), |
1222 | | vec![Arc::new(a), Arc::new(b), Arc::new(c)], |
1223 | | )?; |
1224 | | /* the list is columns b and c, not literals; per the expected values, row i of a is tested against row i of b and c, NaN and -NaN each match themselves, and an unmatched row with a NULL in c yields NULL */ |
1225 | | let list = vec![Arc::clone(&col_b), Arc::clone(&col_c)]; |
1226 | | in_list!( |
1227 | | batch, |
1228 | | list.clone(), |
1229 | | &false, |
1230 | | vec![Some(false), Some(true), None, Some(true), Some(true)], |
1231 | | Arc::clone(&col_a), |
1232 | | &schema |
1233 | | ); |
1234 | | /* same list, negated: a not in (b, c) */ |
1235 | | in_list!( |
1236 | | batch, |
1237 | | list, |
1238 | | &true, |
1239 | | vec![Some(true), Some(false), None, Some(false), Some(false)], |
1240 | | Arc::clone(&col_a), |
1241 | | &schema |
1242 | | ); |
1243 | | |
1244 | | Ok(()) |
1245 | | } |
1246 | | |
1247 | | macro_rules! test_nullable { /* runs in_list_cast on $COL and $LIST, builds an in_list expr from the result (third argument &false) and asserts that its nullable($SCHEMA) equals $EXPECTED; uses ? so the caller must return a Result */ |
1248 | | ($COL:expr, $LIST:expr, $SCHEMA:expr, $EXPECTED:expr) => {{ |
1249 | | let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?; |
1250 | | let expr = in_list(cast_expr, cast_list_exprs, &false, $SCHEMA).unwrap(); |
1251 | | let result = expr.nullable($SCHEMA)?; |
1252 | | assert_eq!($EXPECTED, result); |
1253 | | }}; |
1254 | | } |
1255 | | |
1256 | | #[test] |
1257 | | fn in_list_nullable() -> Result<()> { |
1258 | | let schema = Schema::new(vec![ |
1259 | | Field::new("c1_nullable", DataType::Int64, true), |
1260 | | Field::new("c2_non_nullable", DataType::Int64, false), |
1261 | | ]); |
1262 | | |
1263 | | let c1_nullable = col("c1_nullable", &schema)?; |
1264 | | let c2_non_nullable = col("c2_non_nullable", &schema)?; |
1265 | | |
1266 | | // static_filter has no nulls: the result is nullable only when the tested column is nullable |
1267 | | let list = vec![lit(1_i64), lit(2_i64)]; |
1268 | | test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true); |
1269 | | test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, false); |
1270 | | |
1271 | | // static_filter has nulls: the result is nullable even for the non-nullable column |
1272 | | let list = vec![lit(1_i64), lit(2_i64), lit(ScalarValue::Null)]; |
1273 | | test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true); |
1274 | | test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, true); |
1275 | | /* the list is a nullable column: the result is nullable although the tested column is not */ |
1276 | | let list = vec![Arc::clone(&c1_nullable)]; |
1277 | | test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, true); |
1278 | | /* the tested column is nullable and the list column is not: the result is still nullable */ |
1279 | | let list = vec![Arc::clone(&c2_non_nullable)]; |
1280 | | test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true); |
1281 | | /* neither the tested column nor any list column is nullable: the result is not nullable */ |
1282 | | let list = vec![Arc::clone(&c2_non_nullable), Arc::clone(&c2_non_nullable)]; |
1283 | | test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, false); |
1284 | | |
1285 | | Ok(()) |
1286 | | } |
1287 | | |
1288 | | #[test] |
1289 | | fn in_list_no_cols() -> Result<()> { |
1290 | | // test logic when the in_list expression doesn't have any columns: the tested value and every list entry are literals, and column "a" of the batch is never referenced |
1291 | | let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); |
1292 | | let a = Int32Array::from(vec![Some(1), Some(2), None]); |
1293 | | let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; |
1294 | | |
1295 | | let list = vec![lit(ScalarValue::from(1i32)), lit(ScalarValue::from(6i32))]; |
1296 | | |
1297 | | // 1 IN (1, 6): expected to be true for every row |
1298 | | let expr = lit(ScalarValue::Int32(Some(1))); |
1299 | | in_list!( |
1300 | | batch, |
1301 | | list.clone(), |
1302 | | &false, |
1303 | | // should have three outputs, as the input batch has three rows |
1304 | | vec![Some(true), Some(true), Some(true)], |
1305 | | expr, |
1306 | | &schema |
1307 | | ); |
1308 | | |
1309 | | // 2 IN (1, 6): expected to be false for every row |
1310 | | let expr = lit(ScalarValue::Int32(Some(2))); |
1311 | | in_list!( |
1312 | | batch, |
1313 | | list.clone(), |
1314 | | &false, |
1315 | | // should have three outputs, as the input batch has three rows |
1316 | | vec![Some(false), Some(false), Some(false)], |
1317 | | expr, |
1318 | | &schema |
1319 | | ); |
1320 | | |
1321 | | // NULL IN (1, 6): expected to be NULL for every row |
1322 | | let expr = lit(ScalarValue::Int32(None)); |
1323 | | in_list!( |
1324 | | batch, |
1325 | | list.clone(), |
1326 | | &false, |
1327 | | // should have three outputs, as the input batch has three rows |
1328 | | vec![None, None, None], |
1329 | | expr, |
1330 | | &schema |
1331 | | ); |
1332 | | |
1333 | | Ok(()) |
1334 | | } |
1335 | | |
1336 | | #[test] |
1337 | | fn in_list_utf8_with_dict_types() -> Result<()> { |
1338 | | fn dict_lit(key_type: DataType, value: &str) -> Arc<dyn PhysicalExpr> { |
1339 | | lit(ScalarValue::Dictionary( |
1340 | | Box::new(key_type), |
1341 | | Box::new(ScalarValue::new_utf8(value.to_string())), |
1342 | | )) |
1343 | | } |
1344 | | /* like dict_lit, but the dictionary value is a NULL Utf8 */ |
1345 | | fn null_dict_lit(key_type: DataType) -> Arc<dyn PhysicalExpr> { |
1346 | | lit(ScalarValue::Dictionary( |
1347 | | Box::new(key_type), |
1348 | | Box::new(ScalarValue::Utf8(None)), |
1349 | | )) |
1350 | | } |
1351 | | /* column a is Dictionary(UInt16, Utf8) and holds a, d and NULL */ |
1352 | | let schema = Schema::new(vec![Field::new( |
1353 | | "a", |
1354 | | DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), |
1355 | | true, |
1356 | | )]); |
1357 | | let a: UInt16DictionaryArray = |
1358 | | vec![Some("a"), Some("d"), None].into_iter().collect(); |
1359 | | let col_a = col("a", &schema)?; |
1360 | | let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; |
1361 | | |
1362 | | // expression: "a in ("a", "b")", with the list given once as plain Utf8 literals and once as dictionary literals (Int8 and UInt16 keys); both must give the same result |
1363 | | let lists = [ |
1364 | | vec![lit("a"), lit("b")], |
1365 | | vec![ |
1366 | | dict_lit(DataType::Int8, "a"), |
1367 | | dict_lit(DataType::UInt16, "b"), |
1368 | | ], |
1369 | | ]; |
1370 | | for list in lists.iter() { |
1371 | | in_list_raw!( |
1372 | | batch, |
1373 | | list.clone(), |
1374 | | &false, |
1375 | | vec![Some(true), Some(false), None], |
1376 | | Arc::clone(&col_a), |
1377 | | &schema |
1378 | | ); |
1379 | | } |
1380 | | |
1381 | | // expression: "a not in ("a", "b")", checked against both list encodings again |
1382 | | for list in lists.iter() { |
1383 | | in_list_raw!( |
1384 | | batch, |
1385 | | list.clone(), |
1386 | | &true, |
1387 | | vec![Some(false), Some(true), None], |
1388 | | Arc::clone(&col_a), |
1389 | | &schema |
1390 | | ); |
1391 | | } |
1392 | | |
1393 | | // expression: "a in ("a", "b", null)", with the null given as Utf8(None) in the plain list and as a null dictionary literal in the dictionary list |
1394 | | let lists = [ |
1395 | | vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))], |
1396 | | vec![ |
1397 | | dict_lit(DataType::Int8, "a"), |
1398 | | dict_lit(DataType::UInt16, "b"), |
1399 | | null_dict_lit(DataType::UInt16), |
1400 | | ], |
1401 | | ]; |
1402 | | for list in lists.iter() { |
1403 | | in_list_raw!( |
1404 | | batch, |
1405 | | list.clone(), |
1406 | | &false, |
1407 | | vec![Some(true), None, None], |
1408 | | Arc::clone(&col_a), |
1409 | | &schema |
1410 | | ); |
1411 | | } |
1412 | | |
1413 | | // expression: "a not in ("a", "b", null)": the matched value becomes false, the unmatched value is still NULL |
1414 | | for list in lists.iter() { |
1415 | | in_list_raw!( |
1416 | | batch, |
1417 | | list.clone(), |
1418 | | &true, |
1419 | | vec![Some(false), None, None], |
1420 | | Arc::clone(&col_a), |
1421 | | &schema |
1422 | | ); |
1423 | | } |
1424 | | |
1425 | | Ok(()) |
1426 | | } |
1427 | | } |