diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 67a367c5ef3f8..da2296177188d 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -630,9 +630,9 @@ impl ExprSchema for DFSchema { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct DFField { /// Optional qualifier (usually a table or relation name) - qualifier: Option, + pub qualifier: Option, /// Arrow field definition - field: Field, + pub field: Field, } impl DFField { diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs index 23c6623ab1b5a..af9abb5bed5cc 100644 --- a/datafusion/core/tests/dataframe.rs +++ b/datafusion/core/tests/dataframe.rs @@ -51,11 +51,9 @@ async fn count_wildcard() -> Result<()> { let sql_results = ctx .sql("select count(*) from alltypes_tiny_pages") .await? - .select(vec![count(Expr::Wildcard)])? .explain(false, false)? .collect() .await?; - // add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all node instead of just top node. let df_results = ctx .table("alltypes_tiny_pages") @@ -452,7 +450,7 @@ async fn select_with_alias_overwrite() -> Result<()> { let results = df.collect().await?; #[rustfmt::skip] - let expected = vec![ + let expected = vec![ "+-------+", "| a |", "+-------+", diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 6f7150d2a53b6..f47ceb5ce608f 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -756,10 +756,11 @@ async fn explain_logical_plan_only() { let expected = vec![ vec![ "logical_plan", - "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ - \n SubqueryAlias: t\ - \n Projection: column1\ - \n Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))" + "Projection: COUNT(UInt8(1))\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ + \n SubqueryAlias: t\ + \n Projection: column1\ + \n Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))" ]]; assert_eq!(expected, actual); } @@ -775,9 +776,9 @@ async fn explain_physical_plan_only() { let expected = vec![vec![ "physical_plan", - "ProjectionExec: expr=[2 as COUNT(UInt8(1))]\ - \n EmptyExec: produce_one_row=true\ - \n", + "ProjectionExec: expr=[COUNT(UInt8(1))@0 as COUNT(UInt8(1))]\ + \n ProjectionExec: expr=[2 as COUNT(UInt8(1))]\ + \n EmptyExec: produce_one_row=true\n", ]]; assert_eq!(expected, actual); } diff --git a/datafusion/core/tests/sql/json.rs b/datafusion/core/tests/sql/json.rs index 965a9c14fc985..02bbb7ac91a42 100644 --- a/datafusion/core/tests/sql/json.rs +++ b/datafusion/core/tests/sql/json.rs @@ -82,18 +82,19 @@ async fn json_explain() { let actual = normalize_vec_for_explain(actual); let expected = vec![ vec![ - "logical_plan", - "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ - \n TableScan: t1 projection=[a]", + "logical_plan", "Projection: COUNT(UInt8(1))\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ + \n TableScan: t1 projection=[a]" ], vec![ - "physical_plan", - "AggregateExec: mode=Final, gby=[], aggr=[COUNT(UInt8(1))]\ - \n CoalescePartitionsExec\ - \n AggregateExec: mode=Partial, gby=[], aggr=[COUNT(UInt8(1))]\ - \n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1\ - \n JsonExec: limit=None, files={1 group: [[WORKING_DIR/tests/jsons/2.json]]}\n", - ], + "physical_plan", + "ProjectionExec: expr=[COUNT(UInt8(1))@0 as COUNT(UInt8(1))]\ + \n AggregateExec: mode=Final, gby=[], aggr=[COUNT(UInt8(1))]\ + \n CoalescePartitionsExec\ + \n AggregateExec: mode=Partial, gby=[], aggr=[COUNT(UInt8(1))]\ + \n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1\ + \n JsonExec: limit=None, files={1 group: [[WORKING_DIR/tests/jsons/2.json]]}\ + \n" ], ]; assert_eq!(expected, actual); } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 4b4c603bcfe46..fb426eaa40864 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -16,24 +16,28 @@ // under the License. use datafusion_common::config::ConfigOptions; -use datafusion_common::Result; +use datafusion_common::{Column, DFField, DFSchema, DFSchemaRef, Result}; use datafusion_expr::expr::AggregateFunction; +use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_expr::{aggregate_function, lit, Aggregate, Expr, LogicalPlan, Window}; +use datafusion_expr::Expr::Exists; +use datafusion_expr::{ + aggregate_function, count, expr, lit, window_function, Aggregate, Expr, Filter, + LogicalPlan, Projection, Subquery, Window, +}; +use std::string::ToString; +use std::sync::Arc; use crate::analyzer::AnalyzerRule; use crate::rewrite::TreeNodeRewritable; +pub const COUNT_STAR: &str = "COUNT(*)"; + /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. /// Resolve issue: https://github.com/apache/arrow-datafusion/issues/5473. +#[derive(Default)] pub struct CountWildcardRule {} -impl Default for CountWildcardRule { - fn default() -> Self { - CountWildcardRule::new() - } -} - impl CountWildcardRule { pub fn new() -> Self { CountWildcardRule {} @@ -41,7 +45,7 @@ impl CountWildcardRule { } impl AnalyzerRule for CountWildcardRule { fn analyze(&self, plan: &LogicalPlan, _: &ConfigOptions) -> Result { - plan.clone().transform_down(&analyze_internal) + Ok(plan.clone().transform_down(&analyze_internal).unwrap()) } fn name(&self) -> &str { @@ -50,35 +54,145 @@ impl AnalyzerRule for CountWildcardRule { } fn analyze_internal(plan: LogicalPlan) -> Result> { + let mut rewriter = CountWildcardRewriter {}; + match plan { LogicalPlan::Window(window) => { - let window_expr = handle_wildcard(&window.window_expr); + let window_expr = window + .window_expr + .iter() + .map(|expr| { + let name = expr.name(); + let variant_name = expr.variant_name(); + expr.clone().rewrite(&mut rewriter).unwrap() + }) + .collect::>(); + Ok(Some(LogicalPlan::Window(Window { input: window.input.clone(), window_expr, - schema: window.schema, + schema: rewrite_schema(window.schema), }))) } LogicalPlan::Aggregate(agg) => { - let aggr_expr = handle_wildcard(&agg.aggr_expr); + let aggr_expr = agg + .aggr_expr + .iter() + .map(|expr| expr.clone().rewrite(&mut rewriter).unwrap()) + .collect(); Ok(Some(LogicalPlan::Aggregate( Aggregate::try_new_with_schema( agg.input.clone(), agg.group_expr.clone(), aggr_expr, - agg.schema, + rewrite_schema(agg.schema), + )?, + ))) + } + LogicalPlan::Projection(projection) => { + let projection_expr = projection + .expr + .iter() + .map(|expr| { + let name = expr.name(); + let variant_name = expr.variant_name(); + expr.clone().rewrite(&mut rewriter).unwrap() + }) + .collect(); + Ok(Some(LogicalPlan::Projection( + Projection::try_new_with_schema( + projection_expr, + projection.input, + rewrite_schema(projection.schema), )?, ))) } + LogicalPlan::Filter(Filter { + predicate, input, .. + }) => { + let predicate = match predicate { + Exists { subquery, negated } => { + let new_plan = subquery + .subquery + .as_ref() + .clone() + .transform_down(&analyze_internal) + .unwrap(); + + Exists { + subquery: Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns: subquery.outer_ref_columns, + }, + negated, + } + } + _ => predicate, + }; + + Ok(Some(LogicalPlan::Filter( + Filter::try_new(predicate, input).unwrap(), + ))) + } + _ => Ok(None), } } -// handle Count(Expr:Wildcard) with DataFrame API -pub fn handle_wildcard(exprs: &[Expr]) -> Vec { - exprs - .iter() - .map(|expr| match expr { +struct CountWildcardRewriter {} + +impl ExprRewriter for CountWildcardRewriter { + fn mutate(&mut self, expr: Expr) -> Result { + let count_star: String = count(Expr::Wildcard).to_string(); + let old_expr = expr.clone(); + + let new_expr = match old_expr.clone() { + Expr::Column(Column { name, relation }) if name.contains(&count_star) => { + Expr::Column(Column { + name: name.replace( + &count_star, + count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(), + ), + relation: relation.clone(), + }) + } + Expr::WindowFunction(expr::WindowFunction { + fun: + window_function::WindowFunction::AggregateFunction( + aggregate_function::AggregateFunction::Count, + ), + args, + partition_by, + order_by, + window_frame, + }) if args.len() == 1 => match args[0] { + Expr::Wildcard => { + Expr::WindowFunction(datafusion_expr::expr::WindowFunction { + fun: window_function::WindowFunction::AggregateFunction( + aggregate_function::AggregateFunction::Count, + ), + args: vec![lit(COUNT_STAR_EXPANSION)], + partition_by: partition_by.clone(), + order_by: order_by.clone(), + window_frame: window_frame.clone(), + }) + } + + _ => old_expr.clone(), + }, + Expr::WindowFunction(expr::WindowFunction { + fun: + window_function::WindowFunction::AggregateFunction( + aggregate_function::AggregateFunction::Count, + ), + args, + partition_by, + order_by, + window_frame, + }) => { + println!("hahahhaha {}", args[0]); + old_expr.clone() + } Expr::AggregateFunction(AggregateFunction { fun: aggregate_function::AggregateFunction::Count, args, @@ -88,12 +202,39 @@ pub fn handle_wildcard(exprs: &[Expr]) -> Vec { Expr::Wildcard => Expr::AggregateFunction(AggregateFunction { fun: aggregate_function::AggregateFunction::Count, args: vec![lit(COUNT_STAR_EXPANSION)], - distinct: *distinct, + distinct, filter: filter.clone(), }), - _ => expr.clone(), + _ => old_expr.clone(), }, - _ => expr.clone(), + _ => old_expr.clone(), + }; + Ok(new_expr) + } +} + +fn rewrite_schema(schema: DFSchemaRef) -> DFSchemaRef { + let new_fields = schema + .fields() + .iter() + .map(|DFField { qualifier, field }| { + let mut name = field.name().clone(); + if name.contains(COUNT_STAR.clone()) { + name = name.replace( + COUNT_STAR, + count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(), + ) + } + DFField::new( + qualifier.clone(), + name.as_str(), + field.data_type().clone(), + field.is_nullable(), + ) }) - .collect() + .collect::>(); + + DFSchemaRef::new( + DFSchema::new_with_metadata(new_fields, schema.metadata().clone()).unwrap(), + ) } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index e5af0eb26976c..f529726fbdd52 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -18,7 +18,6 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use crate::utils::normalize_ident; use datafusion_common::{DFSchema, DataFusionError, Result}; -use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::window_frame::regularize; use datafusion_expr::{ expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, @@ -216,12 +215,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Special case rewrite COUNT(*) to COUNT(constant) AggregateFunction::Count => args .into_iter() - .map(|a| match a { - FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => { - Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone())) - } - _ => self.sql_fn_arg_to_logical_expr(a, schema, planner_context), - }) + .map(|a| self.sql_fn_arg_to_logical_expr(a, schema, planner_context)) .collect::>>()?, _ => self.function_args_to_expr(args, schema, planner_context)?, };