/Users/andrewlamb/Software/datafusion/datafusion/expr/src/expr_schema.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use super::{Between, Expr, Like}; |
19 | | use crate::expr::{ |
20 | | AggregateFunction, Alias, BinaryExpr, Cast, InList, InSubquery, Placeholder, |
21 | | ScalarFunction, TryCast, Unnest, WindowFunction, |
22 | | }; |
23 | | use crate::type_coercion::binary::get_result_type; |
24 | | use crate::type_coercion::functions::{ |
25 | | data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf, |
26 | | }; |
27 | | use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; |
28 | | use arrow::compute::can_cast_types; |
29 | | use arrow::datatypes::{DataType, Field}; |
30 | | use datafusion_common::{ |
31 | | not_impl_err, plan_datafusion_err, plan_err, Column, ExprSchema, Result, |
32 | | TableReference, |
33 | | }; |
34 | | use datafusion_functions_window_common::field::WindowUDFFieldArgs; |
35 | | use std::collections::HashMap; |
36 | | use std::sync::Arc; |
37 | | |
38 | | /// trait to allow expr to typable with respect to a schema |
39 | | pub trait ExprSchemable { |
40 | | /// given a schema, return the type of the expr |
41 | | fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType>; |
42 | | |
43 | | /// given a schema, return the nullability of the expr |
44 | | fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool>; |
45 | | |
46 | | /// given a schema, return the expr's optional metadata |
47 | | fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>>; |
48 | | |
49 | | /// convert to a field with respect to a schema |
50 | | fn to_field( |
51 | | &self, |
52 | | input_schema: &dyn ExprSchema, |
53 | | ) -> Result<(Option<TableReference>, Arc<Field>)>; |
54 | | |
55 | | /// cast to a type with respect to a schema |
56 | | fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr>; |
57 | | |
58 | | /// given a schema, return the type and nullability of the expr |
59 | | fn data_type_and_nullable(&self, schema: &dyn ExprSchema) |
60 | | -> Result<(DataType, bool)>; |
61 | | } |
62 | | |
63 | | impl ExprSchemable for Expr { |
64 | | /// Returns the [arrow::datatypes::DataType] of the expression |
65 | | /// based on [ExprSchema] |
66 | | /// |
67 | | /// Note: [`DFSchema`] implements [ExprSchema]. |
68 | | /// |
69 | | /// [`DFSchema`]: datafusion_common::DFSchema |
70 | | /// |
71 | | /// # Examples |
72 | | /// |
73 | | /// Get the type of an expression that adds 2 columns. Adding an Int32 |
74 | | /// and Float32 results in Float32 type |
75 | | /// |
76 | | /// ``` |
77 | | /// # use arrow::datatypes::{DataType, Field}; |
78 | | /// # use datafusion_common::DFSchema; |
79 | | /// # use datafusion_expr::{col, ExprSchemable}; |
80 | | /// # use std::collections::HashMap; |
81 | | /// |
82 | | /// fn main() { |
83 | | /// let expr = col("c1") + col("c2"); |
84 | | /// let schema = DFSchema::from_unqualified_fields( |
85 | | /// vec![ |
86 | | /// Field::new("c1", DataType::Int32, true), |
87 | | /// Field::new("c2", DataType::Float32, true), |
88 | | /// ].into(), |
89 | | /// HashMap::new(), |
90 | | /// ).unwrap(); |
91 | | /// assert_eq!("Float32", format!("{}", expr.get_type(&schema).unwrap())); |
92 | | /// } |
93 | | /// ``` |
94 | | /// |
95 | | /// # Errors |
96 | | /// |
97 | | /// This function errors when it is not possible to compute its |
98 | | /// [arrow::datatypes::DataType]. This happens when e.g. the |
99 | | /// expression refers to a column that does not exist in the |
100 | | /// schema, or when the expression is incorrectly typed |
101 | | /// (e.g. `[utf8] + [bool]`). |
102 | 0 | fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType> { |
103 | 0 | match self { |
104 | 0 | Expr::Alias(Alias { expr, name, .. }) => match &**expr { |
105 | 0 | Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type { |
106 | 0 | None => schema.data_type(&Column::from_name(name)).cloned(), |
107 | 0 | Some(dt) => Ok(dt.clone()), |
108 | | }, |
109 | 0 | _ => expr.get_type(schema), |
110 | | }, |
111 | 0 | Expr::Negative(expr) => expr.get_type(schema), |
112 | 0 | Expr::Column(c) => Ok(schema.data_type(c)?.clone()), |
113 | 0 | Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), |
114 | 0 | Expr::ScalarVariable(ty, _) => Ok(ty.clone()), |
115 | 0 | Expr::Literal(l) => Ok(l.data_type()), |
116 | 0 | Expr::Case(case) => { |
117 | 0 | for (_, then_expr) in &case.when_then_expr { |
118 | 0 | let then_type = then_expr.get_type(schema)?; |
119 | 0 | if !then_type.is_null() { |
120 | 0 | return Ok(then_type); |
121 | 0 | } |
122 | | } |
123 | 0 | case.else_expr |
124 | 0 | .as_ref() |
125 | 0 | .map_or(Ok(DataType::Null), |e| e.get_type(schema)) |
126 | | } |
127 | 0 | Expr::Cast(Cast { data_type, .. }) |
128 | 0 | | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), |
129 | 0 | Expr::Unnest(Unnest { expr }) => { |
130 | 0 | let arg_data_type = expr.get_type(schema)?; |
131 | | // Unnest's output type is the inner type of the list |
132 | 0 | match arg_data_type { |
133 | 0 | DataType::List(field) |
134 | 0 | | DataType::LargeList(field) |
135 | 0 | | DataType::FixedSizeList(field, _) => Ok(field.data_type().clone()), |
136 | 0 | DataType::Struct(_) => Ok(arg_data_type), |
137 | | DataType::Null => { |
138 | 0 | not_impl_err!("unnest() does not support null yet") |
139 | | } |
140 | | _ => { |
141 | 0 | plan_err!( |
142 | 0 | "unnest() can only be applied to array, struct and null" |
143 | 0 | ) |
144 | | } |
145 | | } |
146 | | } |
147 | 0 | Expr::ScalarFunction(ScalarFunction { func, args }) => { |
148 | 0 | let arg_data_types = args |
149 | 0 | .iter() |
150 | 0 | .map(|e| e.get_type(schema)) |
151 | 0 | .collect::<Result<Vec<_>>>()?; |
152 | | |
153 | | // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` |
154 | 0 | let new_data_types = data_types_with_scalar_udf(&arg_data_types, func) |
155 | 0 | .map_err(|err| { |
156 | 0 | plan_datafusion_err!( |
157 | 0 | "{} {}", |
158 | 0 | err, |
159 | 0 | utils::generate_signature_error_msg( |
160 | 0 | func.name(), |
161 | 0 | func.signature().clone(), |
162 | 0 | &arg_data_types, |
163 | 0 | ) |
164 | 0 | ) |
165 | 0 | })?; |
166 | | |
167 | | // perform additional function arguments validation (due to limited |
168 | | // expressiveness of `TypeSignature`), then infer return type |
169 | 0 | Ok(func.return_type_from_exprs(args, schema, &new_data_types)?) |
170 | | } |
171 | 0 | Expr::WindowFunction(window_function) => self |
172 | 0 | .data_type_and_nullable_with_window_function(schema, window_function) |
173 | 0 | .map(|(return_type, _)| return_type), |
174 | 0 | Expr::AggregateFunction(AggregateFunction { func, args, .. }) => { |
175 | 0 | let data_types = args |
176 | 0 | .iter() |
177 | 0 | .map(|e| e.get_type(schema)) |
178 | 0 | .collect::<Result<Vec<_>>>()?; |
179 | 0 | let new_types = data_types_with_aggregate_udf(&data_types, func) |
180 | 0 | .map_err(|err| { |
181 | 0 | plan_datafusion_err!( |
182 | 0 | "{} {}", |
183 | 0 | err, |
184 | 0 | utils::generate_signature_error_msg( |
185 | 0 | func.name(), |
186 | 0 | func.signature().clone(), |
187 | 0 | &data_types |
188 | 0 | ) |
189 | 0 | ) |
190 | 0 | })?; |
191 | 0 | Ok(func.return_type(&new_types)?) |
192 | | } |
193 | | Expr::Not(_) |
194 | | | Expr::IsNull(_) |
195 | | | Expr::Exists { .. } |
196 | | | Expr::InSubquery(_) |
197 | | | Expr::Between { .. } |
198 | | | Expr::InList { .. } |
199 | | | Expr::IsNotNull(_) |
200 | | | Expr::IsTrue(_) |
201 | | | Expr::IsFalse(_) |
202 | | | Expr::IsUnknown(_) |
203 | | | Expr::IsNotTrue(_) |
204 | | | Expr::IsNotFalse(_) |
205 | 0 | | Expr::IsNotUnknown(_) => Ok(DataType::Boolean), |
206 | 0 | Expr::ScalarSubquery(subquery) => { |
207 | 0 | Ok(subquery.subquery.schema().field(0).data_type().clone()) |
208 | | } |
209 | | Expr::BinaryExpr(BinaryExpr { |
210 | 0 | ref left, |
211 | 0 | ref right, |
212 | 0 | ref op, |
213 | 0 | }) => get_result_type(&left.get_type(schema)?, op, &right.get_type(schema)?), |
214 | 0 | Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean), |
215 | 0 | Expr::Placeholder(Placeholder { data_type, .. }) => { |
216 | 0 | data_type.clone().ok_or_else(|| { |
217 | 0 | plan_datafusion_err!( |
218 | 0 | "Placeholder type could not be resolved. Make sure that the \ |
219 | 0 | placeholder is bound to a concrete type, e.g. by providing \ |
220 | 0 | parameter values." |
221 | 0 | ) |
222 | 0 | }) |
223 | | } |
224 | 0 | Expr::Wildcard { .. } => Ok(DataType::Null), |
225 | | Expr::GroupingSet(_) => { |
226 | | // grouping sets do not really have a type and do not appear in projections |
227 | 0 | Ok(DataType::Null) |
228 | | } |
229 | | } |
230 | 0 | } |
231 | | |
232 | | /// Returns the nullability of the expression based on [ExprSchema]. |
233 | | /// |
234 | | /// Note: [`DFSchema`] implements [ExprSchema]. |
235 | | /// |
236 | | /// [`DFSchema`]: datafusion_common::DFSchema |
237 | | /// |
238 | | /// # Errors |
239 | | /// |
240 | | /// This function errors when it is not possible to compute its |
241 | | /// nullability. This happens when the expression refers to a |
242 | | /// column that does not exist in the schema. |
243 | 0 | fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool> { |
244 | 0 | match self { |
245 | 0 | Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::Negative(expr) => { |
246 | 0 | expr.nullable(input_schema) |
247 | | } |
248 | | |
249 | 0 | Expr::InList(InList { expr, list, .. }) => { |
250 | | // Avoid inspecting too many expressions. |
251 | | const MAX_INSPECT_LIMIT: usize = 6; |
252 | | // Stop if a nullable expression is found or an error occurs. |
253 | 0 | let has_nullable = std::iter::once(expr.as_ref()) |
254 | 0 | .chain(list) |
255 | 0 | .take(MAX_INSPECT_LIMIT) |
256 | 0 | .find_map(|e| { |
257 | 0 | e.nullable(input_schema) |
258 | 0 | .map(|nullable| if nullable { Some(()) } else { None }) |
259 | 0 | .transpose() |
260 | 0 | }) |
261 | 0 | .transpose()?; |
262 | 0 | Ok(match has_nullable { |
263 | | // If a nullable subexpression is found, the result may also be nullable. |
264 | 0 | Some(_) => true, |
265 | | // If the list is too long, we assume it is nullable. |
266 | 0 | None if list.len() + 1 > MAX_INSPECT_LIMIT => true, |
267 | | // All the subexpressions are non-nullable, so the result must be non-nullable. |
268 | 0 | _ => false, |
269 | | }) |
270 | | } |
271 | | |
272 | | Expr::Between(Between { |
273 | 0 | expr, low, high, .. |
274 | 0 | }) => Ok(expr.nullable(input_schema)? |
275 | 0 | || low.nullable(input_schema)? |
276 | 0 | || high.nullable(input_schema)?), |
277 | | |
278 | 0 | Expr::Column(c) => input_schema.nullable(c), |
279 | 0 | Expr::OuterReferenceColumn(_, _) => Ok(true), |
280 | 0 | Expr::Literal(value) => Ok(value.is_null()), |
281 | 0 | Expr::Case(case) => { |
282 | | // this expression is nullable if any of the input expressions are nullable |
283 | 0 | let then_nullable = case |
284 | 0 | .when_then_expr |
285 | 0 | .iter() |
286 | 0 | .map(|(_, t)| t.nullable(input_schema)) |
287 | 0 | .collect::<Result<Vec<_>>>()?; |
288 | 0 | if then_nullable.contains(&true) { |
289 | 0 | Ok(true) |
290 | 0 | } else if let Some(e) = &case.else_expr { |
291 | 0 | e.nullable(input_schema) |
292 | | } else { |
293 | | // CASE produces NULL if there is no `else` expr |
294 | | // (aka when none of the `when_then_exprs` match) |
295 | 0 | Ok(true) |
296 | | } |
297 | | } |
298 | 0 | Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema), |
299 | 0 | Expr::ScalarFunction(ScalarFunction { func, args }) => { |
300 | 0 | Ok(func.is_nullable(args, input_schema)) |
301 | | } |
302 | 0 | Expr::AggregateFunction(AggregateFunction { func, .. }) => { |
303 | 0 | Ok(func.is_nullable()) |
304 | | } |
305 | 0 | Expr::WindowFunction(window_function) => self |
306 | 0 | .data_type_and_nullable_with_window_function( |
307 | 0 | input_schema, |
308 | 0 | window_function, |
309 | 0 | ) |
310 | 0 | .map(|(_, nullable)| nullable), |
311 | | Expr::ScalarVariable(_, _) |
312 | | | Expr::TryCast { .. } |
313 | | | Expr::Unnest(_) |
314 | 0 | | Expr::Placeholder(_) => Ok(true), |
315 | | Expr::IsNull(_) |
316 | | | Expr::IsNotNull(_) |
317 | | | Expr::IsTrue(_) |
318 | | | Expr::IsFalse(_) |
319 | | | Expr::IsUnknown(_) |
320 | | | Expr::IsNotTrue(_) |
321 | | | Expr::IsNotFalse(_) |
322 | | | Expr::IsNotUnknown(_) |
323 | 0 | | Expr::Exists { .. } => Ok(false), |
324 | 0 | Expr::InSubquery(InSubquery { expr, .. }) => expr.nullable(input_schema), |
325 | 0 | Expr::ScalarSubquery(subquery) => { |
326 | 0 | Ok(subquery.subquery.schema().field(0).is_nullable()) |
327 | | } |
328 | | Expr::BinaryExpr(BinaryExpr { |
329 | 0 | ref left, |
330 | 0 | ref right, |
331 | 0 | .. |
332 | 0 | }) => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?), |
333 | 0 | Expr::Like(Like { expr, pattern, .. }) |
334 | 0 | | Expr::SimilarTo(Like { expr, pattern, .. }) => { |
335 | 0 | Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?) |
336 | | } |
337 | 0 | Expr::Wildcard { .. } => Ok(false), |
338 | | Expr::GroupingSet(_) => { |
339 | | // grouping sets do not really have the concept of nullable and do not appear |
340 | | // in projections |
341 | 0 | Ok(true) |
342 | | } |
343 | | } |
344 | 0 | } |
345 | | |
346 | 0 | fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>> { |
347 | 0 | match self { |
348 | 0 | Expr::Column(c) => Ok(schema.metadata(c)?.clone()), |
349 | 0 | Expr::Alias(Alias { expr, .. }) => expr.metadata(schema), |
350 | 0 | _ => Ok(HashMap::new()), |
351 | | } |
352 | 0 | } |
353 | | |
354 | | /// Returns the datatype and nullability of the expression based on [ExprSchema]. |
355 | | /// |
356 | | /// Note: [`DFSchema`] implements [ExprSchema]. |
357 | | /// |
358 | | /// [`DFSchema`]: datafusion_common::DFSchema |
359 | | /// |
360 | | /// # Errors |
361 | | /// |
362 | | /// This function errors when it is not possible to compute its |
363 | | /// datatype or nullability. |
364 | 0 | fn data_type_and_nullable( |
365 | 0 | &self, |
366 | 0 | schema: &dyn ExprSchema, |
367 | 0 | ) -> Result<(DataType, bool)> { |
368 | 0 | match self { |
369 | 0 | Expr::Alias(Alias { expr, name, .. }) => match &**expr { |
370 | 0 | Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type { |
371 | 0 | None => schema |
372 | 0 | .data_type_and_nullable(&Column::from_name(name)) |
373 | 0 | .map(|(d, n)| (d.clone(), n)), |
374 | 0 | Some(dt) => Ok((dt.clone(), expr.nullable(schema)?)), |
375 | | }, |
376 | 0 | _ => expr.data_type_and_nullable(schema), |
377 | | }, |
378 | 0 | Expr::Negative(expr) => expr.data_type_and_nullable(schema), |
379 | 0 | Expr::Column(c) => schema |
380 | 0 | .data_type_and_nullable(c) |
381 | 0 | .map(|(d, n)| (d.clone(), n)), |
382 | 0 | Expr::OuterReferenceColumn(ty, _) => Ok((ty.clone(), true)), |
383 | 0 | Expr::ScalarVariable(ty, _) => Ok((ty.clone(), true)), |
384 | 0 | Expr::Literal(l) => Ok((l.data_type(), l.is_null())), |
385 | | Expr::IsNull(_) |
386 | | | Expr::IsNotNull(_) |
387 | | | Expr::IsTrue(_) |
388 | | | Expr::IsFalse(_) |
389 | | | Expr::IsUnknown(_) |
390 | | | Expr::IsNotTrue(_) |
391 | | | Expr::IsNotFalse(_) |
392 | | | Expr::IsNotUnknown(_) |
393 | 0 | | Expr::Exists { .. } => Ok((DataType::Boolean, false)), |
394 | 0 | Expr::ScalarSubquery(subquery) => Ok(( |
395 | 0 | subquery.subquery.schema().field(0).data_type().clone(), |
396 | 0 | subquery.subquery.schema().field(0).is_nullable(), |
397 | 0 | )), |
398 | | Expr::BinaryExpr(BinaryExpr { |
399 | 0 | ref left, |
400 | 0 | ref right, |
401 | 0 | ref op, |
402 | | }) => { |
403 | 0 | let left = left.data_type_and_nullable(schema)?; |
404 | 0 | let right = right.data_type_and_nullable(schema)?; |
405 | 0 | Ok((get_result_type(&left.0, op, &right.0)?, left.1 || right.1)) |
406 | | } |
407 | 0 | Expr::WindowFunction(window_function) => { |
408 | 0 | self.data_type_and_nullable_with_window_function(schema, window_function) |
409 | | } |
410 | 0 | _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), |
411 | | } |
412 | 0 | } |
413 | | |
414 | | /// Returns a [arrow::datatypes::Field] compatible with this expression. |
415 | | /// |
416 | | /// So for example, a projected expression `col(c1) + col(c2)` is |
417 | | /// placed in an output field **named** col("c1 + c2") |
418 | 0 | fn to_field( |
419 | 0 | &self, |
420 | 0 | input_schema: &dyn ExprSchema, |
421 | 0 | ) -> Result<(Option<TableReference>, Arc<Field>)> { |
422 | 0 | let (relation, schema_name) = self.qualified_name(); |
423 | 0 | let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; |
424 | 0 | let field = Field::new(schema_name, data_type, nullable) |
425 | 0 | .with_metadata(self.metadata(input_schema)?) |
426 | 0 | .into(); |
427 | 0 | Ok((relation, field)) |
428 | 0 | } |
429 | | |
430 | | /// Wraps this expression in a cast to a target [arrow::datatypes::DataType]. |
431 | | /// |
432 | | /// # Errors |
433 | | /// |
434 | | /// This function errors when it is impossible to cast the |
435 | | /// expression to the target [arrow::datatypes::DataType]. |
436 | 0 | fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr> { |
437 | 0 | let this_type = self.get_type(schema)?; |
438 | 0 | if this_type == *cast_to_type { |
439 | 0 | return Ok(self); |
440 | 0 | } |
441 | 0 |
|
442 | 0 | // TODO(kszucs): most of the operations do not validate the type correctness |
443 | 0 | // like all of the binary expressions below. Perhaps Expr should track the |
444 | 0 | // type of the expression? |
445 | 0 |
|
446 | 0 | if can_cast_types(&this_type, cast_to_type) { |
447 | 0 | match self { |
448 | 0 | Expr::ScalarSubquery(subquery) => { |
449 | 0 | Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?)) |
450 | | } |
451 | 0 | _ => Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))), |
452 | | } |
453 | | } else { |
454 | 0 | plan_err!("Cannot automatically convert {this_type:?} to {cast_to_type:?}") |
455 | | } |
456 | 0 | } |
457 | | } |
458 | | |
459 | | impl Expr { |
460 | | /// Common method for window functions that applies type coercion |
461 | | /// to all arguments of the window function to check if it matches |
462 | | /// its signature. |
463 | | /// |
464 | | /// If successful, this method returns the data type and |
465 | | /// nullability of the window function's result. |
466 | | /// |
467 | | /// Otherwise, returns an error if there's a type mismatch between |
468 | | /// the window function's signature and the provided arguments. |
469 | 0 | fn data_type_and_nullable_with_window_function( |
470 | 0 | &self, |
471 | 0 | schema: &dyn ExprSchema, |
472 | 0 | window_function: &WindowFunction, |
473 | 0 | ) -> Result<(DataType, bool)> { |
474 | 0 | let WindowFunction { fun, args, .. } = window_function; |
475 | | |
476 | 0 | let data_types = args |
477 | 0 | .iter() |
478 | 0 | .map(|e| e.get_type(schema)) |
479 | 0 | .collect::<Result<Vec<_>>>()?; |
480 | 0 | match fun { |
481 | 0 | WindowFunctionDefinition::BuiltInWindowFunction(window_fun) => { |
482 | 0 | let return_type = window_fun.return_type(&data_types)?; |
483 | 0 | let nullable = |
484 | 0 | !["RANK", "NTILE", "CUME_DIST"].contains(&window_fun.name()); |
485 | 0 | Ok((return_type, nullable)) |
486 | | } |
487 | 0 | WindowFunctionDefinition::AggregateUDF(udaf) => { |
488 | 0 | let new_types = data_types_with_aggregate_udf(&data_types, udaf) |
489 | 0 | .map_err(|err| { |
490 | 0 | plan_datafusion_err!( |
491 | 0 | "{} {}", |
492 | 0 | err, |
493 | 0 | utils::generate_signature_error_msg( |
494 | 0 | fun.name(), |
495 | 0 | fun.signature(), |
496 | 0 | &data_types |
497 | 0 | ) |
498 | 0 | ) |
499 | 0 | })?; |
500 | | |
501 | 0 | let return_type = udaf.return_type(&new_types)?; |
502 | 0 | let nullable = udaf.is_nullable(); |
503 | 0 |
|
504 | 0 | Ok((return_type, nullable)) |
505 | | } |
506 | 0 | WindowFunctionDefinition::WindowUDF(udwf) => { |
507 | 0 | let new_types = |
508 | 0 | data_types_with_window_udf(&data_types, udwf).map_err(|err| { |
509 | 0 | plan_datafusion_err!( |
510 | 0 | "{} {}", |
511 | 0 | err, |
512 | 0 | utils::generate_signature_error_msg( |
513 | 0 | fun.name(), |
514 | 0 | fun.signature(), |
515 | 0 | &data_types |
516 | 0 | ) |
517 | 0 | ) |
518 | 0 | })?; |
519 | 0 | let (_, function_name) = self.qualified_name(); |
520 | 0 | let field_args = WindowUDFFieldArgs::new(&new_types, &function_name); |
521 | 0 |
|
522 | 0 | udwf.field(field_args) |
523 | 0 | .map(|field| (field.data_type().clone(), field.is_nullable())) |
524 | | } |
525 | | } |
526 | 0 | } |
527 | | } |
528 | | |
529 | | /// cast subquery in InSubquery/ScalarSubquery to a given type. |
530 | 0 | pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> { |
531 | 0 | if subquery.subquery.schema().field(0).data_type() == cast_to_type { |
532 | 0 | return Ok(subquery); |
533 | 0 | } |
534 | 0 |
|
535 | 0 | let plan = subquery.subquery.as_ref(); |
536 | 0 | let new_plan = match plan { |
537 | 0 | LogicalPlan::Projection(projection) => { |
538 | 0 | let cast_expr = projection.expr[0] |
539 | 0 | .clone() |
540 | 0 | .cast_to(cast_to_type, projection.input.schema())?; |
541 | 0 | LogicalPlan::Projection(Projection::try_new( |
542 | 0 | vec![cast_expr], |
543 | 0 | Arc::clone(&projection.input), |
544 | 0 | )?) |
545 | | } |
546 | | _ => { |
547 | 0 | let cast_expr = Expr::Column(Column::from(plan.schema().qualified_field(0))) |
548 | 0 | .cast_to(cast_to_type, subquery.subquery.schema())?; |
549 | 0 | LogicalPlan::Projection(Projection::try_new( |
550 | 0 | vec![cast_expr], |
551 | 0 | subquery.subquery, |
552 | 0 | )?) |
553 | | } |
554 | | }; |
555 | 0 | Ok(Subquery { |
556 | 0 | subquery: Arc::new(new_plan), |
557 | 0 | outer_ref_columns: subquery.outer_ref_columns, |
558 | 0 | }) |
559 | 0 | } |
560 | | |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{col, lit};

    use datafusion_common::{internal_err, DFSchema, ScalarValue};

    /// Schema stub that gives the same configured answer for every column.
    #[derive(Debug)]
    struct MockExprSchema {
        nullable: bool,
        data_type: DataType,
        error_on_nullable: bool,
        metadata: HashMap<String, String>,
    }

    impl MockExprSchema {
        /// Non-nullable `Null`-typed schema without metadata or errors.
        fn new() -> Self {
            Self {
                nullable: false,
                data_type: DataType::Null,
                error_on_nullable: false,
                metadata: HashMap::new(),
            }
        }

        fn with_nullable(self, nullable: bool) -> Self {
            Self { nullable, ..self }
        }

        fn with_data_type(self, data_type: DataType) -> Self {
            Self { data_type, ..self }
        }

        /// When enabled, every nullability lookup fails.
        fn with_error_on_nullable(self, error_on_nullable: bool) -> Self {
            Self {
                error_on_nullable,
                ..self
            }
        }

        fn with_metadata(self, metadata: HashMap<String, String>) -> Self {
            Self { metadata, ..self }
        }
    }

    impl ExprSchema for MockExprSchema {
        fn nullable(&self, _col: &Column) -> Result<bool> {
            if self.error_on_nullable {
                return internal_err!("nullable error");
            }
            Ok(self.nullable)
        }

        fn data_type(&self, _col: &Column) -> Result<&DataType> {
            Ok(&self.data_type)
        }

        fn metadata(&self, _col: &Column) -> Result<&HashMap<String, String>> {
            Ok(&self.metadata)
        }

        fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> {
            let data_type = self.data_type(col)?;
            let nullable = self.nullable(col)?;
            Ok((data_type, nullable))
        }
    }

    /// Mock schema with the given column type and nullability.
    fn schema_of(data_type: DataType, nullable: bool) -> MockExprSchema {
        MockExprSchema::new()
            .with_data_type(data_type)
            .with_nullable(nullable)
    }

    /// Asserts that applying the given `IS ...` builder to a NULL literal
    /// yields a non-nullable expression.
    macro_rules! test_is_expr_nullable {
        ($EXPR_TYPE:ident) => {{
            let expr = lit(ScalarValue::Null).$EXPR_TYPE();
            assert!(!expr.nullable(&MockExprSchema::new()).unwrap());
        }};
    }

    #[test]
    fn expr_schema_nullability() {
        let comparison = col("foo").eq(lit(1));
        let non_nullable = MockExprSchema::new();
        let nullable = MockExprSchema::new().with_nullable(true);
        assert!(!comparison.nullable(&non_nullable).unwrap());
        assert!(comparison.nullable(&nullable).unwrap());

        test_is_expr_nullable!(is_null);
        test_is_expr_nullable!(is_not_null);
        test_is_expr_nullable!(is_true);
        test_is_expr_nullable!(is_not_true);
        test_is_expr_nullable!(is_false);
        test_is_expr_nullable!(is_not_false);
        test_is_expr_nullable!(is_unknown);
        test_is_expr_nullable!(is_not_unknown);
    }

    #[test]
    fn test_between_nullability() {
        // With literal bounds, nullability follows the column.
        let between = col("foo").between(lit(1), lit(2));
        assert!(!between
            .nullable(&schema_of(DataType::Int32, false))
            .unwrap());
        assert!(between
            .nullable(&schema_of(DataType::Int32, true))
            .unwrap());

        // A NULL bound makes the result nullable even for a non-nullable column.
        let null = lit(ScalarValue::Int32(None));

        let null_low = col("foo").between(null.clone(), lit(2));
        assert!(null_low
            .nullable(&schema_of(DataType::Int32, false))
            .unwrap());

        let null_high = col("foo").between(lit(1), null.clone());
        assert!(null_high
            .nullable(&schema_of(DataType::Int32, false))
            .unwrap());

        let null_both = col("foo").between(null.clone(), null);
        assert!(null_both
            .nullable(&schema_of(DataType::Int32, false))
            .unwrap());
    }

    #[test]
    fn test_inlist_nullability() {
        let short_list = col("foo").in_list(vec![lit(1); 5], false);
        assert!(!short_list
            .nullable(&schema_of(DataType::Int32, false))
            .unwrap());
        assert!(short_list
            .nullable(&schema_of(DataType::Int32, true))
            .unwrap());
        // An error raised by the schema must be propagated.
        let failing_schema =
            schema_of(DataType::Int32, false).with_error_on_nullable(true);
        assert!(short_list.nullable(&failing_schema).is_err());

        // A NULL list item makes the result nullable.
        let null = lit(ScalarValue::Int32(None));
        let with_null = col("foo").in_list(vec![null, lit(1)], false);
        assert!(with_null
            .nullable(&schema_of(DataType::Int32, false))
            .unwrap());

        // Lists beyond the inspection limit are assumed nullable.
        let long_list = col("foo").in_list(vec![lit(1); 6], false);
        assert!(long_list
            .nullable(&schema_of(DataType::Int32, false))
            .unwrap());
    }

    #[test]
    fn test_like_nullability() {
        let like = col("foo").like(lit("bar"));
        assert!(!like.nullable(&schema_of(DataType::Utf8, false)).unwrap());
        assert!(like.nullable(&schema_of(DataType::Utf8, true)).unwrap());

        // A NULL pattern makes the result nullable.
        let null_pattern = col("foo").like(lit(ScalarValue::Utf8(None)));
        assert!(null_pattern
            .nullable(&schema_of(DataType::Utf8, false))
            .unwrap());
    }

    #[test]
    fn expr_schema_data_type() {
        let schema = MockExprSchema::new().with_data_type(DataType::Utf8);
        let resolved = col("foo").get_type(&schema).unwrap();
        assert_eq!(DataType::Utf8, resolved);
    }

    #[test]
    fn test_expr_metadata() {
        let meta: HashMap<String, String> =
            HashMap::from([("bar".to_string(), "buzz".to_string())]);
        let expr = col("foo");
        let schema = MockExprSchema::new()
            .with_data_type(DataType::Int32)
            .with_metadata(meta.clone());

        // Columns and aliases preserve metadata.
        assert_eq!(meta, expr.metadata(&schema).unwrap());
        assert_eq!(meta, expr.clone().alias("bar").metadata(&schema).unwrap());

        // A cast drops the input metadata because the type has changed.
        let casted = expr.clone().cast_to(&DataType::Int64, &schema).unwrap();
        assert_eq!(HashMap::new(), casted.metadata(&schema).unwrap());

        let field = Field::new("foo", DataType::Int32, true).with_metadata(meta.clone());
        let schema =
            DFSchema::from_unqualified_fields(vec![field].into(), HashMap::new())
                .unwrap();

        // `to_field` carries the metadata over to the produced field.
        let (_, produced) = expr.to_field(&schema).unwrap();
        assert_eq!(&meta, produced.metadata());
    }
}