Skip to content

Commit

Permalink
UDAF: Extend more args to state_fields and `groups_accumulator_supp…
Browse files Browse the repository at this point in the history
…orted` and introduce `ReversedUDAF` (apache#10525)

* extends args

Signed-off-by: jayzhan211 <[email protected]>

* reuse accumulator args

Signed-off-by: jayzhan211 <[email protected]>

* fix example

Signed-off-by: jayzhan211 <[email protected]>

---------

Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 authored and findepi committed Jul 16, 2024
1 parent 5e9fc43 commit bbaf15e
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 102 deletions.
15 changes: 5 additions & 10 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::{cast::as_float64_array, ScalarValue};
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDF, AggregateUDFImpl,
GroupsAccumulator, Signature,
function::{AccumulatorArgs, StateFieldsArgs},
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
};

/// This example shows how to use the full AggregateUDFImpl API to implement a user
Expand Down Expand Up @@ -92,21 +92,16 @@ impl AggregateUDFImpl for GeoMeanUdaf {
}

/// This is the description of the state. The accumulator's `state()` must match the types here.
fn state_fields(
&self,
_name: &str,
value_type: DataType,
_ordering_fields: Vec<arrow_schema::Field>,
) -> Result<Vec<arrow_schema::Field>> {
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<arrow_schema::Field>> {
Ok(vec![
Field::new("prod", value_type, true),
Field::new("prod", args.return_type.clone(), true),
Field::new("n", DataType::UInt32, true),
])
}

/// Tell DataFusion that this aggregate supports the more performant `GroupsAccumulator`
/// which is used for cases when there are grouping columns in the query
fn groups_accumulator_supported(&self) -> bool {
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
true
}

Expand Down
11 changes: 3 additions & 8 deletions datafusion-examples/examples/simplify_udaf_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use arrow_schema::{Field, Schema};
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_expr::function::AggregateFunctionSimplification;
use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs};
use datafusion_expr::simplify::SimplifyInfo;

use std::{any::Any, sync::Arc};
Expand Down Expand Up @@ -70,16 +70,11 @@ impl AggregateUDFImpl for BetterAvgUdaf {
unimplemented!("should not be invoked")
}

fn state_fields(
&self,
_name: &str,
_value_type: DataType,
_ordering_fields: Vec<arrow_schema::Field>,
) -> Result<Vec<arrow_schema::Field>> {
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<arrow_schema::Field>> {
unimplemented!("should not be invoked")
}

fn groups_accumulator_supported(&self) -> bool {
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
true
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
panic!("accumulator shouldn't invoke");
}

fn groups_accumulator_supported(&self) -> bool {
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
true
}

Expand Down
8 changes: 2 additions & 6 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::expr::{
};
use crate::function::{
AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
StateFieldsArgs,
};
use crate::{
aggregate_function, conditional_expressions::CaseBuilder, logical_plan::Subquery,
Expand Down Expand Up @@ -690,12 +691,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF {
(self.accumulator)(acc_args)
}

fn state_fields(
&self,
_name: &str,
_value_type: DataType,
_ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
Ok(self.state_fields.clone())
}
}
Expand Down
51 changes: 36 additions & 15 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
use crate::ColumnarValue;
use crate::{Accumulator, Expr, PartitionEvaluator};
use arrow::datatypes::{DataType, Schema};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use std::sync::Arc;

Expand All @@ -41,11 +41,14 @@ pub type ReturnTypeFunction =
/// [`AccumulatorArgs`] contains information about how an aggregate
/// function was called, including the types of its arguments and any optional
/// ordering expressions.
#[derive(Debug)]
pub struct AccumulatorArgs<'a> {
/// The return type of the aggregate function.
pub data_type: &'a DataType,

/// The schema of the input arguments
pub schema: &'a Schema,

/// Whether to ignore nulls.
///
/// SQL allows the user to specify `IGNORE NULLS`, for example:
Expand All @@ -66,22 +69,40 @@ pub struct AccumulatorArgs<'a> {
///
/// If no `ORDER BY` is specified, `sort_exprs` will be empty.
pub sort_exprs: &'a [Expr],

/// Whether the aggregate function is distinct.
///
/// ```sql
/// SELECT COUNT(DISTINCT column1) FROM t;
/// ```
pub is_distinct: bool,

/// The input type of the aggregate function.
pub input_type: &'a DataType,

/// The number of arguments the aggregate function takes.
pub args_num: usize,
}

impl<'a> AccumulatorArgs<'a> {
pub fn new(
data_type: &'a DataType,
schema: &'a Schema,
ignore_nulls: bool,
sort_exprs: &'a [Expr],
) -> Self {
Self {
data_type,
schema,
ignore_nulls,
sort_exprs,
}
}
/// [`StateFieldsArgs`] bundles the information passed to
/// [`AggregateUDFImpl::state_fields`] so that an aggregate function can
/// describe the fields of its accumulator's intermediate state.
///
/// All members are borrowed (or `Copy`), so building one does not clone the
/// underlying types or fields.
///
/// [`AggregateUDFImpl::state_fields`]: crate::udaf::AggregateUDFImpl::state_fields
#[derive(Debug)]
pub struct StateFieldsArgs<'a> {
    /// The name of the aggregate function.
    pub name: &'a str,

    /// The input type of the aggregate function.
    pub input_type: &'a DataType,

    /// The return type of the aggregate function.
    pub return_type: &'a DataType,

    /// The ordering fields of the aggregate function.
    pub ordering_fields: &'a [Field],

    /// Whether the aggregate function is distinct.
    pub is_distinct: bool,
}

/// Factory that returns an accumulator for the given aggregate function.
Expand Down
57 changes: 33 additions & 24 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

//! [`AggregateUDF`]: User Defined Aggregate Functions
use crate::function::{AccumulatorArgs, AggregateFunctionSimplification};
use crate::function::{
AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
};
use crate::groups_accumulator::GroupsAccumulator;
use crate::utils::format_state_name;
use crate::{Accumulator, Expr};
Expand Down Expand Up @@ -177,18 +179,13 @@ impl AggregateUDF {
/// for more details.
///
/// This is used to support multi-phase aggregations
pub fn state_fields(
&self,
name: &str,
value_type: DataType,
ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
self.inner.state_fields(name, value_type, ordering_fields)
pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
self.inner.state_fields(args)
}

/// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
pub fn groups_accumulator_supported(&self) -> bool {
self.inner.groups_accumulator_supported()
pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
self.inner.groups_accumulator_supported(args)
}

/// See [`AggregateUDFImpl::create_groups_accumulator`] for more details.
Expand Down Expand Up @@ -232,7 +229,7 @@ where
/// # use arrow::datatypes::DataType;
/// # use datafusion_common::{DataFusionError, plan_err, Result};
/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr};
/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::AccumulatorArgs};
/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}};
/// # use arrow::datatypes::Schema;
/// # use arrow::datatypes::Field;
/// #[derive(Debug, Clone)]
Expand Down Expand Up @@ -261,9 +258,9 @@ where
/// }
/// // This is the accumulator factory; DataFusion uses it to create new accumulators.
/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() }
/// fn state_fields(&self, _name: &str, value_type: DataType, _ordering_fields: Vec<Field>) -> Result<Vec<Field>> {
/// fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
/// Ok(vec![
/// Field::new("value", value_type, true),
/// Field::new("value", args.return_type.clone(), true),
/// Field::new("ordering", DataType::UInt32, true)
/// ])
/// }
Expand Down Expand Up @@ -319,19 +316,17 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
/// The name of the fields must be unique within the query and thus should
/// be derived from `name`. See [`format_state_name`] for a utility function
/// to generate a unique name.
fn state_fields(
&self,
name: &str,
value_type: DataType,
ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
let value_fields = vec![Field::new(
format_state_name(name, "value"),
value_type,
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let fields = vec![Field::new(
format_state_name(args.name, "value"),
args.return_type.clone(),
true,
)];

Ok(value_fields.into_iter().chain(ordering_fields).collect())
Ok(fields
.into_iter()
.chain(args.ordering_fields.to_vec())
.collect())
}

/// If the aggregate expression has a specialized
Expand All @@ -344,7 +339,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
/// `Self::accumulator` for certain queries, such as when this aggregate is
/// used as a window function or when there are no GROUP BY columns in the
/// query.
fn groups_accumulator_supported(&self) -> bool {
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
false
}

Expand Down Expand Up @@ -389,6 +384,20 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
None
}

/// Returns the reverse expression of the aggregate function.
///
/// The default implementation returns [`ReversedUDAF::NotSupported`].
/// Implementations may override it to return another [`ReversedUDAF`]
/// variant instead.
fn reverse_expr(&self) -> ReversedUDAF {
ReversedUDAF::NotSupported
}
}

/// The value returned by [`AggregateUDFImpl::reverse_expr`], describing the
/// reverse expression of an aggregate function.
pub enum ReversedUDAF {
/// The reverse expression is the same as the original expression, like SUM, COUNT
Identical,
/// The expression does not support reverse calculation, like ArrayAgg
NotSupported,
/// The reverse expression is different from the original expression;
/// the payload is presumably the function that implements it.
Reversed(Arc<dyn AggregateUDFImpl>),
}

/// AggregateUDF that adds an alias to the underlying function. It is better to
Expand Down
22 changes: 8 additions & 14 deletions datafusion/functions-aggregate/src/covariance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ use datafusion_common::{
ScalarValue,
};
use datafusion_expr::{
function::AccumulatorArgs, type_coercion::aggregates::NUMERICS,
utils::format_state_name, Accumulator, AggregateUDFImpl, Signature, Volatility,
function::{AccumulatorArgs, StateFieldsArgs},
type_coercion::aggregates::NUMERICS,
utils::format_state_name,
Accumulator, AggregateUDFImpl, Signature, Volatility,
};
use datafusion_physical_expr_common::aggregate::stats::StatsType;

Expand Down Expand Up @@ -101,12 +103,8 @@ impl AggregateUDFImpl for CovarianceSample {
Ok(DataType::Float64)
}

fn state_fields(
&self,
name: &str,
_value_type: DataType,
_ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let name = args.name;
Ok(vec![
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
Expand Down Expand Up @@ -176,12 +174,8 @@ impl AggregateUDFImpl for CovariancePopulation {
Ok(DataType::Float64)
}

fn state_fields(
&self,
name: &str,
_value_type: DataType,
_ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let name = args.name;
Ok(vec![
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
Expand Down
15 changes: 5 additions & 10 deletions datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at
use datafusion_common::{
arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
Expand Down Expand Up @@ -147,18 +147,13 @@ impl AggregateUDFImpl for FirstValue {
.map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _)
}

fn state_fields(
&self,
name: &str,
value_type: DataType,
ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let mut fields = vec![Field::new(
format_state_name(name, "first_value"),
value_type,
format_state_name(args.name, "first_value"),
args.return_type.clone(),
true,
)];
fields.extend(ordering_fields);
fields.extend(args.ordering_fields.to_vec());
fields.push(Field::new("is_set", DataType::Boolean, true));
Ok(fields)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1759,7 +1759,9 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
mod tests {
use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
use datafusion_expr::{
function::AggregateFunctionSimplification, interval_arithmetic::Interval, *,
function::{AccumulatorArgs, AggregateFunctionSimplification},
interval_arithmetic::Interval,
*,
};
use std::{
collections::HashMap,
Expand Down Expand Up @@ -3783,7 +3785,7 @@ mod tests {
unimplemented!("not needed for tests")
}

fn groups_accumulator_supported(&self) -> bool {
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
unimplemented!("not needed for testing")
}

Expand Down
Loading

0 comments on commit bbaf15e

Please sign in to comment.