Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Rewrite SimplifyExprs to avoid a Plan copy
Browse files Browse the repository at this point in the history
fixup
alamb committed Apr 7, 2024

Verified

This commit was signed with the committer’s verified signature.
TimoGlastra Timo Glastra
1 parent 3784f76 commit dacbfa6
Showing 2 changed files with 99 additions and 78 deletions.
21 changes: 11 additions & 10 deletions datafusion/core/tests/simplification.rs
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@ use datafusion_expr::{
expr, table_scan, BuiltinScalarFunction, Cast, ColumnarValue, Expr, ExprSchemable,
LogicalPlan, LogicalPlanBuilder, ScalarUDF, Volatility,
};
use datafusion_optimizer::optimizer::Optimizer;
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions};
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
use std::sync::Arc;
@@ -107,14 +108,14 @@ fn test_table_scan() -> LogicalPlan {
.expect("building plan")
}

fn get_optimized_plan_formatted(plan: &LogicalPlan, date_time: &DateTime<Utc>) -> String {
fn get_optimized_plan_formatted(plan: LogicalPlan, date_time: &DateTime<Utc>) -> String {
let config = OptimizerContext::new().with_query_execution_start_time(*date_time);
let rule = SimplifyExpressions::new();

let optimized_plan = rule
.try_optimize(plan, &config)
.unwrap()
.expect("failed to optimize plan");
// Use Optimizer to do plan traversal
fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
let optimizer = Optimizer::with_rules(vec![Arc::new(SimplifyExpressions::new())]);
let optimized_plan = optimizer.optimize(plan, &config, observe).unwrap();

format!("{optimized_plan:?}")
}

@@ -236,7 +237,7 @@ fn to_timestamp_expr_folded() -> Result<()> {
let expected = "Projection: TimestampNanosecond(1599566400000000000, None) AS to_timestamp(Utf8(\"2020-09-08T12:00:00+00:00\"))\
\n TableScan: test"
.to_string();
let actual = get_optimized_plan_formatted(&plan, &Utc::now());
let actual = get_optimized_plan_formatted(plan, &Utc::now());
assert_eq!(expected, actual);
Ok(())
}
@@ -260,7 +261,7 @@ fn now_less_than_timestamp() -> Result<()> {
// expression down to a single constant (true)
let expected = "Filter: Boolean(true)\
\n TableScan: test";
let actual = get_optimized_plan_formatted(&plan, &time);
let actual = get_optimized_plan_formatted(plan, &time);

assert_eq!(expected, actual);
Ok(())
@@ -288,7 +289,7 @@ fn select_date_plus_interval() -> Result<()> {
// expression down to a single constant (a Date32 literal)
let expected = r#"Projection: Date32("18636") AS to_timestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("528280977408")
TableScan: test"#;
let actual = get_optimized_plan_formatted(&plan, &time);
let actual = get_optimized_plan_formatted(plan, &time);

assert_eq!(expected, actual);
Ok(())
@@ -420,7 +421,7 @@ fn multiple_now() -> Result<()> {
.build()?;

// expect the same timestamp to appear in both exprs
let actual = get_optimized_plan_formatted(&plan, &time);
let actual = get_optimized_plan_formatted(plan, &time);
let expected = format!(
"Projection: TimestampNanosecond({}, Some(\"+00:00\")) AS now(), TimestampNanosecond({}, Some(\"+00:00\")) AS t2\
\n TableScan: test",
156 changes: 88 additions & 68 deletions datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
Original file line number Diff line number Diff line change
@@ -19,12 +19,14 @@
use std::sync::Arc;

use datafusion_common::{DFSchema, DFSchemaRef, Result};
use datafusion_common::tree_node::Transformed;
use datafusion_common::{internal_err, DFSchema, DFSchemaRef, DataFusionError, Result};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::logical_plan::LogicalPlan;
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::utils::merge_schema;

use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};

use super::ExprSimplifier;
@@ -46,29 +48,47 @@ use super::ExprSimplifier;
pub struct SimplifyExpressions {}

impl OptimizerRule for SimplifyExpressions {
fn try_optimize(
&self,
_plan: &LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
internal_err!("Should have called SimplifyExpressions::try_optimize_owned")
}

fn name(&self) -> &str {
"simplify_expressions"
}

fn try_optimize(
fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::BottomUp)
}

fn supports_owned(&self) -> bool {
true
}

/// If `supports_owned` returns true, the Optimizer calls
/// [`Self::try_optimize_owned`] instead of [`Self::try_optimize`].
fn try_optimize_owned(
&self,
plan: &LogicalPlan,
plan: LogicalPlan,
config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
) -> Result<Transformed<LogicalPlan>, DataFusionError> {
let mut execution_props = ExecutionProps::new();
execution_props.query_execution_start_time = config.query_execution_start_time();
Ok(Some(Self::optimize_internal(plan, &execution_props)?))
Self::optimize_internal(plan, &execution_props)
}
}

impl SimplifyExpressions {
fn optimize_internal(
plan: &LogicalPlan,
plan: LogicalPlan,
execution_props: &ExecutionProps,
) -> Result<LogicalPlan> {
) -> Result<Transformed<LogicalPlan>> {
let schema = if !plan.inputs().is_empty() {
DFSchemaRef::new(merge_schema(plan.inputs()))
} else if let LogicalPlan::TableScan(scan) = plan {
} else if let LogicalPlan::TableScan(scan) = &plan {
// When predicates are pushed into a table scan, there is no input
// schema to resolve predicates against, so it must be handled specially
//
@@ -86,13 +106,11 @@ impl SimplifyExpressions {
} else {
Arc::new(DFSchema::empty())
};

let info = SimplifyContext::new(execution_props).with_schema(schema);

let new_inputs = plan
.inputs()
.iter()
.map(|input| Self::optimize_internal(input, execution_props))
.collect::<Result<Vec<_>>>()?;
// Inputs have already been rewritten (the Optimizer handles the bottom-up traversal),
// so we only need to rewrite this node's own expressions.

let simplifier = ExprSimplifier::new(info);

@@ -109,18 +127,22 @@ impl SimplifyExpressions {
simplifier
};

let exprs = plan
.expressions()
.into_iter()
.map(|e| {
// the output schema of a filter or join is the input schema. Thus they
// can't handle aliased expressions
let use_alias = !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_));
plan.map_expressions(|e| {
let new_e = if use_alias {
// TODO: unify with `rewrite_preserving_name`
let original_name = e.name_for_alias()?;
let new_e = simplifier.simplify(e)?;
new_e.alias_if_changed(original_name)
})
.collect::<Result<Vec<_>>>()?;
simplifier.simplify(e)?.alias_if_changed(original_name)
} else {
simplifier.simplify(e)
}?;

plan.with_new_exprs(exprs, new_inputs)
// TODO: it would be nice to have a way to know whether the expression was
// actually simplified. For now, conservatively return `Transformed::yes`.
Ok(Transformed::yes(new_e))
})
}
}

@@ -138,6 +160,7 @@ mod tests {
use arrow::datatypes::{DataType, Field, Schema};
use chrono::{DateTime, Utc};

use crate::optimizer::Optimizer;
use datafusion_expr::logical_plan::builder::table_scan_with_filters;
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{
@@ -165,12 +188,12 @@ mod tests {
.expect("building plan")
}

fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
let rule = SimplifyExpressions::new();
let optimized_plan = rule
.try_optimize(plan, &OptimizerContext::new())
.unwrap()
.expect("failed to optimize plan");
fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
// Use Optimizer to do plan traversal
fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
let optimizer = Optimizer::with_rules(vec![Arc::new(SimplifyExpressions::new())]);
let optimized_plan =
optimizer.optimize(plan, &OptimizerContext::new(), observe)?;
let formatted_plan = format!("{optimized_plan:?}");
assert_eq!(formatted_plan, expected);
Ok(())
@@ -198,7 +221,7 @@ mod tests {

let expected = "TableScan: test projection=[a], full_filters=[Boolean(true) AS b IS NOT NULL]";

assert_optimized_plan_eq(&table_scan, expected)
assert_optimized_plan_eq(table_scan, expected)
}

#[test]
@@ -210,7 +233,7 @@ mod tests {
.build()?;

assert_optimized_plan_eq(
&plan,
plan,
"\
Filter: test.b > Int32(1)\
\n Projection: test.a\
@@ -227,7 +250,7 @@ mod tests {
.build()?;

assert_optimized_plan_eq(
&plan,
plan,
"\
Filter: test.b > Int32(1)\
\n Projection: test.a\
@@ -244,7 +267,7 @@ mod tests {
.build()?;

assert_optimized_plan_eq(
&plan,
plan,
"\
Filter: test.b > Int32(1)\
\n Projection: test.a\
@@ -265,7 +288,7 @@ mod tests {
.build()?;

assert_optimized_plan_eq(
&plan,
plan,
"\
Filter: test.a > Int32(5) AND test.b < Int32(6)\
\n Projection: test.a, test.b\
@@ -288,7 +311,7 @@ mod tests {
\n Filter: test.b\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -308,7 +331,7 @@ mod tests {
\n Filter: NOT test.b\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -324,7 +347,7 @@ mod tests {
\n Filter: NOT test.b AND test.c\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -340,7 +363,7 @@ mod tests {
\n Filter: NOT test.b OR NOT test.c\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -356,7 +379,7 @@ mod tests {
\n Filter: test.b\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -370,7 +393,7 @@ mod tests {
Projection: test.a, test.d, NOT test.b AS test.b = Boolean(false)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -392,7 +415,7 @@ mod tests {
\n Projection: test.a, test.c, test.b\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -413,20 +436,17 @@ mod tests {
let expected = "\
Values: (Int32(3) AS Int32(1) + Int32(2), Int32(1) AS Int32(2) - Int32(1))";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

fn get_optimized_plan_formatted(
plan: &LogicalPlan,
plan: LogicalPlan,
date_time: &DateTime<Utc>,
) -> String {
let config = OptimizerContext::new().with_query_execution_start_time(*date_time);
let rule = SimplifyExpressions::new();

let optimized_plan = rule
.try_optimize(plan, &config)
.unwrap()
.expect("failed to optimize plan");
let optimized_plan = rule.try_optimize_owned(plan, &config).unwrap().data;
format!("{optimized_plan:?}")
}

@@ -440,7 +460,7 @@ mod tests {

let expected = "Projection: Int32(0) AS Utf8(\"0\")\
\n TableScan: test";
let actual = get_optimized_plan_formatted(&plan, &Utc::now());
let actual = get_optimized_plan_formatted(plan, &Utc::now());
assert_eq!(expected, actual);
Ok(())
}
@@ -457,7 +477,7 @@ mod tests {
.project(proj)?
.build()?;

let actual = get_optimized_plan_formatted(&plan, &time);
let actual = get_optimized_plan_formatted(plan, &time);
let expected =
"Projection: NOT test.a AS Boolean(true) OR Boolean(false) != test.a\
\n TableScan: test";
@@ -476,7 +496,7 @@ mod tests {
let expected = "Filter: test.d <= Int32(10)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -489,7 +509,7 @@ mod tests {
let expected = "Filter: test.d <= Int32(10) OR test.d >= Int32(100)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -502,7 +522,7 @@ mod tests {
let expected = "Filter: test.d <= Int32(10) AND test.d >= Int32(100)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -515,7 +535,7 @@ mod tests {
let expected = "Filter: test.d > Int32(10)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -528,7 +548,7 @@ mod tests {
let expected = "Filter: test.e IS NOT NULL\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -541,7 +561,7 @@ mod tests {
let expected = "Filter: test.e IS NULL\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -555,7 +575,7 @@ mod tests {
"Filter: test.d != Int32(1) AND test.d != Int32(2) AND test.d != Int32(3)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -569,7 +589,7 @@ mod tests {
"Filter: test.d = Int32(1) OR test.d = Int32(2) OR test.d = Int32(3)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -583,7 +603,7 @@ mod tests {
let expected = "Filter: test.d < Int32(1) OR test.d > Int32(10)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -597,7 +617,7 @@ mod tests {
let expected = "Filter: test.d >= Int32(1) AND test.d <= Int32(10)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -617,7 +637,7 @@ mod tests {
let expected = "Filter: test.a NOT LIKE test.b\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -637,7 +657,7 @@ mod tests {
let expected = "Filter: test.a LIKE test.b\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -657,7 +677,7 @@ mod tests {
let expected = "Filter: test.a NOT ILIKE test.b\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -670,7 +690,7 @@ mod tests {
let expected = "Filter: test.d IS NOT DISTINCT FROM Int32(10)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -683,7 +703,7 @@ mod tests {
let expected = "Filter: test.d IS DISTINCT FROM Int32(10)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -709,7 +729,7 @@ mod tests {
\n TableScan: t1\
\n TableScan: t2";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -725,7 +745,7 @@ mod tests {
let expected = "Projection: test.f AS power(test.f,Float64(1))\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -745,7 +765,7 @@ mod tests {
// before simplify: t.g = power(t.f, 1.0)
// after simplify: (t.g = t.f) as "t.g = power(t.f, 1.0)"
let expected = "TableScan: test, full_filters=[g = f AS g = power(f,Float64(1))]";
assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -758,7 +778,7 @@ mod tests {
let expected = "Filter: Boolean(true)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}

#[test]
@@ -771,6 +791,6 @@ mod tests {
let expected = "Filter: Boolean(false)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected)
assert_optimized_plan_eq(plan, expected)
}
}

0 comments on commit dacbfa6

Please sign in to comment.