Skip to content

Commit

Permalink
fix more rules
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-toth committed Nov 27, 2024
1 parent d63b6fc commit 7fa8610
Show file tree
Hide file tree
Showing 10 changed files with 583 additions and 422 deletions.
12 changes: 8 additions & 4 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1799,12 +1799,15 @@ impl Expr {
}

/// Builds an [`Expr::InSubquery`] node together with its pattern statistics.
///
/// The statistics start from the `ExprInSubquery` pattern contributed by this
/// node and are combined, via `merge`, with the statistics reported by
/// `in_subquery` itself.
pub fn in_subquery(in_subquery: InSubquery) -> Self {
    let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprInSubquery))
        .merge(in_subquery.stats());
    Expr::InSubquery(in_subquery, stats)
}

/// Builds an [`Expr::ScalarSubquery`] node together with its pattern statistics.
///
/// The statistics start from the `ExprScalarSubquery` pattern contributed by
/// this node and are combined, via `merge`, with the statistics reported by
/// `subquery` itself.
pub fn scalar_subquery(subquery: Subquery) -> Self {
    let stats =
        LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprScalarSubquery))
            .merge(subquery.stats());
    Expr::ScalarSubquery(subquery, stats)
}

Expand Down Expand Up @@ -1919,7 +1922,8 @@ impl Expr {
}

/// Builds an [`Expr::Exists`] node together with its pattern statistics.
///
/// The statistics start from the `ExprExists` pattern contributed by this node
/// and are combined, via `merge`, with the statistics reported by `exists`
/// itself.
pub fn exists(exists: Exists) -> Self {
    let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprExists))
        .merge(exists.stats());
    Expr::Exists(exists, stats)
}

Expand All @@ -1939,7 +1943,7 @@ impl Expr {
}

/// Builds an [`Expr::Placeholder`] node whose statistics carry the
/// `ExprPlaceholder` pattern, so stats-based checks can tell that a
/// placeholder is present.
pub fn placeholder(placeholder: Placeholder) -> Self {
    let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprPlaceholder));
    Expr::Placeholder(placeholder, stats)
}

Expand Down
124 changes: 100 additions & 24 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! Logical plan types
use std::cell::Cell;
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug, Display, Formatter};
Expand Down Expand Up @@ -674,6 +675,13 @@ impl LogicalPlan {
let mut using_columns: Vec<HashSet<Column>> = vec![];

self.apply_with_subqueries(|plan| {
if !plan
.stats()
.contains_pattern(LogicalPlanPattern::LogicalPlanJoin)
{
return Ok(TreeNodeRecursion::Jump);
}

if let LogicalPlan::Join(
Join {
join_constraint: JoinConstraint::Using,
Expand Down Expand Up @@ -1693,40 +1701,94 @@ impl LogicalPlan {
self,
param_values: &ParamValues,
) -> Result<LogicalPlan> {
self.transform_up_with_subqueries(|plan| {
let schema = Arc::clone(plan.schema());
let name_preserver = NamePreserver::new(&plan);
plan.map_expressions(|e| {
let (e, has_placeholder) = e.infer_placeholder_types(&schema)?;
if !has_placeholder {
// Performance optimization:
// avoid NamePreserver copy and second pass over expression
// if no placeholders.
Ok(Transformed::no(e))
} else {
let original_name = name_preserver.save(&e);
let transformed_expr = e.transform_up(|e| {
if let Expr::Placeholder(Placeholder { id, .. }, _) = e {
let value = param_values.get_placeholders_with_values(&id)?;
Ok(Transformed::yes(Expr::literal(value)))
} else {
Ok(Transformed::no(e))
}
})?;
// Preserve name to avoid breaking column references to this expression
Ok(transformed_expr.update_data(|expr| original_name.restore(expr)))
let skip = Cell::new(false);
self.transform_down_up_with_subqueries(
|plan| {
if !plan
.stats()
.contains_pattern(LogicalPlanPattern::ExprPlaceholder)
{
skip.set(true);
return Ok(Transformed::jump(plan));
}
})
})

Ok(Transformed::no(plan))
},
|plan| {
if skip.get() {
skip.set(false);
return Ok(Transformed::no(plan));
}

let schema = Arc::clone(plan.schema());
let name_preserver = NamePreserver::new(&plan);
plan.map_expressions(|e| {
let (e, has_placeholder) = e.infer_placeholder_types(&schema)?;
if !has_placeholder {
// Performance optimization:
// avoid NamePreserver copy and second pass over expression
// if no placeholders.
Ok(Transformed::no(e))
} else {
let original_name = name_preserver.save(&e);
let skip = Cell::new(false);
let transformed_expr = e.transform_down_up(
|e| {
if !e
.stats()
.contains_pattern(LogicalPlanPattern::ExprPlaceholder)
{
skip.set(true);
return Ok(Transformed::jump(e));
}

Ok(Transformed::no(e))
},
|e| {
if skip.get() {
skip.set(false);
return Ok(Transformed::no(e));
}

if let Expr::Placeholder(Placeholder { id, .. }, _) = e {
let value =
param_values.get_placeholders_with_values(&id)?;
Ok(Transformed::yes(Expr::literal(value)))
} else {
Ok(Transformed::no(e))
}
},
)?;
// Preserve name to avoid breaking column references to this expression
Ok(transformed_expr
.update_data(|expr| original_name.restore(expr)))
}
})
},
)
.map(|res| res.data)
}

/// Walk the logical plan, find any `Placeholder` tokens, and return a set of their names.
pub fn get_parameter_names(&self) -> Result<HashSet<String>> {
let mut param_names = HashSet::new();
self.apply_with_subqueries(|plan| {
if !plan
.stats()
.contains_pattern(LogicalPlanPattern::ExprPlaceholder)
{
return Ok(TreeNodeRecursion::Jump);
}

plan.apply_expressions(|expr| {
expr.apply(|expr| {
// Prune on the stats of the expression being visited. The enclosing plan
// was already tested before its expressions are walked, so testing `plan`
// again here could never jump and no sub-expression would be skipped.
if !expr
    .stats()
    .contains_pattern(LogicalPlanPattern::ExprPlaceholder)
{
    return Ok(TreeNodeRecursion::Jump);
}

if let Expr::Placeholder(Placeholder { id, .. }, _) = expr {
param_names.insert(id.clone());
}
Expand All @@ -1744,8 +1806,22 @@ impl LogicalPlan {
let mut param_types: HashMap<String, Option<DataType>> = HashMap::new();

self.apply_with_subqueries(|plan| {
if !plan
.stats()
.contains_pattern(LogicalPlanPattern::ExprPlaceholder)
{
return Ok(TreeNodeRecursion::Jump);
}

plan.apply_expressions(|expr| {
expr.apply(|expr| {
// Prune on the stats of the expression being visited. The enclosing plan
// was already tested before its expressions are walked, so testing `plan`
// again here could never jump and no sub-expression would be skipped.
if !expr
    .stats()
    .contains_pattern(LogicalPlanPattern::ExprPlaceholder)
{
    return Ok(TreeNodeRecursion::Jump);
}

if let Expr::Placeholder(Placeholder { id, data_type }, _) = expr {
let prev = param_types.get(id);
match (prev, data_type) {
Expand Down
8 changes: 4 additions & 4 deletions datafusion/expr/src/logical_plan/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ pub enum LogicalPlanPattern {
ExprAggregateFunction,
ExprWindowFunction,
ExprInList,
// ExprExists,
// ExprInSubquery,
// ExprScalarSubquery,
ExprExists,
ExprInSubquery,
ExprScalarSubquery,
// ExprWildcard,
// ExprGroupingSet,
// ExprPlaceholder,
ExprPlaceholder,
// ExprOuterReferenceColumn,
// ExprUnnest,

Expand Down
84 changes: 58 additions & 26 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
// under the License.

use crate::analyzer::AnalyzerRule;
use enumset::enum_set;
use std::cell::Cell;

use crate::utils::NamePreserver;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
use datafusion_expr::expr::{AggregateFunction, WindowFunction};
use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern;
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition};

Expand All @@ -39,7 +42,61 @@ impl CountWildcardRule {

impl AnalyzerRule for CountWildcardRule {
/// Rewrites wildcard `count` calls (window and aggregate) so that their
/// argument list becomes `[lit(COUNT_STAR_EXPANSION)]`.
///
/// Plans and expressions whose stats contain neither `ExprWindowFunction`
/// nor `ExprAggregateFunction` are jumped over without being rewritten.
/// Each rewritten top-level expression has its original name restored.
///
/// # Errors
///
/// Propagates any error returned by the plan or expression traversals.
fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
    plan.transform_down_with_subqueries(|plan| {
        // Nothing to rewrite below this plan node: skip its whole subtree.
        if !plan.stats().contains_any_patterns(enum_set!(
            LogicalPlanPattern::ExprWindowFunction
                | LogicalPlanPattern::ExprAggregateFunction
        )) {
            return Ok(Transformed::jump(plan));
        }

        let name_preserver = NamePreserver::new(&plan);
        plan.map_expressions(|expr| {
            let original_name = name_preserver.save(&expr);
            // Set by the top-down closure when it prunes a node, and consumed
            // by the bottom-up closure so that node is passed through as-is.
            let skip = Cell::new(false);
            let transformed_expr = expr.transform_down_up(
                |expr| {
                    if !expr.stats().contains_any_patterns(enum_set!(
                        LogicalPlanPattern::ExprWindowFunction
                            | LogicalPlanPattern::ExprAggregateFunction
                    )) {
                        skip.set(true);
                        return Ok(Transformed::jump(expr));
                    }

                    Ok(Transformed::no(expr))
                },
                |expr| {
                    if skip.get() {
                        skip.set(false);
                        return Ok(Transformed::no(expr));
                    }

                    match expr {
                        Expr::WindowFunction(mut window_function, _)
                            if is_count_star_window_aggregate(&window_function) =>
                        {
                            window_function.args = vec![lit(COUNT_STAR_EXPANSION)];
                            Ok(Transformed::yes(Expr::window_function(
                                window_function,
                            )))
                        }
                        Expr::AggregateFunction(mut aggregate_function, _)
                            if is_count_star_aggregate(&aggregate_function) =>
                        {
                            aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)];
                            Ok(Transformed::yes(Expr::aggregate_function(
                                aggregate_function,
                            )))
                        }
                        _ => Ok(Transformed::no(expr)),
                    }
                },
            )?;
            // Keep the original name so references to this expression stay valid.
            Ok(transformed_expr.update_data(|data| original_name.restore(data)))
        })
    })
    .data()
}

fn name(&self) -> &str {
Expand Down Expand Up @@ -67,31 +124,6 @@ fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
if udaf.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty()))
}

fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
let name_preserver = NamePreserver::new(&plan);
plan.map_expressions(|expr| {
let original_name = name_preserver.save(&expr);
let transformed_expr = expr.transform_up(|expr| match expr {
Expr::WindowFunction(mut window_function, _)
if is_count_star_window_aggregate(&window_function) =>
{
window_function.args = vec![lit(COUNT_STAR_EXPANSION)];
Ok(Transformed::yes(Expr::window_function(window_function)))
}
Expr::AggregateFunction(mut aggregate_function, _)
if is_count_star_aggregate(&aggregate_function) =>
{
aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)];
Ok(Transformed::yes(Expr::aggregate_function(
aggregate_function,
)))
}
_ => Ok(Transformed::no(expr)),
})?;
Ok(transformed_expr.update_data(|data| original_name.restore(data)))
})
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading

0 comments on commit 7fa8610

Please sign in to comment.