Skip to content

Commit

Permalink
remove unnecessary logic on sql count wildcard
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangzhx committed Mar 24, 2023
1 parent 68e3040 commit c9e610d
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 51 deletions.
4 changes: 2 additions & 2 deletions datafusion/common/src/dfschema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -630,9 +630,9 @@ impl ExprSchema for DFSchema {
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DFField {
/// Optional qualifier (usually a table or relation name)
qualifier: Option<OwnedTableReference>,
pub qualifier: Option<OwnedTableReference>,
/// Arrow field definition
field: Field,
pub field: Field,
}

impl DFField {
Expand Down
4 changes: 1 addition & 3 deletions datafusion/core/tests/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,9 @@ async fn count_wildcard() -> Result<()> {
let sql_results = ctx
.sql("select count(*) from alltypes_tiny_pages")
.await?
.select(vec![count(Expr::Wildcard)])?
.explain(false, false)?
.collect()
.await?;

// add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all node instead of just top node.
let df_results = ctx
.table("alltypes_tiny_pages")
Expand Down Expand Up @@ -452,7 +450,7 @@ async fn select_with_alias_overwrite() -> Result<()> {
let results = df.collect().await?;

#[rustfmt::skip]
let expected = vec![
let expected = vec![
"+-------+",
"| a |",
"+-------+",
Expand Down
15 changes: 8 additions & 7 deletions datafusion/core/tests/sql/explain_analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -756,10 +756,11 @@ async fn explain_logical_plan_only() {
let expected = vec![
vec![
"logical_plan",
"Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
\n SubqueryAlias: t\
\n Projection: column1\
\n Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))"
"Projection: COUNT(UInt8(1))\
\n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
\n SubqueryAlias: t\
\n Projection: column1\
\n Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))"
]];
assert_eq!(expected, actual);
}
Expand All @@ -775,9 +776,9 @@ async fn explain_physical_plan_only() {

let expected = vec![vec![
"physical_plan",
"ProjectionExec: expr=[2 as COUNT(UInt8(1))]\
\n EmptyExec: produce_one_row=true\
\n",
"ProjectionExec: expr=[COUNT(UInt8(1))@0 as COUNT(UInt8(1))]\
\n ProjectionExec: expr=[2 as COUNT(UInt8(1))]\
\n EmptyExec: produce_one_row=true\n",
]];
assert_eq!(expected, actual);
}
21 changes: 11 additions & 10 deletions datafusion/core/tests/sql/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,19 @@ async fn json_explain() {
let actual = normalize_vec_for_explain(actual);
let expected = vec![
vec![
"logical_plan",
"Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
\n TableScan: t1 projection=[a]",
"logical_plan", "Projection: COUNT(UInt8(1))\
\n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
\n TableScan: t1 projection=[a]"
],
vec![
"physical_plan",
"AggregateExec: mode=Final, gby=[], aggr=[COUNT(UInt8(1))]\
\n CoalescePartitionsExec\
\n AggregateExec: mode=Partial, gby=[], aggr=[COUNT(UInt8(1))]\
\n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1\
\n JsonExec: limit=None, files={1 group: [[WORKING_DIR/tests/jsons/2.json]]}\n",
],
"physical_plan",
"ProjectionExec: expr=[COUNT(UInt8(1))@0 as COUNT(UInt8(1))]\
\n AggregateExec: mode=Final, gby=[], aggr=[COUNT(UInt8(1))]\
\n CoalescePartitionsExec\
\n AggregateExec: mode=Partial, gby=[], aggr=[COUNT(UInt8(1))]\
\n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1\
\n JsonExec: limit=None, files={1 group: [[WORKING_DIR/tests/jsons/2.json]]}\
\n" ],
];
assert_eq!(expected, actual);
}
185 changes: 163 additions & 22 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,36 @@
// under the License.

use datafusion_common::config::ConfigOptions;
use datafusion_common::Result;
use datafusion_common::{Column, DFField, DFSchema, DFSchemaRef, Result};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{aggregate_function, lit, Aggregate, Expr, LogicalPlan, Window};
use datafusion_expr::Expr::Exists;
use datafusion_expr::{
aggregate_function, count, expr, lit, window_function, Aggregate, Expr, Filter,
LogicalPlan, Projection, Subquery, Window,
};
use std::string::ToString;
use std::sync::Arc;

use crate::analyzer::AnalyzerRule;
use crate::rewrite::TreeNodeRewritable;

pub const COUNT_STAR: &str = "COUNT(*)";

/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
/// Resolve issue: https://github.com/apache/arrow-datafusion/issues/5473.
#[derive(Default)]
pub struct CountWildcardRule {}

impl Default for CountWildcardRule {
fn default() -> Self {
CountWildcardRule::new()
}
}

impl CountWildcardRule {
pub fn new() -> Self {
CountWildcardRule {}
}
}
impl AnalyzerRule for CountWildcardRule {
fn analyze(&self, plan: &LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
plan.clone().transform_down(&analyze_internal)
Ok(plan.clone().transform_down(&analyze_internal).unwrap())
}

fn name(&self) -> &str {
Expand All @@ -50,35 +54,145 @@ impl AnalyzerRule for CountWildcardRule {
}

fn analyze_internal(plan: LogicalPlan) -> Result<Option<LogicalPlan>> {
let mut rewriter = CountWildcardRewriter {};

match plan {
LogicalPlan::Window(window) => {
let window_expr = handle_wildcard(&window.window_expr);
let window_expr = window
.window_expr
.iter()
.map(|expr| {
let name = expr.name();
let variant_name = expr.variant_name();
expr.clone().rewrite(&mut rewriter).unwrap()
})
.collect::<Vec<Expr>>();

Ok(Some(LogicalPlan::Window(Window {
input: window.input.clone(),
window_expr,
schema: window.schema,
schema: rewrite_schema(window.schema),
})))
}
LogicalPlan::Aggregate(agg) => {
let aggr_expr = handle_wildcard(&agg.aggr_expr);
let aggr_expr = agg
.aggr_expr
.iter()
.map(|expr| expr.clone().rewrite(&mut rewriter).unwrap())
.collect();
Ok(Some(LogicalPlan::Aggregate(
Aggregate::try_new_with_schema(
agg.input.clone(),
agg.group_expr.clone(),
aggr_expr,
agg.schema,
rewrite_schema(agg.schema),
)?,
)))
}
LogicalPlan::Projection(projection) => {
let projection_expr = projection
.expr
.iter()
.map(|expr| {
let name = expr.name();
let variant_name = expr.variant_name();
expr.clone().rewrite(&mut rewriter).unwrap()
})
.collect();
Ok(Some(LogicalPlan::Projection(
Projection::try_new_with_schema(
projection_expr,
projection.input,
rewrite_schema(projection.schema),
)?,
)))
}
LogicalPlan::Filter(Filter {
predicate, input, ..
}) => {
let predicate = match predicate {
Exists { subquery, negated } => {
let new_plan = subquery
.subquery
.as_ref()
.clone()
.transform_down(&analyze_internal)
.unwrap();

Exists {
subquery: Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns: subquery.outer_ref_columns,
},
negated,
}
}
_ => predicate,
};

Ok(Some(LogicalPlan::Filter(
Filter::try_new(predicate, input).unwrap(),
)))
}

_ => Ok(None),
}
}

// handle Count(Expr:Wildcard) with DataFrame API
pub fn handle_wildcard(exprs: &[Expr]) -> Vec<Expr> {
exprs
.iter()
.map(|expr| match expr {
struct CountWildcardRewriter {}

impl ExprRewriter for CountWildcardRewriter {
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
let count_star: String = count(Expr::Wildcard).to_string();
let old_expr = expr.clone();

let new_expr = match old_expr.clone() {
Expr::Column(Column { name, relation }) if name.contains(&count_star) => {
Expr::Column(Column {
name: name.replace(
&count_star,
count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(),
),
relation: relation.clone(),
})
}
Expr::WindowFunction(expr::WindowFunction {
fun:
window_function::WindowFunction::AggregateFunction(
aggregate_function::AggregateFunction::Count,
),
args,
partition_by,
order_by,
window_frame,
}) if args.len() == 1 => match args[0] {
Expr::Wildcard => {
Expr::WindowFunction(datafusion_expr::expr::WindowFunction {
fun: window_function::WindowFunction::AggregateFunction(
aggregate_function::AggregateFunction::Count,
),
args: vec![lit(COUNT_STAR_EXPANSION)],
partition_by: partition_by.clone(),
order_by: order_by.clone(),
window_frame: window_frame.clone(),
})
}

_ => old_expr.clone(),
},
Expr::WindowFunction(expr::WindowFunction {
fun:
window_function::WindowFunction::AggregateFunction(
aggregate_function::AggregateFunction::Count,
),
args,
partition_by,
order_by,
window_frame,
}) => {
println!("hahahhaha {}", args[0]);
old_expr.clone()
}
Expr::AggregateFunction(AggregateFunction {
fun: aggregate_function::AggregateFunction::Count,
args,
Expand All @@ -88,12 +202,39 @@ pub fn handle_wildcard(exprs: &[Expr]) -> Vec<Expr> {
Expr::Wildcard => Expr::AggregateFunction(AggregateFunction {
fun: aggregate_function::AggregateFunction::Count,
args: vec![lit(COUNT_STAR_EXPANSION)],
distinct: *distinct,
distinct,
filter: filter.clone(),
}),
_ => expr.clone(),
_ => old_expr.clone(),
},
_ => expr.clone(),
_ => old_expr.clone(),
};
Ok(new_expr)
}
}

fn rewrite_schema(schema: DFSchemaRef) -> DFSchemaRef {
let new_fields = schema
.fields()
.iter()
.map(|DFField { qualifier, field }| {
let mut name = field.name().clone();
if name.contains(COUNT_STAR.clone()) {
name = name.replace(
COUNT_STAR,
count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(),
)
}
DFField::new(
qualifier.clone(),
name.as_str(),
field.data_type().clone(),
field.is_nullable(),
)
})
.collect()
.collect::<Vec<DFField>>();

DFSchemaRef::new(
DFSchema::new_with_metadata(new_fields, schema.metadata().clone()).unwrap(),
)
}
8 changes: 1 addition & 7 deletions datafusion/sql/src/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use crate::utils::normalize_ident;
use datafusion_common::{DFSchema, DataFusionError, Result};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::window_frame::regularize;
use datafusion_expr::{
expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame,
Expand Down Expand Up @@ -216,12 +215,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// Special case rewrite COUNT(*) to COUNT(constant)
AggregateFunction::Count => args
.into_iter()
.map(|a| match a {
FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => {
Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone()))
}
_ => self.sql_fn_arg_to_logical_expr(a, schema, planner_context),
})
.map(|a| self.sql_fn_arg_to_logical_expr(a, schema, planner_context))
.collect::<Result<Vec<Expr>>>()?,
_ => self.function_args_to_expr(args, schema, planner_context)?,
};
Expand Down

0 comments on commit c9e610d

Please sign in to comment.