diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 110e5c3194c7..95b114ca4a00 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -1339,6 +1339,7 @@ dependencies = [
  "indexmap 2.2.6",
  "itertools",
  "log",
+ "paste",
  "regex-syntax",
 ]
 
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 694911592b5d..b1d9eb057753 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -168,20 +168,6 @@ pub fn max(expr: Expr) -> Expr {
     ))
 }
 
-/// Create an expression to represent the sum() aggregate function
-///
-/// TODO: Remove this function and use `sum` from `datafusion_functions_aggregate::expr_fn` instead
-pub fn sum(expr: Expr) -> Expr {
-    Expr::AggregateFunction(AggregateFunction::new(
-        aggregate_function::AggregateFunction::Sum,
-        vec![expr],
-        false,
-        None,
-        None,
-        None,
-    ))
-}
-
 /// Create an expression to represent the array_agg() aggregate function
 pub fn array_agg(expr: Expr) -> Expr {
     Expr::AggregateFunction(AggregateFunction::new(
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index ae3aaa55199e..2f1ece32ab15 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -1719,7 +1719,7 @@ pub fn unnest_with_options(
 mod tests {
     use super::*;
     use crate::logical_plan::StringifiedPlan;
-    use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery, sum};
+    use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery};
 
     use datafusion_common::SchemaError;
 
@@ -1775,28 +1775,6 @@ mod tests {
         );
     }
 
-    #[test]
-    fn plan_builder_aggregate() -> Result<()> {
-        let plan =
-            table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))?
-                .aggregate(
-                    vec![col("state")],
-                    vec![sum(col("salary")).alias("total_salary")],
-                )?
-                .project(vec![col("state"), col("total_salary")])?
-                .limit(2, Some(10))?
-                .build()?;
-
-        let expected = "Limit: skip=2, fetch=10\
-        \n  Projection: employee_csv.state, total_salary\
-        \n    Aggregate: groupBy=[[employee_csv.state]], aggr=[[SUM(employee_csv.salary) AS total_salary]]\
-        \n      TableScan: employee_csv projection=[state, salary]";
-
-        assert_eq!(expected, format!("{plan:?}"));
-
-        Ok(())
-    }
-
     #[test]
     fn plan_builder_sort() -> Result<()> {
         let plan =
@@ -2037,36 +2015,6 @@ mod tests {
         }
     }
 
-    #[test]
-    fn aggregate_non_unique_names() -> Result<()> {
-        let plan = table_scan(
-            Some("employee_csv"),
-            &employee_schema(),
-            // project state and salary by column index
-            Some(vec![3, 4]),
-        )?
-        // two columns with the same name => error
-        .aggregate(vec![col("state")], vec![sum(col("salary")).alias("state")]);
-
-        match plan {
-            Err(DataFusionError::SchemaError(
-                SchemaError::AmbiguousReference {
-                    field:
-                        Column {
-                            relation: Some(TableReference::Bare { table }),
-                            name,
-                        },
-                },
-                _,
-            )) => {
-                assert_eq!(*"employee_csv", *table);
-                assert_eq!("state", &name);
-                Ok(())
-            }
-            _ => plan_err!("Plan should have returned an DataFusionError::SchemaError"),
-        }
-    }
-
     fn employee_schema() -> Schema {
         Schema::new(vec![
             Field::new("id", DataType::Int32, false),
diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml
index e703250c92e1..cb14f6bdd4a3 100644
--- a/datafusion/optimizer/Cargo.toml
+++ b/datafusion/optimizer/Cargo.toml
@@ -50,6 +50,7 @@ hashbrown = { workspace = true }
 indexmap = { workspace = true }
 itertools = { workspace = true }
 log = { workspace = true }
+paste = "1.0.14"
 regex-syntax = "0.8.0"
 
 [dev-dependencies]
diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
index dfbd5f5632ee..5d219e625235 100644
--- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
+++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
@@ -117,13 +117,14 @@ fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::test::function_stub::sum;
     use crate::test::*;
     use arrow::datatypes::DataType;
     use datafusion_common::ScalarValue;
     use datafusion_expr::expr::Sort;
     use datafusion_expr::{
         col, count, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max,
-        out_ref_col, scalar_subquery, sum, wildcard, AggregateFunction, WindowFrame,
+        out_ref_col, scalar_subquery, wildcard, AggregateFunction, WindowFrame,
         WindowFrameBound, WindowFrameUnits,
     };
     use std::sync::Arc;
@@ -275,11 +276,9 @@ mod tests {
     #[test]
     fn test_count_wildcard_on_non_count_aggregate() -> Result<()> {
         let table_scan = test_table_scan()?;
-        let err = LogicalPlanBuilder::from(table_scan)
-            .aggregate(Vec::<Expr>::new(), vec![sum(wildcard())])
-            .unwrap_err()
-            .to_string();
-        assert!(err.contains("Error during planning: No function matches the given name and argument types 'SUM(Null)'."), "{err}");
+        let res = LogicalPlanBuilder::from(table_scan)
+            .aggregate(Vec::<Expr>::new(), vec![sum(wildcard())]);
+        assert!(res.is_err());
         Ok(())
     }
 
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 174440dac316..c6940ef711d8 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -840,13 +840,15 @@ mod test {
     use arrow::datatypes::Schema;
 
     use datafusion_expr::logical_plan::{table_scan, JoinType};
-    use datafusion_expr::{avg, lit, logical_plan::builder::LogicalPlanBuilder, sum};
+
+    use datafusion_expr::{avg, lit, logical_plan::builder::LogicalPlanBuilder};
     use datafusion_expr::{
         grouping_set, AccumulatorFactoryFunction, AggregateUDF, Signature,
         SimpleAggregateUDF, Volatility,
     };
 
     use crate::optimizer::OptimizerContext;
+    use crate::test::function_stub::sum;
     use crate::test::*;
 
     use super::*;
diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs
index c294bc68f027..0e6dd9ac6332 100644
--- a/datafusion/optimizer/src/eliminate_filter.rs
+++ b/datafusion/optimizer/src/eliminate_filter.rs
@@ -91,10 +91,11 @@ mod tests {
     use datafusion_common::{Result, ScalarValue};
     use datafusion_expr::{
-        col, lit, logical_plan::builder::LogicalPlanBuilder, sum, Expr, LogicalPlan,
+        col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan,
     };
 
     use crate::eliminate_filter::EliminateFilter;
+    use crate::test::function_stub::sum;
     use crate::test::*;
 
     fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs
index 1b0907d9736d..7e60f8f8b0c3 100644
--- a/datafusion/optimizer/src/eliminate_limit.rs
+++ b/datafusion/optimizer/src/eliminate_limit.rs
@@ -100,11 +100,11 @@ mod tests {
     use datafusion_expr::{
         col,
         logical_plan::{builder::LogicalPlanBuilder, JoinType},
-        sum,
     };
     use std::sync::Arc;
 
     use crate::push_down_limit::PushDownLimit;
+    use crate::test::function_stub::sum;
 
     fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
     fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index e56bfd051fe2..3bf16de258ce 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -1090,13 +1090,14 @@ mod tests {
     use datafusion_expr::expr::ScalarFunction;
     use datafusion_expr::logical_plan::table_scan;
     use datafusion_expr::{
-        col, in_list, in_subquery, lit, sum, ColumnarValue, Extension, ScalarUDF,
+        col, in_list, in_subquery, lit, ColumnarValue, Extension, ScalarUDF,
         ScalarUDFImpl, Signature, TableSource, TableType, UserDefinedLogicalNodeCore,
         Volatility,
     };
 
     use crate::optimizer::Optimizer;
     use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
+    use crate::test::function_stub::sum;
     use crate::test::*;
     use crate::OptimizerContext;
diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs
index cb28961497f4..0b1f1a0e6452 100644
--- a/datafusion/optimizer/src/scalar_subquery_to_join.rs
+++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs
@@ -400,10 +400,9 @@ mod tests {
     use super::*;
     use crate::test::*;
 
+    use crate::test::function_stub::sum;
     use arrow::datatypes::DataType;
-    use datafusion_expr::{
-        col, lit, max, min, out_ref_col, scalar_subquery, sum, Between,
-    };
+    use datafusion_expr::{col, lit, max, min, out_ref_col, scalar_subquery, Between};
 
     /// Test multiple correlated subqueries
     #[test]
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 06d0dee27099..5d9de39f9676 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -360,12 +360,13 @@ impl OptimizerRule for SingleDistinctToGroupBy {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::test::function_stub::sum;
     use crate::test::*;
     use datafusion_expr::expr;
     use datafusion_expr::expr::GroupingSet;
     use datafusion_expr::{
         count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max, min,
-        sum, AggregateFunction,
+        AggregateFunction,
     };
 
     fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
diff --git a/datafusion/optimizer/src/test/function_stub.rs b/datafusion/optimizer/src/test/function_stub.rs
new file mode 100644
index 000000000000..997d2cb607e1
--- /dev/null
+++ b/datafusion/optimizer/src/test/function_stub.rs
@@ -0,0 +1,193 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Aggregate function stubs to test SQL optimizers.
+//!
+//! These are used to avoid a dependency on `datafusion-functions-aggregate`, which lives in a different crate
+
+use std::any::Any;
+
+use arrow::datatypes::{
+    DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
+};
+use datafusion_common::{exec_err, Result};
+use datafusion_expr::{
+    expr::AggregateFunction,
+    function::{AccumulatorArgs, StateFieldsArgs},
+    utils::AggregateOrderSensitivity,
+    Accumulator, AggregateUDFImpl, Expr, GroupsAccumulator, ReversedUDAF, Signature,
+    Volatility,
+};
+
+macro_rules! create_func {
+    ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
+        paste::paste! {
+            /// Singleton instance of [$UDAF], ensures the UDAF is only created once,
+            /// named `STATIC_$UDAF`. For example `STATIC_Sum`
+            #[allow(non_upper_case_globals)]
+            static [< STATIC_ $UDAF >]: std::sync::OnceLock<std::sync::Arc<datafusion_expr::AggregateUDF>> =
+                std::sync::OnceLock::new();
+
+            /// Function that returns an [AggregateUDF] for [$UDAF]
+            ///
+            /// [AggregateUDF]: datafusion_expr::AggregateUDF
+            pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc<datafusion_expr::AggregateUDF> {
+                [< STATIC_ $UDAF >]
+                    .get_or_init(|| {
+                        std::sync::Arc::new(datafusion_expr::AggregateUDF::from(<$UDAF>::default()))
+                    })
+                    .clone()
+            }
+        }
+    }
+}
+
+create_func!(Sum, sum_udaf);
+
+pub(crate) fn sum(expr: Expr) -> Expr {
+    Expr::AggregateFunction(AggregateFunction::new_udf(
+        sum_udaf(),
+        vec![expr],
+        false,
+        None,
+        None,
+        None,
+    ))
+}
+
+/// Stub `sum` used for optimizer testing
+#[derive(Debug)]
+pub struct Sum {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Sum {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+            aliases: vec!["sum".to_string()],
+        }
+    }
+}
+
+impl Default for Sum {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl AggregateUDFImpl for Sum {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "SUM"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        if arg_types.len() != 1 {
+            return exec_err!("SUM expects exactly one argument");
+        }
+
+        // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html:
+        // smallint, int, bigint, real, double precision, decimal, or interval.
+
+        fn coerced_type(data_type: &DataType) -> Result<DataType> {
+            match data_type {
+                DataType::Dictionary(_, v) => coerced_type(v),
+                // In Spark, the result type is DECIMAL(min(38, precision + 10), s)
+                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+                DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
+                    Ok(data_type.clone())
+                }
+                dt if dt.is_signed_integer() => Ok(DataType::Int64),
+                dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
+                dt if dt.is_floating() => Ok(DataType::Float64),
+                _ => exec_err!("Sum not supported for {}", data_type),
+            }
+        }
+
+        Ok(vec![coerced_type(&arg_types[0])?])
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        match &arg_types[0] {
+            DataType::Int64 => Ok(DataType::Int64),
+            DataType::UInt64 => Ok(DataType::UInt64),
+            DataType::Float64 => Ok(DataType::Float64),
+            DataType::Decimal128(precision, scale) => {
+                // In Spark, the result type is DECIMAL(min(38, precision + 10), s)
+                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
+                Ok(DataType::Decimal128(new_precision, *scale))
+            }
+            DataType::Decimal256(precision, scale) => {
+                // In Spark, the result type is DECIMAL(min(38, precision + 10), s)
+                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
+                Ok(DataType::Decimal256(new_precision, *scale))
+            }
+            other => {
+                exec_err!("[return_type] SUM not supported for {}", other)
+            }
+        }
+    }
+
+    fn accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        unreachable!("stub should not have accumulate()")
+    }
+
+    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
+        unreachable!("stub should not have state_fields()")
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+
+    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
+        false
+    }
+
+    fn create_groups_accumulator(
+        &self,
+        _args: AccumulatorArgs,
+    ) -> Result<Box<dyn GroupsAccumulator>> {
+        unreachable!("stub should not have accumulate()")
+    }
+
+    fn create_sliding_accumulator(
+        &self,
+        _args: AccumulatorArgs,
+    ) -> Result<Box<dyn Accumulator>> {
+        unreachable!("stub should not have accumulate()")
+    }
+
+    fn reverse_expr(&self) -> ReversedUDAF {
+        ReversedUDAF::Identical
+    }
+
+    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
+        AggregateOrderSensitivity::Insensitive
+    }
+}
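For reference, `create_func!(Sum, sum_udaf)` above expands, via `paste::paste!`, to roughly the following. This is a hand-written sketch of the macro output for review, not part of the patch:

    // Approximate expansion of create_func!(Sum, sum_udaf) -- illustration only.
    #[allow(non_upper_case_globals)]
    static STATIC_Sum: std::sync::OnceLock<std::sync::Arc<datafusion_expr::AggregateUDF>> =
        std::sync::OnceLock::new();

    /// Returns the shared AggregateUDF wrapping the `Sum` stub, built at most once.
    pub fn sum_udaf() -> std::sync::Arc<datafusion_expr::AggregateUDF> {
        STATIC_Sum
            .get_or_init(|| {
                std::sync::Arc::new(datafusion_expr::AggregateUDF::from(Sum::default()))
            })
            .clone()
    }

The `OnceLock` makes every call site share a single `AggregateUDF` instance, mirroring how the real `datafusion-functions-aggregate` crate registers its functions.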
diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs
index 98d19956df3c..fa468ccedca3 100644
--- a/datafusion/optimizer/src/test/mod.rs
+++ b/datafusion/optimizer/src/test/mod.rs
@@ -24,6 +24,7 @@ use datafusion_common::{assert_contains, Result};
 use datafusion_expr::{col, logical_plan::table_scan, LogicalPlan, LogicalPlanBuilder};
 use std::sync::Arc;
 
+pub mod function_stub;
 pub mod user_defined;
 
 pub fn test_table_scan_fields() -> Vec<Field> {
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index deae97fecc96..f32f4b04938f 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -32,7 +32,7 @@ use datafusion::execution::context::SessionState;
 use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
 use datafusion::execution::FunctionRegistry;
 use datafusion::functions_aggregate::expr_fn::{
-    covar_pop, covar_samp, first_value, median, var_sample,
+    covar_pop, covar_samp, first_value, median, sum, var_sample,
 };
 use datafusion::prelude::*;
 use datafusion::test_util::{TestTableFactory, TestTableProvider};
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index b08d5846733b..505f78fa9c4e 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -5334,3 +5334,34 @@ physical_plan
 03)----AggregateExec: mode=Partial, gby=[], aggr=[first_value(convert_first_last_table.c1) ORDER BY [convert_first_last_table.c2 DESC NULLS FIRST]]
 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/convert_first_last.csv]]}, projection=[c1, c2], output_orderings=[[c1@0 ASC NULLS LAST], [c2@1 DESC]], has_header=true
+
+# test building a plan with an aggregate sum
+
+statement ok
+create table employee_csv(id int, first_name string, last_name varchar, state varchar, salary bigint) as values (1, 'jenson', 'huang', 'unemployed', 10);
+
+query TI
+select state, sum(salary) total_salary from employee_csv group by state;
+----
+unemployed 10
+
+statement ok
+set datafusion.explain.logical_plan_only = true;
+
+query TT
+explain select state, sum(salary) as total_salary from employee_csv group by state;
+----
+logical_plan
+01)Projection: employee_csv.state, SUM(employee_csv.salary) AS total_salary
+02)--Aggregate: groupBy=[[employee_csv.state]], aggr=[[SUM(employee_csv.salary)]]
+03)----TableScan: employee_csv projection=[state, salary]
+
+# fail if there is a duplicate name
+query error DataFusion error: Schema error: Schema contains qualified field name employee_csv\.state and unqualified field name state which would be ambiguous
+select state, sum(salary) as state from employee_csv group by state;
+
+statement ok
+set datafusion.explain.logical_plan_only = false;
+
+statement ok
+drop table employee_csv;
\ No newline at end of file
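For reference, a minimal sketch of how the optimizer tests consume the stub after this change, assuming the existing `test_table_scan()` helper from `crate::test` (table `test` with columns `a`, `b`, `c`). This is hypothetical test code, not part of the patch:

    use crate::test::function_stub::sum;
    use crate::test::test_table_scan;
    use datafusion_common::Result;
    use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder};

    fn sum_stub_example() -> Result<()> {
        // Builds an aggregate plan whose SUM comes from the UDAF stub, so the
        // optimizer crate no longer depends on datafusion-functions-aggregate.
        let plan = LogicalPlanBuilder::from(test_table_scan()?)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .build()?;
        // The stub only participates in planning and optimization; calling
        // accumulator() or state_fields() would hit the unreachable!() guards.
        let _ = plan;
        Ok(())
    }

This is the design point of the stub: rules like `single_distinct_to_groupby` and `push_down_filter` only inspect logical plans, so a UDAF with planning metadata but no execution path is sufficient for their tests.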