From 552f52b660a6616e59c1bf65c4831506e3392665 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Wed, 19 Jun 2024 11:57:03 +0000 Subject: [PATCH 01/18] Moving min and max to new API and removing from protobuf --- datafusion/core/src/dataframe/mod.rs | 4 +- datafusion/core/tests/dataframe/mod.rs | 4 +- .../core/tests/fuzz_cases/window_fuzz.rs | 7 +- datafusion/expr/src/aggregate_function.rs | 25 - datafusion/expr/src/expr.rs | 12 - datafusion/expr/src/expr_fn.rs | 24 - datafusion/expr/src/expr_rewriter/order_by.rs | 2 +- .../expr/src/type_coercion/aggregates.rs | 33 +- datafusion/expr/src/utils.rs | 8 +- datafusion/functions-aggregate/src/lib.rs | 4 + datafusion/functions-aggregate/src/min_max.rs | 918 ++++++++++++++++++ .../src/single_distinct_to_groupby.rs | 5 +- .../physical-expr/src/aggregate/build_in.rs | 60 +- datafusion/proto/proto/datafusion.proto | 4 +- datafusion/proto/src/generated/prost.rs | 8 +- .../proto/src/logical_plan/from_proto.rs | 4 +- datafusion/proto/src/logical_plan/to_proto.rs | 4 - .../proto/src/physical_plan/to_proto.rs | 13 +- .../tests/cases/roundtrip_logical_plan.rs | 11 +- datafusion/sql/src/query.rs | 1 + datafusion/sql/src/relation/mod.rs | 1 + datafusion/sql/src/select.rs | 2 + 22 files changed, 955 insertions(+), 199 deletions(-) create mode 100644 datafusion/functions-aggregate/src/min_max.rs diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index b5c58eff577c4..61916191f8ff1 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -50,11 +50,11 @@ use datafusion_common::{ }; use datafusion_expr::lit; use datafusion_expr::{ - avg, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, + avg,utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, }; use datafusion_expr::{case, is_null}; -use datafusion_functions_aggregate::expr_fn::{count, median, stddev, sum}; +use datafusion_functions_aggregate::expr_fn::{count,max, median,min, stddev, sum}; use async_trait::async_trait; diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index fa364c5f2a653..fb3f101dc2264 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -52,11 +52,11 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - array_agg, avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, + array_agg, avg, cast, col, exists, expr, in_subquery, lit, out_ref_col, placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion_functions_aggregate::expr_fn::{count, sum}; +use datafusion_functions_aggregate::expr_fn::{count, max, sum}; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 5bd19850cacc8..756a6477e762d 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -35,10 +35,11 @@ use datafusion_common_runtime::SpawnedTask; use datafusion_expr::type_coercion::aggregates::coerce_types; use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf; use datafusion_expr::{ - AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, + BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_functions_aggregate::count::count_udaf; +use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf}; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; @@ -360,14 +361,14 @@ fn get_random_function( window_fn_map.insert( "min", ( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![arg.clone()], ), ); window_fn_map.insert( "max", ( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![arg.clone()], ), ); diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 1cde1c5050a80..e93a92885993b 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -33,10 +33,6 @@ use strum_macros::EnumIter; // https://datafusion.apache.org/contributor-guide/index.html#how-to-add-a-new-aggregate-function #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] pub enum AggregateFunction { - /// Minimum - Min, - /// Maximum - Max, /// Average Avg, /// Aggregation into an array @@ -57,8 +53,6 @@ impl AggregateFunction { pub fn name(&self) -> &str { use AggregateFunction::*; match self { - Min => "MIN", - Max => "MAX", Avg => "AVG", ArrayAgg => "ARRAY_AGG", NthValue => "NTH_VALUE", @@ -84,9 +78,7 @@ impl FromStr for AggregateFunction { "avg" => AggregateFunction::Avg, "bool_and" => AggregateFunction::BoolAnd, "bool_or" => AggregateFunction::BoolOr, - "max" => AggregateFunction::Max, "mean" => AggregateFunction::Avg, - "min" => AggregateFunction::Min, "array_agg" => AggregateFunction::ArrayAgg, "nth_value" => AggregateFunction::NthValue, // statistical @@ -123,11 +115,6 @@ impl AggregateFunction { })?; match self { - AggregateFunction::Max | AggregateFunction::Min => { - // For min and max agg function, the returned type is same as input type. - // The coerced_data_types is same with input_types. - Ok(coerced_data_types[0].clone()) - } AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { Ok(DataType::Boolean) } @@ -167,18 +154,6 @@ impl AggregateFunction { AggregateFunction::Grouping | AggregateFunction::ArrayAgg => { Signature::any(1, Volatility::Immutable) } - AggregateFunction::Min | AggregateFunction::Max => { - let valid = STRINGS - .iter() - .chain(NUMERICS.iter()) - .chain(TIMESTAMPS.iter()) - .chain(DATES.iter()) - .chain(TIMES.iter()) - .chain(BINARYS.iter()) - .cloned() - .collect::>(); - Signature::uniform(1, valid, Volatility::Immutable) - } AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { Signature::uniform(1, vec![DataType::Boolean], Volatility::Immutable) } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 9ba866a4c9198..2de7d4c178267 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2255,18 +2255,6 @@ mod test { #[test] fn test_find_df_window_function() { - assert_eq!( - find_df_window_func("max"), - Some(WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Max - )) - ); - assert_eq!( - find_df_window_func("min"), - Some(WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Min - )) - ); assert_eq!( find_df_window_func("avg"), Some(WindowFunctionDefinition::AggregateFunction( diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index a87412ee63565..5b859f2becc53 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -145,30 +145,6 @@ pub fn not(expr: Expr) -> Expr { expr.not() } -/// Create an expression to represent the min() aggregate function -pub fn min(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Min, - vec![expr], - false, - None, - None, - None, - )) -} - -/// Create an expression to represent the max() aggregate function -pub fn max(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Max, - vec![expr], - false, - None, - None, - None, - )) -} - /// Create an expression to represent the array_agg() aggregate function pub fn array_agg(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index eb38fee7cad07..6eefd656dfc1f 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -156,7 +156,7 @@ mod test { use arrow::datatypes::{DataType, Field, Schema}; use crate::{ - avg, cast, col, lit, logical_plan::builder::LogicalTableSource, min, try_cast, + avg, cast, col, lit, logical_plan::builder::LogicalTableSource, try_cast, LogicalPlanBuilder, }; diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index abe6d8b1823da..d2e44967cf9b3 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::ops::Deref; use crate::{AggregateFunction, Signature, TypeSignature}; @@ -96,11 +95,6 @@ pub fn coerce_types( match agg_fun { AggregateFunction::ArrayAgg => Ok(input_types.to_vec()), - AggregateFunction::Min | AggregateFunction::Max => { - // min and max support the dictionary data type - // unpack the dictionary to get the value - get_min_max_result_type(input_types) - } AggregateFunction::Avg => { // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc // smallint, int, bigint, real, double precision, decimal, or interval @@ -208,22 +202,6 @@ pub fn check_arg_count( Ok(()) } -fn get_min_max_result_type(input_types: &[DataType]) -> Result> { - // make sure that the input types only has one element. - assert_eq!(input_types.len(), 1); - // min and max support the dictionary data type - // unpack the dictionary to get the value - match &input_types[0] { - DataType::Dictionary(_, dict_value_type) => { - // TODO add checker, if the value type is complex data type - Ok(vec![dict_value_type.deref().clone()]) - } - // TODO add checker for datatype which min and max supported - // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function - _ => Ok(input_types.to_vec()), - } -} - /// function return type of a sum pub fn sum_return_type(arg_type: &DataType) -> Result { match arg_type { @@ -380,13 +358,6 @@ mod tests { #[test] fn test_aggregate_coerce_types() { - // test input args with error number input types - let fun = AggregateFunction::Min; - let input_types = vec![DataType::Int64, DataType::Int32]; - let signature = fun.signature(); - let result = coerce_types(&fun, &input_types, &signature); - assert_eq!("Error during planning: The function MIN expects 1 arguments, but 2 were provided", result.unwrap_err().strip_backtrace()); - let fun = AggregateFunction::Avg; // test input args is invalid data type for avg let input_types = vec![DataType::Utf8]; @@ -397,12 +368,10 @@ mod tests { result.unwrap_err().strip_backtrace() ); - // test count, array_agg, approx_distinct, min, max. + // test count, array_agg, approx_distinct. // the coerced types is same with input types let funs = vec![ AggregateFunction::ArrayAgg, - AggregateFunction::Min, - AggregateFunction::Max, ]; let input_types = vec![ vec![DataType::Int32], diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 3ab0c180dcba9..60afef8626462 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1268,7 +1268,7 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], vec![], vec![], @@ -1276,7 +1276,7 @@ mod tests { None, )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], vec![], vec![], @@ -1284,7 +1284,7 @@ mod tests { None, )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], vec![], vec![], @@ -1371,7 +1371,7 @@ mod tests { fn test_find_sort_exprs() -> Result<()> { let exprs = &[ Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], vec![], vec![ diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 20a8d2c159266..45c2fa351a51e 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -61,6 +61,7 @@ pub mod covariance; pub mod first_last; pub mod hyperloglog; pub mod median; +pub mod min_max; pub mod regr; pub mod stddev; pub mod sum; @@ -96,6 +97,8 @@ pub mod expr_fn { pub use super::first_last::first_value; pub use super::first_last::last_value; pub use super::median::median; + pub use super::min_max::max; + pub use super::min_max::min; pub use super::regr::regr_avgx; pub use super::regr::regr_avgy; pub use super::regr::regr_count; @@ -120,6 +123,7 @@ pub fn all_default_aggregate_functions() -> Vec> { covariance::covar_samp_udaf(), sum::sum_udaf(), covariance::covar_pop_udaf(), + min_max::max_udaf(), median::median_udaf(), count::count_udaf(), regr::regr_slope_udaf(), diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs new file mode 100644 index 0000000000000..a53a446ab2a78 --- /dev/null +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -0,0 +1,918 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function + +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines `MAX` aggregate accumulators + +use arrow::datatypes::{ + DataType, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use arrow::array::{ + ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, + Decimal256Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, + Int8Array, LargeBinaryArray, LargeStringArray, StringArray, Time32MillisecondArray, + Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, +}; +use arrow::compute; +use datafusion_common::{downcast_value, internal_err, DataFusionError, Result}; +use datafusion_physical_expr_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; +use std::fmt::Debug; + +use arrow::datatypes::{ + Date32Type, Date64Type, Time32MillisecondType, Time32SecondType, + Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, +}; + +use arrow::datatypes::i256; + +use datafusion_common::ScalarValue; +use datafusion_expr::GroupsAccumulator; +use datafusion_expr::{ + function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility, +}; + +// min/max of two non-string scalar values. +macro_rules! typed_min_max { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ + ScalarValue::$SCALAR( + match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(*a), + (None, Some(b)) => Some(*b), + (Some(a), Some(b)) => Some((*a).$OP(*b)), + }, + $($EXTRA_ARGS.clone()),* + ) + }}; +} + +macro_rules! typed_min_max_float { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ + ScalarValue::$SCALAR(match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(*a), + (None, Some(b)) => Some(*b), + (Some(a), Some(b)) => match a.total_cmp(b) { + choose_min_max!($OP) => Some(*b), + _ => Some(*a), + }, + }) + }}; +} + +// Statically-typed version of min/max(array) -> ScalarValue for binay types. +macro_rules! typed_min_max_batch_binary { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ + let array = downcast_value!($VALUES, $ARRAYTYPE); + let value = compute::$OP(array); + let value = value.and_then(|e| Some(e.to_vec())); + ScalarValue::$SCALAR(value) + }}; +} + +// Statically-typed version of min/max(array) -> ScalarValue for non-string types. +macro_rules! typed_min_max_batch { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ + let array = downcast_value!($VALUES, $ARRAYTYPE); + let value = compute::$OP(array); + ScalarValue::$SCALAR(value, $($EXTRA_ARGS.clone()),*) + }}; +} +// min/max of two scalar string values. +macro_rules! typed_min_max_string { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ + ScalarValue::$SCALAR(match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(a.clone()), + (None, Some(b)) => Some(b.clone()), + (Some(a), Some(b)) => Some((a).$OP(b).clone()), + }) + }}; +} + +// Statically-typed version of min/max(array) -> ScalarValue for string types. +macro_rules! typed_min_max_batch_string { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ + let array = downcast_value!($VALUES, $ARRAYTYPE); + let value = compute::$OP(array); + let value = value.and_then(|e| Some(e.to_string())); + ScalarValue::$SCALAR(value) + }}; +} + +macro_rules! min_max { + ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ + Ok(match ($VALUE, $DELTA) { + ( + lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss), + rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss) + ) => { + if lhsp.eq(rhsp) && lhss.eq(rhss) { + typed_min_max!(lhsv, rhsv, Decimal128, $OP, lhsp, lhss) + } else { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (lhs, rhs) + ); + } + } + ( + lhs @ ScalarValue::Decimal256(lhsv, lhsp, lhss), + rhs @ ScalarValue::Decimal256(rhsv, rhsp, rhss) + ) => { + if lhsp.eq(rhsp) && lhss.eq(rhss) { + typed_min_max!(lhsv, rhsv, Decimal256, $OP, lhsp, lhss) + } else { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (lhs, rhs) + ); + } + } + (ScalarValue::Boolean(lhs), ScalarValue::Boolean(rhs)) => { + typed_min_max!(lhs, rhs, Boolean, $OP) + } + (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { + typed_min_max_float!(lhs, rhs, Float64, $OP) + } + (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { + typed_min_max_float!(lhs, rhs, Float32, $OP) + } + (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { + typed_min_max!(lhs, rhs, UInt64, $OP) + } + (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => { + typed_min_max!(lhs, rhs, UInt32, $OP) + } + (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => { + typed_min_max!(lhs, rhs, UInt16, $OP) + } + (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => { + typed_min_max!(lhs, rhs, UInt8, $OP) + } + (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => { + typed_min_max!(lhs, rhs, Int64, $OP) + } + (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => { + typed_min_max!(lhs, rhs, Int32, $OP) + } + (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => { + typed_min_max!(lhs, rhs, Int16, $OP) + } + (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { + typed_min_max!(lhs, rhs, Int8, $OP) + } + (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => { + typed_min_max_string!(lhs, rhs, Utf8, $OP) + } + (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => { + typed_min_max_string!(lhs, rhs, LargeUtf8, $OP) + } + (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => { + typed_min_max_string!(lhs, rhs, Binary, $OP) + } + (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => { + typed_min_max_string!(lhs, rhs, LargeBinary, $OP) + } + (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => { + typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz) + } + ( + ScalarValue::TimestampMillisecond(lhs, l_tz), + ScalarValue::TimestampMillisecond(rhs, _), + ) => { + typed_min_max!(lhs, rhs, TimestampMillisecond, $OP, l_tz) + } + ( + ScalarValue::TimestampMicrosecond(lhs, l_tz), + ScalarValue::TimestampMicrosecond(rhs, _), + ) => { + typed_min_max!(lhs, rhs, TimestampMicrosecond, $OP, l_tz) + } + ( + ScalarValue::TimestampNanosecond(lhs, l_tz), + ScalarValue::TimestampNanosecond(rhs, _), + ) => { + typed_min_max!(lhs, rhs, TimestampNanosecond, $OP, l_tz) + } + ( + ScalarValue::Date32(lhs), + ScalarValue::Date32(rhs), + ) => { + typed_min_max!(lhs, rhs, Date32, $OP) + } + ( + ScalarValue::Date64(lhs), + ScalarValue::Date64(rhs), + ) => { + typed_min_max!(lhs, rhs, Date64, $OP) + } + ( + ScalarValue::Time32Second(lhs), + ScalarValue::Time32Second(rhs), + ) => { + typed_min_max!(lhs, rhs, Time32Second, $OP) + } + ( + ScalarValue::Time32Millisecond(lhs), + ScalarValue::Time32Millisecond(rhs), + ) => { + typed_min_max!(lhs, rhs, Time32Millisecond, $OP) + } + ( + ScalarValue::Time64Microsecond(lhs), + ScalarValue::Time64Microsecond(rhs), + ) => { + typed_min_max!(lhs, rhs, Time64Microsecond, $OP) + } + ( + ScalarValue::Time64Nanosecond(lhs), + ScalarValue::Time64Nanosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, Time64Nanosecond, $OP) + } + ( + ScalarValue::IntervalYearMonth(lhs), + ScalarValue::IntervalYearMonth(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalYearMonth, $OP) + } + ( + ScalarValue::IntervalMonthDayNano(lhs), + ScalarValue::IntervalMonthDayNano(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalMonthDayNano, $OP) + } + ( + ScalarValue::IntervalDayTime(lhs), + ScalarValue::IntervalDayTime(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalDayTime, $OP) + } + ( + ScalarValue::IntervalYearMonth(_), + ScalarValue::IntervalMonthDayNano(_), + ) | ( + ScalarValue::IntervalYearMonth(_), + ScalarValue::IntervalDayTime(_), + ) | ( + ScalarValue::IntervalMonthDayNano(_), + ScalarValue::IntervalDayTime(_), + ) | ( + ScalarValue::IntervalMonthDayNano(_), + ScalarValue::IntervalYearMonth(_), + ) | ( + ScalarValue::IntervalDayTime(_), + ScalarValue::IntervalYearMonth(_), + ) | ( + ScalarValue::IntervalDayTime(_), + ScalarValue::IntervalMonthDayNano(_), + ) => { + interval_min_max!($OP, $VALUE, $DELTA) + } + ( + ScalarValue::DurationSecond(lhs), + ScalarValue::DurationSecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationSecond, $OP) + } + ( + ScalarValue::DurationMillisecond(lhs), + ScalarValue::DurationMillisecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationMillisecond, $OP) + } + ( + ScalarValue::DurationMicrosecond(lhs), + ScalarValue::DurationMicrosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationMicrosecond, $OP) + } + ( + ScalarValue::DurationNanosecond(lhs), + ScalarValue::DurationNanosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationNanosecond, $OP) + } + e => { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + e + ) + } + }) + }}; +} + +macro_rules! choose_min_max { + (min) => { + std::cmp::Ordering::Greater + }; + (max) => { + std::cmp::Ordering::Less + }; +} + +macro_rules! interval_min_max { + ($OP:tt, $LHS:expr, $RHS:expr) => {{ + match $LHS.partial_cmp(&$RHS) { + Some(choose_min_max!($OP)) => $RHS.clone(), + Some(_) => $LHS.clone(), + None => { + return internal_err!("Comparison error while computing interval min/max") + } + } + }}; +} + +// Statically-typed version of min/max(array) -> ScalarValue for non-string types. +// this is a macro to support both operations (min and max). +macro_rules! min_max_batch { + ($VALUES:expr, $OP:ident) => {{ + match $VALUES.data_type() { + DataType::Decimal128(precision, scale) => { + typed_min_max_batch!( + $VALUES, + Decimal128Array, + Decimal128, + $OP, + precision, + scale + ) + } + DataType::Decimal256(precision, scale) => { + typed_min_max_batch!( + $VALUES, + Decimal256Array, + Decimal256, + $OP, + precision, + scale + ) + } + // all types that have a natural order + DataType::Float64 => { + typed_min_max_batch!($VALUES, Float64Array, Float64, $OP) + } + DataType::Float32 => { + typed_min_max_batch!($VALUES, Float32Array, Float32, $OP) + } + DataType::Int64 => typed_min_max_batch!($VALUES, Int64Array, Int64, $OP), + DataType::Int32 => typed_min_max_batch!($VALUES, Int32Array, Int32, $OP), + DataType::Int16 => typed_min_max_batch!($VALUES, Int16Array, Int16, $OP), + DataType::Int8 => typed_min_max_batch!($VALUES, Int8Array, Int8, $OP), + DataType::UInt64 => typed_min_max_batch!($VALUES, UInt64Array, UInt64, $OP), + DataType::UInt32 => typed_min_max_batch!($VALUES, UInt32Array, UInt32, $OP), + DataType::UInt16 => typed_min_max_batch!($VALUES, UInt16Array, UInt16, $OP), + DataType::UInt8 => typed_min_max_batch!($VALUES, UInt8Array, UInt8, $OP), + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + typed_min_max_batch!( + $VALUES, + TimestampSecondArray, + TimestampSecond, + $OP, + tz_opt + ) + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_min_max_batch!( + $VALUES, + TimestampMillisecondArray, + TimestampMillisecond, + $OP, + tz_opt + ), + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_min_max_batch!( + $VALUES, + TimestampMicrosecondArray, + TimestampMicrosecond, + $OP, + tz_opt + ), + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_min_max_batch!( + $VALUES, + TimestampNanosecondArray, + TimestampNanosecond, + $OP, + tz_opt + ), + DataType::Date32 => typed_min_max_batch!($VALUES, Date32Array, Date32, $OP), + DataType::Date64 => typed_min_max_batch!($VALUES, Date64Array, Date64, $OP), + DataType::Time32(TimeUnit::Second) => { + typed_min_max_batch!($VALUES, Time32SecondArray, Time32Second, $OP) + } + DataType::Time32(TimeUnit::Millisecond) => { + typed_min_max_batch!( + $VALUES, + Time32MillisecondArray, + Time32Millisecond, + $OP + ) + } + DataType::Time64(TimeUnit::Microsecond) => { + typed_min_max_batch!( + $VALUES, + Time64MicrosecondArray, + Time64Microsecond, + $OP + ) + } + DataType::Time64(TimeUnit::Nanosecond) => { + typed_min_max_batch!( + $VALUES, + Time64NanosecondArray, + Time64Nanosecond, + $OP + ) + } + other => { + // This should have been handled before + return internal_err!( + "Min/Max accumulator not implemented for type {:?}", + other + ); + } + } + }}; +} + +/// dynamically-typed max(array) -> ScalarValue +fn max_batch(values: &ArrayRef) -> Result { + Ok(match values.data_type() { + DataType::Utf8 => { + typed_min_max_batch_string!(values, StringArray, Utf8, max_string) + } + DataType::LargeUtf8 => { + typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string) + } + DataType::Boolean => { + typed_min_max_batch!(values, BooleanArray, Boolean, max_boolean) + } + DataType::Binary => { + typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary) + } + DataType::LargeBinary => { + typed_min_max_batch_binary!( + &values, + LargeBinaryArray, + LargeBinary, + max_binary + ) + } + _ => min_max_batch!(values, max), + }) +} + +// min/max of two non-string scalar values. +macro_rules! typed_min_max { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ + ScalarValue::$SCALAR( + match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(*a), + (None, Some(b)) => Some(*b), + (Some(a), Some(b)) => Some((*a).$OP(*b)), + }, + $($EXTRA_ARGS.clone()),* + ) + }}; +} + +make_udaf_expr_and_func!( + Max, + max, + expression, + "Returns the maximum of a group of values.", + max_udaf +); + +make_udaf_expr_and_func!( + Min, + min, + expression, + "Returns the minimum of a group of values.", + min_udaf +); + +fn min_max_aggregate_data_type(input_type: DataType) -> DataType { + if let DataType::Dictionary(_, value_type) = input_type { + *value_type + } else { + input_type + } +} + +#[derive(Debug)] +pub struct Max { + signature: Signature, +} + +impl Max { + pub fn new() -> Self { + Self { + signature: Signature::numeric(1, Volatility::Immutable), + } + } +} + +impl Default for Max { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for Max { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "max" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(min_max_aggregate_data_type(arg_types[0].clone())) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(MaxAccumulator::try_new(acc_args.data_type)?)) + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + use DataType::*; + use TimeUnit::*; + let data_type = args.data_type; + macro_rules! helper { + ($NATIVE:ident, $PRIMTYPE:ident) => {{ + Ok(Box::new( + PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new( + data_type, + |cur, new| { + if *cur > new { + *cur = new + } + }, + ) + // Initialize each accumulator to $NATIVE::MIN + .with_starting_value($NATIVE::MIN), + )) + }}; + } + + match args.data_type { + Int8 => helper!(i8, Int8Type), + Int16 => helper!(i16, Int16Type), + Int32 => helper!(i32, Int32Type), + Int64 => helper!(i64, Int64Type), + UInt8 => helper!(u8, UInt8Type), + UInt16 => helper!(u16, UInt16Type), + UInt32 => helper!(u32, UInt32Type), + UInt64 => helper!(u64, UInt64Type), + Float32 => { + helper!(f32, Float32Type) + } + Float64 => { + helper!(f64, Float64Type) + } + Date32 => helper!(i32, Date32Type), + Date64 => helper!(i64, Date64Type), + Time32(Second) => { + helper!(i32, Time32SecondType) + } + Time32(Millisecond) => { + helper!(i32, Time32MillisecondType) + } + Time64(Microsecond) => { + helper!(i64, Time64MicrosecondType) + } + Time64(Nanosecond) => { + helper!(i64, Time64NanosecondType) + } + Timestamp(Second, _) => { + helper!(i64, TimestampSecondType) + } + Timestamp(Millisecond, _) => { + helper!(i64, TimestampMillisecondType) + } + Timestamp(Microsecond, _) => { + helper!(i64, TimestampMicrosecondType) + } + Timestamp(Nanosecond, _) => { + helper!(i64, TimestampNanosecondType) + } + Decimal128(_, _) => { + helper!(i128, Decimal128Type) + } + Decimal256(_, _) => { + helper!(i256, Decimal256Type) + } + + // It would be nice to have a fast implementation for Strings as well + // https://github.com/apache/datafusion/issues/6906 + + // This is only reached if groups_accumulator_supported is out of sync + _ => internal_err!("GroupsAccumulator not supported for max({})", data_type), + } + } +} + +/// An accumulator to compute the maximum value +#[derive(Debug)] +pub struct MaxAccumulator { + max: ScalarValue, +} + +impl MaxAccumulator { + /// new max accumulator + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + max: ScalarValue::try_from(datatype)?, + }) + } +} + +impl Accumulator for MaxAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + let delta = &max_batch(values)?; + let new_max: Result = + min_max!(&self.max, delta, max); + self.max = new_max?; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn evaluate(&mut self) -> Result { + Ok(self.max.clone()) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size() + } +} + + + +#[derive(Debug)] +pub struct Min { + signature: Signature, +} + +impl Min { + pub fn new() -> Self { + Self { + signature: Signature::numeric(1, Volatility::Immutable), + } + } +} +impl Default for Min { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for Min { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "Min" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(min_max_aggregate_data_type(arg_types[0].clone())) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(MinAccumulator::try_new(acc_args.data_type)?)) + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + use DataType::*; + use TimeUnit::*; + let data_type = args.data_type; + macro_rules! helper { + ($NATIVE:ident, $PRIMTYPE:ident) => {{ + Ok(Box::new( + PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new( + data_type, + |cur, new| { + if *cur < new { + *cur = new + } + }, + ) + // Initialize each accumulator to $NATIVE::MIN + .with_starting_value($NATIVE::MIN), + )) + }}; + } + + match args.data_type { + Int8 => helper!(i8, Int8Type), + Int16 => helper!(i16, Int16Type), + Int32 => helper!(i32, Int32Type), + Int64 => helper!(i64, Int64Type), + UInt8 => helper!(u8, UInt8Type), + UInt16 => helper!(u16, UInt16Type), + UInt32 => helper!(u32, UInt32Type), + UInt64 => helper!(u64, UInt64Type), + Float32 => { + helper!(f32, Float32Type) + } + Float64 => { + helper!(f64, Float64Type) + } + Date32 => helper!(i32, Date32Type), + Date64 => helper!(i64, Date64Type), + Time32(Second) => { + helper!(i32, Time32SecondType) + } + Time32(Millisecond) => { + helper!(i32, Time32MillisecondType) + } + Time64(Microsecond) => { + helper!(i64, Time64MicrosecondType) + } + Time64(Nanosecond) => { + helper!(i64, Time64NanosecondType) + } + Timestamp(Second, _) => { + helper!(i64, TimestampSecondType) + } + Timestamp(Millisecond, _) => { + helper!(i64, TimestampMillisecondType) + } + Timestamp(Microsecond, _) => { + helper!(i64, TimestampMicrosecondType) + } + Timestamp(Nanosecond, _) => { + helper!(i64, TimestampNanosecondType) + } + Decimal128(_, _) => { + helper!(i128, Decimal128Type) + } + Decimal256(_, _) => { + helper!(i256, Decimal256Type) + } + + // It would be nice to have a fast implementation for Strings as well + // https://github.com/apache/datafusion/issues/6906 + + // This is only reached if groups_accumulator_supported is out of sync + _ => internal_err!("GroupsAccumulator not supported for min({})", data_type), + } + } +} +/// An accumulator to compute the minimum value +#[derive(Debug)] +pub struct MinAccumulator { + min: ScalarValue, +} + +impl MinAccumulator { + /// new max accumulator + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + min: ScalarValue::try_from(datatype)?, + }) + } +} + +impl Accumulator for MinAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + let delta = &max_batch(values)?; + let new_min: Result = + min_max!(&self.min, delta, min); + self.min = new_min?; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn evaluate(&mut self) -> Result { + Ok(self.min.clone()) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() + } +} + + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + #[test] + fn float_min_max_with_nans() { + let pos_nan = f32::NAN; + let zero = 0_f32; + let neg_inf = f32::NEG_INFINITY; + + let check = |acc: &mut dyn Accumulator, values: &[&[f32]], expected: f32| { + for batch in values.iter() { + let batch = + Arc::new(Float32Array::from_iter_values(batch.iter().copied())); + acc.update_batch(&[batch]).unwrap(); + } + let result = acc.evaluate().unwrap(); + assert_eq!(result, ScalarValue::Float32(Some(expected))); + }; + + // This test checks both comparison between batches (which uses the min_max macro + // defined above) and within a batch (which uses the arrow min/max compute function + // and verifies both respect the total order comparison for floats) + + let min = || MinAccumulator::try_new(&DataType::Float32).unwrap(); + let max = || MaxAccumulator::try_new(&DataType::Float32).unwrap(); + + check(&mut min(), &[&[zero], &[pos_nan]], zero); + check(&mut min(), &[&[zero, pos_nan]], zero); + check(&mut min(), &[&[zero], &[neg_inf]], neg_inf); + check(&mut min(), &[&[zero, neg_inf]], neg_inf); + check(&mut max(), &[&[zero], &[pos_nan]], pos_nan); + check(&mut max(), &[&[zero, pos_nan]], pos_nan); + check(&mut max(), &[&[zero], &[neg_inf]], zero); + check(&mut max(), &[&[zero, neg_inf]], zero); + } +} \ No newline at end of file diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index d3d22eb53f395..4836eaa428a06 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -28,7 +28,6 @@ use datafusion_common::{ use datafusion_expr::builder::project; use datafusion_expr::expr::AggregateFunctionDefinition; use datafusion_expr::{ - aggregate_function::AggregateFunction::{Max, Min}, col, expr::AggregateFunction, logical_plan::{Aggregate, LogicalPlan}, @@ -71,7 +70,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { let mut aggregate_count = 0; for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn(fun), + func_def: AggregateFunctionDefinition::BuiltIn(_fun), distinct, args, filter, @@ -87,7 +86,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { for e in args { fields_set.insert(e); } - } else if !matches!(fun, Min | Max) { + } else { return Ok(false); } } else if let Expr::AggregateFunction(AggregateFunction { diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 1dfe9ffd69057..4e39b2352d3f7 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -107,17 +107,7 @@ pub fn create_aggregate_expr( data_type, is_expr_nullable, )) - } - (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::Max, _) => Arc::new(expressions::Max::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), + }, (AggregateFunction::Avg, false) => { Arc::new(Avg::new(input_phy_exprs[0].clone(), name, data_type)) } @@ -232,54 +222,6 @@ mod tests { Ok(()) } - #[test] - fn test_min_max_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Min, AggregateFunction::Max]; - let data_types = vec![ - DataType::UInt32, - DataType::Int32, - DataType::Float32, - DataType::Float64, - DataType::Decimal128(10, 2), - DataType::Utf8, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - match fun { - AggregateFunction::Min => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - AggregateFunction::Max => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; - } - } - Ok(()) - } #[test] fn test_bool_and_or_expr() -> Result<()> { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 6375df721ae6e..55d7da4364dca 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -472,8 +472,8 @@ message InListNode { } enum AggregateFunction { - MIN = 0; - MAX = 1; + // MIN = 0; + // MAX = 1; // SUM = 2; AVG = 3; // COUNT = 4; diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index bc5b6be2ad87f..51b2732a9cee6 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1926,8 +1926,7 @@ pub struct PartitionStats { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum AggregateFunction { - Min = 0, - Max = 1, + UNUSED = 0, /// SUM = 2; Avg = 3, /// COUNT = 4; @@ -1969,8 +1968,6 @@ impl AggregateFunction { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - AggregateFunction::Min => "MIN", - AggregateFunction::Max => "MAX", AggregateFunction::Avg => "AVG", AggregateFunction::ArrayAgg => "ARRAY_AGG", AggregateFunction::Correlation => "CORRELATION", @@ -1978,13 +1975,12 @@ impl AggregateFunction { AggregateFunction::BoolAnd => "BOOL_AND", AggregateFunction::BoolOr => "BOOL_OR", AggregateFunction::NthValueAgg => "NTH_VALUE_AGG", + AggregateFunction::UNUSED => "UNUSED", } } /// Creates an enum from field names used in the ProtoBuf definition. pub fn from_str_name(value: &str) -> ::core::option::Option { match value { - "MIN" => Some(Self::Min), - "MAX" => Some(Self::Max), "AVG" => Some(Self::Avg), "ARRAY_AGG" => Some(Self::ArrayAgg), "CORRELATION" => Some(Self::Correlation), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 5bec655bb1ff5..8ff50647c7de4 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -137,8 +137,7 @@ impl From<&protobuf::StringifiedPlan> for StringifiedPlan { impl From for AggregateFunction { fn from(agg_fun: protobuf::AggregateFunction) -> Self { match agg_fun { - protobuf::AggregateFunction::Min => Self::Min, - protobuf::AggregateFunction::Max => Self::Max, + protobuf::AggregateFunction::Avg => Self::Avg, protobuf::AggregateFunction::BoolAnd => Self::BoolAnd, protobuf::AggregateFunction::BoolOr => Self::BoolOr, @@ -146,6 +145,7 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Correlation => Self::Correlation, protobuf::AggregateFunction::Grouping => Self::Grouping, protobuf::AggregateFunction::NthValueAgg => Self::NthValue, + protobuf::AggregateFunction::UNUSED => panic!("This should never happen, we are retiring this but protobuf doesn't support enum with no 0 values"), } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 66b7c77799ea7..fb31acebb12e4 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -108,8 +108,6 @@ impl From<&StringifiedPlan> for protobuf::StringifiedPlan { impl From<&AggregateFunction> for protobuf::AggregateFunction { fn from(value: &AggregateFunction) -> Self { match value { - AggregateFunction::Min => Self::Min, - AggregateFunction::Max => Self::Max, AggregateFunction::Avg => Self::Avg, AggregateFunction::BoolAnd => Self::BoolAnd, AggregateFunction::BoolOr => Self::BoolOr, @@ -374,8 +372,6 @@ pub fn serialize_expr( AggregateFunctionDefinition::BuiltIn(fun) => { let aggr_function = match fun { AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, - AggregateFunction::Min => protobuf::AggregateFunction::Min, - AggregateFunction::Max => protobuf::AggregateFunction::Max, AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, AggregateFunction::Avg => protobuf::AggregateFunction::Avg, diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index ed966509b842d..8a0331018359f 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,10 +23,11 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - ArrayAgg, Avg, BinaryExpr, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, - CumeDist, DistinctArrayAgg, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, - Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, - OrderSensitiveArrayAgg, Rank, RankType, RowNumber, TryCastExpr, WindowShift, + ArrayAgg, Avg, BinaryExpr, BoolAnd, BoolOr, CaseExpr, + CastExpr, Column, Correlation, CumeDist, DistinctArrayAgg, Grouping, + InListExpr, IsNotNullExpr, IsNullExpr, Literal, NegativeExpr, NotExpr, + NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, RowNumber, + TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -251,10 +252,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { protobuf::AggregateFunction::ArrayAgg } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::ArrayAgg - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Min - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Max } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Avg } else if aggr_expr.downcast_ref::().is_some() { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 61764394ee74f..1b2f8f6e07e15 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -54,7 +54,7 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - Accumulator, AggregateExt, AggregateFunction, ColumnarValue, ExprSchemable, + Accumulator, AggregateExt, ColumnarValue, ExprSchemable, LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, WindowUDFImpl, @@ -2026,14 +2026,6 @@ fn roundtrip_window() { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), ); - let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("col1")], - vec![col("col1")], - vec![col("col2")], - row_number_frame.clone(), - None, - )); // 5. test with AggregateUDF #[derive(Debug)] @@ -2168,7 +2160,6 @@ fn roundtrip_window() { roundtrip_expr_test(test_expr1, ctx.clone()); roundtrip_expr_test(test_expr2, ctx.clone()); roundtrip_expr_test(test_expr3, ctx.clone()); - roundtrip_expr_test(test_expr4, ctx.clone()); roundtrip_expr_test(test_expr5, ctx.clone()); roundtrip_expr_test(test_expr6, ctx); } diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index cbbff19321d81..d5e5fc9c51cf6 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -45,6 +45,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let set_expr = *query.body; + match set_expr { SetExpr::Select(mut select) => { let select_into = select.into.take(); diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 9380e569f2e43..d694035ad429d 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -29,6 +29,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { relation: TableFactor, planner_context: &mut PlannerContext, ) -> Result { + println!("Creating relation {:#?}", relation); let (plan, alias) = match relation { TableFactor::Table { name, alias, args, .. diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 0fa266e4e01d7..59c9c58c73701 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -71,6 +71,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // process `from` clause let plan = self.plan_from_tables(select.from, planner_context)?; + let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); // process `where` clause @@ -377,6 +378,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { mut from: Vec, planner_context: &mut PlannerContext, ) -> Result { + println!("from len {}", from.len()); match from.len() { 0 => Ok(LogicalPlanBuilder::empty(true).build()?), 1 => { From 0e45a682f2c5ee9296941717c62846d877ec73d8 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Thu, 20 Jun 2024 14:58:45 +0000 Subject: [PATCH 02/18] Fixing tests --- .../examples/dataframe_subquery.rs | 1 + datafusion/core/src/dataframe/mod.rs | 5 +- datafusion/core/tests/dataframe/mod.rs | 6 +- .../core/tests/fuzz_cases/window_fuzz.rs | 4 +- datafusion/expr/src/expr_rewriter/order_by.rs | 1 + datafusion/expr/src/test/function_stub.rs | 168 ++++++++++++++++++ .../expr/src/type_coercion/aggregates.rs | 5 +- datafusion/expr/src/utils.rs | 10 +- datafusion/functions-aggregate/src/min_max.rs | 13 +- .../src/analyzer/count_wildcard_rule.rs | 6 +- .../optimizer/src/optimize_projections/mod.rs | 12 +- datafusion/optimizer/src/push_down_limit.rs | 4 +- .../optimizer/src/scalar_subquery_to_join.rs | 4 +- .../simplify_expressions/simplify_exprs.rs | 6 +- .../src/single_distinct_to_groupby.rs | 31 +--- .../physical-expr/src/aggregate/build_in.rs | 28 +-- .../proto/src/logical_plan/from_proto.rs | 1 - .../proto/src/physical_plan/to_proto.rs | 9 +- .../tests/cases/roundtrip_logical_plan.rs | 9 +- datafusion/sql/src/query.rs | 2 +- datafusion/sql/src/select.rs | 2 +- 21 files changed, 221 insertions(+), 106 deletions(-) diff --git a/datafusion-examples/examples/dataframe_subquery.rs b/datafusion-examples/examples/dataframe_subquery.rs index 9fb61008b9f69..85d0c25c60bc4 100644 --- a/datafusion-examples/examples/dataframe_subquery.rs +++ b/datafusion-examples/examples/dataframe_subquery.rs @@ -19,6 +19,7 @@ use arrow_schema::DataType; use std::sync::Arc; use datafusion::error::Result; +use datafusion::logical_expr::test::function_stub::max; use datafusion::prelude::*; use datafusion::test_util::arrow_test_data; use datafusion_common::ScalarValue; diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 61916191f8ff1..31e72e83a0732 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -50,11 +50,10 @@ use datafusion_common::{ }; use datafusion_expr::lit; use datafusion_expr::{ - avg,utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, - UNNAMED_TABLE, + avg, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, }; use datafusion_expr::{case, is_null}; -use datafusion_functions_aggregate::expr_fn::{count,max, median,min, stddev, sum}; +use datafusion_functions_aggregate::expr_fn::{count, max, median, min, stddev, sum}; use async_trait::async_trait; diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index fb3f101dc2264..dbb53745f3856 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -52,9 +52,9 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - array_agg, avg, cast, col, exists, expr, in_subquery, lit, out_ref_col, - placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + array_agg, avg, cast, col, exists, expr, in_subquery, lit, out_ref_col, placeholder, + scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_functions_aggregate::expr_fn::{count, max, sum}; diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 756a6477e762d..ddfa940975d8a 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -35,8 +35,8 @@ use datafusion_common_runtime::SpawnedTask; use datafusion_expr::type_coercion::aggregates::coerce_types; use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf; use datafusion_expr::{ - BuiltInWindowFunction, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunctionDefinition, + BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, + WindowFunctionDefinition, }; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf}; diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index 6eefd656dfc1f..6c0ed5c077f16 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -161,6 +161,7 @@ mod test { }; use super::*; + use crate::test::function_stub::min; #[test] fn rewrite_sort_cols_by_agg() { diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index ac98ee9747cc1..c0d9ccc94141a 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -273,3 +273,171 @@ impl AggregateUDFImpl for Count { ReversedUDAF::Identical } } + +create_func!(Min, min_udaf); + +pub fn min(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new_udf( + min_udaf(), + vec![expr], + false, + None, + None, + None, + )) +} + +/// Testing stub implementation of COUNT aggregate +pub struct Min { + signature: Signature, + aliases: Vec, +} + +impl std::fmt::Debug for Min { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Min") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Min { + fn default() -> Self { + Self::new() + } +} + +impl Min { + pub fn new() -> Self { + Self { + aliases: vec!["count".to_string()], + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Min { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "MIN" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + not_impl_err!("no impl for stub") + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } +} + +create_func!(Max, max_udaf); + +pub fn max(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new_udf( + max_udaf(), + vec![expr], + false, + None, + None, + None, + )) +} + +/// Testing stub implementation of COUNT aggregate +pub struct Max { + signature: Signature, + aliases: Vec, +} + +impl std::fmt::Debug for Max { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Min") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Max { + fn default() -> Self { + Self::new() + } +} + +impl Max { + pub fn new() -> Self { + Self { + aliases: vec!["count".to_string()], + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Max { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "MIN" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + not_impl_err!("no impl for stub") + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } +} diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index d2e44967cf9b3..ac17fecba796f 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. - use crate::{AggregateFunction, Signature, TypeSignature}; use arrow::datatypes::{ @@ -370,9 +369,7 @@ mod tests { // test count, array_agg, approx_distinct. // the coerced types is same with input types - let funs = vec![ - AggregateFunction::ArrayAgg, - ]; + let funs = vec![AggregateFunction::ArrayAgg]; let input_types = vec![ vec![DataType::Int32], vec![DataType::Decimal128(10, 2)], diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 60afef8626462..07514ca0edc2d 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1253,8 +1253,8 @@ mod tests { use super::*; use crate::{ col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, - test::function_stub::sum_udaf, AggregateFunction, Cast, WindowFrame, - WindowFunctionDefinition, + test::function_stub::max_udaf, test::function_stub::min_udaf, + test::function_stub::sum_udaf, Cast, WindowFrame, WindowFunctionDefinition, }; #[test] @@ -1315,7 +1315,7 @@ mod tests { let created_at_desc = Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)); let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], vec![], vec![age_asc.clone(), name_desc.clone()], @@ -1323,7 +1323,7 @@ mod tests { None, )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], vec![], vec![], @@ -1331,7 +1331,7 @@ mod tests { None, )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], vec![], vec![age_asc.clone(), name_desc.clone()], diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index a53a446ab2a78..d7f6d5afe5742 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -34,10 +34,6 @@ //! Defines `MAX` aggregate accumulators -use arrow::datatypes::{ - DataType, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, - UInt16Type, UInt32Type, UInt64Type, UInt8Type, -}; use arrow::array::{ ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, @@ -47,6 +43,10 @@ use arrow::array::{ TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; use arrow::compute; +use arrow::datatypes::{ + DataType, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type, + Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; use datafusion_common::{downcast_value, internal_err, DataFusionError, Result}; use datafusion_physical_expr_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use std::fmt::Debug; @@ -706,8 +706,6 @@ impl Accumulator for MaxAccumulator { } } - - #[derive(Debug)] pub struct Min { signature: Signature, @@ -877,7 +875,6 @@ impl Accumulator for MinAccumulator { } } - #[cfg(test)] mod tests { use super::*; @@ -915,4 +912,4 @@ mod tests { check(&mut max(), &[&[zero], &[neg_inf]], zero); check(&mut max(), &[&[zero, neg_inf]], zero); } -} \ No newline at end of file +} diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index de2af520053a2..d78bd702481ae 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -102,11 +102,11 @@ mod tests { use datafusion_common::ScalarValue; use datafusion_expr::expr::Sort; use datafusion_expr::{ - col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, - out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound, - WindowFrameUnits, + col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, out_ref_col, + scalar_subquery, wildcard, WindowFrame, WindowFrameBound, WindowFrameUnits, }; use datafusion_functions_aggregate::count::count_udaf; + use datafusion_functions_aggregate::expr_fn::max; use std::sync::Arc; use datafusion_functions_aggregate::expr_fn::{count, sum}; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 11540d3e162e4..c01dbab400341 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -826,13 +826,13 @@ mod tests { expr::{self, Cast}, lit, logical_plan::{builder::LogicalPlanBuilder, table_scan}, - max, min, not, try_cast, when, AggregateFunction, BinaryExpr, Expr, Extension, - Like, LogicalPlan, Operator, Projection, UserDefinedLogicalNodeCore, WindowFrame, - WindowFunctionDefinition, + not, try_cast, when, BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator, + Projection, UserDefinedLogicalNodeCore, WindowFrame, WindowFunctionDefinition, }; use datafusion_functions_aggregate::count::count_udaf; - use datafusion_functions_aggregate::expr_fn::count; + use datafusion_functions_aggregate::expr_fn::{count, max, min}; + use datafusion_functions_aggregate::min_max::max_udaf; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) @@ -1929,7 +1929,7 @@ mod tests { let table_scan = test_table_scan()?; let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.a")], vec![col("test.b")], vec![], @@ -1938,7 +1938,7 @@ mod tests { )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.b")], vec![], vec![], diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 6723672ff498f..e76573e465aad 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -338,8 +338,8 @@ mod test { use super::*; use crate::test::*; - - use datafusion_expr::{col, exists, logical_plan::builder::LogicalPlanBuilder, max}; + use datafusion_expr::{col, exists, logical_plan::builder::LogicalPlanBuilder}; + use datafusion_functions_aggregate::expr_fn::max; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan, expected) diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 279eca9c912bb..3d8b659aa12bc 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -401,8 +401,10 @@ mod tests { use crate::test::*; use arrow::datatypes::DataType; + // TODO: stubs or real functions use datafusion_expr::test::function_stub::sum; - use datafusion_expr::{col, lit, max, min, out_ref_col, scalar_subquery, Between}; + use datafusion_expr::{col, lit, out_ref_col, scalar_subquery, Between}; + use datafusion_functions_aggregate::expr_fn::{max, min}; /// Test multiple correlated subqueries #[test] diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index d15d12b690da8..f11bcf3a86ed1 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -168,6 +168,7 @@ mod tests { ExprSchemable, JoinType, }; use datafusion_expr::{or, BinaryExpr, Cast, Operator}; + use datafusion_functions_aggregate::expr_fn::{max, min}; use crate::test::{assert_fields_eq, test_table_scan_with_name}; use crate::OptimizerContext; @@ -403,10 +404,7 @@ mod tests { .project(vec![col("a"), col("c"), col("b")])? .aggregate( vec![col("a"), col("c")], - vec![ - datafusion_expr::max(col("b").eq(lit(true))), - datafusion_expr::min(col("b")), - ], + vec![max(col("b").eq(lit(true))), min(col("b"))], )? .build()?; diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 4836eaa428a06..3acff9862d513 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -362,11 +362,9 @@ mod tests { use crate::test::*; use datafusion_expr::expr::{self, GroupingSet}; use datafusion_expr::AggregateExt; - use datafusion_expr::{ - lit, logical_plan::builder::LogicalPlanBuilder, max, min, AggregateFunction, - }; + use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use datafusion_functions_aggregate::count::count_udaf; - use datafusion_functions_aggregate::expr_fn::{count, count_distinct, sum}; + use datafusion_functions_aggregate::expr_fn::{count, count_distinct, max, min, sum}; use datafusion_functions_aggregate::sum::sum_udaf; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { @@ -527,17 +525,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate( vec![col("a")], - vec![ - count_distinct(col("b")), - Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Max, - vec![col("b")], - true, - None, - None, - None, - )), - ], + vec![count_distinct(col("b")), max(col("b"))], )? .build()?; // Should work @@ -591,18 +579,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate( vec![col("a")], - vec![ - sum(col("c")), - count_distinct(col("b")), - Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Max, - vec![col("b")], - true, - None, - None, - None, - )), - ], + vec![sum(col("c")), count_distinct(col("b")), max(col("b"))], )? .build()?; // Should work diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 4e39b2352d3f7..bcbbbe3130136 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -107,7 +107,7 @@ pub fn create_aggregate_expr( data_type, is_expr_nullable, )) - }, + } (AggregateFunction::Avg, false) => { Arc::new(Avg::new(input_phy_exprs[0].clone(), name, data_type)) } @@ -156,7 +156,7 @@ mod tests { use datafusion_expr::{type_coercion, Signature}; use crate::expressions::{ - try_cast, ArrayAgg, Avg, BoolAnd, BoolOr, DistinctArrayAgg, Max, Min, + try_cast, ArrayAgg, Avg, BoolAnd, BoolOr, DistinctArrayAgg, }; use super::*; @@ -222,7 +222,6 @@ mod tests { Ok(()) } - #[test] fn test_bool_and_or_expr() -> Result<()> { let funcs = vec![AggregateFunction::BoolAnd, AggregateFunction::BoolOr]; @@ -266,7 +265,7 @@ mod tests { } #[test] - fn test_sum_avg_expr() -> Result<()> { + fn test_avg_expr() -> Result<()> { let funcs = vec![AggregateFunction::Avg]; let data_types = vec![ DataType::UInt32, @@ -303,27 +302,6 @@ mod tests { Ok(()) } - #[test] - fn test_min_max() -> Result<()> { - let observed = AggregateFunction::Min.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = AggregateFunction::Max.return_type(&[DataType::Int32])?; - assert_eq!(DataType::Int32, observed); - - // test decimal for min - let observed = - AggregateFunction::Min.return_type(&[DataType::Decimal128(10, 6)])?; - assert_eq!(DataType::Decimal128(10, 6), observed); - - // test decimal for max - let observed = - AggregateFunction::Max.return_type(&[DataType::Decimal128(28, 13)])?; - assert_eq!(DataType::Decimal128(28, 13), observed); - - Ok(()) - } - #[test] fn test_avg_return_type() -> Result<()> { let observed = AggregateFunction::Avg.return_type(&[DataType::Float32])?; diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 8ff50647c7de4..be65d830e0b3d 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -137,7 +137,6 @@ impl From<&protobuf::StringifiedPlan> for StringifiedPlan { impl From for AggregateFunction { fn from(agg_fun: protobuf::AggregateFunction) -> Self { match agg_fun { - protobuf::AggregateFunction::Avg => Self::Avg, protobuf::AggregateFunction::BoolAnd => Self::BoolAnd, protobuf::AggregateFunction::BoolOr => Self::BoolOr, diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 8a0331018359f..97ca66ceb61e9 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,11 +23,10 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - ArrayAgg, Avg, BinaryExpr, BoolAnd, BoolOr, CaseExpr, - CastExpr, Column, Correlation, CumeDist, DistinctArrayAgg, Grouping, - InListExpr, IsNotNullExpr, IsNullExpr, Literal, NegativeExpr, NotExpr, - NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, RowNumber, - TryCastExpr, WindowShift, + ArrayAgg, Avg, BinaryExpr, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, + CumeDist, DistinctArrayAgg, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, + NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, + RankType, RowNumber, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 1b2f8f6e07e15..cc3e6f76cdbb7 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -54,10 +54,10 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - Accumulator, AggregateExt, ColumnarValue, ExprSchemable, - LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, - TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, - WindowFunctionDefinition, WindowUDF, WindowUDFImpl, + Accumulator, AggregateExt, ColumnarValue, ExprSchemable, LogicalPlan, Operator, + PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, TryCast, Volatility, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, + WindowUDFImpl, }; use datafusion_functions_aggregate::expr_fn::{bit_and, bit_or, bit_xor}; use datafusion_functions_aggregate::string_agg::string_agg; @@ -2026,7 +2026,6 @@ fn roundtrip_window() { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), ); - // 5. test with AggregateUDF #[derive(Debug)] struct DummyAggr {} diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index d5e5fc9c51cf6..658589f5d9fc8 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -45,7 +45,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let set_expr = *query.body; - + match set_expr { SetExpr::Select(mut select) => { let select_into = select.into.take(); diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 59c9c58c73701..749d101aefa49 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -71,7 +71,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // process `from` clause let plan = self.plan_from_tables(select.from, planner_context)?; - + let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); // process `where` clause From aaa494be13690ffd2206ab4cb382f1eae9d11d54 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Thu, 20 Jun 2024 15:42:10 +0000 Subject: [PATCH 03/18] Fixing rollbacks --- .../tests/cases/roundtrip_logical_plan.rs | 23 +++++++++++++++---- datafusion/sql/src/query.rs | 1 - datafusion/sql/src/relation/mod.rs | 1 - datafusion/sql/src/select.rs | 2 -- 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 7b47c8ab254f1..54a4110480554 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -34,9 +34,10 @@ use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::count::count_udaf; +use datafusion::functions_aggregate::min_max::max_udaf; use datafusion::functions_aggregate::expr_fn::{ approx_median, approx_percentile_cont, approx_percentile_cont_with_weight, count, - count_distinct, covar_pop, covar_samp, first_value, median, stddev, stddev_pop, sum, + count_distinct, covar_pop, covar_samp, first_value, max, median, min, stddev, stddev_pop, sum, var_pop, var_sample, }; use datafusion::prelude::*; @@ -54,10 +55,10 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - Accumulator, AggregateExt, ColumnarValue, ExprSchemable, LogicalPlan, Operator, - PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, TryCast, Volatility, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, - WindowUDFImpl, + Accumulator, AggregateExt, ColumnarValue, ExprSchemable, + LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, + TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, + WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; use datafusion_functions_aggregate::expr_fn::{ bit_and, bit_or, bit_xor, bool_and, bool_or, @@ -661,7 +662,9 @@ async fn roundtrip_expr_api() -> Result<()> { covar_samp(lit(1.5), lit(2.2)), covar_pop(lit(1.5), lit(2.2)), sum(lit(1)), + max(lit(1)), median(lit(2)), + min(lit(2)), var_sample(lit(2.2)), var_pop(lit(2.2)), stddev(lit(2.2)), @@ -2030,6 +2033,15 @@ fn roundtrip_window() { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), ); + let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(max_udaf()), + vec![col("col1")], + vec![col("col1")], + vec![col("col2")], + row_number_frame.clone(), + None, + )); + // 5. test with AggregateUDF #[derive(Debug)] struct DummyAggr {} @@ -2163,6 +2175,7 @@ fn roundtrip_window() { roundtrip_expr_test(test_expr1, ctx.clone()); roundtrip_expr_test(test_expr2, ctx.clone()); roundtrip_expr_test(test_expr3, ctx.clone()); + roundtrip_expr_test(test_expr4, ctx.clone()); roundtrip_expr_test(test_expr5, ctx.clone()); roundtrip_expr_test(test_expr6, ctx); } diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 658589f5d9fc8..cbbff19321d81 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -45,7 +45,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let set_expr = *query.body; - match set_expr { SetExpr::Select(mut select) => { let select_into = select.into.take(); diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index d694035ad429d..9380e569f2e43 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -29,7 +29,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { relation: TableFactor, planner_context: &mut PlannerContext, ) -> Result { - println!("Creating relation {:#?}", relation); let (plan, alias) = match relation { TableFactor::Table { name, alias, args, .. diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 749d101aefa49..0fa266e4e01d7 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -71,7 +71,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // process `from` clause let plan = self.plan_from_tables(select.from, planner_context)?; - let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); // process `where` clause @@ -378,7 +377,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { mut from: Vec, planner_context: &mut PlannerContext, ) -> Result { - println!("from len {}", from.len()); match from.len() { 0 => Ok(LogicalPlanBuilder::empty(true).build()?), 1 => { From 09c4ce19561b7b8aa2afb1c0446c66b7c3da98b7 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Thu, 20 Jun 2024 15:48:37 +0000 Subject: [PATCH 04/18] Cleaning up rebase confusion --- datafusion/expr/src/aggregate_function.rs | 6 ---- .../physical-expr/src/aggregate/build_in.rs | 31 ------------------- datafusion/proto/src/logical_plan/to_proto.rs | 2 -- .../proto/src/physical_plan/to_proto.rs | 2 +- .../tests/cases/roundtrip_logical_plan.rs | 14 ++++----- 5 files changed, 8 insertions(+), 47 deletions(-) diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index e99f275cdaaf2..4574415c8d0e1 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -107,9 +107,6 @@ impl AggregateFunction { })?; match self { - AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { - Ok(DataType::Boolean) - } AggregateFunction::Correlation => { correlation_return_type(&coerced_data_types[0]) } @@ -146,9 +143,6 @@ impl AggregateFunction { AggregateFunction::Grouping | AggregateFunction::ArrayAgg => { Signature::any(1, Volatility::Immutable) } - AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { - Signature::uniform(1, vec![DataType::Boolean], Volatility::Immutable) - } AggregateFunction::Avg => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index b5372d53cc762..9727431ccdd3f 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -212,39 +212,8 @@ mod tests { Ok(()) } - #[test] - fn test_bool_and_or_expr() -> Result<()> { - let funcs = vec![AggregateFunction::BoolAnd, AggregateFunction::BoolOr]; - let data_types = vec![DataType::Boolean]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - match fun { - AggregateFunction::BoolAnd => { - assert!(result_agg_phy_exprs.as_any().is::()); - } - AggregateFunction::BoolOr => { - assert!(result_agg_phy_exprs.as_any().is::()); - }; - } - } - Ok(()) - } - #[test] fn test_avg_expr() -> Result<()> { - fn test_sum_avg_expr() -> Result<()> { let funcs = vec![AggregateFunction::Avg]; let data_types = vec![ DataType::UInt32, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index b2390b5e04012..8de97ec26451d 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -370,8 +370,6 @@ pub fn serialize_expr( AggregateFunctionDefinition::BuiltIn(fun) => { let aggr_function = match fun { AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, - AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, - AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, AggregateFunction::Avg => protobuf::AggregateFunction::Avg, AggregateFunction::Correlation => { protobuf::AggregateFunction::Correlation diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index de5a77fcd75b7..27b4cfdd5a0db 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -24,7 +24,7 @@ use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ ArrayAgg, Avg, BinaryExpr, CaseExpr, CastExpr, Column, Correlation, CumeDist, - DistinctArrayAgg, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, + DistinctArrayAgg, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, RowNumber, TryCastExpr, WindowShift, }; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 54a4110480554..3e329bcbba091 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -34,12 +34,12 @@ use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::count::count_udaf; -use datafusion::functions_aggregate::min_max::max_udaf; use datafusion::functions_aggregate::expr_fn::{ approx_median, approx_percentile_cont, approx_percentile_cont_with_weight, count, - count_distinct, covar_pop, covar_samp, first_value, max, median, min, stddev, stddev_pop, sum, - var_pop, var_sample, + count_distinct, covar_pop, covar_samp, first_value, max, median, min, stddev, + stddev_pop, sum, var_pop, var_sample, }; +use datafusion::functions_aggregate::min_max::max_udaf; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::{FormatOptions, TableOptions}; @@ -55,10 +55,10 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - Accumulator, AggregateExt, ColumnarValue, ExprSchemable, - LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, - TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, - WindowFunctionDefinition, WindowUDF, WindowUDFImpl, + Accumulator, AggregateExt, ColumnarValue, ExprSchemable, LogicalPlan, Operator, + PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, TryCast, Volatility, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, + WindowUDFImpl, }; use datafusion_functions_aggregate::expr_fn::{ bit_and, bit_or, bit_xor, bool_and, bool_or, From 7575843b5fde8b5bb829381f4bc827358410a2d8 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Thu, 20 Jun 2024 16:08:37 +0000 Subject: [PATCH 05/18] last fix --- datafusion/physical-expr/src/aggregate/build_in.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 9727431ccdd3f..a4c9edcc196dc 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -145,9 +145,7 @@ mod tests { use datafusion_common::plan_err; use datafusion_expr::{type_coercion, Signature}; - use crate::expressions::{ - try_cast, ArrayAgg, Avg, BoolAnd, BoolOr, DistinctArrayAgg, - }; + use crate::expressions::{try_cast, ArrayAgg, Avg, DistinctArrayAgg}; use super::*; #[test] From c21bd5e6077961d085bd249dc77b45889a3fbffd Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Thu, 20 Jun 2024 16:41:54 +0000 Subject: [PATCH 06/18] Fixing protobuf --- datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 9 +++------ datafusion/proto/src/generated/prost.rs | 6 ++++-- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 1d1fed062b972..ac1273a0530d5 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -472,7 +472,7 @@ message InListNode { } enum AggregateFunction { - // MIN = 0; + UNUSED = 0; // MAX = 1; // SUM = 2; AVG = 3; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 8cca0fe4a8762..0bcd6e391d07a 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -532,8 +532,7 @@ impl serde::Serialize for AggregateFunction { S: serde::Serializer, { let variant = match self { - Self::Min => "MIN", - Self::Max => "MAX", + Self::Unused => "UNUSED", Self::Avg => "AVG", Self::ArrayAgg => "ARRAY_AGG", Self::Correlation => "CORRELATION", @@ -550,8 +549,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "MIN", - "MAX", + "UNUSED", "AVG", "ARRAY_AGG", "CORRELATION", @@ -597,8 +595,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { E: serde::de::Error, { match value { - "MIN" => Ok(AggregateFunction::Min), - "MAX" => Ok(AggregateFunction::Max), + "UNUSED" => Ok(AggregateFunction::Unused), "AVG" => Ok(AggregateFunction::Avg), "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), "CORRELATION" => Ok(AggregateFunction::Correlation), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 5b0b0221a7b41..3db16e8b0e4c0 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1926,7 +1926,8 @@ pub struct PartitionStats { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum AggregateFunction { - UNUSED = 0, + Unused = 0, + /// MAX = 1; /// SUM = 2; Avg = 3, /// COUNT = 4; @@ -1968,17 +1969,18 @@ impl AggregateFunction { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { + AggregateFunction::Unused => "UNUSED", AggregateFunction::Avg => "AVG", AggregateFunction::ArrayAgg => "ARRAY_AGG", AggregateFunction::Correlation => "CORRELATION", AggregateFunction::Grouping => "GROUPING", AggregateFunction::NthValueAgg => "NTH_VALUE_AGG", - AggregateFunction::UNUSED => "UNUSED", } } /// Creates an enum from field names used in the ProtoBuf definition. pub fn from_str_name(value: &str) -> ::core::option::Option { match value { + "UNUSED" => Some(Self::Unused), "AVG" => Some(Self::Avg), "ARRAY_AGG" => Some(Self::ArrayAgg), "CORRELATION" => Some(Self::Correlation), From fe8669fa3afd39c7c063b58b402c4b59ef5719c4 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Thu, 20 Jun 2024 16:47:07 +0000 Subject: [PATCH 07/18] Fixing case --- datafusion/proto/src/logical_plan/from_proto.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 64aed535fb97e..9e498a12d9b52 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -142,7 +142,7 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Correlation => Self::Correlation, protobuf::AggregateFunction::Grouping => Self::Grouping, protobuf::AggregateFunction::NthValueAgg => Self::NthValue, - protobuf::AggregateFunction::UNUSED => panic!("This should never happen, we are retiring this but protobuf doesn't support enum with no 0 values"), + protobuf::AggregateFunction::Unused => panic!("This should never happen, we are retiring this but protobuf doesn't support enum with no 0 values"), } } } From 9446b0b0d4cac6dd5a4f18063586743a42794afb Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Thu, 20 Jun 2024 17:14:23 +0000 Subject: [PATCH 08/18] Fixed aggregate funcitons name and behavior --- datafusion/functions-aggregate/src/min_max.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index d7f6d5afe5742..d49a82a70b5c1 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -560,7 +560,7 @@ impl AggregateUDFImpl for Max { } fn name(&self) -> &str { - "max" + "MAX" } fn signature(&self) -> &Signature { @@ -596,12 +596,11 @@ impl AggregateUDFImpl for Max { PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new( data_type, |cur, new| { - if *cur > new { + if *cur < new { *cur = new } }, ) - // Initialize each accumulator to $NATIVE::MIN .with_starting_value($NATIVE::MIN), )) }}; @@ -730,7 +729,7 @@ impl AggregateUDFImpl for Min { } fn name(&self) -> &str { - "Min" + "MIN" } fn signature(&self) -> &Signature { @@ -766,13 +765,12 @@ impl AggregateUDFImpl for Min { PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new( data_type, |cur, new| { - if *cur < new { + if *cur > new { *cur = new } }, ) - // Initialize each accumulator to $NATIVE::MIN - .with_starting_value($NATIVE::MIN), + .with_starting_value($NATIVE::MAX), )) }}; } From fd3dd56e6b63bccdf1c32673521339d51afc787b Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Thu, 20 Jun 2024 17:33:57 +0000 Subject: [PATCH 09/18] Adding alias --- datafusion/functions-aggregate/src/lib.rs | 1 + datafusion/functions-aggregate/src/min_max.rs | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 2080ee0e1760b..901f1a594a8ce 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -126,6 +126,7 @@ pub fn all_default_aggregate_functions() -> Vec> { sum::sum_udaf(), covariance::covar_pop_udaf(), min_max::max_udaf(), + min_max::min_udaf(), median::median_udaf(), count::count_udaf(), regr::regr_slope_udaf(), diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index d49a82a70b5c1..7a9169e85d68f 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -538,12 +538,14 @@ fn min_max_aggregate_data_type(input_type: DataType) -> DataType { #[derive(Debug)] pub struct Max { signature: Signature, + aliases: Vec, } impl Max { pub fn new() -> Self { Self { signature: Signature::numeric(1, Volatility::Immutable), + aliases: vec!["max".to_owned()], } } } @@ -576,7 +578,7 @@ impl AggregateUDFImpl for Max { } fn aliases(&self) -> &[String] { - &[] + &self.aliases } fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { @@ -708,12 +710,14 @@ impl Accumulator for MaxAccumulator { #[derive(Debug)] pub struct Min { signature: Signature, + aliases: Vec, } impl Min { pub fn new() -> Self { Self { signature: Signature::numeric(1, Volatility::Immutable), + aliases: vec!["min".to_owned()], } } } @@ -729,7 +733,7 @@ impl AggregateUDFImpl for Min { } fn name(&self) -> &str { - "MIN" + "min" } fn signature(&self) -> &Signature { @@ -745,7 +749,7 @@ impl AggregateUDFImpl for Min { } fn aliases(&self) -> &[String] { - &[] + &self.aliases } fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { From 77c67c51d8b5407f0ad8a51c1e7acb8c04faaba9 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Thu, 20 Jun 2024 17:36:53 +0000 Subject: [PATCH 10/18] Fixing comment --- datafusion/expr/src/test/function_stub.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index c0d9ccc94141a..14e0947ddb546 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -287,7 +287,7 @@ pub fn min(expr: Expr) -> Expr { )) } -/// Testing stub implementation of COUNT aggregate +/// Testing stub implementation of Min aggregate pub struct Min { signature: Signature, aliases: Vec, @@ -371,7 +371,7 @@ pub fn max(expr: Expr) -> Expr { )) } -/// Testing stub implementation of COUNT aggregate +/// Testing stub implementation of MAX aggregate pub struct Max { signature: Signature, aliases: Vec, From 096d1dc759da0bea5c56cf1fd92214544d6f0164 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Thu, 20 Jun 2024 17:45:01 +0000 Subject: [PATCH 11/18] Coherent case --- datafusion/functions-aggregate/src/min_max.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 7a9169e85d68f..85b59bf18171d 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -733,7 +733,7 @@ impl AggregateUDFImpl for Min { } fn name(&self) -> &str { - "min" + "MIN" } fn signature(&self) -> &Signature { From 02708df0f33393fe7b9049be42e75692cb01f113 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Thu, 20 Jun 2024 18:18:49 +0000 Subject: [PATCH 12/18] Last step --- datafusion/functions-aggregate/src/min_max.rs | 80 ++++++++++++++++++- 1 file changed, 77 insertions(+), 3 deletions(-) diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 85b59bf18171d..5c4b82009f8ee 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -496,6 +496,31 @@ fn max_batch(values: &ArrayRef) -> Result { }) } +fn min_batch(values: &ArrayRef) -> Result { + Ok(match values.data_type() { + DataType::Utf8 => { + typed_min_max_batch_string!(values, StringArray, Utf8, min_string) + } + DataType::LargeUtf8 => { + typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string) + } + DataType::Boolean => { + typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean) + } + DataType::Binary => { + typed_min_max_batch_binary!(&values, BinaryArray, Binary, min_binary) + } + DataType::LargeBinary => { + typed_min_max_batch_binary!( + &values, + LargeBinaryArray, + LargeBinary, + min_binary + ) + } + _ => min_max_batch!(values, min), + }) +} // min/max of two non-string scalar values. macro_rules! typed_min_max { ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ @@ -582,7 +607,31 @@ impl AggregateUDFImpl for Max { } fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { - true + matches!( + _args.data_type, + DataType::Int8 | + DataType::Int16 | + DataType::Int32 | + DataType::Int64 | + DataType::UInt8 | + DataType::UInt16 | + DataType::UInt32 | + DataType::UInt64 | + DataType::Float32 | + DataType::Float64 | + DataType::Date32 | + DataType::Date64 | + DataType::Time32(TimeUnit::Second) | + DataType::Time32(TimeUnit::Millisecond) | + DataType::Time64(TimeUnit::Microsecond) | + DataType::Time64(TimeUnit::Nanosecond) | + DataType::Timestamp(TimeUnit::Second, _) | + DataType::Timestamp(TimeUnit::Millisecond, _) | + DataType::Timestamp(TimeUnit::Microsecond, _) | + DataType::Timestamp(TimeUnit::Nanosecond, _) | + DataType::Decimal128(_, _) | + DataType::Decimal256(_, _) + ) } fn create_groups_accumulator( @@ -753,7 +802,31 @@ impl AggregateUDFImpl for Min { } fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { - true + matches!( + _args.data_type, + DataType::Int8 | + DataType::Int16 | + DataType::Int32 | + DataType::Int64 | + DataType::UInt8 | + DataType::UInt16 | + DataType::UInt32 | + DataType::UInt64 | + DataType::Float32 | + DataType::Float64 | + DataType::Date32 | + DataType::Date64 | + DataType::Time32(TimeUnit::Second) | + DataType::Time32(TimeUnit::Millisecond) | + DataType::Time64(TimeUnit::Microsecond) | + DataType::Time64(TimeUnit::Nanosecond) | + DataType::Timestamp(TimeUnit::Second, _) | + DataType::Timestamp(TimeUnit::Millisecond, _) | + DataType::Timestamp(TimeUnit::Microsecond, _) | + DataType::Timestamp(TimeUnit::Nanosecond, _) | + DataType::Decimal128(_, _) | + DataType::Decimal256(_, _) + ) } fn create_groups_accumulator( @@ -831,6 +904,7 @@ impl AggregateUDFImpl for Min { // https://github.com/apache/datafusion/issues/6906 // This is only reached if groups_accumulator_supported is out of sync + _ => internal_err!("GroupsAccumulator not supported for min({})", data_type), } } @@ -857,7 +931,7 @@ impl Accumulator for MinAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; - let delta = &max_batch(values)?; + let delta = &min_batch(values)?; let new_min: Result = min_max!(&self.min, delta, min); self.min = new_min?; From 68e0d88fc91491add6d0d9ee6beb0e700143043b Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Thu, 20 Jun 2024 18:20:39 +0000 Subject: [PATCH 13/18] Fixed formatting --- datafusion/functions-aggregate/src/min_max.rs | 89 +++++++++---------- 1 file changed, 44 insertions(+), 45 deletions(-) diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 5c4b82009f8ee..3cb98b842c801 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -609,28 +609,28 @@ impl AggregateUDFImpl for Max { fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { matches!( _args.data_type, - DataType::Int8 | - DataType::Int16 | - DataType::Int32 | - DataType::Int64 | - DataType::UInt8 | - DataType::UInt16 | - DataType::UInt32 | - DataType::UInt64 | - DataType::Float32 | - DataType::Float64 | - DataType::Date32 | - DataType::Date64 | - DataType::Time32(TimeUnit::Second) | - DataType::Time32(TimeUnit::Millisecond) | - DataType::Time64(TimeUnit::Microsecond) | - DataType::Time64(TimeUnit::Nanosecond) | - DataType::Timestamp(TimeUnit::Second, _) | - DataType::Timestamp(TimeUnit::Millisecond, _) | - DataType::Timestamp(TimeUnit::Microsecond, _) | - DataType::Timestamp(TimeUnit::Nanosecond, _) | - DataType::Decimal128(_, _) | - DataType::Decimal256(_, _) + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(TimeUnit::Second) + | DataType::Time32(TimeUnit::Millisecond) + | DataType::Time64(TimeUnit::Microsecond) + | DataType::Time64(TimeUnit::Nanosecond) + | DataType::Timestamp(TimeUnit::Second, _) + | DataType::Timestamp(TimeUnit::Millisecond, _) + | DataType::Timestamp(TimeUnit::Microsecond, _) + | DataType::Timestamp(TimeUnit::Nanosecond, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) ) } @@ -804,28 +804,28 @@ impl AggregateUDFImpl for Min { fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { matches!( _args.data_type, - DataType::Int8 | - DataType::Int16 | - DataType::Int32 | - DataType::Int64 | - DataType::UInt8 | - DataType::UInt16 | - DataType::UInt32 | - DataType::UInt64 | - DataType::Float32 | - DataType::Float64 | - DataType::Date32 | - DataType::Date64 | - DataType::Time32(TimeUnit::Second) | - DataType::Time32(TimeUnit::Millisecond) | - DataType::Time64(TimeUnit::Microsecond) | - DataType::Time64(TimeUnit::Nanosecond) | - DataType::Timestamp(TimeUnit::Second, _) | - DataType::Timestamp(TimeUnit::Millisecond, _) | - DataType::Timestamp(TimeUnit::Microsecond, _) | - DataType::Timestamp(TimeUnit::Nanosecond, _) | - DataType::Decimal128(_, _) | - DataType::Decimal256(_, _) + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(TimeUnit::Second) + | DataType::Time32(TimeUnit::Millisecond) + | DataType::Time64(TimeUnit::Microsecond) + | DataType::Time64(TimeUnit::Nanosecond) + | DataType::Timestamp(TimeUnit::Second, _) + | DataType::Timestamp(TimeUnit::Millisecond, _) + | DataType::Timestamp(TimeUnit::Microsecond, _) + | DataType::Timestamp(TimeUnit::Nanosecond, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) ) } @@ -904,7 +904,6 @@ impl AggregateUDFImpl for Min { // https://github.com/apache/datafusion/issues/6906 // This is only reached if groups_accumulator_supported is out of sync - _ => internal_err!("GroupsAccumulator not supported for min({})", data_type), } } From 034a330d2ffd77de6110e946d42211aeb0dcfe18 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Fri, 21 Jun 2024 13:25:39 +0000 Subject: [PATCH 14/18] Fixing optimzer --- .../aggregate_statistics.rs | 33 +- datafusion/functions-aggregate/Cargo.toml | 3 + datafusion/functions-aggregate/src/min_max.rs | 454 ++++++++++++++++++ datafusion/proto/gen/src/main.rs | 1 + 4 files changed, 479 insertions(+), 12 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index ca1582bcb34f7..912e07c973e88 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -173,6 +173,23 @@ fn take_optimizable_column_and_table_count( None } +fn unwrap_min(agg_expr: &dyn AggregateExpr) -> Option<&AggregateFunctionExpr> { + if let Some(casted_expr) = agg_expr.as_any().downcast_ref::() { + if casted_expr.fun().name() == "MIN" { + return Some(casted_expr); + } + } + None +} + +fn unwrap_max(agg_expr: &dyn AggregateExpr) -> Option<&AggregateFunctionExpr> { + if let Some(casted_expr) = agg_expr.as_any().downcast_ref::() { + if casted_expr.fun().name() == "MAX" { + return Some(casted_expr); + } + } + None +} /// If this agg_expr is a min that is exactly defined in the statistics, return it. fn take_optimizable_min( agg_expr: &dyn AggregateExpr, @@ -182,9 +199,7 @@ fn take_optimizable_min( match *num_rows { 0 => { // MIN/MAX with 0 rows is always null - if let Some(casted_expr) = - agg_expr.as_any().downcast_ref::() - { + if let Some(casted_expr) = unwrap_min(agg_expr) { if let Ok(min_data_type) = ScalarValue::try_from(casted_expr.field().unwrap().data_type()) { @@ -194,9 +209,7 @@ fn take_optimizable_min( } value if value > 0 => { let col_stats = &stats.column_statistics; - if let Some(casted_expr) = - agg_expr.as_any().downcast_ref::() - { + if let Some(casted_expr) = unwrap_min(agg_expr) { if casted_expr.expressions().len() == 1 { // TODO optimize with exprs other than Column if let Some(col_expr) = casted_expr.expressions()[0] @@ -232,9 +245,7 @@ fn take_optimizable_max( match *num_rows { 0 => { // MIN/MAX with 0 rows is always null - if let Some(casted_expr) = - agg_expr.as_any().downcast_ref::() - { + if let Some(casted_expr) = unwrap_max(agg_expr){ if let Ok(max_data_type) = ScalarValue::try_from(casted_expr.field().unwrap().data_type()) { @@ -244,9 +255,7 @@ fn take_optimizable_max( } value if value > 0 => { let col_stats = &stats.column_statistics; - if let Some(casted_expr) = - agg_expr.as_any().downcast_ref::() - { + if let Some(casted_expr) = unwrap_max(agg_expr){ if casted_expr.expressions().len() == 1 { // TODO optimize with exprs other than Column if let Some(col_expr) = casted_expr.expressions()[0] diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 26630a0352d58..05b627da3467f 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -48,3 +48,6 @@ datafusion-physical-expr-common = { workspace = true } log = { workspace = true } paste = "1.0.14" sqlparser = { workspace = true } + +[dev-dependencies] +rand = { workspace = true } \ No newline at end of file diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 3cb98b842c801..a5e52c5d51137 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -536,6 +536,254 @@ macro_rules! typed_min_max { }}; } +// The implementation is taken from https://github.com/spebern/moving_min_max/blob/master/src/lib.rs. + +// Keep track of the minimum or maximum value in a sliding window. +// +// `moving min max` provides one data structure for keeping track of the +// minimum value and one for keeping track of the maximum value in a sliding +// window. +// +// Each element is stored with the current min/max. One stack to push and another one for pop. If pop stack is empty, +// push to this stack all elements popped from first stack while updating their current min/max. Now pop from +// the second stack (MovingMin/Max struct works as a queue). To find the minimum element of the queue, +// look at the smallest/largest two elements of the individual stacks, then take the minimum of those two values. +// +// The complexity of the operations are +// - O(1) for getting the minimum/maximum +// - O(1) for push +// - amortized O(1) for pop + +/// ``` +/// # use datafusion_physical_expr::aggregate::moving_min_max::MovingMin; +/// let mut moving_min = MovingMin::::new(); +/// moving_min.push(2); +/// moving_min.push(1); +/// moving_min.push(3); +/// +/// assert_eq!(moving_min.min(), Some(&1)); +/// assert_eq!(moving_min.pop(), Some(2)); +/// +/// assert_eq!(moving_min.min(), Some(&1)); +/// assert_eq!(moving_min.pop(), Some(1)); +/// +/// assert_eq!(moving_min.min(), Some(&3)); +/// assert_eq!(moving_min.pop(), Some(3)); +/// +/// assert_eq!(moving_min.min(), None); +/// assert_eq!(moving_min.pop(), None); +/// ``` +#[derive(Debug)] +pub struct MovingMin { + push_stack: Vec<(T, T)>, + pop_stack: Vec<(T, T)>, +} + +impl Default for MovingMin { + fn default() -> Self { + Self { + push_stack: Vec::new(), + pop_stack: Vec::new(), + } + } +} + +impl MovingMin { + /// Creates a new `MovingMin` to keep track of the minimum in a sliding + /// window. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Creates a new `MovingMin` to keep track of the minimum in a sliding + /// window with `capacity` allocated slots. + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + push_stack: Vec::with_capacity(capacity), + pop_stack: Vec::with_capacity(capacity), + } + } + + /// Returns the minimum of the sliding window or `None` if the window is + /// empty. + #[inline] + pub fn min(&self) -> Option<&T> { + match (self.push_stack.last(), self.pop_stack.last()) { + (None, None) => None, + (Some((_, min)), None) => Some(min), + (None, Some((_, min))) => Some(min), + (Some((_, a)), Some((_, b))) => Some(if a < b { a } else { b }), + } + } + + /// Pushes a new element into the sliding window. + #[inline] + pub fn push(&mut self, val: T) { + self.push_stack.push(match self.push_stack.last() { + Some((_, min)) => { + if val > *min { + (val, min.clone()) + } else { + (val.clone(), val) + } + } + None => (val.clone(), val), + }); + } + + /// Removes and returns the last value of the sliding window. + #[inline] + pub fn pop(&mut self) -> Option { + if self.pop_stack.is_empty() { + match self.push_stack.pop() { + Some((val, _)) => { + let mut last = (val.clone(), val); + self.pop_stack.push(last.clone()); + while let Some((val, _)) = self.push_stack.pop() { + let min = if last.1 < val { + last.1.clone() + } else { + val.clone() + }; + last = (val.clone(), min); + self.pop_stack.push(last.clone()); + } + } + None => return None, + } + } + self.pop_stack.pop().map(|(val, _)| val) + } + + /// Returns the number of elements stored in the sliding window. + #[inline] + pub fn len(&self) -> usize { + self.push_stack.len() + self.pop_stack.len() + } + + /// Returns `true` if the moving window contains no elements. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} +/// ``` +/// # use datafusion_physical_expr::aggregate::moving_min_max::MovingMax; +/// let mut moving_max = MovingMax::::new(); +/// moving_max.push(2); +/// moving_max.push(3); +/// moving_max.push(1); +/// +/// assert_eq!(moving_max.max(), Some(&3)); +/// assert_eq!(moving_max.pop(), Some(2)); +/// +/// assert_eq!(moving_max.max(), Some(&3)); +/// assert_eq!(moving_max.pop(), Some(3)); +/// +/// assert_eq!(moving_max.max(), Some(&1)); +/// assert_eq!(moving_max.pop(), Some(1)); +/// +/// assert_eq!(moving_max.max(), None); +/// assert_eq!(moving_max.pop(), None); +/// ``` +#[derive(Debug)] +pub struct MovingMax { + push_stack: Vec<(T, T)>, + pop_stack: Vec<(T, T)>, +} + +impl Default for MovingMax { + fn default() -> Self { + Self { + push_stack: Vec::new(), + pop_stack: Vec::new(), + } + } +} + +impl MovingMax { + /// Creates a new `MovingMax` to keep track of the maximum in a sliding window. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Creates a new `MovingMax` to keep track of the maximum in a sliding window with + /// `capacity` allocated slots. + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + push_stack: Vec::with_capacity(capacity), + pop_stack: Vec::with_capacity(capacity), + } + } + + /// Returns the maximum of the sliding window or `None` if the window is empty. + #[inline] + pub fn max(&self) -> Option<&T> { + match (self.push_stack.last(), self.pop_stack.last()) { + (None, None) => None, + (Some((_, max)), None) => Some(max), + (None, Some((_, max))) => Some(max), + (Some((_, a)), Some((_, b))) => Some(if a > b { a } else { b }), + } + } + + /// Pushes a new element into the sliding window. + #[inline] + pub fn push(&mut self, val: T) { + self.push_stack.push(match self.push_stack.last() { + Some((_, max)) => { + if val < *max { + (val, max.clone()) + } else { + (val.clone(), val) + } + } + None => (val.clone(), val), + }); + } + + /// Removes and returns the last value of the sliding window. + #[inline] + pub fn pop(&mut self) -> Option { + if self.pop_stack.is_empty() { + match self.push_stack.pop() { + Some((val, _)) => { + let mut last = (val.clone(), val); + self.pop_stack.push(last.clone()); + while let Some((val, _)) = self.push_stack.pop() { + let max = if last.1 > val { + last.1.clone() + } else { + val.clone() + }; + last = (val.clone(), max); + self.pop_stack.push(last.clone()); + } + } + None => return None, + } + } + self.pop_stack.pop().map(|(val, _)| val) + } + + /// Returns the number of elements stored in the sliding window. + #[inline] + pub fn len(&self) -> usize { + self.push_stack.len() + self.pop_stack.len() + } + + /// Returns `true` if the moving window contains no elements. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + + make_udaf_expr_and_func!( Max, max, @@ -712,6 +960,11 @@ impl AggregateUDFImpl for Max { _ => internal_err!("GroupsAccumulator not supported for max({})", data_type), } } + + fn create_sliding_accumulator(&self, args:AccumulatorArgs) -> Result> { + Ok(Box::new(SlidingMaxAccumulator::try_new(args.data_type)?)) + } + } /// An accumulator to compute the maximum value @@ -907,6 +1160,12 @@ impl AggregateUDFImpl for Min { _ => internal_err!("GroupsAccumulator not supported for min({})", data_type), } } + + + fn create_sliding_accumulator(&self, args:AccumulatorArgs) -> Result> { + Ok(Box::new(SlidingMinAccumulator::try_new(args.data_type)?)) + } + } /// An accumulator to compute the minimum value #[derive(Debug)] @@ -950,6 +1209,131 @@ impl Accumulator for MinAccumulator { } } + + +#[derive(Debug)] +pub struct SlidingMinAccumulator { + min: ScalarValue, + moving_min: MovingMin, +} + +impl SlidingMinAccumulator { + + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + min: ScalarValue::try_from(datatype)?, + moving_min: MovingMin::::new(), + }) + } +} + +impl Accumulator for SlidingMinAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.min.clone()]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + for idx in 0..values[0].len() { + let val = ScalarValue::try_from_array(&values[0], idx)?; + if !val.is_null() { + self.moving_min.push(val); + } + } + if let Some(res) = self.moving_min.min() { + self.min = res.clone(); + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + for idx in 0..values[0].len() { + let val = ScalarValue::try_from_array(&values[0], idx)?; + if !val.is_null() { + (self.moving_min).pop(); + } + } + if let Some(res) = self.moving_min.min() { + self.min = res.clone(); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn evaluate(&mut self) -> Result { + Ok(self.min.clone()) + } + + fn supports_retract_batch(&self) -> bool { + true + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() + } +} + +#[derive(Debug)] +pub struct SlidingMaxAccumulator { + max: ScalarValue, + moving_max: MovingMax, +} + +impl SlidingMaxAccumulator { + /// new max accumulator + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + max: ScalarValue::try_from(datatype)?, + moving_max: MovingMax::::new(), + }) + } +} + +impl Accumulator for SlidingMaxAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + for idx in 0..values[0].len() { + let val = ScalarValue::try_from_array(&values[0], idx)?; + self.moving_max.push(val); + } + if let Some(res) = self.moving_max.max() { + self.max = res.clone(); + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + for _idx in 0..values[0].len() { + (self.moving_max).pop(); + } + if let Some(res) = self.moving_max.max() { + self.max = res.clone(); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.max.clone()]) + } + + fn evaluate(&mut self) -> Result { + Ok(self.max.clone()) + } + + fn supports_retract_batch(&self) -> bool { + true + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size() + } +} + #[cfg(test)] mod tests { use super::*; @@ -987,4 +1371,74 @@ mod tests { check(&mut max(), &[&[zero], &[neg_inf]], zero); check(&mut max(), &[&[zero, neg_inf]], zero); } + + + use datafusion_common::Result; + use rand::Rng; + + fn get_random_vec_i32(len: usize) -> Vec { + let mut rng = rand::thread_rng(); + let mut input = Vec::with_capacity(len); + for _i in 0..len { + input.push(rng.gen_range(0..100)); + } + input + } + + fn moving_min_i32(len: usize, n_sliding_window: usize) -> Result<()> { + let data = get_random_vec_i32(len); + let mut expected = Vec::with_capacity(len); + let mut moving_min = MovingMin::::new(); + let mut res = Vec::with_capacity(len); + for i in 0..len { + let start = i.saturating_sub(n_sliding_window); + expected.push(*data[start..i + 1].iter().min().unwrap()); + + moving_min.push(data[i]); + if i > n_sliding_window { + moving_min.pop(); + } + res.push(*moving_min.min().unwrap()); + } + assert_eq!(res, expected); + Ok(()) + } + + fn moving_max_i32(len: usize, n_sliding_window: usize) -> Result<()> { + let data = get_random_vec_i32(len); + let mut expected = Vec::with_capacity(len); + let mut moving_max = MovingMax::::new(); + let mut res = Vec::with_capacity(len); + for i in 0..len { + let start = i.saturating_sub(n_sliding_window); + expected.push(*data[start..i + 1].iter().max().unwrap()); + + moving_max.push(data[i]); + if i > n_sliding_window { + moving_max.pop(); + } + res.push(*moving_max.max().unwrap()); + } + assert_eq!(res, expected); + Ok(()) + } + + #[test] + fn moving_min_tests() -> Result<()> { + moving_min_i32(100, 10)?; + moving_min_i32(100, 20)?; + moving_min_i32(100, 50)?; + moving_min_i32(100, 100)?; + Ok(()) + } + + #[test] + fn moving_max_tests() -> Result<()> { + moving_max_i32(100, 10)?; + moving_max_i32(100, 20)?; + moving_max_i32(100, 50)?; + moving_max_i32(100, 100)?; + Ok(()) + } + } diff --git a/datafusion/proto/gen/src/main.rs b/datafusion/proto/gen/src/main.rs index 22c16eacb0938..8fd59336115f0 100644 --- a/datafusion/proto/gen/src/main.rs +++ b/datafusion/proto/gen/src/main.rs @@ -32,6 +32,7 @@ fn main() -> Result<(), String> { .file_descriptor_set_path(&descriptor_path) .out_dir(out_dir) .compile_well_known_types() + .protoc_arg("--experimental_allow_proto3_optional") .extern_path(".google.protobuf", "::pbjson_types") .compile_protos(&[proto_path], &["proto"]) .map_err(|e| format!("protobuf compilation failed: {e}"))?; From 803cadbc2eb01d04bdb80f47a9dfc6732a4340d9 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Fri, 21 Jun 2024 14:10:46 +0000 Subject: [PATCH 15/18] Fixing tests --- datafusion/expr/src/expr.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 2de7d4c178267..4f10318634b20 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2236,8 +2236,6 @@ mod test { "first_value", "last_value", "nth_value", - "min", - "max", "avg", ]; for name in names { From 33a8ee43b3b7321e942c8b8d89d6feb778877ed3 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Fri, 21 Jun 2024 14:58:00 +0000 Subject: [PATCH 16/18] Fixing tests --- datafusion/functions-aggregate/src/lib.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 901f1a594a8ce..26e01af07ac35 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -179,10 +179,11 @@ mod tests { #[test] fn test_no_duplicate_name() -> Result<()> { let mut names = HashSet::new(); + let migrated_functions = vec!["count", "max", "min"]; for func in all_default_aggregate_functions() { // TODO: remove this // These functions are in intermidiate migration state, skip them - if func.name().to_lowercase() == "count" { + if migrated_functions.contains(&func.name().to_lowercase().as_str()) { continue; } assert!( From 81ad007a9b6f44cbd4aff7898d45e2f122b6226a Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Sun, 23 Jun 2024 20:24:05 +0000 Subject: [PATCH 17/18] Adding optimizer rules --- .../optimizer/src/eliminate_distinct.rs | 140 ++++++++++++++++++ datafusion/optimizer/src/lib.rs | 1 + 2 files changed, 141 insertions(+) create mode 100644 datafusion/optimizer/src/eliminate_distinct.rs diff --git a/datafusion/optimizer/src/eliminate_distinct.rs b/datafusion/optimizer/src/eliminate_distinct.rs new file mode 100644 index 0000000000000..f1d5877b1b49f --- /dev/null +++ b/datafusion/optimizer/src/eliminate_distinct.rs @@ -0,0 +1,140 @@ + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`EliminateDistinctFromMinMax`] Removes redundant distinct in min and max + +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::tree_node::Transformed; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::expr::AggregateFunction; +use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::{Aggregate, Expr}; +use std::sync::OnceLock; + +/// Optimization rule that eliminate redundant distinct in min and max expr. +#[derive(Default)] +pub struct EliminateDistinct; + +impl EliminateDistinct { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} +static WORKSPACE_ROOT_LOCK: OnceLock> = OnceLock::new(); + +fn rewrite_aggr_expr(expr:Expr) -> (bool, Expr) { + match expr { + Expr::AggregateFunction(ref fun) => { + let fn_name = fun.func_def.name().to_lowercase(); + if fun.distinct && WORKSPACE_ROOT_LOCK.get_or_init(|| vec!["min".to_string(), "max".to_string()]).contains(&fn_name) { + (true, Expr::AggregateFunction(AggregateFunction{ + func_def:fun.func_def.clone(), + args:fun.args.clone(), + distinct:false, + filter:fun.filter.clone(), + order_by:fun.order_by.clone(), + null_treatment: fun.null_treatment + })) + } else { + (false, expr) + } + }, + _ => (false, expr) + } +} +impl OptimizerRule for EliminateDistinct { + fn try_optimize( + &self, + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + internal_err!("Should have called EliminateDistinct::rewrite") + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + match plan { + LogicalPlan::Aggregate(agg) => { + let mut aggr_expr = vec![]; + let mut transformed = false; + for expr in agg.aggr_expr { + let rewrite_result = rewrite_aggr_expr(expr); + transformed = transformed || rewrite_result.0; + aggr_expr.push(rewrite_result.1); + } + + println!("Transformed yes {}", transformed); + let transformed = if transformed { + Transformed::yes + } else { + Transformed::no + }; + Aggregate::try_new(agg.input, agg.group_expr, aggr_expr) + .map(|f| transformed(LogicalPlan::Aggregate(f))) + } + _ => Ok(Transformed::no(plan)), + } + } + fn name(&self) -> &str { + "eliminate_distinct" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::*; + use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder}; + use datafusion_expr::AggregateExt; + use datafusion_expr::test::function_stub::min; + use std::sync::Arc; + + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { + crate::test::assert_optimized_plan_eq( + Arc::new(EliminateDistinct::new()), + plan, + expected, + ) + } + + #[test] + fn eliminate_distinct_from_min_expr() -> Result<()> { + let table_scan = test_table_scan().unwrap(); + let aggr_expr = min(col("b")).distinct().build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a")], vec![aggr_expr])? + .build()?; + let expected = "Limit: skip=5, fetch=10\ + \n Sort: test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(plan, expected) + } +} diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index c172d59797569..a2daf828c4f7e 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -33,6 +33,7 @@ pub mod common_subexpr_eliminate; pub mod decorrelate; pub mod decorrelate_predicate_subquery; pub mod eliminate_cross_join; +pub mod eliminate_distinct; pub mod eliminate_duplicated_expr; pub mod eliminate_filter; pub mod eliminate_group_by_constant; From ad2bb65df5045a9022618ba44107f5195a33d6bf Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 27 Jun 2024 16:53:32 -0400 Subject: [PATCH 18/18] bring in interval support --- datafusion/functions-aggregate/src/min_max.rs | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 4a03cc3203739..07d4c80498e67 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -41,6 +41,7 @@ use arrow::array::{ Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, }; use arrow::compute; use arrow::datatypes::{ @@ -58,6 +59,7 @@ use arrow::datatypes::{ }; use arrow::datatypes::i256; +use arrow_schema::IntervalUnit; use datafusion_common::ScalarValue; use datafusion_expr::GroupsAccumulator; @@ -443,7 +445,25 @@ macro_rules! min_max_batch { $OP ) } - other => { + DataType::Interval(IntervalUnit::YearMonth) => { + typed_min_max_batch!( + $VALUES, + IntervalYearMonthArray, + IntervalYearMonth, + $OP + ) + } + DataType::Interval(IntervalUnit::DayTime) => { + typed_min_max_batch!($VALUES, IntervalDayTimeArray, IntervalDayTime, $OP) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + typed_min_max_batch!( + $VALUES, + IntervalMonthDayNanoArray, + IntervalMonthDayNano, + $OP + ) + } other => { // This should have been handled before return internal_err!( "Min/Max accumulator not implemented for type {:?}",