diff --git a/src/expr/aggregate.rs b/src/expr/aggregate.rs index 626d92c79..72a633394 100644 --- a/src/expr/aggregate.rs +++ b/src/expr/aggregate.rs @@ -127,7 +127,7 @@ impl PyAggregate { // TODO: This Alias logic seems to be returning some strange results that we should investigate Expr::Alias(Alias { expr, .. }) => self._aggregation_arguments(expr.as_ref()), Expr::AggregateFunction(AggregateFunction { - func_def: _, args, .. + func: _, args, .. }) => Ok(args.iter().map(|e| PyExpr::from(e.clone())).collect()), _ => Err(py_type_err( "Encountered a non Aggregate type in aggregation_arguments", @@ -138,8 +138,8 @@ impl PyAggregate { fn _agg_func_name(expr: &Expr) -> PyResult { match expr { Expr::Alias(Alias { expr, .. }) => Self::_agg_func_name(expr.as_ref()), - Expr::AggregateFunction(AggregateFunction { func_def, .. }) => { - Ok(func_def.name().to_owned()) + Expr::AggregateFunction(AggregateFunction { func, .. }) => { + Ok(func.name().to_owned()) } _ => Err(py_type_err( "Encountered a non Aggregate type in agg_func_name", diff --git a/src/expr/aggregate_expr.rs b/src/expr/aggregate_expr.rs index 04ec29a15..15097e007 100644 --- a/src/expr/aggregate_expr.rs +++ b/src/expr/aggregate_expr.rs @@ -41,7 +41,7 @@ impl From for PyAggregateFunction { impl Display for PyAggregateFunction { fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { let args: Vec = self.aggr.args.iter().map(|expr| expr.to_string()).collect(); - write!(f, "{}({})", self.aggr.func_def.name(), args.join(", ")) + write!(f, "{}({})", self.aggr.func.name(), args.join(", ")) } } @@ -49,7 +49,7 @@ impl Display for PyAggregateFunction { impl PyAggregateFunction { /// Get the aggregate type, such as "MIN", or "MAX" fn aggregate_type(&self) -> String { - self.aggr.func_def.name().to_string() + self.aggr.func.name().to_string() } /// is this a distinct aggregate such as `COUNT(DISTINCT expr)` diff --git a/src/functions.rs b/src/functions.rs index f8f478166..c74711552 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -16,7 +16,7 @@ // under the License. use datafusion::functions_aggregate::all_default_aggregate_functions; -use datafusion_expr::AggregateExt; +use datafusion_expr::ExprFunctionExt as AggregateExt; use pyo3::{prelude::*, wrap_pyfunction}; use crate::common::data_type::NullTreatment; @@ -31,9 +31,7 @@ use datafusion::functions_aggregate; use datafusion_common::{Column, ScalarValue, TableReference}; use datafusion_expr::expr::Alias; use datafusion_expr::{ - expr::{ - find_df_window_func, AggregateFunction, AggregateFunctionDefinition, Sort, WindowFunction, - }, + expr::{find_df_window_func, AggregateFunction, Sort, WindowFunction}, lit, Expr, WindowFunctionDefinition, }; @@ -638,18 +636,16 @@ fn window( } macro_rules! aggregate_function { - ($NAME: ident, $FUNC: ident) => { + ($NAME: ident, $FUNC: path) => { aggregate_function!($NAME, $FUNC, stringify!($NAME)); }; - ($NAME: ident, $FUNC: ident, $DOC: expr) => { + ($NAME: ident, $FUNC: path, $DOC: expr) => { #[doc = $DOC] #[pyfunction] #[pyo3(signature = (*args, distinct=false))] fn $NAME(args: Vec, distinct: bool) -> PyExpr { let expr = datafusion_expr::Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn( - datafusion_expr::aggregate_function::AggregateFunction::$FUNC, - ), + func: $FUNC(), args: args.into_iter().map(|e| e.into()).collect(), distinct, filter: None, @@ -884,9 +880,9 @@ array_fn!(array_resize, array size value); array_fn!(flatten, array); array_fn!(range, start stop step); -aggregate_function!(array_agg, ArrayAgg); -aggregate_function!(max, Max); -aggregate_function!(min, Min); +aggregate_function!(array_agg, functions_aggregate::array_agg::array_agg_udaf); +aggregate_function!(max, functions_aggregate::min_max::max_udaf); +aggregate_function!(min, functions_aggregate::min_max::min_udaf); pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(abs))?;