Skip to content

Commit

Permalink
Rename input_type --> input_types on AggregateFunctionExpr / AccumulatorArgs / StateFieldsArgs (#11666)
Browse files Browse the repository at this point in the history

* UDAF input types

* Rename

* Update COMMENTS.md

* Update datafusion/functions-aggregate/COMMENTS.md

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
lewiszlw and alamb authored Jul 30, 2024
1 parent 35c2e7e commit 66a8570
Show file tree
Hide file tree
Showing 14 changed files with 57 additions and 53 deletions.
8 changes: 4 additions & 4 deletions datafusion/core/tests/user_defined/user_defined_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ use arrow::{
record_batch::RecordBatch,
util::pretty::pretty_format_batches,
};
use async_trait::async_trait;
use futures::{Stream, StreamExt};

use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::{
common::cast::{as_int64_array, as_string_array},
common::{arrow_datafusion_err, internal_err, DFSchemaRef},
Expand All @@ -90,16 +94,12 @@ use datafusion::{
physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner},
prelude::{SessionConfig, SessionContext},
};

use async_trait::async_trait;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::ScalarValue;
use datafusion_expr::Projection;
use datafusion_optimizer::optimizer::ApplyOrder;
use datafusion_optimizer::AnalyzerRule;
use futures::{Stream, StreamExt};

/// Execute the specified sql and return the resulting record batches
/// pretty printed as a String.
Expand Down
8 changes: 4 additions & 4 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ pub struct AccumulatorArgs<'a> {
/// ```
pub is_distinct: bool,

/// The input type of the aggregate function.
pub input_type: &'a DataType,
/// The input types of the aggregate function.
pub input_types: &'a [DataType],

/// The logical expression of arguments the aggregate function takes.
pub input_exprs: &'a [Expr],
Expand All @@ -109,8 +109,8 @@ pub struct StateFieldsArgs<'a> {
/// The name of the aggregate function.
pub name: &'a str,

/// The input type of the aggregate function.
pub input_type: &'a DataType,
/// The input types of the aggregate function.
pub input_types: &'a [DataType],

/// The return type of the aggregate function.
pub return_type: &'a DataType,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/COMMENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ first argument and the definition looks like this:
// `input_types[0]` : data type of the first argument
let mut fields = vec![Field::new_list(
format_state_name(self.name(), "nth_value"),
Field::new("item", args.input_type.clone(), true /* nullable of list item */ ),
Field::new("item", args.input_types[0].clone(), true /* nullable of list item */ ),
false, // nullable of list itself
)];
```
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/approx_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ impl AggregateUDFImpl for ApproxDistinct {
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let accumulator: Box<dyn Accumulator> = match acc_args.input_type {
let accumulator: Box<dyn Accumulator> = match &acc_args.input_types[0] {
// TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL
// TODO support for boolean (trivial case)
// https://github.com/apache/datafusion/issues/1109
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/approx_median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl AggregateUDFImpl for ApproxMedian {

Ok(Box::new(ApproxPercentileAccumulator::new(
0.5_f64,
acc_args.input_type.clone(),
acc_args.input_types[0].clone(),
)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ impl ApproxPercentileCont {
None
};

let accumulator: ApproxPercentileAccumulator = match args.input_type {
let accumulator: ApproxPercentileAccumulator = match &args.input_types[0] {
t @ (DataType::UInt8
| DataType::UInt16
| DataType::UInt32
Expand Down
12 changes: 7 additions & 5 deletions datafusion/functions-aggregate/src/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,15 @@ impl AggregateUDFImpl for ArrayAgg {
return Ok(vec![Field::new_list(
format_state_name(args.name, "distinct_array_agg"),
// See COMMENTS.md to understand why nullable is set to true
Field::new("item", args.input_type.clone(), true),
Field::new("item", args.input_types[0].clone(), true),
true,
)]);
}

let mut fields = vec![Field::new_list(
format_state_name(args.name, "array_agg"),
// See COMMENTS.md to understand why nullable is set to true
Field::new("item", args.input_type.clone(), true),
Field::new("item", args.input_types[0].clone(), true),
true,
)];

Expand All @@ -119,12 +119,14 @@ impl AggregateUDFImpl for ArrayAgg {
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
if acc_args.is_distinct {
return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
acc_args.input_type,
&acc_args.input_types[0],
)?));
}

if acc_args.sort_exprs.is_empty() {
return Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?));
return Ok(Box::new(ArrayAggAccumulator::try_new(
&acc_args.input_types[0],
)?));
}

let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema(
Expand All @@ -138,7 +140,7 @@ impl AggregateUDFImpl for ArrayAgg {
.collect::<Result<Vec<_>>>()?;

OrderSensitiveArrayAggAccumulator::try_new(
acc_args.input_type,
&acc_args.input_types[0],
&ordering_dtypes,
ordering_req,
acc_args.is_reversed,
Expand Down
16 changes: 8 additions & 8 deletions datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl AggregateUDFImpl for Avg {
}
use DataType::*;
// instantiate specialized accumulator based on the type
match (acc_args.input_type, acc_args.data_type) {
match (&acc_args.input_types[0], acc_args.data_type) {
(Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
(
Decimal128(sum_precision, sum_scale),
Expand All @@ -120,7 +120,7 @@ impl AggregateUDFImpl for Avg {
})),
_ => exec_err!(
"AvgAccumulator for ({} --> {})",
acc_args.input_type,
&acc_args.input_types[0],
acc_args.data_type
),
}
Expand All @@ -135,7 +135,7 @@ impl AggregateUDFImpl for Avg {
),
Field::new(
format_state_name(args.name, "sum"),
args.input_type.clone(),
args.input_types[0].clone(),
true,
),
])
Expand All @@ -154,10 +154,10 @@ impl AggregateUDFImpl for Avg {
) -> Result<Box<dyn GroupsAccumulator>> {
use DataType::*;
// instantiate specialized accumulator based on the type
match (args.input_type, args.data_type) {
match (&args.input_types[0], args.data_type) {
(Float64, Float64) => {
Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
args.input_type,
&args.input_types[0],
args.data_type,
|sum: f64, count: u64| Ok(sum / count as f64),
)))
Expand All @@ -176,7 +176,7 @@ impl AggregateUDFImpl for Avg {
move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);

Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
args.input_type,
&args.input_types[0],
args.data_type,
avg_fn,
)))
Expand All @@ -197,15 +197,15 @@ impl AggregateUDFImpl for Avg {
};

Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
args.input_type,
&args.input_types[0],
args.data_type,
avg_fn,
)))
}

_ => not_impl_err!(
"AvgGroupsAccumulator for ({} --> {})",
args.input_type,
&args.input_types[0],
args.data_type
),
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ impl AggregateUDFImpl for Count {
Ok(vec![Field::new_list(
format_state_name(args.name, "count distinct"),
// See COMMENTS.md to understand why nullable is set to true
Field::new("item", args.input_type.clone(), true),
Field::new("item", args.input_types[0].clone(), true),
false,
)])
} else {
Expand All @@ -148,7 +148,7 @@ impl AggregateUDFImpl for Count {
return not_impl_err!("COUNT DISTINCT with multiple arguments");
}

let data_type = acc_args.input_type;
let data_type = &acc_args.input_types[0];
Ok(match data_type {
// try and use a specialized accumulator if possible, otherwise fall back to generic accumulator
DataType::Int8 => Box::new(
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,14 +440,14 @@ impl AggregateUDFImpl for LastValue {
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let StateFieldsArgs {
name,
input_type,
input_types,
return_type: _,
ordering_fields,
is_distinct: _,
} = args;
let mut fields = vec![Field::new(
format_state_name(name, "last_value"),
input_type.clone(),
input_types[0].clone(),
true,
)];
fields.extend(ordering_fields.to_vec());
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions-aggregate/src/median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl AggregateUDFImpl for Median {

fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
//Intermediate state is a list of the elements we have collected so far
let field = Field::new("item", args.input_type.clone(), true);
let field = Field::new("item", args.input_types[0].clone(), true);
let state_name = if args.is_distinct {
"distinct_median"
} else {
Expand Down Expand Up @@ -133,7 +133,7 @@ impl AggregateUDFImpl for Median {
};
}

let dt = acc_args.input_type;
let dt = &acc_args.input_types[0];
downcast_integer! {
dt => (helper, dt),
DataType::Float16 => helper!(Float16Type, dt),
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions-aggregate/src/nth_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ impl AggregateUDFImpl for NthValueAgg {

NthValueAccumulator::try_new(
n,
acc_args.input_type,
&acc_args.input_types[0],
&ordering_dtypes,
ordering_req,
)
Expand All @@ -125,7 +125,7 @@ impl AggregateUDFImpl for NthValueAgg {
let mut fields = vec![Field::new_list(
format_state_name(self.name(), "nth_value"),
// See COMMENTS.md to understand why nullable is set to true
Field::new("item", args.input_type.clone(), true),
Field::new("item", args.input_types[0].clone(), true),
false,
)];
let orderings = args.ordering_fields.to_vec();
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions-aggregate/src/stddev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ mod tests {
name: "a",
is_distinct: false,
is_reversed: false,
input_type: &DataType::Float64,
input_types: &[DataType::Float64],
input_exprs: &[datafusion_expr::col("a")],
};

Expand All @@ -348,7 +348,7 @@ mod tests {
name: "a",
is_distinct: false,
is_reversed: false,
input_type: &DataType::Float64,
input_types: &[DataType::Float64],
input_exprs: &[datafusion_expr::col("a")],
};

Expand Down
38 changes: 20 additions & 18 deletions datafusion/physical-expr-common/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,33 @@
// specific language governing permissions and limitations
// under the License.

pub mod count_distinct;
pub mod groups_accumulator;
pub mod merge_arrays;
pub mod stats;
pub mod tdigest;
pub mod utils;
use std::fmt::Debug;
use std::{any::Any, sync::Arc};

use arrow::datatypes::{DataType, Field, Schema, SchemaRef};

use datafusion_common::exec_err;
use datafusion_common::{internal_err, not_impl_err, DFSchema, Result};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::type_coercion::aggregates::check_arg_count;
use datafusion_expr::utils::AggregateOrderSensitivity;
use datafusion_expr::ReversedUDAF;
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator,
};
use std::fmt::Debug;
use std::{any::Any, sync::Arc};

use self::utils::down_cast_any_ref;
use crate::physical_expr::PhysicalExpr;
use crate::sort_expr::{LexOrdering, PhysicalSortExpr};
use crate::utils::reverse_order_bys;

use datafusion_common::exec_err;
use datafusion_expr::utils::AggregateOrderSensitivity;
use self::utils::down_cast_any_ref;

pub mod count_distinct;
pub mod groups_accumulator;
pub mod merge_arrays;
pub mod stats;
pub mod tdigest;
pub mod utils;

/// Creates a physical expression of the UDAF, that includes all necessary type coercion.
/// This function errors when `args` can't be coerced to a valid argument type of the UDAF.
Expand Down Expand Up @@ -225,7 +227,7 @@ impl AggregateExprBuilder {
ignore_nulls,
ordering_fields,
is_distinct,
input_type: input_exprs_types[0].clone(),
input_types: input_exprs_types,
is_reversed,
}))
}
Expand Down Expand Up @@ -466,7 +468,7 @@ pub struct AggregateFunctionExpr {
ordering_fields: Vec<Field>,
is_distinct: bool,
is_reversed: bool,
input_type: DataType,
input_types: Vec<DataType>,
}

impl AggregateFunctionExpr {
Expand Down Expand Up @@ -504,7 +506,7 @@ impl AggregateExpr for AggregateFunctionExpr {
fn state_fields(&self) -> Result<Vec<Field>> {
let args = StateFieldsArgs {
name: &self.name,
input_type: &self.input_type,
input_types: &self.input_types,
return_type: &self.data_type,
ordering_fields: &self.ordering_fields,
is_distinct: self.is_distinct,
Expand All @@ -525,7 +527,7 @@ impl AggregateExpr for AggregateFunctionExpr {
ignore_nulls: self.ignore_nulls,
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
input_types: &self.input_types,
input_exprs: &self.logical_args,
name: &self.name,
is_reversed: self.is_reversed,
Expand All @@ -542,7 +544,7 @@ impl AggregateExpr for AggregateFunctionExpr {
ignore_nulls: self.ignore_nulls,
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
input_types: &self.input_types,
input_exprs: &self.logical_args,
name: &self.name,
is_reversed: self.is_reversed,
Expand Down Expand Up @@ -614,7 +616,7 @@ impl AggregateExpr for AggregateFunctionExpr {
ignore_nulls: self.ignore_nulls,
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
input_types: &self.input_types,
input_exprs: &self.logical_args,
name: &self.name,
is_reversed: self.is_reversed,
Expand All @@ -630,7 +632,7 @@ impl AggregateExpr for AggregateFunctionExpr {
ignore_nulls: self.ignore_nulls,
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
input_types: &self.input_types,
input_exprs: &self.logical_args,
name: &self.name,
is_reversed: self.is_reversed,
Expand Down

0 comments on commit 66a8570

Please sign in to comment.