Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce OptimizerRule::rewrite to rewrite in place, Rewrite `Simp…
Browse files Browse the repository at this point in the history
…lifyExprs` to avoid copies
alamb committed Apr 10, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 03d8ba1 commit e31a358
Showing 3 changed files with 139 additions and 88 deletions.
25 changes: 13 additions & 12 deletions datafusion/core/tests/simplification.rs
Original file line number Diff line number Diff line change
@@ -32,6 +32,7 @@ use datafusion_expr::{
LogicalPlanBuilder, ScalarUDF, Volatility,
};
use datafusion_functions::math;
use datafusion_optimizer::optimizer::Optimizer;
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions};
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
use std::sync::Arc;
@@ -109,14 +110,14 @@ fn test_table_scan() -> LogicalPlan {
.expect("building plan")
}

fn get_optimized_plan_formatted(plan: &LogicalPlan, date_time: &DateTime<Utc>) -> String {
fn get_optimized_plan_formatted(plan: LogicalPlan, date_time: &DateTime<Utc>) -> String {
let config = OptimizerContext::new().with_query_execution_start_time(*date_time);
let rule = SimplifyExpressions::new();

let optimized_plan = rule
.try_optimize(plan, &config)
.unwrap()
.expect("failed to optimize plan");
// Use Optimizer to do plan traversal
fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
let optimizer = Optimizer::with_rules(vec![Arc::new(SimplifyExpressions::new())]);
let optimized_plan = optimizer.optimize(plan, &config, observe).unwrap();

format!("{optimized_plan:?}")
}

@@ -238,7 +239,7 @@ fn to_timestamp_expr_folded() -> Result<()> {
let expected = "Projection: TimestampNanosecond(1599566400000000000, None) AS to_timestamp(Utf8(\"2020-09-08T12:00:00+00:00\"))\
\n TableScan: test"
.to_string();
let actual = get_optimized_plan_formatted(&plan, &Utc::now());
let actual = get_optimized_plan_formatted(plan, &Utc::now());
assert_eq!(expected, actual);
Ok(())
}
@@ -262,7 +263,7 @@ fn now_less_than_timestamp() -> Result<()> {
// expression down to a single constant (true)
let expected = "Filter: Boolean(true)\
\n TableScan: test";
let actual = get_optimized_plan_formatted(&plan, &time);
let actual = get_optimized_plan_formatted(plan, &time);

assert_eq!(expected, actual);
Ok(())
@@ -290,7 +291,7 @@ fn select_date_plus_interval() -> Result<()> {
// expression down to a single constant (true)
let expected = r#"Projection: Date32("18636") AS to_timestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("528280977408")
TableScan: test"#;
let actual = get_optimized_plan_formatted(&plan, &time);
let actual = get_optimized_plan_formatted(plan, &time);

assert_eq!(expected, actual);
Ok(())
@@ -308,7 +309,7 @@ fn simplify_project_scalar_fn() -> Result<()> {
// after simplify: t.f as "power(t.f, 1.0)"
let expected = "Projection: test.f AS power(test.f,Float64(1))\
\n TableScan: test";
let actual = get_optimized_plan_formatted(&plan, &Utc::now());
let actual = get_optimized_plan_formatted(plan, &Utc::now());
assert_eq!(expected, actual);
Ok(())
}
@@ -330,7 +331,7 @@ fn simplify_scan_predicate() -> Result<()> {
// before simplify: t.g = power(t.f, 1.0)
// after simplify: (t.g = t.f) as "t.g = power(t.f, 1.0)"
let expected = "TableScan: test, full_filters=[g = f AS g = power(f,Float64(1))]";
let actual = get_optimized_plan_formatted(&plan, &Utc::now());
let actual = get_optimized_plan_formatted(plan, &Utc::now());
assert_eq!(expected, actual);
Ok(())
}
@@ -461,7 +462,7 @@ fn multiple_now() -> Result<()> {
.build()?;

// expect the same timestamp appears in both exprs
let actual = get_optimized_plan_formatted(&plan, &time);
let actual = get_optimized_plan_formatted(plan, &time);
let expected = format!(
"Projection: TimestampNanosecond({}, Some(\"+00:00\")) AS now(), TimestampNanosecond({}, Some(\"+00:00\")) AS t2\
\n TableScan: test",
50 changes: 40 additions & 10 deletions datafusion/optimizer/src/optimizer.rs
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@ use datafusion_common::alias::AliasGenerator;
use datafusion_common::config::ConfigOptions;
use datafusion_common::instant::Instant;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_common::{DFSchema, DataFusionError, Result};
use datafusion_common::{internal_err, DFSchema, DataFusionError, Result};
use datafusion_expr::logical_plan::LogicalPlan;

use crate::common_subexpr_eliminate::CommonSubexprEliminate;
@@ -69,8 +69,12 @@ use crate::utils::log_plan;
/// [`SessionState::add_optimizer_rule`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionState.html#method.add_optimizer_rule
pub trait OptimizerRule {
/// Try and rewrite `plan` to an optimized form, returning None if the plan cannot be
/// optimized by this rule.
/// Try and rewrite `plan` to an optimized form, returning None if the plan
/// cannot be optimized by this rule.
///
/// Note this API will be deprecated in the future as it requires `clone`ing
/// the input plan, which can be expensive. OptimizerRules should implement
/// [`Self::rewrite`] instead.
fn try_optimize(
&self,
plan: &LogicalPlan,
@@ -80,12 +84,31 @@ pub trait OptimizerRule {
/// A human readable name for this optimizer rule
fn name(&self) -> &str;

/// How should the rule be applied by the optimizer? See comments on [`ApplyOrder`] for details.
/// How should the rule be applied by the optimizer? See comments on
/// [`ApplyOrder`] for details.
///
/// If a rule use default None, it should traverse recursively plan inside itself
/// If returns `None`, the default, the rule must handle recursion itself
fn apply_order(&self) -> Option<ApplyOrder> {
None
}

/// Does this rule support rewriting owned plans (rather than by reference)?
fn supports_rewrite(&self) -> bool {
false
}

/// Try to rewrite `plan` to an optimized form, returning `Transformed::yes`
/// if the plan was rewritten and `Transformed::no` if it was not.
///
/// Note: this function is only called if [`Self::supports_rewrite`] returns
/// true. Otherwise the Optimizer calls [`Self::try_optimize`]
fn rewrite(
&self,
_plan: LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>, DataFusionError> {
internal_err!("rewrite is not implemented for {}", self.name())
}
}

/// Options to control the DataFusion Optimizer.
@@ -298,12 +321,19 @@ fn optimize_plan_node(
rule: &dyn OptimizerRule,
config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
// TODO: add API to OptimizerRule to allow rewriting by ownership
rule.try_optimize(&plan, config)
.map(|maybe_plan| match maybe_plan {
Some(new_plan) => Transformed::yes(new_plan),
if rule.supports_rewrite() {
return rule.rewrite(plan, config);
}

rule.try_optimize(&plan, config).map(|maybe_plan| {
match maybe_plan {
Some(new_plan) => {
// if the node was rewritten by the optimizer, replace the node
Transformed::yes(new_plan)
}
None => Transformed::no(plan),
})
}
})
}

impl Optimizer {
152 changes: 86 additions & 66 deletions datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
Original file line number Diff line number Diff line change
@@ -19,12 +19,14 @@
use std::sync::Arc;

use datafusion_common::{DFSchema, DFSchemaRef, Result};
use datafusion_common::tree_node::Transformed;
use datafusion_common::{internal_err, DFSchema, DFSchemaRef, DataFusionError, Result};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::logical_plan::LogicalPlan;
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::utils::merge_schema;

use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};

use super::ExprSimplifier;
@@ -46,29 +48,47 @@ use super::ExprSimplifier;
pub struct SimplifyExpressions {}

impl OptimizerRule for SimplifyExpressions {
fn try_optimize(
&self,
_plan: &LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
internal_err!("Should have called SimplifyExpressions::try_optimize_owned")
}

fn name(&self) -> &str {
"simplify_expressions"
}

fn try_optimize(
fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::BottomUp)
}

fn supports_rewrite(&self) -> bool {
true
}

/// if supports_owned returns true, the Optimizer calls
/// [`Self::rewrite`] instead of [`Self::try_optimize`]
fn rewrite(
&self,
plan: &LogicalPlan,
plan: LogicalPlan,
config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
) -> Result<Transformed<LogicalPlan>, DataFusionError> {
let mut execution_props = ExecutionProps::new();
execution_props.query_execution_start_time = config.query_execution_start_time();
Ok(Some(Self::optimize_internal(plan, &execution_props)?))
Self::optimize_internal(plan, &execution_props)
}
}

impl SimplifyExpressions {
fn optimize_internal(
plan: &LogicalPlan,
plan: LogicalPlan,
execution_props: &ExecutionProps,
) -> Result<LogicalPlan> {
) -> Result<Transformed<LogicalPlan>> {
let schema = if !plan.inputs().is_empty() {
DFSchemaRef::new(merge_schema(plan.inputs()))
} else if let LogicalPlan::TableScan(scan) = plan {
} else if let LogicalPlan::TableScan(scan) = &plan {
// When predicates are pushed into a table scan, there is no input
// schema to resolve predicates against, so it must be handled specially
//
@@ -86,13 +106,11 @@ impl SimplifyExpressions {
} else {
Arc::new(DFSchema::empty())
};

let info = SimplifyContext::new(execution_props).with_schema(schema);

let new_inputs = plan
.inputs()
.iter()
.map(|input| Self::optimize_internal(input, execution_props))
.collect::<Result<Vec<_>>>()?;
// Inputs have already been rewritten (due to bottom-up traversal handled by Optimizer)
// Just need to rewrite our own expressions

let simplifier = ExprSimplifier::new(info);

@@ -109,18 +127,22 @@ impl SimplifyExpressions {
simplifier
};

let exprs = plan
.expressions()
.into_iter()
.map(|e| {
// the output schema of a filter or join is the input schema. Thus they
// can't handle aliased expressions
let use_alias = !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_));
plan.map_expressions(|e| {
let new_e = if use_alias {
// TODO: unify with `rewrite_preserving_name`
let original_name = e.name_for_alias()?;
let new_e = simplifier.simplify(e)?;
new_e.alias_if_changed(original_name)
})
.collect::<Result<Vec<_>>>()?;
simplifier.simplify(e)?.alias_if_changed(original_name)
} else {
simplifier.simplify(e)
}?;

plan.with_new_exprs(exprs, new_inputs)
// TODO it would be nice to have a way to know if the expression was simplified
// or not. For now conservatively return Transformed::yes
Ok(Transformed::yes(new_e))
})
}
}

@@ -138,6 +160,7 @@ mod tests {
use arrow::datatypes::{DataType, Field, Schema};
use chrono::{DateTime, Utc};

use crate::optimizer::Optimizer;
use datafusion_expr::logical_plan::builder::table_scan_with_filters;
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{
@@ -165,12 +188,12 @@ mod tests {
.expect("building plan")
}

fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
let rule = SimplifyExpressions::new();
let optimized_plan = rule
.try_optimize(plan, &OptimizerContext::new())
.unwrap()
.expect("failed to optimize plan");
fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
// Use Optimizer to do plan traversal
fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
let optimizer = Optimizer::with_rules(vec![Arc::new(SimplifyExpressions::new())]);
let optimized_plan =
optimizer.optimize(plan, &OptimizerContext::new(), observe)?;
let formatted_plan = format!("{optimized_plan:?}");
assert_eq!(formatted_plan, expected);
Ok(())
@@ -198,7 +221,7 @@ mod tests {

let expected = "TableScan: test projection=[a], full_filters=[Boolean(true) AS b IS NOT NULL]";

assert_optimized_plan_eq(&table_scan, expected)
assert_optimized_plan_eq(table_scan, expected)
}

#[test]
@@ -210,7 +233,7 @@ mod tests {
.build()?;

assert_optimized_plan_eq(
&plan,
plan,
"\
Filter: test.b > Int32(1)\
\n Projection: test.a\
@@ -227,7 +250,7 @@ mod tests {
.build()?;

assert_optimized_plan_eq(
&plan,
plan,
"\
Filter: test.b > Int32(1)\
\n Projection: test.a\
@@ -244,7 +267,7 @@ mod tests {
.build()?;

assert_optimized_plan_eq(
&plan,
plan,
"\
Filter: test.b > Int32(1)\
\n Projection: test.a\
@@ -265,7 +288,7 @@ mod tests {
.build()?;

assert_optimized_plan_eq(
&plan,
plan,
"\
Filter: test.a > Int32(5) AND test.b < Int32(6)\
\n Projection: test.a, test.b\
@@ -288,7 +311,7 @@ mod tests {
\n Filter: test.b\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -308,7 +331,7 @@ mod tests {
\n Filter: NOT test.b\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -324,7 +347,7 @@ mod tests {
\n Filter: NOT test.b AND test.c\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -340,7 +363,7 @@ mod tests {
\n Filter: NOT test.b OR NOT test.c\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -356,7 +379,7 @@ mod tests {
\n Filter: test.b\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -370,7 +393,7 @@ mod tests {
Projection: test.a, test.d, NOT test.b AS test.b = Boolean(false)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -392,7 +415,7 @@ mod tests {
\n Projection: test.a, test.c, test.b\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -413,20 +436,17 @@ mod tests {
let expected = "\
Values: (Int32(3) AS Int32(1) + Int32(2), Int32(1) AS Int32(2) - Int32(1))";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

fn get_optimized_plan_formatted(
plan: &LogicalPlan,
plan: LogicalPlan,
date_time: &DateTime<Utc>,
) -> String {
let config = OptimizerContext::new().with_query_execution_start_time(*date_time);
let rule = SimplifyExpressions::new();

let optimized_plan = rule
.try_optimize(plan, &config)
.unwrap()
.expect("failed to optimize plan");
let optimized_plan = rule.rewrite(plan, &config).unwrap().data;
format!("{optimized_plan:?}")
}

@@ -440,7 +460,7 @@ mod tests {

let expected = "Projection: Int32(0) AS Utf8(\"0\")\
\n TableScan: test";
let actual = get_optimized_plan_formatted(&plan, &Utc::now());
let actual = get_optimized_plan_formatted(plan, &Utc::now());
assert_eq!(expected, actual);
Ok(())
}
@@ -457,7 +477,7 @@ mod tests {
.project(proj)?
.build()?;

let actual = get_optimized_plan_formatted(&plan, &time);
let actual = get_optimized_plan_formatted(plan, &time);
let expected =
"Projection: NOT test.a AS Boolean(true) OR Boolean(false) != test.a\
\n TableScan: test";
@@ -476,7 +496,7 @@ mod tests {
let expected = "Filter: test.d <= Int32(10)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -489,7 +509,7 @@ mod tests {
let expected = "Filter: test.d <= Int32(10) OR test.d >= Int32(100)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -502,7 +522,7 @@ mod tests {
let expected = "Filter: test.d <= Int32(10) AND test.d >= Int32(100)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -515,7 +535,7 @@ mod tests {
let expected = "Filter: test.d > Int32(10)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -528,7 +548,7 @@ mod tests {
let expected = "Filter: test.e IS NOT NULL\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -541,7 +561,7 @@ mod tests {
let expected = "Filter: test.e IS NULL\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -555,7 +575,7 @@ mod tests {
"Filter: test.d != Int32(1) AND test.d != Int32(2) AND test.d != Int32(3)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -569,7 +589,7 @@ mod tests {
"Filter: test.d = Int32(1) OR test.d = Int32(2) OR test.d = Int32(3)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -583,7 +603,7 @@ mod tests {
let expected = "Filter: test.d < Int32(1) OR test.d > Int32(10)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -597,7 +617,7 @@ mod tests {
let expected = "Filter: test.d >= Int32(1) AND test.d <= Int32(10)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -617,7 +637,7 @@ mod tests {
let expected = "Filter: test.a NOT LIKE test.b\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -637,7 +657,7 @@ mod tests {
let expected = "Filter: test.a LIKE test.b\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -657,7 +677,7 @@ mod tests {
let expected = "Filter: test.a NOT ILIKE test.b\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -670,7 +690,7 @@ mod tests {
let expected = "Filter: test.d IS NOT DISTINCT FROM Int32(10)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -683,7 +703,7 @@ mod tests {
let expected = "Filter: test.d IS DISTINCT FROM Int32(10)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -709,7 +729,7 @@ mod tests {
\n TableScan: t1\
\n TableScan: t2";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -722,7 +742,7 @@ mod tests {
let expected = "Filter: Boolean(true)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -735,6 +755,6 @@ mod tests {
let expected = "Filter: Boolean(false)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}
}

0 comments on commit e31a358

Please sign in to comment.