diff --git a/datafusion/core/src/physical_plan/mod.rs b/datafusion/core/src/physical_plan/mod.rs index ce5dc041b99c..23464525171e 100644 --- a/datafusion/core/src/physical_plan/mod.rs +++ b/datafusion/core/src/physical_plan/mod.rs @@ -703,9 +703,7 @@ use crate::physical_plan::common::AbortOnDropSingle; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_execution::TaskContext; -pub use datafusion_physical_expr::{ - expressions, functions, hash_utils, type_coercion, udf, -}; +pub use datafusion_physical_expr::{expressions, functions, hash_utils, udf}; #[cfg(test)] mod tests { diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs index a17b8ba87e6a..b4b2d6bc6474 100644 --- a/datafusion/core/src/physical_plan/windows/mod.rs +++ b/datafusion/core/src/physical_plan/windows/mod.rs @@ -24,14 +24,13 @@ use crate::physical_plan::{ cume_dist, dense_rank, lag, lead, percent_rank, rank, Literal, NthValue, Ntile, PhysicalSortExpr, RowNumber, }, - type_coercion::coerce, udaf, ExecutionPlan, PhysicalExpr, }; use crate::scalar::ScalarValue; use arrow::datatypes::Schema; use arrow_schema::{SchemaRef, SortOptions}; use datafusion_expr::{ - window_function::{signature_for_built_in, BuiltInWindowFunction, WindowFunction}, + window_function::{BuiltInWindowFunction, WindowFunction}, WindowFrame, }; use datafusion_physical_expr::window::{ @@ -133,8 +132,7 @@ fn create_built_in_window_expr( BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name)), BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name)), BuiltInWindowFunction::Ntile => { - let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?; - let n: i64 = get_scalar_value_from_args(&coerced_args, 0)? + let n: i64 = get_scalar_value_from_args(args, 0)? .ok_or_else(|| { DataFusionError::Execution( "NTILE requires at least 1 argument".to_string(), @@ -145,33 +143,26 @@ fn create_built_in_window_expr( Arc::new(Ntile::new(name, n)) } BuiltInWindowFunction::Lag => { - let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?; - let arg = coerced_args[0].clone(); + let arg = args[0].clone(); let data_type = args[0].data_type(input_schema)?; - let shift_offset = get_scalar_value_from_args(&coerced_args, 1)? + let shift_offset = get_scalar_value_from_args(args, 1)? .map(|v| v.try_into()) .and_then(|v| v.ok()); - let default_value = get_scalar_value_from_args(&coerced_args, 2)?; + let default_value = get_scalar_value_from_args(args, 2)?; Arc::new(lag(name, data_type, arg, shift_offset, default_value)) } BuiltInWindowFunction::Lead => { - let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?; - let arg = coerced_args[0].clone(); + let arg = args[0].clone(); let data_type = args[0].data_type(input_schema)?; - let shift_offset = get_scalar_value_from_args(&coerced_args, 1)? + let shift_offset = get_scalar_value_from_args(args, 1)? .map(|v| v.try_into()) .and_then(|v| v.ok()); - let default_value = get_scalar_value_from_args(&coerced_args, 2)?; + let default_value = get_scalar_value_from_args(args, 2)?; Arc::new(lead(name, data_type, arg, shift_offset, default_value)) } BuiltInWindowFunction::NthValue => { - let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?; - let arg = coerced_args[0].clone(); - let n = coerced_args[1] - .as_any() - .downcast_ref::() - .unwrap() - .value(); + let arg = args[0].clone(); + let n = args[1].as_any().downcast_ref::().unwrap().value(); let n: i64 = n .clone() .try_into() @@ -181,14 +172,12 @@ fn create_built_in_window_expr( Arc::new(NthValue::nth(name, arg, data_type, n)?) } BuiltInWindowFunction::FirstValue => { - let arg = - coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone(); + let arg = args[0].clone(); let data_type = args[0].data_type(input_schema)?; Arc::new(NthValue::first(name, arg, data_type)) } BuiltInWindowFunction::LastValue => { - let arg = - coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone(); + let arg = args[0].clone(); let data_type = args[0].data_type(input_schema)?; Arc::new(NthValue::last(name, arg, data_type)) } diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 15728854ff3d..e8b262c7617c 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -809,9 +809,9 @@ pub fn create_physical_fun( #[cfg(test)] mod tests { use super::*; + use crate::expressions::try_cast; use crate::expressions::{col, lit}; use crate::from_slice::FromSlice; - use crate::type_coercion::coerce; use arrow::{ array::{ Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, @@ -822,6 +822,8 @@ mod tests { }; use datafusion_common::cast::as_uint64_array; use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::type_coercion::functions::data_types; + use datafusion_expr::Signature; /// $FUNC function to test /// $ARGS arguments (vec) to pass to function @@ -2885,7 +2887,33 @@ mod tests { Ok(()) } - // Helper function + // Helper function just for testing. + // Returns `expressions` coerced to types compatible with + // `signature`, if possible. + pub fn coerce( + expressions: &[Arc], + schema: &Schema, + signature: &Signature, + ) -> Result>> { + if expressions.is_empty() { + return Ok(vec![]); + } + + let current_types = expressions + .iter() + .map(|e| e.data_type(schema)) + .collect::>>()?; + + let new_types = data_types(¤t_types, signature)?; + + expressions + .iter() + .enumerate() + .map(|(i, expr)| try_cast(expr.clone(), schema, new_types[i].clone())) + .collect::>>() + } + + // Helper function just for testing. // The type coercion will be done in the logical phase, should do the type coercion for the test fn create_physical_expr_with_type_coercion( fun: &BuiltinScalarFunction, diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 710e9342b127..21a88b5d891d 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -37,7 +37,6 @@ mod sort_expr; pub mod string_expressions; pub mod struct_expressions; pub mod tree_node; -pub mod type_coercion; pub mod udf; #[cfg(feature = "unicode_expressions")] pub mod unicode_expressions; diff --git a/datafusion/physical-expr/src/sort_expr.rs b/datafusion/physical-expr/src/sort_expr.rs index dc93b67fa655..df519551d8a6 100644 --- a/datafusion/physical-expr/src/sort_expr.rs +++ b/datafusion/physical-expr/src/sort_expr.rs @@ -80,7 +80,7 @@ impl PhysicalSortExpr { /// Represents sort requirement associated with a plan /// -/// If the requirement incudes [`SortOptions`] then both the +/// If the requirement includes [`SortOptions`] then both the /// expression *and* the sort options must match. /// /// If the requirement does not include [`SortOptions`]) then only the diff --git a/datafusion/physical-expr/src/type_coercion.rs b/datafusion/physical-expr/src/type_coercion.rs deleted file mode 100644 index 399dcc089900..000000000000 --- a/datafusion/physical-expr/src/type_coercion.rs +++ /dev/null @@ -1,201 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Type coercion rules for functions with multiple valid signatures -//! -//! Coercion is performed automatically by DataFusion when the types -//! of arguments passed to a function do not exactly match the types -//! required by that function. In this case, DataFusion will attempt to -//! *coerce* the arguments to types accepted by the function by -//! inserting CAST operations. -//! -//! CAST operations added by coercion are lossless and never discard -//! information. For example coercion from i32 -> i64 might be -//! performed because all valid i32 values can be represented using an -//! i64. However, i64 -> i32 is never performed as there are i64 -//! values which can not be represented by i32 values. - -use super::PhysicalExpr; -use crate::expressions::try_cast; -use arrow::datatypes::Schema; -use datafusion_common::Result; -use datafusion_expr::{type_coercion::functions::data_types, Signature}; -use std::{sync::Arc, vec}; - -/// Returns `expressions` coerced to types compatible with -/// `signature`, if possible. -/// -/// See the module level documentation for more detail on coercion. -pub fn coerce( - expressions: &[Arc], - schema: &Schema, - signature: &Signature, -) -> Result>> { - if expressions.is_empty() { - return Ok(vec![]); - } - - let current_types = expressions - .iter() - .map(|e| e.data_type(schema)) - .collect::>>()?; - - let new_types = data_types(¤t_types, signature)?; - - expressions - .iter() - .enumerate() - .map(|(i, expr)| try_cast(expr.clone(), schema, new_types[i].clone())) - .collect::>>() -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::col; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::Fields; - use datafusion_common::DataFusionError; - use datafusion_expr::Volatility; - - #[test] - fn test_coerce() -> Result<()> { - // create a schema - let schema = |t: Vec| { - Schema::new( - t.iter() - .enumerate() - .map(|(i, t)| Field::new(format!("c{i}"), t.clone(), true)) - .collect::(), - ) - }; - - // create a vector of expressions - let expressions = |t: Vec, schema| -> Result> { - t.iter() - .enumerate() - .map(|(i, t)| { - try_cast(col(&format!("c{i}"), &schema)?, &schema, t.clone()) - }) - .collect::>>() - }; - - // create a case: input + expected result - let case = - |observed: Vec, valid, expected: Vec| -> Result<_> { - let schema = schema(observed.clone()); - let expr = expressions(observed, schema.clone())?; - let expected = expressions(expected, schema.clone())?; - Ok((expr.clone(), schema, valid, expected)) - }; - - let cases = vec![ - // u16 -> u32 - case( - vec![DataType::UInt16], - Signature::uniform(1, vec![DataType::UInt32], Volatility::Immutable), - vec![DataType::UInt32], - )?, - // same type - case( - vec![DataType::UInt32, DataType::UInt32], - Signature::uniform(2, vec![DataType::UInt32], Volatility::Immutable), - vec![DataType::UInt32, DataType::UInt32], - )?, - case( - vec![DataType::UInt32], - Signature::uniform( - 1, - vec![DataType::Float32, DataType::Float64], - Volatility::Immutable, - ), - vec![DataType::Float32], - )?, - // u32 -> f32 - case( - vec![DataType::UInt32, DataType::UInt32], - Signature::variadic(vec![DataType::Float32], Volatility::Immutable), - vec![DataType::Float32, DataType::Float32], - )?, - // u32 -> f32 - case( - vec![DataType::Float32, DataType::UInt32], - Signature::variadic_equal(Volatility::Immutable), - vec![DataType::Float32, DataType::Float32], - )?, - // common type is u64 - case( - vec![DataType::UInt32, DataType::UInt64], - Signature::variadic( - vec![DataType::UInt32, DataType::UInt64], - Volatility::Immutable, - ), - vec![DataType::UInt64, DataType::UInt64], - )?, - // f32 -> f32 - case( - vec![DataType::Float32], - Signature::any(1, Volatility::Immutable), - vec![DataType::Float32], - )?, - ]; - - for case in cases { - let observed = format!("{:?}", coerce(&case.0, &case.1, &case.2)?); - let expected = format!("{:?}", case.3); - assert_eq!(observed, expected); - } - - // now cases that are expected to fail - let cases = vec![ - // we do not know how to cast bool to UInt16 => fail - case( - vec![DataType::Boolean], - Signature::uniform(1, vec![DataType::UInt16], Volatility::Immutable), - vec![], - )?, - // u32 and bool are not uniform - case( - vec![DataType::UInt32, DataType::Boolean], - Signature::variadic_equal(Volatility::Immutable), - vec![], - )?, - // bool is not castable to u32 - case( - vec![DataType::Boolean, DataType::Boolean], - Signature::variadic(vec![DataType::UInt32], Volatility::Immutable), - vec![], - )?, - // expected two arguments - case( - vec![DataType::UInt32], - Signature::any(2, Volatility::Immutable), - vec![], - )?, - ]; - - for case in cases { - if coerce(&case.0, &case.1, &case.2).is_ok() { - return Err(DataFusionError::Plan(format!( - "Error was expected in {case:?}" - ))); - } - } - - Ok(()) - } -}