/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/expressions/case.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use std::borrow::Cow; |
19 | | use std::hash::{Hash, Hasher}; |
20 | | use std::{any::Any, sync::Arc}; |
21 | | |
22 | | use crate::expressions::try_cast; |
23 | | use crate::physical_expr::down_cast_any_ref; |
24 | | use crate::PhysicalExpr; |
25 | | |
26 | | use arrow::array::*; |
27 | | use arrow::compute::kernels::zip::zip; |
28 | | use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; |
29 | | use arrow::datatypes::{DataType, Schema}; |
30 | | use datafusion_common::cast::as_boolean_array; |
31 | | use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; |
32 | | use datafusion_expr::ColumnarValue; |
33 | | |
34 | | use super::{Column, Literal}; |
35 | | use datafusion_physical_expr_common::datum::compare_with_eq; |
36 | | use itertools::Itertools; |
37 | | |
38 | | type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>); |
39 | | |
40 | | #[derive(Debug, Hash)] |
41 | | enum EvalMethod { |
42 | | /// CASE WHEN condition THEN result |
43 | | /// [WHEN ...] |
44 | | /// [ELSE result] |
45 | | /// END |
46 | | NoExpression, |
47 | | /// CASE expression |
48 | | /// WHEN value THEN result |
49 | | /// [WHEN ...] |
50 | | /// [ELSE result] |
51 | | /// END |
52 | | WithExpression, |
53 | | /// This is a specialization for a specific use case where we can take a fast path |
54 | | /// for expressions that are infallible and can be cheaply computed for the entire |
55 | | /// record batch rather than just for the rows where the predicate is true. |
56 | | /// |
57 | | /// CASE WHEN condition THEN column [ELSE NULL] END |
58 | | InfallibleExprOrNull, |
59 | | /// This is a specialization for a specific use case where we can take a fast path |
60 | | /// if there is just one when/then pair and both the `then` and `else` expressions |
61 | | /// are literal values |
62 | | /// CASE WHEN condition THEN literal ELSE literal END |
63 | | ScalarOrScalar, |
64 | | } |
65 | | |
66 | | /// The CASE expression is similar to a series of nested if/else and there are two forms that |
67 | | /// can be used. The first form consists of a series of boolean "when" expressions with |
68 | | /// corresponding "then" expressions, and an optional "else" expression. |
69 | | /// |
70 | | /// CASE WHEN condition THEN result |
71 | | /// [WHEN ...] |
72 | | /// [ELSE result] |
73 | | /// END |
74 | | /// |
75 | | /// The second form uses a base expression and then a series of "when" clauses that match on a |
76 | | /// literal value. |
77 | | /// |
78 | | /// CASE expression |
79 | | /// WHEN value THEN result |
80 | | /// [WHEN ...] |
81 | | /// [ELSE result] |
82 | | /// END |
83 | | #[derive(Debug, Hash)] |
84 | | pub struct CaseExpr { |
85 | | /// Optional base expression that can be compared to literal values in the "when" expressions |
86 | | expr: Option<Arc<dyn PhysicalExpr>>, |
87 | | /// One or more when/then expressions |
88 | | when_then_expr: Vec<WhenThen>, |
89 | | /// Optional "else" expression |
90 | | else_expr: Option<Arc<dyn PhysicalExpr>>, |
91 | | /// Evaluation method to use |
92 | | eval_method: EvalMethod, |
93 | | } |
94 | | |
95 | | impl std::fmt::Display for CaseExpr { |
96 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
97 | 0 | write!(f, "CASE ")?; |
98 | 0 | if let Some(e) = &self.expr { |
99 | 0 | write!(f, "{e} ")?; |
100 | 0 | } |
101 | 0 | for (w, t) in &self.when_then_expr { |
102 | 0 | write!(f, "WHEN {w} THEN {t} ")?; |
103 | | } |
104 | 0 | if let Some(e) = &self.else_expr { |
105 | 0 | write!(f, "ELSE {e} ")?; |
106 | 0 | } |
107 | 0 | write!(f, "END") |
108 | 0 | } |
109 | | } |
110 | | |
111 | | /// This is a specialization for a specific use case where we can take a fast path |
112 | | /// for expressions that are infallible and can be cheaply computed for the entire |
113 | | /// record batch rather than just for the rows where the predicate is true. For now, |
114 | | /// this is limited to use with Column expressions but could potentially be used for other |
115 | | /// expressions in the future |
116 | 0 | fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool { |
117 | 0 | expr.as_any().is::<Column>() |
118 | 0 | } |
119 | | |
120 | | impl CaseExpr { |
121 | | /// Create a new CASE WHEN expression |
122 | 0 | pub fn try_new( |
123 | 0 | expr: Option<Arc<dyn PhysicalExpr>>, |
124 | 0 | when_then_expr: Vec<WhenThen>, |
125 | 0 | else_expr: Option<Arc<dyn PhysicalExpr>>, |
126 | 0 | ) -> Result<Self> { |
127 | | // normalize null literals to None in the else_expr (this already happens |
128 | | // during SQL planning, but not necessarily for other use cases) |
129 | 0 | let else_expr = match &else_expr { |
130 | 0 | Some(e) => match e.as_any().downcast_ref::<Literal>() { |
131 | 0 | Some(lit) if lit.value().is_null() => None, |
132 | 0 | _ => else_expr, |
133 | | }, |
134 | 0 | _ => else_expr, |
135 | | }; |
136 | | |
137 | 0 | if when_then_expr.is_empty() { |
138 | 0 | exec_err!("There must be at least one WHEN clause") |
139 | | } else { |
140 | 0 | let eval_method = if expr.is_some() { |
141 | 0 | EvalMethod::WithExpression |
142 | 0 | } else if when_then_expr.len() == 1 |
143 | 0 | && is_cheap_and_infallible(&(when_then_expr[0].1)) |
144 | 0 | && else_expr.is_none() |
145 | | { |
146 | 0 | EvalMethod::InfallibleExprOrNull |
147 | 0 | } else if when_then_expr.len() == 1 |
148 | 0 | && when_then_expr[0].1.as_any().is::<Literal>() |
149 | 0 | && else_expr.is_some() |
150 | 0 | && else_expr.as_ref().unwrap().as_any().is::<Literal>() |
151 | | { |
152 | 0 | EvalMethod::ScalarOrScalar |
153 | | } else { |
154 | 0 | EvalMethod::NoExpression |
155 | | }; |
156 | | |
157 | 0 | Ok(Self { |
158 | 0 | expr, |
159 | 0 | when_then_expr, |
160 | 0 | else_expr, |
161 | 0 | eval_method, |
162 | 0 | }) |
163 | | } |
164 | 0 | } |
165 | | |
166 | | /// Optional base expression that can be compared to literal values in the "when" expressions |
167 | 0 | pub fn expr(&self) -> Option<&Arc<dyn PhysicalExpr>> { |
168 | 0 | self.expr.as_ref() |
169 | 0 | } |
170 | | |
171 | | /// One or more when/then expressions |
172 | 0 | pub fn when_then_expr(&self) -> &[WhenThen] { |
173 | 0 | &self.when_then_expr |
174 | 0 | } |
175 | | |
176 | | /// Optional "else" expression |
177 | 0 | pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> { |
178 | 0 | self.else_expr.as_ref() |
179 | 0 | } |
180 | | } |
181 | | |
182 | | impl CaseExpr { |
183 | | /// This function evaluates the form of CASE that matches an expression to fixed values. |
184 | | /// |
185 | | /// CASE expression |
186 | | /// WHEN value THEN result |
187 | | /// [WHEN ...] |
188 | | /// [ELSE result] |
189 | | /// END |
190 | 0 | fn case_when_with_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> { |
191 | 0 | let return_type = self.data_type(&batch.schema())?; |
192 | 0 | let expr = self.expr.as_ref().unwrap(); |
193 | 0 | let base_value = expr.evaluate(batch)?; |
194 | 0 | let base_value = base_value.into_array(batch.num_rows())?; |
195 | 0 | let base_nulls = is_null(base_value.as_ref())?; |
196 | | |
197 | | // start with nulls as default output |
198 | 0 | let mut current_value = new_null_array(&return_type, batch.num_rows()); |
199 | | // We only consider non-null values while comparing with whens |
200 | 0 | let mut remainder = not(&base_nulls)?; |
201 | 0 | for i in 0..self.when_then_expr.len() { |
202 | 0 | let when_value = self.when_then_expr[i] |
203 | 0 | .0 |
204 | 0 | .evaluate_selection(batch, &remainder)?; |
205 | 0 | let when_value = when_value.into_array(batch.num_rows())?; |
206 | | // build boolean array representing which rows match the "when" value |
207 | 0 | let when_match = compare_with_eq( |
208 | 0 | &when_value, |
209 | 0 | &base_value, |
210 | 0 | // The types of case and when expressions will be coerced to match. |
211 | 0 | // We only need to check if the base_value is nested. |
212 | 0 | base_value.data_type().is_nested(), |
213 | 0 | )?; |
214 | | // Treat nulls as false |
215 | 0 | let when_match = match when_match.null_count() { |
216 | 0 | 0 => Cow::Borrowed(&when_match), |
217 | 0 | _ => Cow::Owned(prep_null_mask_filter(&when_match)), |
218 | | }; |
219 | | // Make sure we only consider rows that have not been matched yet |
220 | 0 | let when_match = and(&when_match, &remainder)?; |
221 | | |
222 | | // When no rows available for when clause, skip then clause |
223 | 0 | if when_match.true_count() == 0 { |
224 | 0 | continue; |
225 | 0 | } |
226 | | |
227 | 0 | let then_value = self.when_then_expr[i] |
228 | 0 | .1 |
229 | 0 | .evaluate_selection(batch, &when_match)?; |
230 | | |
231 | 0 | current_value = match then_value { |
232 | | ColumnarValue::Scalar(ScalarValue::Null) => { |
233 | 0 | nullif(current_value.as_ref(), &when_match)? |
234 | | } |
235 | 0 | ColumnarValue::Scalar(then_value) => { |
236 | 0 | zip(&when_match, &then_value.to_scalar()?, &current_value)? |
237 | | } |
238 | 0 | ColumnarValue::Array(then_value) => { |
239 | 0 | zip(&when_match, &then_value, &current_value)? |
240 | | } |
241 | | }; |
242 | | |
243 | 0 | remainder = and_not(&remainder, &when_match)?; |
244 | | } |
245 | | |
246 | 0 | if let Some(e) = &self.else_expr { |
247 | | // keep `else_expr`'s data type and return type consistent |
248 | 0 | let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) |
249 | 0 | .unwrap_or_else(|_| Arc::clone(e)); |
250 | | // null and unmatched tuples should be assigned else value |
251 | 0 | remainder = or(&base_nulls, &remainder)?; |
252 | 0 | let else_ = expr |
253 | 0 | .evaluate_selection(batch, &remainder)? |
254 | 0 | .into_array(batch.num_rows())?; |
255 | 0 | current_value = zip(&remainder, &else_, &current_value)?; |
256 | 0 | } |
257 | | |
258 | 0 | Ok(ColumnarValue::Array(current_value)) |
259 | 0 | } |
260 | | |
261 | | /// This function evaluates the form of CASE where each WHEN expression is a boolean |
262 | | /// expression. |
263 | | /// |
264 | | /// CASE WHEN condition THEN result |
265 | | /// [WHEN ...] |
266 | | /// [ELSE result] |
267 | | /// END |
268 | 0 | fn case_when_no_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> { |
269 | 0 | let return_type = self.data_type(&batch.schema())?; |
270 | | |
271 | | // start with nulls as default output |
272 | 0 | let mut current_value = new_null_array(&return_type, batch.num_rows()); |
273 | 0 | let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]); |
274 | 0 | for i in 0..self.when_then_expr.len() { |
275 | 0 | let when_value = self.when_then_expr[i] |
276 | 0 | .0 |
277 | 0 | .evaluate_selection(batch, &remainder)?; |
278 | 0 | let when_value = when_value.into_array(batch.num_rows())?; |
279 | 0 | let when_value = as_boolean_array(&when_value).map_err(|e| { |
280 | 0 | DataFusionError::Context( |
281 | 0 | "WHEN expression did not return a BooleanArray".to_string(), |
282 | 0 | Box::new(e), |
283 | 0 | ) |
284 | 0 | })?; |
285 | | // Treat 'NULL' as false value |
286 | 0 | let when_value = match when_value.null_count() { |
287 | 0 | 0 => Cow::Borrowed(when_value), |
288 | 0 | _ => Cow::Owned(prep_null_mask_filter(when_value)), |
289 | | }; |
290 | | // Make sure we only consider rows that have not been matched yet |
291 | 0 | let when_value = and(&when_value, &remainder)?; |
292 | | |
293 | | // When no rows available for when clause, skip then clause |
294 | 0 | if when_value.true_count() == 0 { |
295 | 0 | continue; |
296 | 0 | } |
297 | | |
298 | 0 | let then_value = self.when_then_expr[i] |
299 | 0 | .1 |
300 | 0 | .evaluate_selection(batch, &when_value)?; |
301 | | |
302 | 0 | current_value = match then_value { |
303 | | ColumnarValue::Scalar(ScalarValue::Null) => { |
304 | 0 | nullif(current_value.as_ref(), &when_value)? |
305 | | } |
306 | 0 | ColumnarValue::Scalar(then_value) => { |
307 | 0 | zip(&when_value, &then_value.to_scalar()?, &current_value)? |
308 | | } |
309 | 0 | ColumnarValue::Array(then_value) => { |
310 | 0 | zip(&when_value, &then_value, &current_value)? |
311 | | } |
312 | | }; |
313 | | |
314 | | // Succeed tuples should be filtered out for short-circuit evaluation, |
315 | | // null values for the current when expr should be kept |
316 | 0 | remainder = and_not(&remainder, &when_value)?; |
317 | | } |
318 | | |
319 | 0 | if let Some(e) = &self.else_expr { |
320 | | // keep `else_expr`'s data type and return type consistent |
321 | 0 | let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) |
322 | 0 | .unwrap_or_else(|_| Arc::clone(e)); |
323 | 0 | let else_ = expr |
324 | 0 | .evaluate_selection(batch, &remainder)? |
325 | 0 | .into_array(batch.num_rows())?; |
326 | 0 | current_value = zip(&remainder, &else_, &current_value)?; |
327 | 0 | } |
328 | | |
329 | 0 | Ok(ColumnarValue::Array(current_value)) |
330 | 0 | } |
331 | | |
332 | | /// This function evaluates the specialized case of: |
333 | | /// |
334 | | /// CASE WHEN condition THEN column |
335 | | /// [ELSE NULL] |
336 | | /// END |
337 | | /// |
338 | | /// Note that this function is only safe to use for "then" expressions |
339 | | /// that are infallible because the expression will be evaluated for all |
340 | | /// rows in the input batch. |
341 | 0 | fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> { |
342 | 0 | let when_expr = &self.when_then_expr[0].0; |
343 | 0 | let then_expr = &self.when_then_expr[0].1; |
344 | 0 | if let ColumnarValue::Array(bit_mask) = when_expr.evaluate(batch)? { |
345 | 0 | let bit_mask = bit_mask |
346 | 0 | .as_any() |
347 | 0 | .downcast_ref::<BooleanArray>() |
348 | 0 | .expect("predicate should evaluate to a boolean array"); |
349 | | // invert the bitmask |
350 | 0 | let bit_mask = not(bit_mask)?; |
351 | 0 | match then_expr.evaluate(batch)? { |
352 | 0 | ColumnarValue::Array(array) => { |
353 | 0 | Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?)) |
354 | | } |
355 | | ColumnarValue::Scalar(_) => { |
356 | 0 | internal_err!("expression did not evaluate to an array") |
357 | | } |
358 | | } |
359 | | } else { |
360 | 0 | internal_err!("predicate did not evaluate to an array") |
361 | | } |
362 | 0 | } |
363 | | |
364 | 0 | fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result<ColumnarValue> { |
365 | 0 | let return_type = self.data_type(&batch.schema())?; |
366 | | |
367 | | // evaluate when expression |
368 | 0 | let when_value = self.when_then_expr[0].0.evaluate(batch)?; |
369 | 0 | let when_value = when_value.into_array(batch.num_rows())?; |
370 | 0 | let when_value = as_boolean_array(&when_value).map_err(|e| { |
371 | 0 | DataFusionError::Context( |
372 | 0 | "WHEN expression did not return a BooleanArray".to_string(), |
373 | 0 | Box::new(e), |
374 | 0 | ) |
375 | 0 | })?; |
376 | | |
377 | | // Treat 'NULL' as false value |
378 | 0 | let when_value = match when_value.null_count() { |
379 | 0 | 0 => Cow::Borrowed(when_value), |
380 | 0 | _ => Cow::Owned(prep_null_mask_filter(when_value)), |
381 | | }; |
382 | | |
383 | | // evaluate then_value |
384 | 0 | let then_value = self.when_then_expr[0].1.evaluate(batch)?; |
385 | 0 | let then_value = Scalar::new(then_value.into_array(1)?); |
386 | | |
387 | | // keep `else_expr`'s data type and return type consistent |
388 | 0 | let e = self.else_expr.as_ref().unwrap(); |
389 | 0 | let expr = try_cast(Arc::clone(e), &batch.schema(), return_type) |
390 | 0 | .unwrap_or_else(|_| Arc::clone(e)); |
391 | 0 | let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?); |
392 | | |
393 | 0 | Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?)) |
394 | 0 | } |
395 | | } |
396 | | |
397 | | impl PhysicalExpr for CaseExpr { |
398 | | /// Return a reference to Any that can be used for down-casting |
399 | 0 | fn as_any(&self) -> &dyn Any { |
400 | 0 | self |
401 | 0 | } |
402 | | |
403 | 0 | fn data_type(&self, input_schema: &Schema) -> Result<DataType> { |
404 | 0 | // since all then results have the same data type, we can choose any one as the |
405 | 0 | // return data type except for the null. |
406 | 0 | let mut data_type = DataType::Null; |
407 | 0 | for i in 0..self.when_then_expr.len() { |
408 | 0 | data_type = self.when_then_expr[i].1.data_type(input_schema)?; |
409 | 0 | if !data_type.equals_datatype(&DataType::Null) { |
410 | 0 | break; |
411 | 0 | } |
412 | | } |
413 | | // if all then results are null, we use data type of else expr instead if possible. |
414 | 0 | if data_type.equals_datatype(&DataType::Null) { |
415 | 0 | if let Some(e) = &self.else_expr { |
416 | 0 | data_type = e.data_type(input_schema)?; |
417 | 0 | } |
418 | 0 | } |
419 | | |
420 | 0 | Ok(data_type) |
421 | 0 | } |
422 | | |
423 | 0 | fn nullable(&self, input_schema: &Schema) -> Result<bool> { |
424 | | // this expression is nullable if any of the input expressions are nullable |
425 | 0 | let then_nullable = self |
426 | 0 | .when_then_expr |
427 | 0 | .iter() |
428 | 0 | .map(|(_, t)| t.nullable(input_schema)) |
429 | 0 | .collect::<Result<Vec<_>>>()?; |
430 | 0 | if then_nullable.contains(&true) { |
431 | 0 | Ok(true) |
432 | 0 | } else if let Some(e) = &self.else_expr { |
433 | 0 | e.nullable(input_schema) |
434 | | } else { |
435 | | // CASE produces NULL if there is no `else` expr |
436 | | // (aka when none of the `when_then_exprs` match) |
437 | 0 | Ok(true) |
438 | | } |
439 | 0 | } |
440 | | |
441 | 0 | fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> { |
442 | 0 | match self.eval_method { |
443 | | EvalMethod::WithExpression => { |
444 | | // this use case evaluates "expr" and then compares the values with the "when" |
445 | | // values |
446 | 0 | self.case_when_with_expr(batch) |
447 | | } |
448 | | EvalMethod::NoExpression => { |
449 | | // The "when" conditions all evaluate to boolean in this use case and can be |
450 | | // arbitrary expressions |
451 | 0 | self.case_when_no_expr(batch) |
452 | | } |
453 | | EvalMethod::InfallibleExprOrNull => { |
454 | | // Specialization for CASE WHEN expr THEN column [ELSE NULL] END |
455 | 0 | self.case_column_or_null(batch) |
456 | | } |
457 | 0 | EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch), |
458 | | } |
459 | 0 | } |
460 | | |
461 | 0 | fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> { |
462 | 0 | let mut children = vec![]; |
463 | 0 | if let Some(expr) = &self.expr { |
464 | 0 | children.push(expr) |
465 | 0 | } |
466 | 0 | self.when_then_expr.iter().for_each(|(cond, value)| { |
467 | 0 | children.push(cond); |
468 | 0 | children.push(value); |
469 | 0 | }); |
470 | | |
471 | 0 | if let Some(else_expr) = &self.else_expr { |
472 | 0 | children.push(else_expr) |
473 | 0 | } |
474 | 0 | children |
475 | 0 | } |
476 | | |
477 | | // For physical CaseExpr, we do not allow modifying children size |
478 | 0 | fn with_new_children( |
479 | 0 | self: Arc<Self>, |
480 | 0 | children: Vec<Arc<dyn PhysicalExpr>>, |
481 | 0 | ) -> Result<Arc<dyn PhysicalExpr>> { |
482 | 0 | if children.len() != self.children().len() { |
483 | 0 | internal_err!("CaseExpr: Wrong number of children") |
484 | | } else { |
485 | 0 | let (expr, when_then_expr, else_expr) = |
486 | 0 | match (self.expr().is_some(), self.else_expr().is_some()) { |
487 | 0 | (true, true) => ( |
488 | 0 | Some(&children[0]), |
489 | 0 | &children[1..children.len() - 1], |
490 | 0 | Some(&children[children.len() - 1]), |
491 | 0 | ), |
492 | | (true, false) => { |
493 | 0 | (Some(&children[0]), &children[1..children.len()], None) |
494 | | } |
495 | 0 | (false, true) => ( |
496 | 0 | None, |
497 | 0 | &children[0..children.len() - 1], |
498 | 0 | Some(&children[children.len() - 1]), |
499 | 0 | ), |
500 | 0 | (false, false) => (None, &children[0..children.len()], None), |
501 | | }; |
502 | 0 | Ok(Arc::new(CaseExpr::try_new( |
503 | 0 | expr.cloned(), |
504 | 0 | when_then_expr.iter().cloned().tuples().collect(), |
505 | 0 | else_expr.cloned(), |
506 | 0 | )?)) |
507 | | } |
508 | 0 | } |
509 | | |
510 | 0 | fn dyn_hash(&self, state: &mut dyn Hasher) { |
511 | 0 | let mut s = state; |
512 | 0 | self.hash(&mut s); |
513 | 0 | } |
514 | | } |
515 | | |
516 | | impl PartialEq<dyn Any> for CaseExpr { |
517 | 0 | fn eq(&self, other: &dyn Any) -> bool { |
518 | 0 | down_cast_any_ref(other) |
519 | 0 | .downcast_ref::<Self>() |
520 | 0 | .map(|x| { |
521 | 0 | let expr_eq = match (&self.expr, &x.expr) { |
522 | 0 | (Some(expr1), Some(expr2)) => expr1.eq(expr2), |
523 | 0 | (None, None) => true, |
524 | 0 | _ => false, |
525 | | }; |
526 | 0 | let else_expr_eq = match (&self.else_expr, &x.else_expr) { |
527 | 0 | (Some(expr1), Some(expr2)) => expr1.eq(expr2), |
528 | 0 | (None, None) => true, |
529 | 0 | _ => false, |
530 | | }; |
531 | 0 | expr_eq |
532 | 0 | && else_expr_eq |
533 | 0 | && self.when_then_expr.len() == x.when_then_expr.len() |
534 | 0 | && self.when_then_expr.iter().zip(x.when_then_expr.iter()).all( |
535 | 0 | |((when1, then1), (when2, then2))| { |
536 | 0 | when1.eq(when2) && then1.eq(then2) |
537 | 0 | }, |
538 | 0 | ) |
539 | 0 | }) |
540 | 0 | .unwrap_or(false) |
541 | 0 | } |
542 | | } |
543 | | |
544 | | /// Create a CASE expression |
545 | 0 | pub fn case( |
546 | 0 | expr: Option<Arc<dyn PhysicalExpr>>, |
547 | 0 | when_thens: Vec<WhenThen>, |
548 | 0 | else_expr: Option<Arc<dyn PhysicalExpr>>, |
549 | 0 | ) -> Result<Arc<dyn PhysicalExpr>> { |
550 | 0 | Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?)) |
551 | 0 | } |
552 | | |
553 | | #[cfg(test)] |
554 | | mod tests { |
555 | | use super::*; |
556 | | |
557 | | use crate::expressions::{binary, cast, col, lit, BinaryExpr}; |
558 | | use arrow::buffer::Buffer; |
559 | | use arrow::datatypes::DataType::Float64; |
560 | | use arrow::datatypes::*; |
561 | | use datafusion_common::cast::{as_float64_array, as_int32_array}; |
562 | | use datafusion_common::plan_err; |
563 | | use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; |
564 | | use datafusion_expr::type_coercion::binary::comparison_coercion; |
565 | | use datafusion_expr::Operator; |
566 | | |
567 | | #[test] |
568 | | fn case_with_expr() -> Result<()> { |
569 | | let batch = case_test_batch()?; |
570 | | let schema = batch.schema(); |
571 | | |
572 | | // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END |
573 | | let when1 = lit("foo"); |
574 | | let then1 = lit(123i32); |
575 | | let when2 = lit("bar"); |
576 | | let then2 = lit(456i32); |
577 | | |
578 | | let expr = generate_case_when_with_type_coercion( |
579 | | Some(col("a", &schema)?), |
580 | | vec![(when1, then1), (when2, then2)], |
581 | | None, |
582 | | schema.as_ref(), |
583 | | )?; |
584 | | let result = expr |
585 | | .evaluate(&batch)? |
586 | | .into_array(batch.num_rows()) |
587 | | .expect("Failed to convert to array"); |
588 | | let result = as_int32_array(&result)?; |
589 | | |
590 | | let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); |
591 | | |
592 | | assert_eq!(expected, result); |
593 | | |
594 | | Ok(()) |
595 | | } |
596 | | |
597 | | #[test] |
598 | | fn case_with_expr_else() -> Result<()> { |
599 | | let batch = case_test_batch()?; |
600 | | let schema = batch.schema(); |
601 | | |
602 | | // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END |
603 | | let when1 = lit("foo"); |
604 | | let then1 = lit(123i32); |
605 | | let when2 = lit("bar"); |
606 | | let then2 = lit(456i32); |
607 | | let else_value = lit(999i32); |
608 | | |
609 | | let expr = generate_case_when_with_type_coercion( |
610 | | Some(col("a", &schema)?), |
611 | | vec![(when1, then1), (when2, then2)], |
612 | | Some(else_value), |
613 | | schema.as_ref(), |
614 | | )?; |
615 | | let result = expr |
616 | | .evaluate(&batch)? |
617 | | .into_array(batch.num_rows()) |
618 | | .expect("Failed to convert to array"); |
619 | | let result = as_int32_array(&result)?; |
620 | | |
621 | | let expected = |
622 | | &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]); |
623 | | |
624 | | assert_eq!(expected, result); |
625 | | |
626 | | Ok(()) |
627 | | } |
628 | | |
629 | | #[test] |
630 | | fn case_with_expr_divide_by_zero() -> Result<()> { |
631 | | let batch = case_test_batch1()?; |
632 | | let schema = batch.schema(); |
633 | | |
634 | | // CASE a when 0 THEN float64(null) ELSE 25.0 / cast(a, float64) END |
635 | | let when1 = lit(0i32); |
636 | | let then1 = lit(ScalarValue::Float64(None)); |
637 | | let else_value = binary( |
638 | | lit(25.0f64), |
639 | | Operator::Divide, |
640 | | cast(col("a", &schema)?, &batch.schema(), Float64)?, |
641 | | &batch.schema(), |
642 | | )?; |
643 | | |
644 | | let expr = generate_case_when_with_type_coercion( |
645 | | Some(col("a", &schema)?), |
646 | | vec![(when1, then1)], |
647 | | Some(else_value), |
648 | | schema.as_ref(), |
649 | | )?; |
650 | | let result = expr |
651 | | .evaluate(&batch)? |
652 | | .into_array(batch.num_rows()) |
653 | | .expect("Failed to convert to array"); |
654 | | let result = |
655 | | as_float64_array(&result).expect("failed to downcast to Float64Array"); |
656 | | |
657 | | let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]); |
658 | | |
659 | | assert_eq!(expected, result); |
660 | | |
661 | | Ok(()) |
662 | | } |
663 | | |
664 | | #[test] |
665 | | fn case_without_expr() -> Result<()> { |
666 | | let batch = case_test_batch()?; |
667 | | let schema = batch.schema(); |
668 | | |
669 | | // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END |
670 | | let when1 = binary( |
671 | | col("a", &schema)?, |
672 | | Operator::Eq, |
673 | | lit("foo"), |
674 | | &batch.schema(), |
675 | | )?; |
676 | | let then1 = lit(123i32); |
677 | | let when2 = binary( |
678 | | col("a", &schema)?, |
679 | | Operator::Eq, |
680 | | lit("bar"), |
681 | | &batch.schema(), |
682 | | )?; |
683 | | let then2 = lit(456i32); |
684 | | |
685 | | let expr = generate_case_when_with_type_coercion( |
686 | | None, |
687 | | vec![(when1, then1), (when2, then2)], |
688 | | None, |
689 | | schema.as_ref(), |
690 | | )?; |
691 | | let result = expr |
692 | | .evaluate(&batch)? |
693 | | .into_array(batch.num_rows()) |
694 | | .expect("Failed to convert to array"); |
695 | | let result = as_int32_array(&result)?; |
696 | | |
697 | | let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); |
698 | | |
699 | | assert_eq!(expected, result); |
700 | | |
701 | | Ok(()) |
702 | | } |
703 | | |
704 | | #[test] |
705 | | fn case_with_expr_when_null() -> Result<()> { |
706 | | let batch = case_test_batch()?; |
707 | | let schema = batch.schema(); |
708 | | |
709 | | // CASE a WHEN NULL THEN 0 WHEN a THEN 123 ELSE 999 END |
710 | | let when1 = lit(ScalarValue::Utf8(None)); |
711 | | let then1 = lit(0i32); |
712 | | let when2 = col("a", &schema)?; |
713 | | let then2 = lit(123i32); |
714 | | let else_value = lit(999i32); |
715 | | |
716 | | let expr = generate_case_when_with_type_coercion( |
717 | | Some(col("a", &schema)?), |
718 | | vec![(when1, then1), (when2, then2)], |
719 | | Some(else_value), |
720 | | schema.as_ref(), |
721 | | )?; |
722 | | let result = expr |
723 | | .evaluate(&batch)? |
724 | | .into_array(batch.num_rows()) |
725 | | .expect("Failed to convert to array"); |
726 | | let result = as_int32_array(&result)?; |
727 | | |
728 | | let expected = |
729 | | &Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]); |
730 | | |
731 | | assert_eq!(expected, result); |
732 | | |
733 | | Ok(()) |
734 | | } |
735 | | |
736 | | #[test] |
737 | | fn case_without_expr_divide_by_zero() -> Result<()> { |
738 | | let batch = case_test_batch1()?; |
739 | | let schema = batch.schema(); |
740 | | |
741 | | // CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE float64(null) END |
742 | | let when1 = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &batch.schema())?; |
743 | | let then1 = binary( |
744 | | lit(25.0f64), |
745 | | Operator::Divide, |
746 | | cast(col("a", &schema)?, &batch.schema(), Float64)?, |
747 | | &batch.schema(), |
748 | | )?; |
749 | | let x = lit(ScalarValue::Float64(None)); |
750 | | |
751 | | let expr = generate_case_when_with_type_coercion( |
752 | | None, |
753 | | vec![(when1, then1)], |
754 | | Some(x), |
755 | | schema.as_ref(), |
756 | | )?; |
757 | | let result = expr |
758 | | .evaluate(&batch)? |
759 | | .into_array(batch.num_rows()) |
760 | | .expect("Failed to convert to array"); |
761 | | let result = |
762 | | as_float64_array(&result).expect("failed to downcast to Float64Array"); |
763 | | |
764 | | let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]); |
765 | | |
766 | | assert_eq!(expected, result); |
767 | | |
768 | | Ok(()) |
769 | | } |
770 | | |
771 | | fn case_test_batch1() -> Result<RecordBatch> { |
772 | | let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); |
773 | | let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]); |
774 | | let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; |
775 | | Ok(batch) |
776 | | } |
777 | | |
778 | | #[test] |
779 | | fn case_without_expr_else() -> Result<()> { |
780 | | let batch = case_test_batch()?; |
781 | | let schema = batch.schema(); |
782 | | |
783 | | // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END |
784 | | let when1 = binary( |
785 | | col("a", &schema)?, |
786 | | Operator::Eq, |
787 | | lit("foo"), |
788 | | &batch.schema(), |
789 | | )?; |
790 | | let then1 = lit(123i32); |
791 | | let when2 = binary( |
792 | | col("a", &schema)?, |
793 | | Operator::Eq, |
794 | | lit("bar"), |
795 | | &batch.schema(), |
796 | | )?; |
797 | | let then2 = lit(456i32); |
798 | | let else_value = lit(999i32); |
799 | | |
800 | | let expr = generate_case_when_with_type_coercion( |
801 | | None, |
802 | | vec![(when1, then1), (when2, then2)], |
803 | | Some(else_value), |
804 | | schema.as_ref(), |
805 | | )?; |
806 | | let result = expr |
807 | | .evaluate(&batch)? |
808 | | .into_array(batch.num_rows()) |
809 | | .expect("Failed to convert to array"); |
810 | | let result = as_int32_array(&result)?; |
811 | | |
812 | | let expected = |
813 | | &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]); |
814 | | |
815 | | assert_eq!(expected, result); |
816 | | |
817 | | Ok(()) |
818 | | } |
819 | | |
/// Branch values of different numeric types (`f64` THEN, `i32` ELSE) are
/// coerced to a common type before evaluation, so the result is read back as
/// a `Float64Array`.
#[test]
fn case_with_type_cast() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();

    // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
    let is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;

    let case_expr = generate_case_when_with_type_coercion(
        None,
        vec![(is_foo, lit(123.3f64))],
        Some(lit(999i32)),
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual =
        as_float64_array(&evaluated).expect("failed to downcast to Float64Array");

    // Only the first row ("foo") matches; every other row takes the ELSE value.
    let expected =
        Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]);
    assert_eq!(&expected, actual);

    Ok(())
}
855 | | |
/// A searched `CASE` without `ELSE` must not treat NULL slots as matches, even
/// though the fixture's NULL slots sit on top of values equal to the predicate
/// literal.
#[test]
fn case_with_matches_and_nulls() -> Result<()> {
    let batch = case_test_batch_nulls()?;
    let schema = batch.schema();

    // SELECT CASE WHEN load4 = 1.77 THEN load4 END
    let predicate = binary(
        col("load4", &schema)?,
        Operator::Eq,
        lit(1.77f64),
        &schema,
    )?;
    let branch_value = col("load4", &schema)?;

    let case_expr = generate_case_when_with_type_coercion(
        None,
        vec![(predicate, branch_value)],
        None,
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual =
        as_float64_array(&evaluated).expect("failed to downcast to Float64Array");

    // Only the two valid 1.77 rows survive; everything else is NULL.
    let expected =
        Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
    assert_eq!(&expected, actual);

    Ok(())
}
890 | | |
/// Same scenario as the searched form, but using the `CASE <expr> WHEN <value>`
/// form: NULL slots of `load4` must not match the `WHEN` literal.
#[test]
fn case_expr_matches_and_nulls() -> Result<()> {
    let batch = case_test_batch_nulls()?;
    let schema = batch.schema();

    // SELECT CASE load4 WHEN 1.77 THEN load4 END
    let base = col("load4", &schema)?;
    let branches = vec![(lit(1.77f64), col("load4", &schema)?)];

    let case_expr = generate_case_when_with_type_coercion(
        Some(base),
        branches,
        None,
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual =
        as_float64_array(&evaluated).expect("failed to downcast to Float64Array");

    // Only the two valid 1.77 rows survive; everything else is NULL.
    let expected =
        Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
    assert_eq!(&expected, actual);

    Ok(())
}
921 | | |
922 | | fn case_test_batch() -> Result<RecordBatch> { |
923 | | let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); |
924 | | let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]); |
925 | | let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; |
926 | | Ok(batch) |
927 | | } |
928 | | |
929 | | // Construct an array that has several NULL values whose |
930 | | // underlying buffer actually matches the where expr predicate |
931 | | fn case_test_batch_nulls() -> Result<RecordBatch> { |
932 | | let load4: Float64Array = vec![ |
933 | | Some(1.77), // 1.77 |
934 | | Some(1.77), // null <-- same value, but will be set to null |
935 | | Some(1.77), // null <-- same value, but will be set to null |
936 | | Some(1.78), // 1.78 |
937 | | None, // null |
938 | | Some(1.77), // 1.77 |
939 | | ] |
940 | | .into_iter() |
941 | | .collect(); |
942 | | |
943 | | //let valid_array = vec![true, false, false, true, false, tru |
944 | | let null_buffer = Buffer::from([0b00101001u8]); |
945 | | let load4 = load4 |
946 | | .into_data() |
947 | | .into_builder() |
948 | | .null_bit_buffer(Some(null_buffer)) |
949 | | .build() |
950 | | .unwrap(); |
951 | | let load4: Float64Array = load4.into(); |
952 | | |
953 | | let batch = |
954 | | RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?; |
955 | | Ok(batch) |
956 | | } |
957 | | |
/// Checks common-type resolution across THEN/ELSE branches:
///
/// 1. `Int32` and `Boolean` THEN values have no common type, so building the
///    expression fails.
/// 2. `Int32`, `Int64` and `Float64` branch values resolve to `Float64`.
#[test]
fn case_test_incompatible() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();

    // then 1 is int32
    // then 2 is boolean
    // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END
    let when1 = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
    let then1 = lit(123i32);
    let when2 = binary(col("a", &schema)?, Operator::Eq, lit("bar"), &schema)?;
    let then2 = lit(true);

    let expr = generate_case_when_with_type_coercion(
        None,
        vec![(when1, then1), (when2, then2)],
        None,
        schema.as_ref(),
    );
    assert!(expr.is_err());

    // then 1 is int32
    // then 2 is int64
    // else is float
    // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END
    let when1 = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
    let then1 = lit(123i32);
    let when2 = binary(col("a", &schema)?, Operator::Eq, lit("bar"), &schema)?;
    let then2 = lit(456i64);
    let else_expr = lit(1.23f64);

    // `?` fails the test with the underlying error if coercion unexpectedly
    // fails, instead of a bare `is_ok()` assertion followed by `unwrap()`.
    let expr = generate_case_when_with_type_coercion(
        None,
        vec![(when1, then1), (when2, then2)],
        Some(else_expr),
        schema.as_ref(),
    )?;
    let result_type = expr.data_type(schema.as_ref())?;
    assert_eq!(Float64, result_type);

    Ok(())
}
1020 | | |
/// Equality between `CASE` expressions: two expressions built from the same
/// parts compare equal, while dropping the `ELSE` or a `WHEN`/`THEN` pair
/// makes them unequal. Both directions of each comparison are checked.
#[test]
fn case_eq() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

    let foo_branch = (lit("foo"), lit(123i32));
    let bar_branch = (lit("bar"), lit(456i32));
    let fallback = lit(999i32);

    // Clones a (when, then) pair by bumping the reference counts.
    let share = |branch: &WhenThen| -> WhenThen {
        (Arc::clone(&branch.0), Arc::clone(&branch.1))
    };

    // Two branches plus ELSE.
    let expr1 = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        vec![share(&foo_branch), share(&bar_branch)],
        Some(Arc::clone(&fallback)),
        &schema,
    )?;

    // Same parts as `expr1`.
    let expr2 = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        vec![share(&foo_branch), share(&bar_branch)],
        Some(Arc::clone(&fallback)),
        &schema,
    )?;

    // Same branches, but no ELSE.
    let expr3 = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        vec![share(&foo_branch), bar_branch],
        None,
        &schema,
    )?;

    // Only the first branch, with ELSE.
    let expr4 = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        vec![foo_branch],
        Some(fallback),
        &schema,
    )?;

    assert!(expr1.eq(&expr2));
    assert!(expr2.eq(&expr1));

    assert!(expr2.ne(&expr3));
    assert!(expr3.ne(&expr2));

    assert!(expr1.ne(&expr4));
    assert!(expr4.ne(&expr1));

    Ok(())
}
1076 | | |
/// Rewrites every non-null Utf8 literal inside a `CASE` expression to upper
/// case using both `transform` and `transform_down`, then checks that the
/// rewritten expression differs from the original and that both rewrites
/// agree with each other.
#[test]
fn case_transform() -> Result<()> {
    /// Replaces a non-null Utf8 literal with its upper-cased form and leaves
    /// every other expression untouched.
    fn uppercase_utf8_literal(
        e: Arc<dyn PhysicalExpr>,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        let replacement = e
            .as_any()
            .downcast_ref::<crate::expressions::Literal>()
            .and_then(|literal| match literal.value() {
                ScalarValue::Utf8(Some(str_value)) => {
                    Some(lit(str_value.to_uppercase()))
                }
                _ => None,
            });
        Ok(match replacement {
            Some(rewritten) => Transformed::yes(rewritten),
            None => Transformed::no(e),
        })
    }

    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

    let branches = vec![(lit("foo"), lit(123i32)), (lit("bar"), lit(456i32))];
    let expr = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        branches,
        Some(lit(999i32)),
        &schema,
    )?;

    let expr2 = Arc::clone(&expr)
        .transform(uppercase_utf8_literal)
        .data()
        .unwrap();

    let expr3 = Arc::clone(&expr)
        .transform_down(uppercase_utf8_literal)
        .data()
        .unwrap();

    assert!(expr.ne(&expr2));
    assert!(expr2.eq(&expr3));

    Ok(())
}
1144 | | |
/// A single-branch `CASE WHEN c1 <= 250 THEN c2 END` (no base expression, no
/// `ELSE`) must select the `InfallibleExprOrNull` evaluation method and
/// produce the expected number of rows and NULLs.
#[test]
fn test_column_or_null_specialization() -> Result<()> {
    // create input data: c1 = 0..1000, c2 = "string {i}" with a NULL for
    // every multiple of 7
    let mut c1 = Int32Builder::new();
    let mut c2 = StringBuilder::new();
    for i in 0..1000 {
        c1.append_value(i);
        if i % 7 == 0 {
            c2.append_null();
        } else {
            c2.append_value(format!("string {i}"));
        }
    }
    let c1 = Arc::new(c1.finish());
    let c2 = Arc::new(c2.finish());
    let schema = Schema::new(vec![
        Field::new("c1", DataType::Int32, true),
        Field::new("c2", DataType::Utf8, true),
    ]);
    // Propagate a construction failure instead of panicking; the test
    // already returns `Result`.
    let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2])?;

    // CASE WHEN c1 <= 250 THEN c2 END
    let predicate = Arc::new(BinaryExpr::new(
        make_col("c1", 0),
        Operator::LtEq,
        make_lit_i32(250),
    ));
    let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?;
    assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull));
    match expr.evaluate(&batch)? {
        ColumnarValue::Array(array) => {
            assert_eq!(1000, array.len());
            // 749 rows fail the predicate (c1 in 251..=999) and become NULL;
            // of the 251 matching rows, the 36 multiples of 7 are already
            // NULL in c2: 749 + 36 = 785.
            assert_eq!(785, array.null_count());
        }
        _ => unreachable!(),
    }
    Ok(())
}
1183 | | |
1184 | | fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> { |
1185 | | Arc::new(Column::new(name, index)) |
1186 | | } |
1187 | | |
1188 | | fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> { |
1189 | | Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) |
1190 | | } |
1191 | | |
1192 | | fn generate_case_when_with_type_coercion( |
1193 | | expr: Option<Arc<dyn PhysicalExpr>>, |
1194 | | when_thens: Vec<WhenThen>, |
1195 | | else_expr: Option<Arc<dyn PhysicalExpr>>, |
1196 | | input_schema: &Schema, |
1197 | | ) -> Result<Arc<dyn PhysicalExpr>> { |
1198 | | let coerce_type = |
1199 | | get_case_common_type(&when_thens, else_expr.clone(), input_schema); |
1200 | | let (when_thens, else_expr) = match coerce_type { |
1201 | | None => plan_err!( |
1202 | | "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression" |
1203 | | ), |
1204 | | Some(data_type) => { |
1205 | | // cast then expr |
1206 | | let left = when_thens |
1207 | | .into_iter() |
1208 | | .map(|(when, then)| { |
1209 | | let then = try_cast(then, input_schema, data_type.clone())?; |
1210 | | Ok((when, then)) |
1211 | | }) |
1212 | | .collect::<Result<Vec<_>>>()?; |
1213 | | let right = match else_expr { |
1214 | | None => None, |
1215 | | Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?), |
1216 | | }; |
1217 | | |
1218 | | Ok((left, right)) |
1219 | | } |
1220 | | }?; |
1221 | | case(expr, when_thens, else_expr) |
1222 | | } |
1223 | | |
1224 | | fn get_case_common_type( |
1225 | | when_thens: &[WhenThen], |
1226 | | else_expr: Option<Arc<dyn PhysicalExpr>>, |
1227 | | input_schema: &Schema, |
1228 | | ) -> Option<DataType> { |
1229 | | let thens_type = when_thens |
1230 | | .iter() |
1231 | | .map(|when_then| { |
1232 | | let data_type = &when_then.1.data_type(input_schema).unwrap(); |
1233 | | data_type.clone() |
1234 | | }) |
1235 | | .collect::<Vec<_>>(); |
1236 | | let else_type = match else_expr { |
1237 | | None => { |
1238 | | // case when then exprs must have one then value |
1239 | | thens_type[0].clone() |
1240 | | } |
1241 | | Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(), |
1242 | | }; |
1243 | | thens_type |
1244 | | .iter() |
1245 | | .try_fold(else_type, |left_type, right_type| { |
1246 | | // TODO: now just use the `equal` coercion rule for case when. If find the issue, and |
1247 | | // refactor again. |
1248 | | comparison_coercion(&left_type, right_type) |
1249 | | }) |
1250 | | } |
1251 | | } |