Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/expressions/case.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use std::borrow::Cow;
19
use std::hash::{Hash, Hasher};
20
use std::{any::Any, sync::Arc};
21
22
use crate::expressions::try_cast;
23
use crate::physical_expr::down_cast_any_ref;
24
use crate::PhysicalExpr;
25
26
use arrow::array::*;
27
use arrow::compute::kernels::zip::zip;
28
use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter};
29
use arrow::datatypes::{DataType, Schema};
30
use datafusion_common::cast::as_boolean_array;
31
use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue};
32
use datafusion_expr::ColumnarValue;
33
34
use super::{Column, Literal};
35
use datafusion_physical_expr_common::datum::compare_with_eq;
36
use itertools::Itertools;
37
38
/// One `WHEN ... THEN ...` branch of a CASE expression: the first element is
/// the "when" expression and the second is the corresponding "then" expression.
type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
39
40
/// Evaluation strategy used by a CASE expression.
#[derive(Debug, Hash)]
enum EvalMethod {
    /// General form whose WHEN clauses are boolean conditions:
    ///
    /// CASE WHEN condition THEN result
    ///      [WHEN ...]
    ///      [ELSE result]
    /// END
    NoExpression,
    /// Form with a base expression that is compared against each WHEN value:
    ///
    /// CASE expression
    ///     WHEN value THEN result
    ///     [WHEN ...]
    ///     [ELSE result]
    /// END
    WithExpression,
    /// Fast path for a single WHEN/THEN pair whose THEN expression is
    /// infallible and cheap enough to compute for the whole record batch
    /// instead of only for the rows where the predicate is true:
    ///
    /// CASE WHEN condition THEN column [ELSE NULL] END
    InfallibleExprOrNull,
    /// Fast path for a single WHEN/THEN pair where both the `then` and the
    /// `else` expressions are literal values:
    ///
    /// CASE WHEN condition THEN literal ELSE literal END
    ScalarOrScalar,
}
65
66
/// The CASE expression is similar to a series of nested if/else and there are two forms that
67
/// can be used. The first form consists of a series of boolean "when" expressions with
68
/// corresponding "then" expressions, and an optional "else" expression.
69
///
70
/// CASE WHEN condition THEN result
71
///      [WHEN ...]
72
///      [ELSE result]
73
/// END
74
///
75
/// The second form uses a base expression and then a series of "when" clauses that match on a
76
/// literal value.
77
///
78
/// CASE expression
79
///     WHEN value THEN result
80
///     [WHEN ...]
81
///     [ELSE result]
82
/// END
83
#[derive(Debug, Hash)]
84
pub struct CaseExpr {
85
    /// Optional base expression that can be compared to literal values in the "when" expressions
86
    expr: Option<Arc<dyn PhysicalExpr>>,
87
    /// One or more when/then expressions
88
    when_then_expr: Vec<WhenThen>,
89
    /// Optional "else" expression
90
    else_expr: Option<Arc<dyn PhysicalExpr>>,
91
    /// Evaluation method to use
92
    eval_method: EvalMethod,
93
}
94
95
impl std::fmt::Display for CaseExpr {
96
0
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
97
0
        write!(f, "CASE ")?;
98
0
        if let Some(e) = &self.expr {
99
0
            write!(f, "{e} ")?;
100
0
        }
101
0
        for (w, t) in &self.when_then_expr {
102
0
            write!(f, "WHEN {w} THEN {t} ")?;
103
        }
104
0
        if let Some(e) = &self.else_expr {
105
0
            write!(f, "ELSE {e} ")?;
106
0
        }
107
0
        write!(f, "END")
108
0
    }
109
}
110
111
/// This is a specialization for a specific use case where we can take a fast path
112
/// for expressions that are infallible and can be cheaply computed for the entire
113
/// record batch rather than just for the rows where the predicate is true. For now,
114
/// this is limited to use with Column expressions but could potentially be used for other
115
/// expressions in the future
116
0
fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
117
0
    expr.as_any().is::<Column>()
118
0
}
119
120
impl CaseExpr {
121
    /// Create a new CASE WHEN expression
122
0
    pub fn try_new(
123
0
        expr: Option<Arc<dyn PhysicalExpr>>,
124
0
        when_then_expr: Vec<WhenThen>,
125
0
        else_expr: Option<Arc<dyn PhysicalExpr>>,
126
0
    ) -> Result<Self> {
127
        // normalize null literals to None in the else_expr (this already happens
128
        // during SQL planning, but not necessarily for other use cases)
129
0
        let else_expr = match &else_expr {
130
0
            Some(e) => match e.as_any().downcast_ref::<Literal>() {
131
0
                Some(lit) if lit.value().is_null() => None,
132
0
                _ => else_expr,
133
            },
134
0
            _ => else_expr,
135
        };
136
137
0
        if when_then_expr.is_empty() {
138
0
            exec_err!("There must be at least one WHEN clause")
139
        } else {
140
0
            let eval_method = if expr.is_some() {
141
0
                EvalMethod::WithExpression
142
0
            } else if when_then_expr.len() == 1
143
0
                && is_cheap_and_infallible(&(when_then_expr[0].1))
144
0
                && else_expr.is_none()
145
            {
146
0
                EvalMethod::InfallibleExprOrNull
147
0
            } else if when_then_expr.len() == 1
148
0
                && when_then_expr[0].1.as_any().is::<Literal>()
149
0
                && else_expr.is_some()
150
0
                && else_expr.as_ref().unwrap().as_any().is::<Literal>()
151
            {
152
0
                EvalMethod::ScalarOrScalar
153
            } else {
154
0
                EvalMethod::NoExpression
155
            };
156
157
0
            Ok(Self {
158
0
                expr,
159
0
                when_then_expr,
160
0
                else_expr,
161
0
                eval_method,
162
0
            })
163
        }
164
0
    }
165
166
    /// Optional base expression that can be compared to literal values in the "when" expressions
167
0
    pub fn expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
168
0
        self.expr.as_ref()
169
0
    }
170
171
    /// One or more when/then expressions
172
0
    pub fn when_then_expr(&self) -> &[WhenThen] {
173
0
        &self.when_then_expr
174
0
    }
175
176
    /// Optional "else" expression
177
0
    pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
178
0
        self.else_expr.as_ref()
179
0
    }
180
}
181
182
impl CaseExpr {
183
    /// This function evaluates the form of CASE that matches an expression to fixed values.
184
    ///
185
    /// CASE expression
186
    ///     WHEN value THEN result
187
    ///     [WHEN ...]
188
    ///     [ELSE result]
189
    /// END
190
0
    fn case_when_with_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
191
0
        let return_type = self.data_type(&batch.schema())?;
192
0
        let expr = self.expr.as_ref().unwrap();
193
0
        let base_value = expr.evaluate(batch)?;
194
0
        let base_value = base_value.into_array(batch.num_rows())?;
195
0
        let base_nulls = is_null(base_value.as_ref())?;
196
197
        // start with nulls as default output
198
0
        let mut current_value = new_null_array(&return_type, batch.num_rows());
199
        // We only consider non-null values while comparing with whens
200
0
        let mut remainder = not(&base_nulls)?;
201
0
        for i in 0..self.when_then_expr.len() {
202
0
            let when_value = self.when_then_expr[i]
203
0
                .0
204
0
                .evaluate_selection(batch, &remainder)?;
205
0
            let when_value = when_value.into_array(batch.num_rows())?;
206
            // build boolean array representing which rows match the "when" value
207
0
            let when_match = compare_with_eq(
208
0
                &when_value,
209
0
                &base_value,
210
0
                // The types of case and when expressions will be coerced to match.
211
0
                // We only need to check if the base_value is nested.
212
0
                base_value.data_type().is_nested(),
213
0
            )?;
214
            // Treat nulls as false
215
0
            let when_match = match when_match.null_count() {
216
0
                0 => Cow::Borrowed(&when_match),
217
0
                _ => Cow::Owned(prep_null_mask_filter(&when_match)),
218
            };
219
            // Make sure we only consider rows that have not been matched yet
220
0
            let when_match = and(&when_match, &remainder)?;
221
222
            // When no rows available for when clause, skip then clause
223
0
            if when_match.true_count() == 0 {
224
0
                continue;
225
0
            }
226
227
0
            let then_value = self.when_then_expr[i]
228
0
                .1
229
0
                .evaluate_selection(batch, &when_match)?;
230
231
0
            current_value = match then_value {
232
                ColumnarValue::Scalar(ScalarValue::Null) => {
233
0
                    nullif(current_value.as_ref(), &when_match)?
234
                }
235
0
                ColumnarValue::Scalar(then_value) => {
236
0
                    zip(&when_match, &then_value.to_scalar()?, &current_value)?
237
                }
238
0
                ColumnarValue::Array(then_value) => {
239
0
                    zip(&when_match, &then_value, &current_value)?
240
                }
241
            };
242
243
0
            remainder = and_not(&remainder, &when_match)?;
244
        }
245
246
0
        if let Some(e) = &self.else_expr {
247
            // keep `else_expr`'s data type and return type consistent
248
0
            let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
249
0
                .unwrap_or_else(|_| Arc::clone(e));
250
            // null and unmatched tuples should be assigned else value
251
0
            remainder = or(&base_nulls, &remainder)?;
252
0
            let else_ = expr
253
0
                .evaluate_selection(batch, &remainder)?
254
0
                .into_array(batch.num_rows())?;
255
0
            current_value = zip(&remainder, &else_, &current_value)?;
256
0
        }
257
258
0
        Ok(ColumnarValue::Array(current_value))
259
0
    }
260
261
    /// This function evaluates the form of CASE where each WHEN expression is a boolean
262
    /// expression.
263
    ///
264
    /// CASE WHEN condition THEN result
265
    ///      [WHEN ...]
266
    ///      [ELSE result]
267
    /// END
268
0
    fn case_when_no_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
269
0
        let return_type = self.data_type(&batch.schema())?;
270
271
        // start with nulls as default output
272
0
        let mut current_value = new_null_array(&return_type, batch.num_rows());
273
0
        let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]);
274
0
        for i in 0..self.when_then_expr.len() {
275
0
            let when_value = self.when_then_expr[i]
276
0
                .0
277
0
                .evaluate_selection(batch, &remainder)?;
278
0
            let when_value = when_value.into_array(batch.num_rows())?;
279
0
            let when_value = as_boolean_array(&when_value).map_err(|e| {
280
0
                DataFusionError::Context(
281
0
                    "WHEN expression did not return a BooleanArray".to_string(),
282
0
                    Box::new(e),
283
0
                )
284
0
            })?;
285
            // Treat 'NULL' as false value
286
0
            let when_value = match when_value.null_count() {
287
0
                0 => Cow::Borrowed(when_value),
288
0
                _ => Cow::Owned(prep_null_mask_filter(when_value)),
289
            };
290
            // Make sure we only consider rows that have not been matched yet
291
0
            let when_value = and(&when_value, &remainder)?;
292
293
            // When no rows available for when clause, skip then clause
294
0
            if when_value.true_count() == 0 {
295
0
                continue;
296
0
            }
297
298
0
            let then_value = self.when_then_expr[i]
299
0
                .1
300
0
                .evaluate_selection(batch, &when_value)?;
301
302
0
            current_value = match then_value {
303
                ColumnarValue::Scalar(ScalarValue::Null) => {
304
0
                    nullif(current_value.as_ref(), &when_value)?
305
                }
306
0
                ColumnarValue::Scalar(then_value) => {
307
0
                    zip(&when_value, &then_value.to_scalar()?, &current_value)?
308
                }
309
0
                ColumnarValue::Array(then_value) => {
310
0
                    zip(&when_value, &then_value, &current_value)?
311
                }
312
            };
313
314
            // Succeed tuples should be filtered out for short-circuit evaluation,
315
            // null values for the current when expr should be kept
316
0
            remainder = and_not(&remainder, &when_value)?;
317
        }
318
319
0
        if let Some(e) = &self.else_expr {
320
            // keep `else_expr`'s data type and return type consistent
321
0
            let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
322
0
                .unwrap_or_else(|_| Arc::clone(e));
323
0
            let else_ = expr
324
0
                .evaluate_selection(batch, &remainder)?
325
0
                .into_array(batch.num_rows())?;
326
0
            current_value = zip(&remainder, &else_, &current_value)?;
327
0
        }
328
329
0
        Ok(ColumnarValue::Array(current_value))
330
0
    }
331
332
    /// This function evaluates the specialized case of:
333
    ///
334
    /// CASE WHEN condition THEN column
335
    ///      [ELSE NULL]
336
    /// END
337
    ///
338
    /// Note that this function is only safe to use for "then" expressions
339
    /// that are infallible because the expression will be evaluated for all
340
    /// rows in the input batch.
341
0
    fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
342
0
        let when_expr = &self.when_then_expr[0].0;
343
0
        let then_expr = &self.when_then_expr[0].1;
344
0
        if let ColumnarValue::Array(bit_mask) = when_expr.evaluate(batch)? {
345
0
            let bit_mask = bit_mask
346
0
                .as_any()
347
0
                .downcast_ref::<BooleanArray>()
348
0
                .expect("predicate should evaluate to a boolean array");
349
            // invert the bitmask
350
0
            let bit_mask = not(bit_mask)?;
351
0
            match then_expr.evaluate(batch)? {
352
0
                ColumnarValue::Array(array) => {
353
0
                    Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?))
354
                }
355
                ColumnarValue::Scalar(_) => {
356
0
                    internal_err!("expression did not evaluate to an array")
357
                }
358
            }
359
        } else {
360
0
            internal_err!("predicate did not evaluate to an array")
361
        }
362
0
    }
363
364
0
    fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
365
0
        let return_type = self.data_type(&batch.schema())?;
366
367
        // evaluate when expression
368
0
        let when_value = self.when_then_expr[0].0.evaluate(batch)?;
369
0
        let when_value = when_value.into_array(batch.num_rows())?;
370
0
        let when_value = as_boolean_array(&when_value).map_err(|e| {
371
0
            DataFusionError::Context(
372
0
                "WHEN expression did not return a BooleanArray".to_string(),
373
0
                Box::new(e),
374
0
            )
375
0
        })?;
376
377
        // Treat 'NULL' as false value
378
0
        let when_value = match when_value.null_count() {
379
0
            0 => Cow::Borrowed(when_value),
380
0
            _ => Cow::Owned(prep_null_mask_filter(when_value)),
381
        };
382
383
        // evaluate then_value
384
0
        let then_value = self.when_then_expr[0].1.evaluate(batch)?;
385
0
        let then_value = Scalar::new(then_value.into_array(1)?);
386
387
        // keep `else_expr`'s data type and return type consistent
388
0
        let e = self.else_expr.as_ref().unwrap();
389
0
        let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)
390
0
            .unwrap_or_else(|_| Arc::clone(e));
391
0
        let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?);
392
393
0
        Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
394
0
    }
395
}
396
397
impl PhysicalExpr for CaseExpr {
398
    /// Return a reference to Any that can be used for down-casting
399
0
    fn as_any(&self) -> &dyn Any {
400
0
        self
401
0
    }
402
403
0
    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
404
0
        // since all then results have the same data type, we can choose any one as the
405
0
        // return data type except for the null.
406
0
        let mut data_type = DataType::Null;
407
0
        for i in 0..self.when_then_expr.len() {
408
0
            data_type = self.when_then_expr[i].1.data_type(input_schema)?;
409
0
            if !data_type.equals_datatype(&DataType::Null) {
410
0
                break;
411
0
            }
412
        }
413
        // if all then results are null, we use data type of else expr instead if possible.
414
0
        if data_type.equals_datatype(&DataType::Null) {
415
0
            if let Some(e) = &self.else_expr {
416
0
                data_type = e.data_type(input_schema)?;
417
0
            }
418
0
        }
419
420
0
        Ok(data_type)
421
0
    }
422
423
0
    /// Reports whether this expression can produce NULL: true if any "then"
    /// expression is nullable, otherwise the nullability of the "else"
    /// expression, or true when there is no "else".
    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
        // Every "then" expression is queried (no early exit) so that an error
        // from any of them is still reported.
        let mut any_then_nullable = false;
        for (_, then_expr) in &self.when_then_expr {
            any_then_nullable |= then_expr.nullable(input_schema)?;
        }

        match &self.else_expr {
            Some(else_expr) if !any_then_nullable => else_expr.nullable(input_schema),
            // Either a "then" is nullable, or there is no `else` expr and
            // CASE produces NULL when none of the `when_then_exprs` match
            _ => Ok(true),
        }
    }
440
441
0
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
442
0
        match self.eval_method {
443
            EvalMethod::WithExpression => {
444
                // this use case evaluates "expr" and then compares the values with the "when"
445
                // values
446
0
                self.case_when_with_expr(batch)
447
            }
448
            EvalMethod::NoExpression => {
449
                // The "when" conditions all evaluate to boolean in this use case and can be
450
                // arbitrary expressions
451
0
                self.case_when_no_expr(batch)
452
            }
453
            EvalMethod::InfallibleExprOrNull => {
454
                // Specialization for CASE WHEN expr THEN column [ELSE NULL] END
455
0
                self.case_column_or_null(batch)
456
            }
457
0
            EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
458
        }
459
0
    }
460
461
0
    /// Lists the child expressions in a fixed order: the base expression (if
    /// any), then each when/then pair in order, then the "else" expression
    /// (if any).
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        let branches = self
            .when_then_expr
            .iter()
            .flat_map(|(when, then)| [when, then]);

        self.expr
            .iter()
            .chain(branches)
            .chain(self.else_expr.iter())
            .collect()
    }
476
477
    // For physical CaseExpr, we do not allow modifying children size
478
0
    fn with_new_children(
479
0
        self: Arc<Self>,
480
0
        children: Vec<Arc<dyn PhysicalExpr>>,
481
0
    ) -> Result<Arc<dyn PhysicalExpr>> {
482
0
        if children.len() != self.children().len() {
483
0
            internal_err!("CaseExpr: Wrong number of children")
484
        } else {
485
0
            let (expr, when_then_expr, else_expr) =
486
0
                match (self.expr().is_some(), self.else_expr().is_some()) {
487
0
                    (true, true) => (
488
0
                        Some(&children[0]),
489
0
                        &children[1..children.len() - 1],
490
0
                        Some(&children[children.len() - 1]),
491
0
                    ),
492
                    (true, false) => {
493
0
                        (Some(&children[0]), &children[1..children.len()], None)
494
                    }
495
0
                    (false, true) => (
496
0
                        None,
497
0
                        &children[0..children.len() - 1],
498
0
                        Some(&children[children.len() - 1]),
499
0
                    ),
500
0
                    (false, false) => (None, &children[0..children.len()], None),
501
                };
502
0
            Ok(Arc::new(CaseExpr::try_new(
503
0
                expr.cloned(),
504
0
                when_then_expr.iter().cloned().tuples().collect(),
505
0
                else_expr.cloned(),
506
0
            )?))
507
        }
508
0
    }
509
510
0
    fn dyn_hash(&self, state: &mut dyn Hasher) {
511
0
        let mut s = state;
512
0
        self.hash(&mut s);
513
0
    }
514
}
515
516
impl PartialEq<dyn Any> for CaseExpr {
517
0
    fn eq(&self, other: &dyn Any) -> bool {
518
0
        down_cast_any_ref(other)
519
0
            .downcast_ref::<Self>()
520
0
            .map(|x| {
521
0
                let expr_eq = match (&self.expr, &x.expr) {
522
0
                    (Some(expr1), Some(expr2)) => expr1.eq(expr2),
523
0
                    (None, None) => true,
524
0
                    _ => false,
525
                };
526
0
                let else_expr_eq = match (&self.else_expr, &x.else_expr) {
527
0
                    (Some(expr1), Some(expr2)) => expr1.eq(expr2),
528
0
                    (None, None) => true,
529
0
                    _ => false,
530
                };
531
0
                expr_eq
532
0
                    && else_expr_eq
533
0
                    && self.when_then_expr.len() == x.when_then_expr.len()
534
0
                    && self.when_then_expr.iter().zip(x.when_then_expr.iter()).all(
535
0
                        |((when1, then1), (when2, then2))| {
536
0
                            when1.eq(when2) && then1.eq(then2)
537
0
                        },
538
0
                    )
539
0
            })
540
0
            .unwrap_or(false)
541
0
    }
542
}
543
544
/// Create a CASE expression
545
0
pub fn case(
546
0
    expr: Option<Arc<dyn PhysicalExpr>>,
547
0
    when_thens: Vec<WhenThen>,
548
0
    else_expr: Option<Arc<dyn PhysicalExpr>>,
549
0
) -> Result<Arc<dyn PhysicalExpr>> {
550
0
    Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
551
0
}
552
553
#[cfg(test)]
554
mod tests {
555
    use super::*;
556
557
    use crate::expressions::{binary, cast, col, lit, BinaryExpr};
558
    use arrow::buffer::Buffer;
559
    use arrow::datatypes::DataType::Float64;
560
    use arrow::datatypes::*;
561
    use datafusion_common::cast::{as_float64_array, as_int32_array};
562
    use datafusion_common::plan_err;
563
    use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
564
    use datafusion_expr::type_coercion::binary::comparison_coercion;
565
    use datafusion_expr::Operator;
566
567
    /// CASE with a base expression and no ELSE branch.
    #[test]
    fn case_with_expr() -> Result<()> {
        let batch = case_test_batch()?;
        let schema = batch.schema();

        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
        let branches = vec![(lit("foo"), lit(123i32)), (lit("bar"), lit(456i32))];
        let case_expr = generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            branches,
            None,
            schema.as_ref(),
        )?;

        let array = case_expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let actual = as_int32_array(&array)?;

        let expected = Int32Array::from(vec![Some(123), None, None, Some(456)]);
        assert_eq!(&expected, actual);

        Ok(())
    }
596
597
    /// CASE with a base expression and a literal ELSE branch.
    #[test]
    fn case_with_expr_else() -> Result<()> {
        let batch = case_test_batch()?;
        let schema = batch.schema();

        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END
        let branches = vec![(lit("foo"), lit(123i32)), (lit("bar"), lit(456i32))];
        let case_expr = generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            branches,
            Some(lit(999i32)),
            schema.as_ref(),
        )?;

        let array = case_expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let actual = as_int32_array(&array)?;

        let expected = Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
        assert_eq!(&expected, actual);

        Ok(())
    }
628
629
    /// CASE with a base expression whose ELSE branch divides by the column,
    /// with the zero value handled by the WHEN branch.
    #[test]
    fn case_with_expr_divide_by_zero() -> Result<()> {
        let batch = case_test_batch1()?;
        let schema = batch.schema();

        // CASE a when 0 THEN float64(null) ELSE 25.0 / cast(a, float64)  END
        let branches = vec![(lit(0i32), lit(ScalarValue::Float64(None)))];
        let quotient = binary(
            lit(25.0f64),
            Operator::Divide,
            cast(col("a", &schema)?, &schema, Float64)?,
            &schema,
        )?;

        let case_expr = generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            branches,
            Some(quotient),
            schema.as_ref(),
        )?;

        let array = case_expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let actual = as_float64_array(&array).expect("failed to downcast to Float64Array");

        let expected = Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
        assert_eq!(&expected, actual);

        Ok(())
    }
663
664
    /// CASE with boolean WHEN conditions and no ELSE branch.
    #[test]
    fn case_without_expr() -> Result<()> {
        let batch = case_test_batch()?;
        let schema = batch.schema();

        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END
        let is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
        let is_bar = binary(col("a", &schema)?, Operator::Eq, lit("bar"), &schema)?;
        let branches = vec![(is_foo, lit(123i32)), (is_bar, lit(456i32))];

        let case_expr =
            generate_case_when_with_type_coercion(None, branches, None, schema.as_ref())?;

        let array = case_expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let actual = as_int32_array(&array)?;

        let expected = Int32Array::from(vec![Some(123), None, None, Some(456)]);
        assert_eq!(&expected, actual);

        Ok(())
    }
703
704
    /// CASE with a base expression where the first WHEN value is a NULL
    /// literal and the second WHEN value is the column itself.
    #[test]
    fn case_with_expr_when_null() -> Result<()> {
        let batch = case_test_batch()?;
        let schema = batch.schema();

        // CASE a WHEN NULL THEN 0 WHEN a THEN 123 ELSE 999 END
        let branches = vec![
            (lit(ScalarValue::Utf8(None)), lit(0i32)),
            (col("a", &schema)?, lit(123i32)),
        ];
        let case_expr = generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            branches,
            Some(lit(999i32)),
            schema.as_ref(),
        )?;

        let array = case_expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let actual = as_int32_array(&array)?;

        let expected = Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]);
        assert_eq!(&expected, actual);

        Ok(())
    }
735
736
    /// CASE with a boolean WHEN guarding a division by the column, and a NULL
    /// float literal as ELSE.
    #[test]
    fn case_without_expr_divide_by_zero() -> Result<()> {
        let batch = case_test_batch1()?;
        let schema = batch.schema();

        // CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE float64(null) END
        let is_positive = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &schema)?;
        let quotient = binary(
            lit(25.0f64),
            Operator::Divide,
            cast(col("a", &schema)?, &schema, Float64)?,
            &schema,
        )?;
        let null_float = lit(ScalarValue::Float64(None));

        let case_expr = generate_case_when_with_type_coercion(
            None,
            vec![(is_positive, quotient)],
            Some(null_float),
            schema.as_ref(),
        )?;

        let array = case_expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let actual = as_float64_array(&array).expect("failed to downcast to Float64Array");

        let expected = Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
        assert_eq!(&expected, actual);

        Ok(())
    }
770
771
    fn case_test_batch1() -> Result<RecordBatch> {
772
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
773
        let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
774
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
775
        Ok(batch)
776
    }
777
778
    /// CASE with boolean WHEN conditions and a literal ELSE branch.
    #[test]
    fn case_without_expr_else() -> Result<()> {
        let batch = case_test_batch()?;
        let schema = batch.schema();

        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END
        let is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
        let is_bar = binary(col("a", &schema)?, Operator::Eq, lit("bar"), &schema)?;
        let branches = vec![(is_foo, lit(123i32)), (is_bar, lit(456i32))];

        let case_expr = generate_case_when_with_type_coercion(
            None,
            branches,
            Some(lit(999i32)),
            schema.as_ref(),
        )?;

        let array = case_expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let actual = as_int32_array(&array)?;

        let expected = Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
        assert_eq!(&expected, actual);

        Ok(())
    }
819
820
    /// A Float64 THEN combined with an Int32 ELSE must produce a Float64
    /// result column after coercion.
    #[test]
    fn case_with_type_cast() -> Result<()> {
        let batch = case_test_batch()?;
        let schema = batch.schema();

        // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
        let is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
        let float_then = lit(123.3f64);
        let int_else = lit(999i32);

        let expr = generate_case_when_with_type_coercion(
            None,
            vec![(is_foo, float_then)],
            Some(int_else),
            schema.as_ref(),
        )?;
        let evaluated = expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let actual =
            as_float64_array(&evaluated).expect("failed to downcast to Float64Array");

        // Only the first row of ["foo", "baz", NULL, "bar"] matches the predicate.
        let expected =
            &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]);
        assert_eq!(expected, actual);

        Ok(())
    }
855
856
    /// Searched `CASE` over a column whose NULL slots still carry a matching
    /// value in the underlying buffer; those slots must stay NULL in the output.
    #[test]
    fn case_with_matches_and_nulls() -> Result<()> {
        let batch = case_test_batch_nulls()?;
        let schema = batch.schema();

        // SELECT CASE WHEN load4 = 1.77 THEN load4 END
        let load4_matches = binary(
            col("load4", &schema)?,
            Operator::Eq,
            lit(1.77f64),
            &schema,
        )?;
        let load4 = col("load4", &schema)?;

        let expr = generate_case_when_with_type_coercion(
            None,
            vec![(load4_matches, load4)],
            None,
            schema.as_ref(),
        )?;
        let evaluated = expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let actual =
            as_float64_array(&evaluated).expect("failed to downcast to Float64Array");

        // Only the first and last rows are both valid and equal to 1.77.
        let expected =
            &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
        assert_eq!(expected, actual);

        Ok(())
    }
890
891
    /// Same scenario as the searched form, but using the `CASE <expr> WHEN`
    /// form: NULL slots whose buffer value equals the WHEN literal stay NULL.
    #[test]
    fn case_expr_matches_and_nulls() -> Result<()> {
        let batch = case_test_batch_nulls()?;
        let schema = batch.schema();

        // SELECT CASE load4 WHEN 1.77 THEN load4 END
        let base = col("load4", &schema)?;
        let when_value = lit(1.77f64);
        let then_value = col("load4", &schema)?;

        let case_expr = generate_case_when_with_type_coercion(
            Some(base),
            vec![(when_value, then_value)],
            None,
            schema.as_ref(),
        )?;
        let evaluated = case_expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let actual =
            as_float64_array(&evaluated).expect("failed to downcast to Float64Array");

        // Only the first and last rows are both valid and equal to 1.77.
        let expected =
            &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
        assert_eq!(expected, actual);

        Ok(())
    }
921
922
    /// Builds a one-column batch whose nullable Utf8 column `a` holds
    /// `["foo", "baz", NULL, "bar"]`.
    fn case_test_batch() -> Result<RecordBatch> {
        let values = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
        let field = Field::new("a", DataType::Utf8, true);
        let schema = Arc::new(Schema::new(vec![field]));
        Ok(RecordBatch::try_new(schema, vec![Arc::new(values)])?)
    }
928
929
    // Construct an array that has several NULL values whose
930
    // underlying buffer actually matches the where expr predicate
931
    fn case_test_batch_nulls() -> Result<RecordBatch> {
932
        let load4: Float64Array = vec![
933
            Some(1.77), // 1.77
934
            Some(1.77), // null <-- same value, but will be set to null
935
            Some(1.77), // null <-- same value, but will be set to null
936
            Some(1.78), // 1.78
937
            None,       // null
938
            Some(1.77), // 1.77
939
        ]
940
        .into_iter()
941
        .collect();
942
943
        //let valid_array = vec![true, false, false, true, false, tru
944
        let null_buffer = Buffer::from([0b00101001u8]);
945
        let load4 = load4
946
            .into_data()
947
            .into_builder()
948
            .null_bit_buffer(Some(null_buffer))
949
            .build()
950
            .unwrap();
951
        let load4: Float64Array = load4.into();
952
953
        let batch =
954
            RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
955
        Ok(batch)
956
    }
957
958
    /// Checks type coercion across THEN/ELSE branches: Int32 with Boolean is
    /// rejected, while Int32, Int64 and Float64 coerce to Float64.
    #[test]
    fn case_test_incompatible() -> Result<()> {
        let batch = case_test_batch()?;
        let schema = batch.schema();

        // First THEN is Int32, second THEN is Boolean: building the CASE must fail.
        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END
        let is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
        let int_then = lit(123i32);
        let is_bar = binary(col("a", &schema)?, Operator::Eq, lit("bar"), &schema)?;
        let bool_then = lit(true);

        let mismatched = generate_case_when_with_type_coercion(
            None,
            vec![(is_foo, int_then), (is_bar, bool_then)],
            None,
            schema.as_ref(),
        );
        assert!(mismatched.is_err());

        // First THEN is Int32, second THEN is Int64, ELSE is Float64.
        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END
        let is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
        let int32_then = lit(123i32);
        let is_bar = binary(col("a", &schema)?, Operator::Eq, lit("bar"), &schema)?;
        let int64_then = lit(456i64);
        let float_else = lit(1.23f64);

        let numeric = generate_case_when_with_type_coercion(
            None,
            vec![(is_foo, int32_then), (is_bar, int64_then)],
            Some(float_else),
            schema.as_ref(),
        );
        assert!(numeric.is_ok());
        let result_type = numeric.unwrap().data_type(schema.as_ref())?;
        assert_eq!(Float64, result_type);
        Ok(())
    }
1020
1021
    /// Equality between CASE expressions: identical structure compares equal,
    /// while dropping the ELSE or a WHEN/THEN pair makes them unequal.
    #[test]
    fn case_eq() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

        let when_foo = lit("foo");
        let then_123 = lit(123i32);
        let when_bar = lit("bar");
        let then_456 = lit(456i32);
        let otherwise = lit(999i32);

        // Two structurally identical expressions built independently.
        let full_a = generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            vec![
                (Arc::clone(&when_foo), Arc::clone(&then_123)),
                (Arc::clone(&when_bar), Arc::clone(&then_456)),
            ],
            Some(Arc::clone(&otherwise)),
            &schema,
        )?;
        let full_b = generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            vec![
                (Arc::clone(&when_foo), Arc::clone(&then_123)),
                (Arc::clone(&when_bar), Arc::clone(&then_456)),
            ],
            Some(Arc::clone(&otherwise)),
            &schema,
        )?;

        // Same branches, but no ELSE.
        let without_else = generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            vec![
                (Arc::clone(&when_foo), Arc::clone(&then_123)),
                (when_bar, then_456),
            ],
            None,
            &schema,
        )?;

        // Same ELSE, but only the first branch.
        let single_branch = generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            vec![(when_foo, then_123)],
            Some(otherwise),
            &schema,
        )?;

        // Each relation is checked in both directions.
        assert!(full_a.eq(&full_b));
        assert!(full_b.eq(&full_a));

        assert!(full_b.ne(&without_else));
        assert!(without_else.ne(&full_b));

        assert!(full_a.ne(&single_branch));
        assert!(single_branch.ne(&full_a));

        Ok(())
    }
1076
1077
    /// Rewrites every Utf8 literal inside a CASE expression with `transform`
    /// and `transform_down`, and checks both rewrites agree with each other
    /// and differ from the original.
    #[test]
    fn case_transform() -> Result<()> {
        /// Replaces a non-null Utf8 literal with its upper-cased form and
        /// leaves every other expression untouched.
        fn uppercase_utf8_literal(
            e: Arc<dyn PhysicalExpr>,
        ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
            let replacement = e
                .as_any()
                .downcast_ref::<crate::expressions::Literal>()
                .and_then(|literal| match literal.value() {
                    ScalarValue::Utf8(Some(text)) => Some(lit(text.to_uppercase())),
                    _ => None,
                });
            Ok(match replacement {
                Some(upper) => Transformed::yes(upper),
                None => Transformed::no(e),
            })
        }

        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

        let when_foo = lit("foo");
        let then_123 = lit(123i32);
        let when_bar = lit("bar");
        let then_456 = lit(456i32);
        let otherwise = lit(999i32);

        let original = generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            vec![
                (Arc::clone(&when_foo), Arc::clone(&then_123)),
                (Arc::clone(&when_bar), Arc::clone(&then_456)),
            ],
            Some(Arc::clone(&otherwise)),
            &schema,
        )?;

        let via_transform = Arc::clone(&original)
            .transform(uppercase_utf8_literal)
            .data()
            .unwrap();

        let via_transform_down = Arc::clone(&original)
            .transform_down(uppercase_utf8_literal)
            .data()
            .unwrap();

        assert!(original.ne(&via_transform));
        assert!(via_transform.eq(&via_transform_down));

        Ok(())
    }
1144
1145
    /// A single `WHEN <predicate> THEN <column>` with no ELSE is expected to
    /// select `EvalMethod::InfallibleExprOrNull`; checks the length and null
    /// count it produces.
    #[test]
    fn test_column_or_null_specialization() -> Result<()> {
        // Input: c1 counts 0..1000; c2 is "string {i}", NULL on every 7th row.
        let mut c1 = Int32Builder::new();
        let mut c2 = StringBuilder::new();
        for i in 0..1000 {
            c1.append_value(i);
            if i % 7 != 0 {
                c2.append_value(format!("string {i}"));
            } else {
                c2.append_null();
            }
        }
        let c1 = Arc::new(c1.finish());
        let c2 = Arc::new(c2.finish());
        let fields = vec![
            Field::new("c1", DataType::Int32, true),
            Field::new("c2", DataType::Utf8, true),
        ];
        let batch =
            RecordBatch::try_new(Arc::new(Schema::new(fields)), vec![c1, c2]).unwrap();

        // CASE WHEN c1 <= 250 THEN c2 END
        let c1_le_250 = Arc::new(BinaryExpr::new(
            make_col("c1", 0),
            Operator::LtEq,
            make_lit_i32(250),
        ));
        let expr = CaseExpr::try_new(None, vec![(c1_le_250, make_col("c2", 1))], None)?;
        assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull));

        let ColumnarValue::Array(array) = expr.evaluate(&batch)? else {
            unreachable!()
        };
        assert_eq!(1000, array.len());
        // 251 rows pass the predicate and 36 of those have a NULL c2,
        // leaving 215 non-null outputs.
        assert_eq!(785, array.null_count());
        Ok(())
    }
1183
1184
    fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
1185
        Arc::new(Column::new(name, index))
1186
    }
1187
1188
    fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
1189
        Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
1190
    }
1191
1192
    fn generate_case_when_with_type_coercion(
1193
        expr: Option<Arc<dyn PhysicalExpr>>,
1194
        when_thens: Vec<WhenThen>,
1195
        else_expr: Option<Arc<dyn PhysicalExpr>>,
1196
        input_schema: &Schema,
1197
    ) -> Result<Arc<dyn PhysicalExpr>> {
1198
        let coerce_type =
1199
            get_case_common_type(&when_thens, else_expr.clone(), input_schema);
1200
        let (when_thens, else_expr) = match coerce_type {
1201
            None => plan_err!(
1202
                "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression"
1203
            ),
1204
            Some(data_type) => {
1205
                // cast then expr
1206
                let left = when_thens
1207
                    .into_iter()
1208
                    .map(|(when, then)| {
1209
                        let then = try_cast(then, input_schema, data_type.clone())?;
1210
                        Ok((when, then))
1211
                    })
1212
                    .collect::<Result<Vec<_>>>()?;
1213
                let right = match else_expr {
1214
                    None => None,
1215
                    Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
1216
                };
1217
1218
                Ok((left, right))
1219
            }
1220
        }?;
1221
        case(expr, when_thens, else_expr)
1222
    }
1223
1224
    fn get_case_common_type(
1225
        when_thens: &[WhenThen],
1226
        else_expr: Option<Arc<dyn PhysicalExpr>>,
1227
        input_schema: &Schema,
1228
    ) -> Option<DataType> {
1229
        let thens_type = when_thens
1230
            .iter()
1231
            .map(|when_then| {
1232
                let data_type = &when_then.1.data_type(input_schema).unwrap();
1233
                data_type.clone()
1234
            })
1235
            .collect::<Vec<_>>();
1236
        let else_type = match else_expr {
1237
            None => {
1238
                // case when then exprs must have one then value
1239
                thens_type[0].clone()
1240
            }
1241
            Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
1242
        };
1243
        thens_type
1244
            .iter()
1245
            .try_fold(else_type, |left_type, right_type| {
1246
                // TODO: now just use the `equal` coercion rule for case when. If find the issue, and
1247
                // refactor again.
1248
                comparison_coercion(&left_type, right_type)
1249
            })
1250
    }
1251
}