Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/expr/src/conditional_expressions.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Conditional expressions
19
use crate::expr::Case;
20
use crate::{expr_schema::ExprSchemable, Expr};
21
use arrow::datatypes::DataType;
22
use datafusion_common::{plan_err, DFSchema, Result};
23
use std::collections::HashSet;
24
25
/// Helper struct for building [Expr::Case]
26
pub struct CaseBuilder {
27
    expr: Option<Box<Expr>>,
28
    when_expr: Vec<Expr>,
29
    then_expr: Vec<Expr>,
30
    else_expr: Option<Box<Expr>>,
31
}
32
33
impl CaseBuilder {
34
0
    pub fn new(
35
0
        expr: Option<Box<Expr>>,
36
0
        when_expr: Vec<Expr>,
37
0
        then_expr: Vec<Expr>,
38
0
        else_expr: Option<Box<Expr>>,
39
0
    ) -> Self {
40
0
        Self {
41
0
            expr,
42
0
            when_expr,
43
0
            then_expr,
44
0
            else_expr,
45
0
        }
46
0
    }
47
0
    pub fn when(&mut self, when: Expr, then: Expr) -> CaseBuilder {
48
0
        self.when_expr.push(when);
49
0
        self.then_expr.push(then);
50
0
        CaseBuilder {
51
0
            expr: self.expr.clone(),
52
0
            when_expr: self.when_expr.clone(),
53
0
            then_expr: self.then_expr.clone(),
54
0
            else_expr: self.else_expr.clone(),
55
0
        }
56
0
    }
57
0
    pub fn otherwise(&mut self, else_expr: Expr) -> Result<Expr> {
58
0
        self.else_expr = Some(Box::new(else_expr));
59
0
        self.build()
60
0
    }
61
62
0
    pub fn end(&self) -> Result<Expr> {
63
0
        self.build()
64
0
    }
65
66
0
    fn build(&self) -> Result<Expr> {
67
0
        // collect all "then" expressions
68
0
        let mut then_expr = self.then_expr.clone();
69
0
        if let Some(e) = &self.else_expr {
70
0
            then_expr.push(e.as_ref().to_owned());
71
0
        }
72
73
0
        let then_types: Vec<DataType> = then_expr
74
0
            .iter()
75
0
            .map(|e| match e {
76
0
                Expr::Literal(_) => e.get_type(&DFSchema::empty()),
77
0
                _ => Ok(DataType::Null),
78
0
            })
79
0
            .collect::<Result<Vec<_>>>()?;
80
81
0
        if then_types.contains(&DataType::Null) {
82
0
            // cannot verify types until execution type
83
0
        } else {
84
0
            let unique_types: HashSet<&DataType> = then_types.iter().collect();
85
0
            if unique_types.len() != 1 {
86
0
                return plan_err!(
87
0
                    "CASE expression 'then' values had multiple data types: {unique_types:?}"
88
0
                );
89
0
            }
90
        }
91
92
0
        Ok(Expr::Case(Case::new(
93
0
            self.expr.clone(),
94
0
            self.when_expr
95
0
                .iter()
96
0
                .zip(self.then_expr.iter())
97
0
                .map(|(w, t)| (Box::new(w.clone()), Box::new(t.clone())))
98
0
                .collect(),
99
0
            self.else_expr.clone(),
100
0
        )))
101
0
    }
102
}
103
104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{col, lit, when};

    #[test]
    fn case_when_same_literal_then_types() -> Result<()> {
        let _ = when(col("state").eq(lit("CO")), lit(303))
            .when(col("state").eq(lit("NY")), lit(212))
            .end()?;
        Ok(())
    }

    #[test]
    fn case_when_different_literal_then_types() {
        let maybe_expr = when(col("state").eq(lit("CO")), lit(303))
            .when(col("state").eq(lit("NY")), lit("212"))
            .end();
        assert!(maybe_expr.is_err());
    }

    #[test]
    fn case_when_otherwise_same_literal_type() -> Result<()> {
        let expr = when(col("state").eq(lit("CO")), lit(303))
            .when(col("state").eq(lit("NY")), lit(212))
            .otherwise(lit(0))?;
        match expr {
            Expr::Case(case) => {
                assert!(case.expr.is_none());
                assert_eq!(case.when_then_expr.len(), 2);
                assert_eq!(case.else_expr, Some(Box::new(lit(0))));
            }
            other => panic!("expected Expr::Case, got {other:?}"),
        }
        Ok(())
    }

    #[test]
    fn case_when_otherwise_different_literal_type() {
        // The `else` literal participates in the type check too
        let maybe_expr =
            when(col("state").eq(lit("CO")), lit(303)).otherwise(lit("unknown"));
        assert!(maybe_expr.is_err());
    }

    #[test]
    fn case_when_non_literal_then_defers_type_check() -> Result<()> {
        // `col("area_code")` cannot be typed without a schema, so mixing it
        // with literals of other types is not rejected at build time
        let _ = when(col("state").eq(lit("CO")), col("area_code"))
            .when(col("state").eq(lit("NY")), lit("212"))
            .otherwise(lit(0))?;
        Ok(())
    }
}