diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 3ea5c63997858..018e0b9fc5e21 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -172,18 +172,19 @@ pub fn coerce_arguments_for_signature( mod test { use crate::type_coercion::TypeCoercion; use crate::{OptimizerConfig, OptimizerRule}; + use arrow::datatypes::DataType; use datafusion_common::{DFSchema, Result}; use datafusion_expr::logical_plan::{EmptyRelation, Projection}; - use datafusion_expr::{lit, LogicalPlan}; + use datafusion_expr::{ + lit, Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation, + ScalarUDF, Signature, Volatility, + }; use std::sync::Arc; #[test] - fn simple_case() -> Result<()> { + fn binary_expr_simple_case() -> Result<()> { let expr = lit(1.2_f64).lt(lit(2_u32)); - let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: Arc::new(DFSchema::empty()), - })); + let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?); let rule = TypeCoercion::new(); let mut config = OptimizerConfig::default(); @@ -196,12 +197,9 @@ mod test { } #[test] - fn nested_case() -> Result<()> { + fn binary_expr_nested_case() -> Result<()> { let expr = lit(1.2_f64).lt(lit(2_u32)); - let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: Arc::new(DFSchema::empty()), - })); + let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new( vec![expr.clone().or(expr)], empty, @@ -214,4 +212,38 @@ mod test { \n EmptyRelation", &format!("{:?}", plan)); Ok(()) } + + #[test] + fn scalar_udf() -> Result<()> { + let empty = empty(); + let return_type: ReturnTypeFunction = + Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); + let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!()); + let udf = Expr::ScalarUDF { + fun: Arc::new(ScalarUDF::new( + "TestScalarUDF", + &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), + &return_type, + &fun, + )), + args: vec![lit(123_i32)], + }; + let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty, None)?); + let rule = TypeCoercion::new(); + let mut config = OptimizerConfig::default(); + let plan = rule.optimize(&plan, &mut config)?; + assert_eq!( + "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation", + &format!("{:?}", plan) + ); + Ok(()) + } + + fn empty() -> Arc { + let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(DFSchema::empty()), + })); + empty + } }