diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index f9b02f4d0c10..0c39877cd11e 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -122,8 +122,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { // - AGGREGATE, WINDOW and SORT should not end up in filter conditions, except maybe in some edge cases // - Can `Wildcard` be considered as a `Literal`? // - ScalarVariable could be `applicable`, but that would require access to the context - Expr::AggregateUDF { .. } - | Expr::AggregateFunction { .. } + Expr::AggregateFunction { .. } | Expr::Sort { .. } | Expr::WindowFunction { .. } | Expr::Wildcard { .. } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index ef364c22ee7d..9e64eb9c5108 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -82,8 +82,9 @@ use datafusion_common::{ }; use datafusion_expr::dml::{CopyOptions, CopyTo}; use datafusion_expr::expr::{ - self, AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Cast, - GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, TryCast, WindowFunction, + self, AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, + Cast, GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, TryCast, + WindowFunction, }; use datafusion_expr::expr_rewriter::{unalias, unnormalize_cols}; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; @@ -229,30 +230,37 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { create_function_physical_name(&fun.to_string(), false, args) } Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, args, - .. - }) => create_function_physical_name(&fun.to_string(), *distinct, args), - Expr::AggregateUDF(AggregateUDF { - fun, - args, filter, order_by, - }) => { - // TODO: Add support for filter and order by in AggregateUDF - if filter.is_some() { - return exec_err!("aggregate expression with filter is not supported"); + }) => match func_def { + AggregateFunctionDefinition::BuiltIn(..) => { + create_function_physical_name(func_def.name(), *distinct, args) } - if order_by.is_some() { - return exec_err!("aggregate expression with order_by is not supported"); + AggregateFunctionDefinition::UDF(fun) => { + // TODO: Add support for filter and order by in AggregateUDF + if filter.is_some() { + return exec_err!( + "aggregate expression with filter is not supported" + ); + } + if order_by.is_some() { + return exec_err!( + "aggregate expression with order_by is not supported" + ); + } + let names = args + .iter() + .map(|e| create_physical_name(e, false)) + .collect::>>()?; + Ok(format!("{}({})", fun.name(), names.join(","))) } - let mut names = Vec::with_capacity(args.len()); - for e in args { - names.push(create_physical_name(e, false)?); + AggregateFunctionDefinition::Name(_) => { + internal_err!("Aggregate function `Expr` with name should be resolved.") } - Ok(format!("{}({})", fun.name(), names.join(","))) - } + }, Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Ok(format!( "ROLLUP ({})", @@ -1705,7 +1713,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ) -> Result { match e { Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, args, filter, @@ -1746,63 +1754,35 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ), None => None, }; - let ordering_reqs = order_by.clone().unwrap_or(vec![]); - let agg_expr = aggregates::create_aggregate_expr( - fun, - *distinct, - &args, - &ordering_reqs, - physical_input_schema, - name, - )?; - Ok((agg_expr, filter, order_by)) - } - Expr::AggregateUDF(AggregateUDF { - fun, - args, - filter, - order_by, - }) => { - let args = args - .iter() - .map(|e| { - create_physical_expr( - e, - logical_input_schema, + let (agg_expr, filter, order_by) = match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let ordering_reqs = order_by.clone().unwrap_or(vec![]); + let agg_expr = aggregates::create_aggregate_expr( + fun, + *distinct, + &args, + &ordering_reqs, physical_input_schema, - execution_props, + name, + )?; + (agg_expr, filter, order_by) + } + AggregateFunctionDefinition::UDF(fun) => { + let agg_expr = udaf::create_aggregate_expr( + fun, + &args, + physical_input_schema, + name, + ); + (agg_expr?, filter, order_by) + } + AggregateFunctionDefinition::Name(_) => { + return internal_err!( + "Aggregate function name should have been resolved" ) - }) - .collect::>>()?; - - let filter = match filter { - Some(e) => Some(create_physical_expr( - e, - logical_input_schema, - physical_input_schema, - execution_props, - )?), - None => None, - }; - let order_by = match order_by { - Some(e) => Some( - e.iter() - .map(|expr| { - create_physical_sort_expr( - expr, - logical_input_schema, - physical_input_schema, - execution_props, - ) - }) - .collect::>>()?, - ), - None => None, + } }; - - let agg_expr = - udaf::create_aggregate_expr(fun, &args, physical_input_schema, name); - Ok((agg_expr?, filter, order_by)) + Ok((agg_expr, filter, order_by)) } other => internal_err!("Invalid aggregate expression '{other:?}'"), } diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 4611c7fb10d7..cea72c3cb5e6 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -105,7 +105,7 @@ pub enum AggregateFunction { } impl AggregateFunction { - fn name(&self) -> &str { + pub fn name(&self) -> &str { use AggregateFunction::*; match self { Count => "COUNT", diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index b46d204faafb..256f5b210ec2 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -154,8 +154,6 @@ pub enum Expr { AggregateFunction(AggregateFunction), /// Represents the call of a window function with arguments. WindowFunction(WindowFunction), - /// aggregate function - AggregateUDF(AggregateUDF), /// Returns whether the list contains the expr value. InList(InList), /// EXISTS subquery @@ -484,11 +482,33 @@ impl Sort { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// Defines which implementation of an aggregate function DataFusion should call. +pub enum AggregateFunctionDefinition { + BuiltIn(aggregate_function::AggregateFunction), + /// Resolved to a user defined aggregate function + UDF(Arc), + /// A aggregation function constructed with name. This variant can not be executed directly + /// and instead must be resolved to one of the other variants prior to physical planning. + Name(Arc), +} + +impl AggregateFunctionDefinition { + /// Function's name for display + pub fn name(&self) -> &str { + match self { + AggregateFunctionDefinition::BuiltIn(fun) => fun.name(), + AggregateFunctionDefinition::UDF(udf) => udf.name(), + AggregateFunctionDefinition::Name(func_name) => func_name.as_ref(), + } + } +} + /// Aggregate function #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct AggregateFunction { /// Name of the function - pub fun: aggregate_function::AggregateFunction, + pub func_def: AggregateFunctionDefinition, /// List of expressions to feed to the functions as arguments pub args: Vec, /// Whether this is a DISTINCT aggregation or not @@ -508,7 +528,24 @@ impl AggregateFunction { order_by: Option>, ) -> Self { Self { - fun, + func_def: AggregateFunctionDefinition::BuiltIn(fun), + args, + distinct, + filter, + order_by, + } + } + + /// Create a new AggregateFunction expression with a user-defined function (UDF) + pub fn new_udf( + udf: Arc, + args: Vec, + distinct: bool, + filter: Option>, + order_by: Option>, + ) -> Self { + Self { + func_def: AggregateFunctionDefinition::UDF(udf), args, distinct, filter, @@ -736,7 +773,6 @@ impl Expr { pub fn variant_name(&self) -> &str { match self { Expr::AggregateFunction { .. } => "AggregateFunction", - Expr::AggregateUDF { .. } => "AggregateUDF", Expr::Alias(..) => "Alias", Expr::Between { .. } => "Between", Expr::BinaryExpr { .. } => "BinaryExpr", @@ -1251,30 +1287,14 @@ impl fmt::Display for Expr { Ok(()) } Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, ref args, filter, order_by, .. }) => { - fmt_function(f, &fun.to_string(), *distinct, args, true)?; - if let Some(fe) = filter { - write!(f, " FILTER (WHERE {fe})")?; - } - if let Some(ob) = order_by { - write!(f, " ORDER BY [{}]", expr_vec_fmt!(ob))?; - } - Ok(()) - } - Expr::AggregateUDF(AggregateUDF { - fun, - ref args, - filter, - order_by, - .. - }) => { - fmt_function(f, fun.name(), false, args, true)?; + fmt_function(f, func_def.name(), *distinct, args, true)?; if let Some(fe) = filter { write!(f, " FILTER (WHERE {fe})")?; } @@ -1579,39 +1599,39 @@ fn create_name(e: &Expr) -> Result { Ok(parts.join(" ")) } Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, args, filter, order_by, }) => { - let mut name = create_function_name(&fun.to_string(), *distinct, args)?; - if let Some(fe) = filter { - name = format!("{name} FILTER (WHERE {fe})"); - }; - if let Some(order_by) = order_by { - name = format!("{name} ORDER BY [{}]", expr_vec_fmt!(order_by)); + let name = match func_def { + AggregateFunctionDefinition::BuiltIn(..) + | AggregateFunctionDefinition::Name(..) => { + create_function_name(func_def.name(), *distinct, args)? + } + AggregateFunctionDefinition::UDF(..) => { + let names: Vec = + args.iter().map(create_name).collect::>()?; + names.join(",") + } }; - Ok(name) - } - Expr::AggregateUDF(AggregateUDF { - fun, - args, - filter, - order_by, - }) => { - let mut names = Vec::with_capacity(args.len()); - for e in args { - names.push(create_name(e)?); - } let mut info = String::new(); if let Some(fe) = filter { info += &format!(" FILTER (WHERE {fe})"); + }; + if let Some(order_by) = order_by { + info += &format!(" ORDER BY [{}]", expr_vec_fmt!(order_by)); + }; + match func_def { + AggregateFunctionDefinition::BuiltIn(..) + | AggregateFunctionDefinition::Name(..) => { + Ok(format!("{}{}", name, info)) + } + AggregateFunctionDefinition::UDF(fun) => { + Ok(format!("{}({}){}", fun.name(), name, info)) + } } - if let Some(ob) = order_by { - info += &format!(" ORDER BY ([{}])", expr_vec_fmt!(ob)); - } - Ok(format!("{}({}){}", fun.name(), names.join(","), info)) } Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index d5d9c848b2e9..99b27e8912bc 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -17,8 +17,8 @@ use super::{Between, Expr, Like}; use crate::expr::{ - AggregateFunction, AggregateUDF, Alias, BinaryExpr, Cast, GetFieldAccess, - GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, + AggregateFunction, AggregateFunctionDefinition, Alias, BinaryExpr, Cast, + GetFieldAccess, GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, ScalarFunctionDefinition, Sort, TryCast, WindowFunction, }; use crate::field_util::GetFieldAccessSchema; @@ -123,19 +123,22 @@ impl ExprSchemable for Expr { .collect::>>()?; fun.return_type(&data_types) } - Expr::AggregateFunction(AggregateFunction { fun, args, .. }) => { + Expr::AggregateFunction(AggregateFunction { func_def, args, .. }) => { let data_types = args .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - fun.return_type(&data_types) - } - Expr::AggregateUDF(AggregateUDF { fun, args, .. }) => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - fun.return_type(&data_types) + match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + fun.return_type(&data_types) + } + AggregateFunctionDefinition::UDF(fun) => { + Ok(fun.return_type(&data_types)?) + } + AggregateFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + } } Expr::Not(_) | Expr::IsNull(_) @@ -252,7 +255,6 @@ impl ExprSchemable for Expr { | Expr::ScalarFunction(..) | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } - | Expr::AggregateUDF { .. } | Expr::Placeholder(_) => Ok(true), Expr::IsNull(_) | Expr::IsNotNull(_) diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 474b5f7689b9..fcb0a4cd93f3 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -18,9 +18,9 @@ //! Tree node implementation for logical expr use crate::expr::{ - AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Case, Cast, - GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, - ScalarFunctionDefinition, Sort, TryCast, WindowFunction, + AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Case, + Cast, GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder, + ScalarFunction, ScalarFunctionDefinition, Sort, TryCast, WindowFunction, }; use crate::{Expr, GetFieldAccess}; @@ -108,7 +108,7 @@ impl TreeNode for Expr { expr_vec } Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) - | Expr::AggregateUDF(AggregateUDF { args, filter, order_by, .. }) => { + => { let mut expr_vec = args.clone(); if let Some(f) = filter { @@ -304,17 +304,40 @@ impl TreeNode for Expr { )), Expr::AggregateFunction(AggregateFunction { args, - fun, + func_def, distinct, filter, order_by, - }) => Expr::AggregateFunction(AggregateFunction::new( - fun, - transform_vec(args, &mut transform)?, - distinct, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )), + }) => match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + Expr::AggregateFunction(AggregateFunction::new( + fun, + transform_vec(args, &mut transform)?, + distinct, + transform_option_box(filter, &mut transform)?, + transform_option_vec(order_by, &mut transform)?, + )) + } + AggregateFunctionDefinition::UDF(fun) => { + let order_by = if let Some(order_by) = order_by { + Some(transform_vec(order_by, &mut transform)?) + } else { + None + }; + Expr::AggregateFunction(AggregateFunction::new_udf( + fun, + transform_vec(args, &mut transform)?, + false, + transform_option_box(filter, &mut transform)?, + transform_option_vec(order_by, &mut transform)?, + )) + } + AggregateFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + }, Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Expr::GroupingSet(GroupingSet::Rollup( transform_vec(exprs, &mut transform)?, @@ -331,24 +354,7 @@ impl TreeNode for Expr { )) } }, - Expr::AggregateUDF(AggregateUDF { - args, - fun, - filter, - order_by, - }) => { - let order_by = if let Some(order_by) = order_by { - Some(transform_vec(order_by, &mut transform)?) - } else { - None - }; - Expr::AggregateUDF(AggregateUDF::new( - fun, - transform_vec(args, &mut transform)?, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )) - } + Expr::InList(InList { expr, list, diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index b06e97acc283..cfbca4ab1337 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -107,12 +107,13 @@ impl AggregateUDF { /// This utility allows using the UDAF without requiring access to /// the registry, such as with the DataFrame API. pub fn call(&self, args: Vec) -> Expr { - Expr::AggregateUDF(crate::expr::AggregateUDF { - fun: Arc::new(self.clone()), + Expr::AggregateFunction(crate::expr::AggregateFunction::new_udf( + Arc::new(self.clone()), args, - filter: None, - order_by: None, - }) + false, + None, + None, + )) } /// Returns this function's name diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 7deb13c89be5..7d126a0f3373 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -291,7 +291,6 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } | Expr::GroupingSet(_) - | Expr::AggregateUDF { .. } | Expr::InList { .. } | Expr::Exists { .. } | Expr::InSubquery(_) @@ -595,15 +594,12 @@ pub fn group_window_expr_by_sort_keys( Ok(result) } -/// Collect all deeply nested `Expr::AggregateFunction` and -/// `Expr::AggregateUDF`. They are returned in order of occurrence (depth +/// Collect all deeply nested `Expr::AggregateFunction`. +/// They are returned in order of occurrence (depth /// first), with duplicates omitted. pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec { find_exprs_in_exprs(exprs, &|nested_expr| { - matches!( - nested_expr, - Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. } - ) + matches!(nested_expr, Expr::AggregateFunction { .. }) }) } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index b4de322f76f6..fd84bb80160b 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -19,7 +19,7 @@ use crate::analyzer::AnalyzerRule; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::Result; -use datafusion_expr::expr::{AggregateFunction, InSubquery}; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery}; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::Expr::ScalarSubquery; @@ -144,20 +144,23 @@ impl TreeNodeRewriter for CountWildcardRewriter { _ => old_expr, }, Expr::AggregateFunction(AggregateFunction { - fun: aggregate_function::AggregateFunction::Count, + func_def: + AggregateFunctionDefinition::BuiltIn( + aggregate_function::AggregateFunction::Count, + ), args, distinct, filter, order_by, }) if args.len() == 1 => match args[0] { Expr::Wildcard { qualifier: None } => { - Expr::AggregateFunction(AggregateFunction { - fun: aggregate_function::AggregateFunction::Count, - args: vec![lit(COUNT_STAR_EXPANSION)], + Expr::AggregateFunction(AggregateFunction::new( + aggregate_function::AggregateFunction::Count, + vec![lit(COUNT_STAR_EXPANSION)], distinct, filter, order_by, - }) + )) } _ => old_expr, }, diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index eb5d8c53a5e0..bedc86e2f4f1 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -28,8 +28,8 @@ use datafusion_common::{ DataFusionError, Result, ScalarValue, }; use datafusion_expr::expr::{ - self, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction, - WindowFunction, + self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList, + InSubquery, Like, ScalarFunction, WindowFunction, }; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::expr_schema::cast_subquery; @@ -346,39 +346,39 @@ impl TreeNodeRewriter for TypeCoercionRewriter { } }, Expr::AggregateFunction(expr::AggregateFunction { - fun, + func_def, args, distinct, filter, order_by, - }) => { - let new_expr = coerce_agg_exprs_for_signature( - &fun, - &args, - &self.schema, - &fun.signature(), - )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, new_expr, distinct, filter, order_by, - )); - Ok(expr) - } - Expr::AggregateUDF(expr::AggregateUDF { - fun, - args, - filter, - order_by, - }) => { - let new_expr = coerce_arguments_for_signature( - args.as_slice(), - &self.schema, - fun.signature(), - )?; - let expr = Expr::AggregateUDF(expr::AggregateUDF::new( - fun, new_expr, filter, order_by, - )); - Ok(expr) - } + }) => match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let new_expr = coerce_agg_exprs_for_signature( + &fun, + &args, + &self.schema, + &fun.signature(), + )?; + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + fun, new_expr, distinct, filter, order_by, + )); + Ok(expr) + } + AggregateFunctionDefinition::UDF(fun) => { + let new_expr = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + fun.signature(), + )?; + let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + fun, new_expr, false, filter, order_by, + )); + Ok(expr) + } + AggregateFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + }, Expr::WindowFunction(WindowFunction { fun, args, @@ -914,9 +914,10 @@ mod test { Arc::new(|_| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); - let udaf = Expr::AggregateUDF(expr::AggregateUDF::new( + let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(my_avg), vec![lit(10i64)], + false, None, None, )); @@ -941,9 +942,10 @@ mod test { &accumulator, &state_type, ); - let udaf = Expr::AggregateUDF(expr::AggregateUDF::new( + let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(my_avg), vec![lit("10")], + false, None, None, )); diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index f5ad767c5016..1d21407a6985 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -509,10 +509,9 @@ enum ExprMask { /// - [`Sort`](Expr::Sort) /// - [`Wildcard`](Expr::Wildcard) /// - [`AggregateFunction`](Expr::AggregateFunction) - /// - [`AggregateUDF`](Expr::AggregateUDF) Normal, - /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction) and [`AggregateUDF`](Expr::AggregateUDF). + /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction). NormalAndAggregates, } @@ -528,10 +527,7 @@ impl ExprMask { | Expr::Wildcard { .. } ); - let is_aggr = matches!( - expr, - Expr::AggregateFunction(..) | Expr::AggregateUDF { .. } - ); + let is_aggr = matches!(expr, Expr::AggregateFunction(..)); match self { Self::Normal => is_normal_minus_aggregates || is_aggr, @@ -908,7 +904,7 @@ mod test { let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!()); let state_type: StateTypeFunction = Arc::new(|_| unimplemented!()); let udf_agg = |inner: Expr| { - Expr::AggregateUDF(datafusion_expr::expr::AggregateUDF::new( + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( Arc::new(AggregateUDF::new( "my_agg", &Signature::exact(vec![DataType::UInt32], Volatility::Stable), @@ -917,6 +913,7 @@ mod test { &state_type, )), vec![inner], + false, None, None, )) diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index ed6f472186d4..b1000f042c98 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -22,7 +22,7 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{plan_err, Result}; use datafusion_common::{Column, DFSchemaRef, DataFusionError, ScalarValue}; -use datafusion_expr::expr::Alias; +use datafusion_expr::expr::{AggregateFunctionDefinition, Alias}; use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion_physical_expr::execution_props::ExecutionProps; @@ -372,16 +372,25 @@ fn agg_exprs_evaluation_result_on_empty_batch( for e in agg_expr.iter() { let result_expr = e.clone().transform_up(&|expr| { let new_expr = match expr { - Expr::AggregateFunction(expr::AggregateFunction { fun, .. }) => { - if matches!(fun, datafusion_expr::AggregateFunction::Count) { - Transformed::Yes(Expr::Literal(ScalarValue::Int64(Some(0)))) - } else { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) + Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => { + match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + if matches!(fun, datafusion_expr::AggregateFunction::Count) { + Transformed::Yes(Expr::Literal(ScalarValue::Int64(Some( + 0, + )))) + } else { + Transformed::Yes(Expr::Literal(ScalarValue::Null)) + } + } + AggregateFunctionDefinition::UDF { .. } => { + Transformed::Yes(Expr::Literal(ScalarValue::Null)) + } + AggregateFunctionDefinition::Name(_) => { + Transformed::Yes(Expr::Literal(ScalarValue::Null)) + } } } - Expr::AggregateUDF(_) => { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) - } _ => Transformed::No(expr), }; Ok(new_expr) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 95eeee931b4f..bad6e24715c9 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -253,7 +253,6 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Expr::Sort(_) | Expr::AggregateFunction(_) | Expr::WindowFunction(_) - | Expr::AggregateUDF { .. } | Expr::Wildcard { .. } | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"), })?; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 3310bfed75bf..c7366e17619c 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -332,7 +332,6 @@ impl<'a> ConstEvaluator<'a> { // Has no runtime cost, but needed during planning Expr::Alias(..) | Expr::AggregateFunction { .. } - | Expr::AggregateUDF { .. } | Expr::ScalarVariable(_, _) | Expr::Column(_) | Expr::OuterReferenceColumn(_, _) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index fa142438c4a3..7e6fb6b355ab 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -23,6 +23,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{DFSchema, Result}; +use datafusion_expr::expr::AggregateFunctionDefinition; use datafusion_expr::{ aggregate_function::AggregateFunction::{Max, Min, Sum}, col, @@ -70,7 +71,7 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result { let mut aggregate_count = 0; for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { - fun, + func_def: AggregateFunctionDefinition::BuiltIn(fun), distinct, args, filter, @@ -170,7 +171,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { .iter() .map(|aggr_expr| match aggr_expr { Expr::AggregateFunction(AggregateFunction { - fun, + func_def: AggregateFunctionDefinition::BuiltIn(fun), args, distinct, .. diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b2455d5a0d13..d7071c6ddf10 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1739,12 +1739,13 @@ pub fn parse_expr( ExprType::AggregateUdfExpr(pb) => { let agg_fn = registry.udaf(pb.fun_name.as_str())?; - Ok(Expr::AggregateUDF(expr::AggregateUDF::new( + Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( agg_fn, pb.args .iter() .map(|expr| parse_expr(expr, registry)) .collect::, Error>>()?, + false, parse_optional_expr(pb.filter.as_deref(), registry)?.map(Box::new), parse_vec_expr(&pb.order_by, registry)?, ))) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 9be4a532bb5b..6bfd4c3438f5 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -44,8 +44,9 @@ use datafusion_common::{ ScalarValue, }; use datafusion_expr::expr::{ - self, Alias, Between, BinaryExpr, Cast, GetFieldAccess, GetIndexedField, GroupingSet, - InList, Like, Placeholder, ScalarFunction, ScalarFunctionDefinition, Sort, + self, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Cast, GetFieldAccess, + GetIndexedField, GroupingSet, InList, Like, Placeholder, ScalarFunction, + ScalarFunctionDefinition, Sort, }; use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, @@ -652,104 +653,139 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { } } Expr::AggregateFunction(expr::AggregateFunction { - ref fun, + ref func_def, ref args, ref distinct, ref filter, ref order_by, }) => { - let aggr_function = match fun { - AggregateFunction::ApproxDistinct => { - protobuf::AggregateFunction::ApproxDistinct - } - AggregateFunction::ApproxPercentileCont => { - protobuf::AggregateFunction::ApproxPercentileCont - } - AggregateFunction::ApproxPercentileContWithWeight => { - protobuf::AggregateFunction::ApproxPercentileContWithWeight - } - AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, - AggregateFunction::Min => protobuf::AggregateFunction::Min, - AggregateFunction::Max => protobuf::AggregateFunction::Max, - AggregateFunction::Sum => protobuf::AggregateFunction::Sum, - AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, - AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, - AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, - AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, - AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, - AggregateFunction::Avg => protobuf::AggregateFunction::Avg, - AggregateFunction::Count => protobuf::AggregateFunction::Count, - AggregateFunction::Variance => protobuf::AggregateFunction::Variance, - AggregateFunction::VariancePop => { - protobuf::AggregateFunction::VariancePop - } - AggregateFunction::Covariance => { - protobuf::AggregateFunction::Covariance - } - AggregateFunction::CovariancePop => { - protobuf::AggregateFunction::CovariancePop - } - AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, - AggregateFunction::StddevPop => { - protobuf::AggregateFunction::StddevPop - } - AggregateFunction::Correlation => { - protobuf::AggregateFunction::Correlation - } - AggregateFunction::RegrSlope => { - protobuf::AggregateFunction::RegrSlope - } - AggregateFunction::RegrIntercept => { - protobuf::AggregateFunction::RegrIntercept - } - AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, - AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, - AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, - AggregateFunction::RegrCount => { - protobuf::AggregateFunction::RegrCount - } - AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, - AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, - AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, - AggregateFunction::ApproxMedian => { - protobuf::AggregateFunction::ApproxMedian - } - AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, - AggregateFunction::Median => protobuf::AggregateFunction::Median, - AggregateFunction::FirstValue => { - protobuf::AggregateFunction::FirstValueAgg - } - AggregateFunction::LastValue => { - protobuf::AggregateFunction::LastValueAgg - } - AggregateFunction::StringAgg => { - protobuf::AggregateFunction::StringAgg + match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let aggr_function = match fun { + AggregateFunction::ApproxDistinct => { + protobuf::AggregateFunction::ApproxDistinct + } + AggregateFunction::ApproxPercentileCont => { + protobuf::AggregateFunction::ApproxPercentileCont + } + AggregateFunction::ApproxPercentileContWithWeight => { + protobuf::AggregateFunction::ApproxPercentileContWithWeight + } + AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, + AggregateFunction::Min => protobuf::AggregateFunction::Min, + AggregateFunction::Max => protobuf::AggregateFunction::Max, + AggregateFunction::Sum => protobuf::AggregateFunction::Sum, + AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, + AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, + AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, + AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, + AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, + AggregateFunction::Avg => protobuf::AggregateFunction::Avg, + AggregateFunction::Count => protobuf::AggregateFunction::Count, + AggregateFunction::Variance => protobuf::AggregateFunction::Variance, + AggregateFunction::VariancePop => { + protobuf::AggregateFunction::VariancePop + } + AggregateFunction::Covariance => { + protobuf::AggregateFunction::Covariance + } + AggregateFunction::CovariancePop => { + protobuf::AggregateFunction::CovariancePop + } + AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, + AggregateFunction::StddevPop => { + protobuf::AggregateFunction::StddevPop + } + AggregateFunction::Correlation => { + protobuf::AggregateFunction::Correlation + } + AggregateFunction::RegrSlope => { + protobuf::AggregateFunction::RegrSlope + } + AggregateFunction::RegrIntercept => { + protobuf::AggregateFunction::RegrIntercept + } + AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, + AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, + AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, + AggregateFunction::RegrCount => { + protobuf::AggregateFunction::RegrCount + } + AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, + AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, + AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, + AggregateFunction::ApproxMedian => { + protobuf::AggregateFunction::ApproxMedian + } + AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, + AggregateFunction::Median => protobuf::AggregateFunction::Median, + AggregateFunction::FirstValue => { + protobuf::AggregateFunction::FirstValueAgg + } + AggregateFunction::LastValue => { + protobuf::AggregateFunction::LastValueAgg + } + AggregateFunction::StringAgg => { + protobuf::AggregateFunction::StringAgg + } + }; + + let aggregate_expr = protobuf::AggregateExprNode { + aggr_function: aggr_function.into(), + expr: args + .iter() + .map(|v| v.try_into()) + .collect::, _>>()?, + distinct: *distinct, + filter: match filter { + Some(e) => Some(Box::new(e.as_ref().try_into()?)), + None => None, + }, + order_by: match order_by { + Some(e) => e + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + None => vec![], + }, + }; + Self { + expr_type: Some(ExprType::AggregateExpr(Box::new( + aggregate_expr, + ))), + } } - }; - - let aggregate_expr = protobuf::AggregateExprNode { - aggr_function: aggr_function.into(), - expr: args - .iter() - .map(|v| v.try_into()) - .collect::, _>>()?, - distinct: *distinct, - filter: match filter { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| expr.try_into()) - .collect::, _>>()?, - None => vec![], + AggregateFunctionDefinition::UDF(fun) => Self { + expr_type: Some(ExprType::AggregateUdfExpr(Box::new( + protobuf::AggregateUdfExprNode { + fun_name: fun.name().to_string(), + args: args + .iter() + .map(|expr| expr.try_into()) + .collect::, Error>>()?, + filter: match filter { + Some(e) => Some(Box::new(e.as_ref().try_into()?)), + None => None, + }, + order_by: match order_by { + Some(e) => e + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + None => vec![], + }, + }, + ))), }, - }; - Self { - expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))), + AggregateFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( + "Proto serialization error: Trying to serialize a unresolved function" + .to_string(), + )); + } } } + Expr::ScalarVariable(_, _) => { return Err(Error::General( "Proto serialization error: Scalar Variable not supported" @@ -790,34 +826,6 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { )); } }, - Expr::AggregateUDF(expr::AggregateUDF { - fun, - args, - filter, - order_by, - }) => Self { - expr_type: Some(ExprType::AggregateUdfExpr(Box::new( - protobuf::AggregateUdfExprNode { - fun_name: fun.name().to_string(), - args: args.iter().map(|expr| expr.try_into()).collect::, - Error, - >>( - )?, - filter: match filter { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| expr.try_into()) - .collect::, _>>()?, - None => vec![], - }, - }, - ))), - }, Expr::Not(expr) => { let expr = Box::new(protobuf::Not { expr: Some(Box::new(expr.as_ref().try_into()?)), diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 3ab001298ed2..45727c39a373 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1375,9 +1375,10 @@ fn roundtrip_aggregate_udf() { Arc::new(vec![DataType::Float64, DataType::UInt32]), ); - let test_expr = Expr::AggregateUDF(expr::AggregateUDF::new( + let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(dummy_agg.clone()), vec![lit(1.0_f64)], + false, Some(Box::new(lit(true))), None, )); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 24ba4d1b506a..958e03879842 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -135,8 +135,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function if let Some(fm) = self.context_provider.get_aggregate_meta(&name) { let args = self.function_args_to_expr(args, schema, planner_context)?; - return Ok(Expr::AggregateUDF(expr::AggregateUDF::new( - fm, args, None, None, + return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( + fm, args, false, None, None, ))); } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 25fe6b6633c2..b8c130055a5a 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -34,6 +34,7 @@ use datafusion_common::{ internal_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::expr::AggregateFunctionDefinition; use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ @@ -706,7 +707,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { match self.sql_expr_to_logical_expr(expr, schema, planner_context)? { Expr::AggregateFunction(expr::AggregateFunction { - fun, + func_def: AggregateFunctionDefinition::BuiltIn(fun), args, distinct, order_by, diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 356c53605131..c546ca755206 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -170,11 +170,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { select_exprs .iter() .filter(|select_expr| match select_expr { - Expr::AggregateFunction(_) | Expr::AggregateUDF(_) => false, - Expr::Alias(Alias { expr, name: _, .. }) => !matches!( - **expr, - Expr::AggregateFunction(_) | Expr::AggregateUDF(_) - ), + Expr::AggregateFunction(_) => false, + Expr::Alias(Alias { expr, name: _, .. }) => { + !matches!(**expr, Expr::AggregateFunction(_)) + } _ => true, }) .cloned() diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index b7a51032dcd9..cf05d814a5cb 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -692,21 +692,14 @@ pub async fn from_substrait_agg_func( // try udaf first, then built-in aggr fn. if let Ok(fun) = ctx.udaf(function_name) { - Ok(Arc::new(Expr::AggregateUDF(expr::AggregateUDF { - fun, - args, - filter, - order_by, - }))) + Ok(Arc::new(Expr::AggregateFunction( + expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by), + ))) } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name) { - Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction { - fun, - args, - distinct, - filter, - order_by, - }))) + Ok(Arc::new(Expr::AggregateFunction( + expr::AggregateFunction::new(fun, args, distinct, filter, order_by), + ))) } else { not_impl_err!( "Aggregated function {} is not supported: function anchor = {:?}", diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 2be3e7b4e884..d576e70711df 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -33,8 +33,8 @@ use datafusion::common::{exec_err, internal_err, not_impl_err}; #[allow(unused_imports)] use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ - Alias, BinaryExpr, Case, Cast, GroupingSet, InList, ScalarFunctionDefinition, Sort, - WindowFunction, + AggregateFunctionDefinition, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, + ScalarFunctionDefinition, Sort, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; @@ -578,65 +578,73 @@ pub fn to_substrait_agg_measure( ), ) -> Result { match expr { - Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter, order_by }) => { - let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? - } else { - vec![] - }; - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); - } - let function_anchor = _register_function(fun.to_string(), extension_info); - Ok(Measure { - measure: Some(AggregateFunction { - function_reference: function_anchor, - arguments, - sorts, - output_type: None, - invocation: match distinct { - true => AggregationInvocation::Distinct as i32, - false => AggregationInvocation::All as i32, - }, - phase: AggregationPhase::Unspecified as i32, - args: vec![], - options: vec![], - }), - filter: match filter { - Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), - None => None + Expr::AggregateFunction(expr::AggregateFunction { func_def, args, distinct, filter, order_by }) => { + match func_def { + AggregateFunctionDefinition::BuiltIn (fun) => { + let sorts = if let Some(order_by) = order_by { + order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? + } else { + vec![] + }; + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); + } + let function_anchor = _register_function(fun.to_string(), extension_info); + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments, + sorts, + output_type: None, + invocation: match distinct { + true => AggregationInvocation::Distinct as i32, + false => AggregationInvocation::All as i32, + }, + phase: AggregationPhase::Unspecified as i32, + args: vec![], + options: vec![], + }), + filter: match filter { + Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), + None => None + } + }) } - }) - } - Expr::AggregateUDF(expr::AggregateUDF{ fun, args, filter, order_by }) =>{ - let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? - } else { - vec![] - }; - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); - } - let function_anchor = _register_function(fun.name().to_string(), extension_info); - Ok(Measure { - measure: Some(AggregateFunction { - function_reference: function_anchor, - arguments, - sorts, - output_type: None, - invocation: AggregationInvocation::All as i32, - phase: AggregationPhase::Unspecified as i32, - args: vec![], - options: vec![], - }), - filter: match filter { - Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), - None => None + AggregateFunctionDefinition::UDF(fun) => { + let sorts = if let Some(order_by) = order_by { + order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? + } else { + vec![] + }; + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); + } + let function_anchor = _register_function(fun.name().to_string(), extension_info); + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments, + sorts, + output_type: None, + invocation: AggregationInvocation::All as i32, + phase: AggregationPhase::Unspecified as i32, + args: vec![], + options: vec![], + }), + filter: match filter { + Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), + None => None + } + }) } - }) - }, + AggregateFunctionDefinition::Name(name) => { + internal_err!("AggregateFunctionDefinition::Name({:?}) should be resolved during `AnalyzerRule`", name) + } + } + + } Expr::Alias(Alias{expr,..})=> { to_substrait_agg_measure(expr, schema, extension_info) }