Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/expr/src/expr_schema.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use super::{Between, Expr, Like};
19
use crate::expr::{
20
    AggregateFunction, Alias, BinaryExpr, Cast, InList, InSubquery, Placeholder,
21
    ScalarFunction, TryCast, Unnest, WindowFunction,
22
};
23
use crate::type_coercion::binary::get_result_type;
24
use crate::type_coercion::functions::{
25
    data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf,
26
};
27
use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition};
28
use arrow::compute::can_cast_types;
29
use arrow::datatypes::{DataType, Field};
30
use datafusion_common::{
31
    not_impl_err, plan_datafusion_err, plan_err, Column, ExprSchema, Result,
32
    TableReference,
33
};
34
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
35
use std::collections::HashMap;
36
use std::sync::Arc;
37
38
/// trait to allow expr to typable with respect to a schema
39
pub trait ExprSchemable {
40
    /// given a schema, return the type of the expr
41
    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType>;
42
43
    /// given a schema, return the nullability of the expr
44
    fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool>;
45
46
    /// given a schema, return the expr's optional metadata
47
    fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>>;
48
49
    /// convert to a field with respect to a schema
50
    fn to_field(
51
        &self,
52
        input_schema: &dyn ExprSchema,
53
    ) -> Result<(Option<TableReference>, Arc<Field>)>;
54
55
    /// cast to a type with respect to a schema
56
    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr>;
57
58
    /// given a schema, return the type and nullability of the expr
59
    fn data_type_and_nullable(&self, schema: &dyn ExprSchema)
60
        -> Result<(DataType, bool)>;
61
}
62
63
impl ExprSchemable for Expr {
64
    /// Returns the [arrow::datatypes::DataType] of the expression
65
    /// based on [ExprSchema]
66
    ///
67
    /// Note: [`DFSchema`] implements [ExprSchema].
68
    ///
69
    /// [`DFSchema`]: datafusion_common::DFSchema
70
    ///
71
    /// # Examples
72
    ///
73
    /// Get the type of an expression that adds 2 columns. Adding an Int32
74
    /// and Float32 results in Float32 type
75
    ///
76
    /// ```
77
    /// # use arrow::datatypes::{DataType, Field};
78
    /// # use datafusion_common::DFSchema;
79
    /// # use datafusion_expr::{col, ExprSchemable};
80
    /// # use std::collections::HashMap;
81
    ///
82
    /// fn main() {
83
    ///   let expr = col("c1") + col("c2");
84
    ///   let schema = DFSchema::from_unqualified_fields(
85
    ///     vec![
86
    ///       Field::new("c1", DataType::Int32, true),
87
    ///       Field::new("c2", DataType::Float32, true),
88
    ///       ].into(),
89
    ///       HashMap::new(),
90
    ///   ).unwrap();
91
    ///   assert_eq!("Float32", format!("{}", expr.get_type(&schema).unwrap()));
92
    /// }
93
    /// ```
94
    ///
95
    /// # Errors
96
    ///
97
    /// This function errors when it is not possible to compute its
98
    /// [arrow::datatypes::DataType].  This happens when e.g. the
99
    /// expression refers to a column that does not exist in the
100
    /// schema, or when the expression is incorrectly typed
101
    /// (e.g. `[utf8] + [bool]`).
102
0
    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType> {
103
0
        match self {
104
0
            Expr::Alias(Alias { expr, name, .. }) => match &**expr {
105
0
                Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
106
0
                    None => schema.data_type(&Column::from_name(name)).cloned(),
107
0
                    Some(dt) => Ok(dt.clone()),
108
                },
109
0
                _ => expr.get_type(schema),
110
            },
111
0
            Expr::Negative(expr) => expr.get_type(schema),
112
0
            Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
113
0
            Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()),
114
0
            Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
115
0
            Expr::Literal(l) => Ok(l.data_type()),
116
0
            Expr::Case(case) => {
117
0
                for (_, then_expr) in &case.when_then_expr {
118
0
                    let then_type = then_expr.get_type(schema)?;
119
0
                    if !then_type.is_null() {
120
0
                        return Ok(then_type);
121
0
                    }
122
                }
123
0
                case.else_expr
124
0
                    .as_ref()
125
0
                    .map_or(Ok(DataType::Null), |e| e.get_type(schema))
126
            }
127
0
            Expr::Cast(Cast { data_type, .. })
128
0
            | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()),
129
0
            Expr::Unnest(Unnest { expr }) => {
130
0
                let arg_data_type = expr.get_type(schema)?;
131
                // Unnest's output type is the inner type of the list
132
0
                match arg_data_type {
133
0
                    DataType::List(field)
134
0
                    | DataType::LargeList(field)
135
0
                    | DataType::FixedSizeList(field, _) => Ok(field.data_type().clone()),
136
0
                    DataType::Struct(_) => Ok(arg_data_type),
137
                    DataType::Null => {
138
0
                        not_impl_err!("unnest() does not support null yet")
139
                    }
140
                    _ => {
141
0
                        plan_err!(
142
0
                            "unnest() can only be applied to array, struct and null"
143
0
                        )
144
                    }
145
                }
146
            }
147
0
            Expr::ScalarFunction(ScalarFunction { func, args }) => {
148
0
                let arg_data_types = args
149
0
                    .iter()
150
0
                    .map(|e| e.get_type(schema))
151
0
                    .collect::<Result<Vec<_>>>()?;
152
153
                // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
154
0
                let new_data_types = data_types_with_scalar_udf(&arg_data_types, func)
155
0
                    .map_err(|err| {
156
0
                        plan_datafusion_err!(
157
0
                            "{} {}",
158
0
                            err,
159
0
                            utils::generate_signature_error_msg(
160
0
                                func.name(),
161
0
                                func.signature().clone(),
162
0
                                &arg_data_types,
163
0
                            )
164
0
                        )
165
0
                    })?;
166
167
                // perform additional function arguments validation (due to limited
168
                // expressiveness of `TypeSignature`), then infer return type
169
0
                Ok(func.return_type_from_exprs(args, schema, &new_data_types)?)
170
            }
171
0
            Expr::WindowFunction(window_function) => self
172
0
                .data_type_and_nullable_with_window_function(schema, window_function)
173
0
                .map(|(return_type, _)| return_type),
174
0
            Expr::AggregateFunction(AggregateFunction { func, args, .. }) => {
175
0
                let data_types = args
176
0
                    .iter()
177
0
                    .map(|e| e.get_type(schema))
178
0
                    .collect::<Result<Vec<_>>>()?;
179
0
                let new_types = data_types_with_aggregate_udf(&data_types, func)
180
0
                    .map_err(|err| {
181
0
                        plan_datafusion_err!(
182
0
                            "{} {}",
183
0
                            err,
184
0
                            utils::generate_signature_error_msg(
185
0
                                func.name(),
186
0
                                func.signature().clone(),
187
0
                                &data_types
188
0
                            )
189
0
                        )
190
0
                    })?;
191
0
                Ok(func.return_type(&new_types)?)
192
            }
193
            Expr::Not(_)
194
            | Expr::IsNull(_)
195
            | Expr::Exists { .. }
196
            | Expr::InSubquery(_)
197
            | Expr::Between { .. }
198
            | Expr::InList { .. }
199
            | Expr::IsNotNull(_)
200
            | Expr::IsTrue(_)
201
            | Expr::IsFalse(_)
202
            | Expr::IsUnknown(_)
203
            | Expr::IsNotTrue(_)
204
            | Expr::IsNotFalse(_)
205
0
            | Expr::IsNotUnknown(_) => Ok(DataType::Boolean),
206
0
            Expr::ScalarSubquery(subquery) => {
207
0
                Ok(subquery.subquery.schema().field(0).data_type().clone())
208
            }
209
            Expr::BinaryExpr(BinaryExpr {
210
0
                ref left,
211
0
                ref right,
212
0
                ref op,
213
0
            }) => get_result_type(&left.get_type(schema)?, op, &right.get_type(schema)?),
214
0
            Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean),
215
0
            Expr::Placeholder(Placeholder { data_type, .. }) => {
216
0
                data_type.clone().ok_or_else(|| {
217
0
                    plan_datafusion_err!(
218
0
                        "Placeholder type could not be resolved. Make sure that the \
219
0
                         placeholder is bound to a concrete type, e.g. by providing \
220
0
                         parameter values."
221
0
                    )
222
0
                })
223
            }
224
0
            Expr::Wildcard { .. } => Ok(DataType::Null),
225
            Expr::GroupingSet(_) => {
226
                // grouping sets do not really have a type and do not appear in projections
227
0
                Ok(DataType::Null)
228
            }
229
        }
230
0
    }
231
232
    /// Returns the nullability of the expression based on [ExprSchema].
233
    ///
234
    /// Note: [`DFSchema`] implements [ExprSchema].
235
    ///
236
    /// [`DFSchema`]: datafusion_common::DFSchema
237
    ///
238
    /// # Errors
239
    ///
240
    /// This function errors when it is not possible to compute its
241
    /// nullability.  This happens when the expression refers to a
242
    /// column that does not exist in the schema.
243
0
    fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool> {
244
0
        match self {
245
0
            Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::Negative(expr) => {
246
0
                expr.nullable(input_schema)
247
            }
248
249
0
            Expr::InList(InList { expr, list, .. }) => {
250
                // Avoid inspecting too many expressions.
251
                const MAX_INSPECT_LIMIT: usize = 6;
252
                // Stop if a nullable expression is found or an error occurs.
253
0
                let has_nullable = std::iter::once(expr.as_ref())
254
0
                    .chain(list)
255
0
                    .take(MAX_INSPECT_LIMIT)
256
0
                    .find_map(|e| {
257
0
                        e.nullable(input_schema)
258
0
                            .map(|nullable| if nullable { Some(()) } else { None })
259
0
                            .transpose()
260
0
                    })
261
0
                    .transpose()?;
262
0
                Ok(match has_nullable {
263
                    // If a nullable subexpression is found, the result may also be nullable.
264
0
                    Some(_) => true,
265
                    // If the list is too long, we assume it is nullable.
266
0
                    None if list.len() + 1 > MAX_INSPECT_LIMIT => true,
267
                    // All the subexpressions are non-nullable, so the result must be non-nullable.
268
0
                    _ => false,
269
                })
270
            }
271
272
            Expr::Between(Between {
273
0
                expr, low, high, ..
274
0
            }) => Ok(expr.nullable(input_schema)?
275
0
                || low.nullable(input_schema)?
276
0
                || high.nullable(input_schema)?),
277
278
0
            Expr::Column(c) => input_schema.nullable(c),
279
0
            Expr::OuterReferenceColumn(_, _) => Ok(true),
280
0
            Expr::Literal(value) => Ok(value.is_null()),
281
0
            Expr::Case(case) => {
282
                // this expression is nullable if any of the input expressions are nullable
283
0
                let then_nullable = case
284
0
                    .when_then_expr
285
0
                    .iter()
286
0
                    .map(|(_, t)| t.nullable(input_schema))
287
0
                    .collect::<Result<Vec<_>>>()?;
288
0
                if then_nullable.contains(&true) {
289
0
                    Ok(true)
290
0
                } else if let Some(e) = &case.else_expr {
291
0
                    e.nullable(input_schema)
292
                } else {
293
                    // CASE produces NULL if there is no `else` expr
294
                    // (aka when none of the `when_then_exprs` match)
295
0
                    Ok(true)
296
                }
297
            }
298
0
            Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
299
0
            Expr::ScalarFunction(ScalarFunction { func, args }) => {
300
0
                Ok(func.is_nullable(args, input_schema))
301
            }
302
0
            Expr::AggregateFunction(AggregateFunction { func, .. }) => {
303
0
                Ok(func.is_nullable())
304
            }
305
0
            Expr::WindowFunction(window_function) => self
306
0
                .data_type_and_nullable_with_window_function(
307
0
                    input_schema,
308
0
                    window_function,
309
0
                )
310
0
                .map(|(_, nullable)| nullable),
311
            Expr::ScalarVariable(_, _)
312
            | Expr::TryCast { .. }
313
            | Expr::Unnest(_)
314
0
            | Expr::Placeholder(_) => Ok(true),
315
            Expr::IsNull(_)
316
            | Expr::IsNotNull(_)
317
            | Expr::IsTrue(_)
318
            | Expr::IsFalse(_)
319
            | Expr::IsUnknown(_)
320
            | Expr::IsNotTrue(_)
321
            | Expr::IsNotFalse(_)
322
            | Expr::IsNotUnknown(_)
323
0
            | Expr::Exists { .. } => Ok(false),
324
0
            Expr::InSubquery(InSubquery { expr, .. }) => expr.nullable(input_schema),
325
0
            Expr::ScalarSubquery(subquery) => {
326
0
                Ok(subquery.subquery.schema().field(0).is_nullable())
327
            }
328
            Expr::BinaryExpr(BinaryExpr {
329
0
                ref left,
330
0
                ref right,
331
0
                ..
332
0
            }) => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?),
333
0
            Expr::Like(Like { expr, pattern, .. })
334
0
            | Expr::SimilarTo(Like { expr, pattern, .. }) => {
335
0
                Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?)
336
            }
337
0
            Expr::Wildcard { .. } => Ok(false),
338
            Expr::GroupingSet(_) => {
339
                // grouping sets do not really have the concept of nullable and do not appear
340
                // in projections
341
0
                Ok(true)
342
            }
343
        }
344
0
    }
345
346
0
    fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>> {
347
0
        match self {
348
0
            Expr::Column(c) => Ok(schema.metadata(c)?.clone()),
349
0
            Expr::Alias(Alias { expr, .. }) => expr.metadata(schema),
350
0
            _ => Ok(HashMap::new()),
351
        }
352
0
    }
353
354
    /// Returns the datatype and nullability of the expression based on [ExprSchema].
355
    ///
356
    /// Note: [`DFSchema`] implements [ExprSchema].
357
    ///
358
    /// [`DFSchema`]: datafusion_common::DFSchema
359
    ///
360
    /// # Errors
361
    ///
362
    /// This function errors when it is not possible to compute its
363
    /// datatype or nullability.
364
0
    fn data_type_and_nullable(
365
0
        &self,
366
0
        schema: &dyn ExprSchema,
367
0
    ) -> Result<(DataType, bool)> {
368
0
        match self {
369
0
            Expr::Alias(Alias { expr, name, .. }) => match &**expr {
370
0
                Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
371
0
                    None => schema
372
0
                        .data_type_and_nullable(&Column::from_name(name))
373
0
                        .map(|(d, n)| (d.clone(), n)),
374
0
                    Some(dt) => Ok((dt.clone(), expr.nullable(schema)?)),
375
                },
376
0
                _ => expr.data_type_and_nullable(schema),
377
            },
378
0
            Expr::Negative(expr) => expr.data_type_and_nullable(schema),
379
0
            Expr::Column(c) => schema
380
0
                .data_type_and_nullable(c)
381
0
                .map(|(d, n)| (d.clone(), n)),
382
0
            Expr::OuterReferenceColumn(ty, _) => Ok((ty.clone(), true)),
383
0
            Expr::ScalarVariable(ty, _) => Ok((ty.clone(), true)),
384
0
            Expr::Literal(l) => Ok((l.data_type(), l.is_null())),
385
            Expr::IsNull(_)
386
            | Expr::IsNotNull(_)
387
            | Expr::IsTrue(_)
388
            | Expr::IsFalse(_)
389
            | Expr::IsUnknown(_)
390
            | Expr::IsNotTrue(_)
391
            | Expr::IsNotFalse(_)
392
            | Expr::IsNotUnknown(_)
393
0
            | Expr::Exists { .. } => Ok((DataType::Boolean, false)),
394
0
            Expr::ScalarSubquery(subquery) => Ok((
395
0
                subquery.subquery.schema().field(0).data_type().clone(),
396
0
                subquery.subquery.schema().field(0).is_nullable(),
397
0
            )),
398
            Expr::BinaryExpr(BinaryExpr {
399
0
                ref left,
400
0
                ref right,
401
0
                ref op,
402
            }) => {
403
0
                let left = left.data_type_and_nullable(schema)?;
404
0
                let right = right.data_type_and_nullable(schema)?;
405
0
                Ok((get_result_type(&left.0, op, &right.0)?, left.1 || right.1))
406
            }
407
0
            Expr::WindowFunction(window_function) => {
408
0
                self.data_type_and_nullable_with_window_function(schema, window_function)
409
            }
410
0
            _ => Ok((self.get_type(schema)?, self.nullable(schema)?)),
411
        }
412
0
    }
413
414
    /// Returns a [arrow::datatypes::Field] compatible with this expression.
415
    ///
416
    /// So for example, a projected expression `col(c1) + col(c2)` is
417
    /// placed in an output field **named** col("c1 + c2")
418
0
    fn to_field(
419
0
        &self,
420
0
        input_schema: &dyn ExprSchema,
421
0
    ) -> Result<(Option<TableReference>, Arc<Field>)> {
422
0
        let (relation, schema_name) = self.qualified_name();
423
0
        let (data_type, nullable) = self.data_type_and_nullable(input_schema)?;
424
0
        let field = Field::new(schema_name, data_type, nullable)
425
0
            .with_metadata(self.metadata(input_schema)?)
426
0
            .into();
427
0
        Ok((relation, field))
428
0
    }
429
430
    /// Wraps this expression in a cast to a target [arrow::datatypes::DataType].
431
    ///
432
    /// # Errors
433
    ///
434
    /// This function errors when it is impossible to cast the
435
    /// expression to the target [arrow::datatypes::DataType].
436
0
    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr> {
437
0
        let this_type = self.get_type(schema)?;
438
0
        if this_type == *cast_to_type {
439
0
            return Ok(self);
440
0
        }
441
0
442
0
        // TODO(kszucs): most of the operations do not validate the type correctness
443
0
        // like all of the binary expressions below. Perhaps Expr should track the
444
0
        // type of the expression?
445
0
446
0
        if can_cast_types(&this_type, cast_to_type) {
447
0
            match self {
448
0
                Expr::ScalarSubquery(subquery) => {
449
0
                    Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?))
450
                }
451
0
                _ => Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))),
452
            }
453
        } else {
454
0
            plan_err!("Cannot automatically convert {this_type:?} to {cast_to_type:?}")
455
        }
456
0
    }
457
}
458
459
impl Expr {
460
    /// Common method for window functions that applies type coercion
461
    /// to all arguments of the window function to check if it matches
462
    /// its signature.
463
    ///
464
    /// If successful, this method returns the data type and
465
    /// nullability of the window function's result.
466
    ///
467
    /// Otherwise, returns an error if there's a type mismatch between
468
    /// the window function's signature and the provided arguments.
469
0
    fn data_type_and_nullable_with_window_function(
470
0
        &self,
471
0
        schema: &dyn ExprSchema,
472
0
        window_function: &WindowFunction,
473
0
    ) -> Result<(DataType, bool)> {
474
0
        let WindowFunction { fun, args, .. } = window_function;
475
476
0
        let data_types = args
477
0
            .iter()
478
0
            .map(|e| e.get_type(schema))
479
0
            .collect::<Result<Vec<_>>>()?;
480
0
        match fun {
481
0
            WindowFunctionDefinition::BuiltInWindowFunction(window_fun) => {
482
0
                let return_type = window_fun.return_type(&data_types)?;
483
0
                let nullable =
484
0
                    !["RANK", "NTILE", "CUME_DIST"].contains(&window_fun.name());
485
0
                Ok((return_type, nullable))
486
            }
487
0
            WindowFunctionDefinition::AggregateUDF(udaf) => {
488
0
                let new_types = data_types_with_aggregate_udf(&data_types, udaf)
489
0
                    .map_err(|err| {
490
0
                        plan_datafusion_err!(
491
0
                            "{} {}",
492
0
                            err,
493
0
                            utils::generate_signature_error_msg(
494
0
                                fun.name(),
495
0
                                fun.signature(),
496
0
                                &data_types
497
0
                            )
498
0
                        )
499
0
                    })?;
500
501
0
                let return_type = udaf.return_type(&new_types)?;
502
0
                let nullable = udaf.is_nullable();
503
0
504
0
                Ok((return_type, nullable))
505
            }
506
0
            WindowFunctionDefinition::WindowUDF(udwf) => {
507
0
                let new_types =
508
0
                    data_types_with_window_udf(&data_types, udwf).map_err(|err| {
509
0
                        plan_datafusion_err!(
510
0
                            "{} {}",
511
0
                            err,
512
0
                            utils::generate_signature_error_msg(
513
0
                                fun.name(),
514
0
                                fun.signature(),
515
0
                                &data_types
516
0
                            )
517
0
                        )
518
0
                    })?;
519
0
                let (_, function_name) = self.qualified_name();
520
0
                let field_args = WindowUDFFieldArgs::new(&new_types, &function_name);
521
0
522
0
                udwf.field(field_args)
523
0
                    .map(|field| (field.data_type().clone(), field.is_nullable()))
524
            }
525
        }
526
0
    }
527
}
528
529
/// cast subquery in InSubquery/ScalarSubquery to a given type.
530
0
pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> {
531
0
    if subquery.subquery.schema().field(0).data_type() == cast_to_type {
532
0
        return Ok(subquery);
533
0
    }
534
0
535
0
    let plan = subquery.subquery.as_ref();
536
0
    let new_plan = match plan {
537
0
        LogicalPlan::Projection(projection) => {
538
0
            let cast_expr = projection.expr[0]
539
0
                .clone()
540
0
                .cast_to(cast_to_type, projection.input.schema())?;
541
0
            LogicalPlan::Projection(Projection::try_new(
542
0
                vec![cast_expr],
543
0
                Arc::clone(&projection.input),
544
0
            )?)
545
        }
546
        _ => {
547
0
            let cast_expr = Expr::Column(Column::from(plan.schema().qualified_field(0)))
548
0
                .cast_to(cast_to_type, subquery.subquery.schema())?;
549
0
            LogicalPlan::Projection(Projection::try_new(
550
0
                vec![cast_expr],
551
0
                subquery.subquery,
552
0
            )?)
553
        }
554
    };
555
0
    Ok(Subquery {
556
0
        subquery: Arc::new(new_plan),
557
0
        outer_ref_columns: subquery.outer_ref_columns,
558
0
    })
559
0
}
560
561
#[cfg(test)]
562
mod tests {
563
    use super::*;
564
    use crate::{col, lit};
565
566
    use datafusion_common::{internal_err, DFSchema, ScalarValue};
567
568
    macro_rules! test_is_expr_nullable {
569
        ($EXPR_TYPE:ident) => {{
570
            let expr = lit(ScalarValue::Null).$EXPR_TYPE();
571
            assert!(!expr.nullable(&MockExprSchema::new()).unwrap());
572
        }};
573
    }
574
575
    #[test]
576
    fn expr_schema_nullability() {
577
        let expr = col("foo").eq(lit(1));
578
        assert!(!expr.nullable(&MockExprSchema::new()).unwrap());
579
        assert!(expr
580
            .nullable(&MockExprSchema::new().with_nullable(true))
581
            .unwrap());
582
583
        test_is_expr_nullable!(is_null);
584
        test_is_expr_nullable!(is_not_null);
585
        test_is_expr_nullable!(is_true);
586
        test_is_expr_nullable!(is_not_true);
587
        test_is_expr_nullable!(is_false);
588
        test_is_expr_nullable!(is_not_false);
589
        test_is_expr_nullable!(is_unknown);
590
        test_is_expr_nullable!(is_not_unknown);
591
    }
592
593
    #[test]
594
    fn test_between_nullability() {
595
        let get_schema = |nullable| {
596
            MockExprSchema::new()
597
                .with_data_type(DataType::Int32)
598
                .with_nullable(nullable)
599
        };
600
601
        let expr = col("foo").between(lit(1), lit(2));
602
        assert!(!expr.nullable(&get_schema(false)).unwrap());
603
        assert!(expr.nullable(&get_schema(true)).unwrap());
604
605
        let null = lit(ScalarValue::Int32(None));
606
607
        let expr = col("foo").between(null.clone(), lit(2));
608
        assert!(expr.nullable(&get_schema(false)).unwrap());
609
610
        let expr = col("foo").between(lit(1), null.clone());
611
        assert!(expr.nullable(&get_schema(false)).unwrap());
612
613
        let expr = col("foo").between(null.clone(), null);
614
        assert!(expr.nullable(&get_schema(false)).unwrap());
615
    }
616
617
    #[test]
618
    fn test_inlist_nullability() {
619
        let get_schema = |nullable| {
620
            MockExprSchema::new()
621
                .with_data_type(DataType::Int32)
622
                .with_nullable(nullable)
623
        };
624
625
        let expr = col("foo").in_list(vec![lit(1); 5], false);
626
        assert!(!expr.nullable(&get_schema(false)).unwrap());
627
        assert!(expr.nullable(&get_schema(true)).unwrap());
628
        // Testing nullable() returns an error.
629
        assert!(expr
630
            .nullable(&get_schema(false).with_error_on_nullable(true))
631
            .is_err());
632
633
        let null = lit(ScalarValue::Int32(None));
634
        let expr = col("foo").in_list(vec![null, lit(1)], false);
635
        assert!(expr.nullable(&get_schema(false)).unwrap());
636
637
        // Testing on long list
638
        let expr = col("foo").in_list(vec![lit(1); 6], false);
639
        assert!(expr.nullable(&get_schema(false)).unwrap());
640
    }
641
642
    #[test]
643
    fn test_like_nullability() {
644
        let get_schema = |nullable| {
645
            MockExprSchema::new()
646
                .with_data_type(DataType::Utf8)
647
                .with_nullable(nullable)
648
        };
649
650
        let expr = col("foo").like(lit("bar"));
651
        assert!(!expr.nullable(&get_schema(false)).unwrap());
652
        assert!(expr.nullable(&get_schema(true)).unwrap());
653
654
        let expr = col("foo").like(lit(ScalarValue::Utf8(None)));
655
        assert!(expr.nullable(&get_schema(false)).unwrap());
656
    }
657
658
    #[test]
659
    fn expr_schema_data_type() {
660
        let expr = col("foo");
661
        assert_eq!(
662
            DataType::Utf8,
663
            expr.get_type(&MockExprSchema::new().with_data_type(DataType::Utf8))
664
                .unwrap()
665
        );
666
    }
667
668
    #[test]
669
    fn test_expr_metadata() {
670
        let mut meta = HashMap::new();
671
        meta.insert("bar".to_string(), "buzz".to_string());
672
        let expr = col("foo");
673
        let schema = MockExprSchema::new()
674
            .with_data_type(DataType::Int32)
675
            .with_metadata(meta.clone());
676
677
        // col and alias should be metadata-preserving
678
        assert_eq!(meta, expr.metadata(&schema).unwrap());
679
        assert_eq!(meta, expr.clone().alias("bar").metadata(&schema).unwrap());
680
681
        // cast should drop input metadata since the type has changed
682
        assert_eq!(
683
            HashMap::new(),
684
            expr.clone()
685
                .cast_to(&DataType::Int64, &schema)
686
                .unwrap()
687
                .metadata(&schema)
688
                .unwrap()
689
        );
690
691
        let schema = DFSchema::from_unqualified_fields(
692
            vec![Field::new("foo", DataType::Int32, true).with_metadata(meta.clone())]
693
                .into(),
694
            HashMap::new(),
695
        )
696
        .unwrap();
697
698
        // verify to_field method populates metadata
699
        assert_eq!(&meta, expr.to_field(&schema).unwrap().1.metadata());
700
    }
701
702
    #[derive(Debug)]
703
    struct MockExprSchema {
704
        nullable: bool,
705
        data_type: DataType,
706
        error_on_nullable: bool,
707
        metadata: HashMap<String, String>,
708
    }
709
710
    impl MockExprSchema {
711
        fn new() -> Self {
712
            Self {
713
                nullable: false,
714
                data_type: DataType::Null,
715
                error_on_nullable: false,
716
                metadata: HashMap::new(),
717
            }
718
        }
719
720
        fn with_nullable(mut self, nullable: bool) -> Self {
721
            self.nullable = nullable;
722
            self
723
        }
724
725
        fn with_data_type(mut self, data_type: DataType) -> Self {
726
            self.data_type = data_type;
727
            self
728
        }
729
730
        fn with_error_on_nullable(mut self, error_on_nullable: bool) -> Self {
731
            self.error_on_nullable = error_on_nullable;
732
            self
733
        }
734
735
        fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
736
            self.metadata = metadata;
737
            self
738
        }
739
    }
740
741
    impl ExprSchema for MockExprSchema {
742
        fn nullable(&self, _col: &Column) -> Result<bool> {
743
            if self.error_on_nullable {
744
                internal_err!("nullable error")
745
            } else {
746
                Ok(self.nullable)
747
            }
748
        }
749
750
        fn data_type(&self, _col: &Column) -> Result<&DataType> {
751
            Ok(&self.data_type)
752
        }
753
754
        fn metadata(&self, _col: &Column) -> Result<&HashMap<String, String>> {
755
            Ok(&self.metadata)
756
        }
757
758
        fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> {
759
            Ok((self.data_type(col)?, self.nullable(col)?))
760
        }
761
    }
762
}