diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 6de8c98f391f5..6cf9da06bb0ae 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -856,6 +856,7 @@ dependencies = [
  "arrow-schema",
  "datafusion-common",
  "datafusion-expr",
+ "datafusion-optimizer",
  "log",
  "sqlparser",
 ]
diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs
index 67a367c5ef3f8..da2296177188d 100644
--- a/datafusion/common/src/dfschema.rs
+++ b/datafusion/common/src/dfschema.rs
@@ -630,9 +630,9 @@ impl ExprSchema for DFSchema {
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub struct DFField {
     /// Optional qualifier (usually a table or relation name)
-    qualifier: Option<OwnedTableReference>,
+    pub qualifier: Option<OwnedTableReference>,
     /// Arrow field definition
-    field: Field,
+    pub field: Field,
 }
 
 impl DFField {
diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs
index 23c6623ab1b5a..aeb1ed47da0b1 100644
--- a/datafusion/core/tests/dataframe.rs
+++ b/datafusion/core/tests/dataframe.rs
@@ -32,36 +32,93 @@ use datafusion::error::Result;
 use datafusion::execution::context::SessionContext;
 use datafusion::prelude::JoinType;
 use datafusion::prelude::{CsvReadOptions, ParquetReadOptions};
+use datafusion::test_util::parquet_test_data;
 use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
+use datafusion_common::ScalarValue;
+use datafusion_common::ScalarValue::UInt64;
 use datafusion_expr::expr::{GroupingSet, Sort};
-use datafusion_expr::{avg, col, count, lit, max, sum, Expr, ExprSchemable};
+use datafusion_expr::Expr::Wildcard;
+use datafusion_expr::{
+    avg, col, count, expr, lit, max, sum, AggregateFunction, Expr, ExprSchemable,
+    Subquery, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction,
+};
+
+#[tokio::test]
+async fn test_count_wildcard_on_where_exist() -> Result<()> {
+    let ctx = create_join_context()?;
+
+    let df_results = ctx
+        .table("t1")
+        .await?
+        .select(vec![col("a"), col("b")])?
+        .filter(Expr::Exists {
+            subquery: Subquery {
+                subquery: Arc::new(
+                    ctx.table("t2")
+                        .await?
+                        .aggregate(vec![], vec![count(Expr::Wildcard)])?
+                        .select(vec![count(Expr::Wildcard)])?
+                        .into_optimized_plan()?,
+                ),
+                outer_ref_columns: vec![],
+            },
+            negated: false,
+        })?
+        .explain(false, false)?
+        .collect()
+        .await?;
+
+    #[rustfmt::skip]
+    let expected = vec![
+        "+--------------+-------------------------------------------------------+",
+        "| plan_type    | plan                                                  |",
+        "+--------------+-------------------------------------------------------+",
+        "| logical_plan | Filter: EXISTS (<subquery>)                           |",
+        "|              |   Subquery:                                           |",
+        "|              |     Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] |",
+        "|              |       TableScan: t2 projection=[a]                    |",
+        "|              |   TableScan: t1 projection=[a, b]                     |",
+        "+--------------+-------------------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &df_results);
+    Ok(())
+}
 
 #[tokio::test]
-async fn count_wildcard() -> Result<()> {
+async fn test_count_wildcard_on_window() -> Result<()> {
     let ctx = SessionContext::new();
-    let testdata = datafusion::test_util::parquet_test_data();
-    ctx.register_parquet(
-        "alltypes_tiny_pages",
-        &format!("{testdata}/alltypes_tiny_pages.parquet"),
-        ParquetReadOptions::default(),
-    )
-    .await?;
+    register_alltypes_tiny_pages_parquet(&ctx).await?;
 
     let sql_results = ctx
-        .sql("select count(*) from alltypes_tiny_pages")
+        .sql("select COUNT(*) OVER(ORDER BY timestamp_col DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from alltypes_tiny_pages")
         .await?
-        .select(vec![count(Expr::Wildcard)])?
         .explain(false, false)?
         .collect()
         .await?;
 
-    // add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all node instead of just top node.
     let df_results = ctx
         .table("alltypes_tiny_pages")
         .await?
-        .aggregate(vec![], vec![count(Expr::Wildcard)])?
-        .select(vec![count(Expr::Wildcard)])?
+        .select(vec![Expr::WindowFunction(expr::WindowFunction::new(
+            WindowFunction::AggregateFunction(AggregateFunction::Count),
+            vec![Expr::Wildcard],
+            vec![],
+            vec![Expr::Sort(Sort::new(
+                Box::new(col("timestamp_col")),
+                false,
+                true,
+            ))],
+            WindowFrame {
+                units: WindowFrameUnits::Range,
+                start_bound: WindowFrameBound::Preceding(ScalarValue::IntervalDayTime(
+                    Some(6),
+                )),
+                end_bound: WindowFrameBound::Following(ScalarValue::IntervalDayTime(
+                    Some(2),
+                )),
+            },
+        ))])?
         .explain(false, false)?
         .collect()
         .await?;
@@ -72,21 +129,37 @@ async fn count_wildcard() -> Result<()> {
         pretty_format_batches(&df_results)?.to_string()
     );
 
-    let results = ctx
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_count_wildcard_on_aggregate() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_alltypes_tiny_pages_parquet(&ctx).await?;
+
+    let sql_results = ctx
+        .sql("select count(*) from alltypes_tiny_pages")
+        .await?
+        .select(vec![count(Expr::Wildcard)])?
+        .explain(false, false)?
+        .collect()
+        .await?;
+
+    // add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all node instead of just top node.
+    let df_results = ctx
         .table("alltypes_tiny_pages")
         .await?
         .aggregate(vec![], vec![count(Expr::Wildcard)])?
+        .select(vec![count(Expr::Wildcard)])?
+        .explain(false, false)?
         .collect()
         .await?;
 
-    let expected = vec![
-        "+-----------------+",
-        "| COUNT(UInt8(1)) |",
-        "+-----------------+",
-        "| 7300            |",
-        "+-----------------+",
-    ];
-    assert_batches_sorted_eq!(expected, &results);
+    //make sure sql plan same with df plan
+    assert_eq!(
+        pretty_format_batches(&sql_results)?.to_string(),
+        pretty_format_batches(&df_results)?.to_string()
+    );
 
     Ok(())
 }
@@ -1047,3 +1120,14 @@ async fn table_with_nested_types(n: usize) -> Result<DataFrame> {
     ctx.register_batch("shapes", batch)?;
     ctx.table("shapes").await
 }
+
+pub async fn register_alltypes_tiny_pages_parquet(ctx: &SessionContext) -> Result<()> {
+    let testdata = parquet_test_data();
+    ctx.register_parquet(
+        "alltypes_tiny_pages",
+        &format!("{testdata}/alltypes_tiny_pages.parquet"),
+        ParquetReadOptions::default(),
+    )
+    .await?;
+    Ok(())
+}
diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs
index 65cbcdfe06d95..beae3ac140d87 100644
--- a/datafusion/core/tests/sql/mod.rs
+++ b/datafusion/core/tests/sql/mod.rs
@@ -1161,15 +1161,6 @@ async fn try_execute_to_batches(
 /// Execute query and return results as a Vec of RecordBatches
 async fn execute_to_batches(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> {
     let df = ctx.sql(sql).await.unwrap();
-
-    // We are not really interested in the direct output of optimized_logical_plan
-    // since the physical plan construction already optimizes the given logical plan
-    // and we want to avoid double-optimization as a consequence. So we just construct
-    // it here to make sure that it doesn't fail at this step and get the optimized
-    // schema (to assert later that the logical and optimized schemas are the same).
-    let optimized = df.clone().into_optimized_plan().unwrap();
-    assert_eq!(df.logical_plan().schema(), optimized.schema());
-
     df.collect().await.unwrap()
 }
 
diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
index 2772090e024e3..a1199f33cb939 100644
--- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
+++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
@@ -16,16 +16,26 @@
 // under the License.
 
 use datafusion_common::config::ConfigOptions;
-use datafusion_common::Result;
+use datafusion_common::{Column, DFField, DFSchema, DFSchemaRef, Result};
 use datafusion_expr::expr::AggregateFunction;
+use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
 use datafusion_expr::utils::COUNT_STAR_EXPANSION;
-use datafusion_expr::{aggregate_function, lit, Aggregate, Expr, LogicalPlan, Window};
+use datafusion_expr::Expr::{Exists, InSubquery, ScalarSubquery};
+use datafusion_expr::{
+    aggregate_function, count, expr, lit, window_function, Aggregate, Expr, Filter,
+    LogicalPlan, Projection, Sort, Subquery, Window,
+};
+use std::string::ToString;
+use std::sync::Arc;
 
 use crate::analyzer::AnalyzerRule;
 use crate::rewrite::TreeNodeRewritable;
 
+pub const COUNT_STAR: &str = "COUNT(*)";
+
 /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
 /// Resolve issue: https://github.com/apache/arrow-datafusion/issues/5473.
+#[derive(Default)]
 pub struct CountWildcardRule {}
 
 impl CountWildcardRule {
@@ -45,35 +55,113 @@ impl AnalyzerRule for CountWildcardRule {
     }
 }
 
 fn analyze_internal(plan: LogicalPlan) -> Result<Option<LogicalPlan>> {
+    let mut rewriter = CountWildcardRewriter {};
+
     match plan {
         LogicalPlan::Window(window) => {
-            let window_expr = handle_wildcard(&window.window_expr);
+            let window_expr = window
+                .window_expr
+                .iter()
+                .map(|expr| expr.clone().rewrite(&mut rewriter).unwrap())
+                .collect::<Vec<_>>();
+
             Ok(Some(LogicalPlan::Window(Window {
                 input: window.input.clone(),
                 window_expr,
-                schema: window.schema,
+                schema: rewrite_schema(window.schema),
             })))
         }
         LogicalPlan::Aggregate(agg) => {
-            let aggr_expr = handle_wildcard(&agg.aggr_expr);
+            let aggr_expr = agg
+                .aggr_expr
+                .iter()
+                .map(|expr| expr.clone().rewrite(&mut rewriter).unwrap())
+                .collect();
             Ok(Some(LogicalPlan::Aggregate(
                 Aggregate::try_new_with_schema(
                     agg.input.clone(),
                     agg.group_expr.clone(),
                     aggr_expr,
-                    agg.schema,
+                    rewrite_schema(agg.schema),
                 )?,
             )))
         }
+        LogicalPlan::Sort(Sort { expr, input, fetch }) => {
+            let sort_expr = expr
+                .iter()
+                .map(|expr| expr.clone().rewrite(&mut rewriter).unwrap())
+                .collect();
+            Ok(Some(LogicalPlan::Sort(Sort {
+                expr: sort_expr,
+                input,
+                fetch,
+            })))
+        }
+        LogicalPlan::Projection(projection) => {
+            let projection_expr = projection
+                .expr
+                .iter()
+                .map(|expr| expr.clone().rewrite(&mut rewriter).unwrap())
+                .collect();
+            Ok(Some(LogicalPlan::Projection(
+                Projection::try_new_with_schema(
+                    projection_expr,
+                    projection.input,
+                    rewrite_schema(projection.schema),
+                )?,
+            )))
+        }
+        LogicalPlan::Filter(Filter {
+            predicate, input, ..
+        }) => {
+            let predicate = predicate.rewrite(&mut rewriter).unwrap();
+            Ok(Some(LogicalPlan::Filter(
+                Filter::try_new(predicate, input).unwrap(),
+            )))
+        }
+
         _ => Ok(None),
     }
 }
 
-// handle Count(Expr:Wildcard) with DataFrame API
-pub fn handle_wildcard(exprs: &[Expr]) -> Vec<Expr> {
-    exprs
-        .iter()
-        .map(|expr| match expr {
+struct CountWildcardRewriter {}
+
+impl ExprRewriter for CountWildcardRewriter {
+    fn mutate(&mut self, old_expr: Expr) -> Result<Expr> {
+        let new_expr = match old_expr.clone() {
+            Expr::Column(Column { name, relation }) if name.contains(COUNT_STAR) => {
+                Expr::Column(Column {
+                    name: name.replace(
+                        COUNT_STAR,
+                        count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(),
+                    ),
+                    relation: relation.clone(),
+                })
+            }
+            Expr::WindowFunction(expr::WindowFunction {
+                fun:
+                    window_function::WindowFunction::AggregateFunction(
+                        aggregate_function::AggregateFunction::Count,
+                    ),
+                args,
+                partition_by,
+                order_by,
+                window_frame,
+            }) if args.len() == 1 => match args[0] {
+                Expr::Wildcard => {
+                    Expr::WindowFunction(datafusion_expr::expr::WindowFunction {
+                        fun: window_function::WindowFunction::AggregateFunction(
+                            aggregate_function::AggregateFunction::Count,
+                        ),
+                        args: vec![lit(COUNT_STAR_EXPANSION)],
+                        partition_by,
+                        order_by,
+                        window_frame,
+                    })
+                }
+
+                _ => old_expr,
+            },
             Expr::AggregateFunction(AggregateFunction {
                 fun: aggregate_function::AggregateFunction::Count,
                 args,
@@ -83,12 +171,91 @@ pub fn handle_wildcard(exprs: &[Expr]) -> Vec<Expr> {
                 Expr::Wildcard => Expr::AggregateFunction(AggregateFunction {
                     fun: aggregate_function::AggregateFunction::Count,
                     args: vec![lit(COUNT_STAR_EXPANSION)],
-                    distinct: *distinct,
-                    filter: filter.clone(),
+                    distinct,
+                    filter,
                 }),
-                _ => expr.clone(),
+                _ => old_expr,
             },
-            _ => expr.clone(),
+
+            ScalarSubquery(Subquery {
+                subquery,
+                outer_ref_columns,
+            }) => {
+                let new_plan = subquery
+                    .as_ref()
+                    .clone()
+                    .transform_down(&analyze_internal)
+                    .unwrap();
+                ScalarSubquery(Subquery {
+                    subquery: Arc::new(new_plan),
+                    outer_ref_columns,
+                })
+            }
+            InSubquery {
+                expr,
+                subquery,
+                negated,
+            } => {
+                let new_plan = subquery
+                    .subquery
+                    .as_ref()
+                    .clone()
+                    .transform_down(&analyze_internal)
+                    .unwrap();
+
+                InSubquery {
+                    expr,
+                    subquery: Subquery {
+                        subquery: Arc::new(new_plan),
+                        outer_ref_columns: subquery.outer_ref_columns,
+                    },
+                    negated,
+                }
+            }
+            Exists { subquery, negated } => {
+                let new_plan = subquery
+                    .subquery
+                    .as_ref()
+                    .clone()
+                    .transform_down(&analyze_internal)
+                    .unwrap();
+
+                Exists {
+                    subquery: Subquery {
+                        subquery: Arc::new(new_plan),
+                        outer_ref_columns: subquery.outer_ref_columns,
+                    },
+                    negated,
+                }
+            }
+            _ => old_expr,
+        };
+        Ok(new_expr)
+    }
+}
+
+fn rewrite_schema(schema: DFSchemaRef) -> DFSchemaRef {
+    let new_fields = schema
+        .fields()
+        .iter()
+        .map(|DFField { qualifier, field }| {
+            let mut name = field.name().clone();
+            if name.contains(COUNT_STAR) {
+                name = name.replace(
+                    COUNT_STAR,
+                    count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(),
+                )
+            }
+            DFField::new(
+                qualifier.clone(),
+                name.as_str(),
+                field.data_type().clone(),
+                field.is_nullable(),
+            )
         })
-        .collect()
+        .collect::<Vec<DFField>>();
+
+    DFSchemaRef::new(
+        DFSchema::new_with_metadata(new_fields, schema.metadata().clone()).unwrap(),
+    )
 }
diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml
index b5cb6aec57f8d..ce92ba910df46 100644
--- a/datafusion/sql/Cargo.toml
+++ b/datafusion/sql/Cargo.toml
@@ -40,6 +40,7 @@ unicode_expressions = []
 arrow-schema = { workspace = true }
 datafusion-common = { path = "../common", version = "21.0.0" }
"21.0.0" } datafusion-expr = { path = "../expr", version = "21.0.0" } +datafusion-optimizer = { path = "../optimizer", version = "21.0.0" } log = "^0.4" sqlparser = "0.32" diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index e5af0eb26976c..f529726fbdd52 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -18,7 +18,6 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use crate::utils::normalize_ident; use datafusion_common::{DFSchema, DataFusionError, Result}; -use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::window_frame::regularize; use datafusion_expr::{ expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, @@ -216,12 +215,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Special case rewrite COUNT(*) to COUNT(constant) AggregateFunction::Count => args .into_iter() - .map(|a| match a { - FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => { - Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone())) - } - _ => self.sql_fn_arg_to_logical_expr(a, schema, planner_context), - }) + .map(|a| self.sql_fn_arg_to_logical_expr(a, schema, planner_context)) .collect::>>()?, _ => self.function_args_to_expr(args, schema, planner_context)?, }; diff --git a/datafusion/sql/tests/integration_test.rs b/datafusion/sql/tests/integration_test.rs index 3242989f574d4..352a8264f7cba 100644 --- a/datafusion/sql/tests/integration_test.rs +++ b/datafusion/sql/tests/integration_test.rs @@ -34,6 +34,8 @@ use datafusion_expr::{AggregateUDF, ScalarUDF}; use datafusion_sql::parser::DFParser; use datafusion_sql::planner::{ContextProvider, ParserOptions, SqlToRel}; +use datafusion_optimizer::analyzer::Analyzer; +use datafusion_optimizer::{OptimizerConfig, OptimizerContext}; use rstest::rstest; #[cfg(test)] @@ -755,7 +757,7 @@ fn select_aggregate_with_having_referencing_column_not_in_select() { assert_eq!( "Plan(\"HAVING clause references non-aggregate values: \ Expression person.first_name could not be resolved from available columns: \ - COUNT(UInt8(1))\")", + COUNT(*)\")", format!("{err:?}") ); } @@ -2424,7 +2426,12 @@ fn logical_plan_with_dialect_and_options( let planner = SqlToRel::new_with_options(&context, options); let result = DFParser::parse_sql_with_dialect(sql, dialect); let mut ast = result?; - planner.statement_to_plan(ast.pop_front().unwrap()) + match planner.statement_to_plan(ast.pop_front().unwrap()) { + Ok(plan) => { + Analyzer::new().execute_and_check(&plan, OptimizerContext::new().options()) + } + Err(err) => Err(err), + } } /// Create logical plan, write with formatter, compare to expected output