From 3c50b7cc49f944af524b8d95e5f016230f82fc12 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Fri, 1 Dec 2023 21:14:19 +0100 Subject: [PATCH] Refactor function argument handling in (#8387) ScalarFunctionDefinition --- datafusion/expr/src/expr_schema.rs | 15 ++--- datafusion/physical-expr/src/planner.rs | 66 ++++++++----------- datafusion/proto/src/logical_plan/to_proto.rs | 53 ++++++++------- 3 files changed, 58 insertions(+), 76 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 2795ac5f0962..e5b0185d90e0 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -83,13 +83,12 @@ impl ExprSchemable for Expr { Expr::Cast(Cast { data_type, .. }) | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let arg_data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { - let arg_data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - // verify that input data types is consistent with function's `TypeSignature` data_types(&arg_data_types, &fun.signature()).map_err(|_| { plan_datafusion_err!( @@ -105,11 +104,7 @@ impl ExprSchemable for Expr { fun.return_type(&arg_data_types) } ScalarFunctionDefinition::UDF(fun) => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - Ok(fun.return_type(&data_types)?) + Ok(fun.return_type(&arg_data_types)?) } ScalarFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 5501647da2c3..9c212cb81f6b 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -348,50 +348,38 @@ pub fn create_physical_expr( ))) } - Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::BuiltIn(fun) => { - let physical_args = args - .iter() - .map(|e| { - create_physical_expr( - e, - input_dfschema, - input_schema, - execution_props, - ) - }) - .collect::>>()?; - functions::create_physical_expr( - fun, - &physical_args, - input_schema, - execution_props, - ) - } - ScalarFunctionDefinition::UDF(fun) => { - let mut physical_args = vec![]; - for e in args { - physical_args.push(create_physical_expr( - e, - input_dfschema, + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let mut physical_args = args + .iter() + .map(|e| { + create_physical_expr(e, input_dfschema, input_schema, execution_props) + }) + .collect::>>()?; + match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + functions::create_physical_expr( + fun, + &physical_args, input_schema, execution_props, - )?); + ) + } + ScalarFunctionDefinition::UDF(fun) => { + // udfs with zero params expect null array as input + if args.is_empty() { + physical_args.push(Arc::new(Literal::new(ScalarValue::Null))); + } + udf::create_physical_expr( + fun.clone().as_ref(), + &physical_args, + input_schema, + ) } - // udfs with zero params expect null array as input - if args.is_empty() { - physical_args.push(Arc::new(Literal::new(ScalarValue::Null))); + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") } - udf::create_physical_expr( - fun.clone().as_ref(), - &physical_args, - input_schema, - ) } - ScalarFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") - } - }, + } Expr::Between(Between { expr, negated, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index ab8e850014e5..ecbfaca5dbfe 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -792,40 +792,39 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { .to_string(), )) } - Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::BuiltIn(fun) => { - let fun: protobuf::ScalarFunction = fun.try_into()?; - let args: Vec = args - .iter() - .map(|e| e.try_into()) - .collect::, Error>>()?; - Self { - expr_type: Some(ExprType::ScalarFunction( - protobuf::ScalarFunctionNode { - fun: fun.into(), + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let args = args + .iter() + .map(|expr| expr.try_into()) + .collect::, Error>>()?; + match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + let fun: protobuf::ScalarFunction = fun.try_into()?; + Self { + expr_type: Some(ExprType::ScalarFunction( + protobuf::ScalarFunctionNode { + fun: fun.into(), + args, + }, + )), + } + } + ScalarFunctionDefinition::UDF(fun) => Self { + expr_type: Some(ExprType::ScalarUdfExpr( + protobuf::ScalarUdfExprNode { + fun_name: fun.name().to_string(), args, }, )), - } - } - ScalarFunctionDefinition::UDF(fun) => Self { - expr_type: Some(ExprType::ScalarUdfExpr( - protobuf::ScalarUdfExprNode { - fun_name: fun.name().to_string(), - args: args - .iter() - .map(|expr| expr.try_into()) - .collect::, Error>>()?, - }, - )), - }, - ScalarFunctionDefinition::Name(_) => { - return Err(Error::NotImplemented( + }, + ScalarFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( "Proto serialization error: Trying to serialize a unresolved function" .to_string(), )); + } } - }, + } Expr::Not(expr) => { let expr = Box::new(protobuf::Not { expr: Some(Box::new(expr.as_ref().try_into()?)),