Skip to content

Commit

Permalink
Simplify small InListExpr (#4090)
Browse files Browse the repository at this point in the history
* Simplify small InListExpr

Simplify small InListExpr

* Tweak

Tweak

* Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Co-authored-by: Andrew Lamb <[email protected]>

* Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Co-authored-by: Andrew Lamb <[email protected]>

* Feedback

* Feedback

* Tweak

* Tweak

Tweak

* Fmt

* clippy

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
Dandandan and alamb authored Nov 4, 2022
1 parent 7e944ed commit 60f3ef6
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 5 deletions.
2 changes: 1 addition & 1 deletion benchmarks/expected-plans/q12.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ Sort: lineitem.l_shipmode ASC NULLS LAST
Projection: lineitem.l_shipmode, SUM(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS high_line_count, SUM(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS low_line_count
Aggregate: groupBy=[[lineitem.l_shipmode]], aggr=[[SUM(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), SUM(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)]]
Inner Join: lineitem.l_orderkey = orders.o_orderkey
Filter: lineitem.l_shipmode IN ([Utf8("MAIL"), Utf8("SHIP")]) AND lineitem.l_commitdate < lineitem.l_receiptdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("8766") AND lineitem.l_receiptdate < Date32("9131")
Filter: (lineitem.l_shipmode = Utf8("SHIP") OR lineitem.l_shipmode = Utf8("MAIL")) AND lineitem.l_commitdate < lineitem.l_receiptdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("8766") AND lineitem.l_receiptdate < Date32("9131")
TableScan: lineitem projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode]
TableScan: orders projection=[o_orderkey, o_orderpriority]
2 changes: 1 addition & 1 deletion benchmarks/expected-plans/q19.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Projection: SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS re
Projection: lineitem.l_extendedprice, lineitem.l_discount
Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15)
Inner Join: lineitem.l_partkey = part.p_partkey
Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND lineitem.l_shipmode IN ([Utf8("AIR"), Utf8("AIR REG")]) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")
Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR REG") OR lineitem.l_shipmode = Utf8("AIR")) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")
TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode]
Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1)
TableScan: part projection=[p_partkey, p_brand, p_size, p_container]
4 changes: 3 additions & 1 deletion datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2040,7 +2040,9 @@ mod tests {
.build()?;
let execution_plan = plan(&logical_plan).await?;
// verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated.
let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"1\") }], negated: false }";

let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") } }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") } } }";

let actual = format!("{:?}", execution_plan);
assert!(actual.contains(expected), "{}", actual);

Expand Down
74 changes: 74 additions & 0 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ pub struct ExprSimplifier<S> {
info: S,
}

const THRESHOLD_INLINE_INLIST: usize = 3;

impl<S: SimplifyInfo> ExprSimplifier<S> {
/// Create a new `ExprSimplifier` with the given `info` such as an
/// instance of [`SimplifyContext`]. See
Expand Down Expand Up @@ -365,7 +367,48 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
None => lit_bool_null(),
}
}
// expr IN () --> false
// expr NOT IN () --> true
Expr::InList {
expr,
list,
negated,
} if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null) => {
lit(negated)
}

// if expr is a single column reference:
// expr IN (A, B, ...) --> (expr = A) OR (expr = B) OR (expr = C)
Expr::InList {
expr,
list,
negated,
} if !list.is_empty()
&& (
// For lists with only 1 value we allow more complex expressions to be simplified
// e.g SUBSTR(c1, 2, 3) IN ('1') -> SUBSTR(c1, 2, 3) = '1'
// for more than one we avoid repeating this potentially expensive
// expressions
list.len() == 1
|| list.len() <= THRESHOLD_INLINE_INLIST
&& expr.try_into_col().is_ok()
) =>
{
let first_val = list[0].clone();
if negated {
list.into_iter()
.skip(1)
.fold((*expr.clone()).not_eq(first_val), |acc, y| {
(*expr.clone()).not_eq(y).and(acc)
})
} else {
list.into_iter()
.skip(1)
.fold((*expr.clone()).eq(first_val), |acc, y| {
(*expr.clone()).eq(y).or(acc)
})
}
}
//
// Rules for NotEq
//
Expand Down Expand Up @@ -1749,6 +1792,37 @@ mod tests {
assert_eq!(expected_expr, result);
}

#[test]
fn simplify_inlist() {
assert_eq!(simplify(in_list(col("c1"), vec![], false)), lit(false));
assert_eq!(simplify(in_list(col("c1"), vec![], true)), lit(true));

assert_eq!(
simplify(in_list(col("c1"), vec![lit(1)], false)),
col("c1").eq(lit(1))
);
assert_eq!(
simplify(in_list(col("c1"), vec![lit(1)], true)),
col("c1").not_eq(lit(1))
);

// more complex expressions can be simplified if list contains
// one element only
assert_eq!(
simplify(in_list(col("c1") * lit(10), vec![lit(2)], false)),
(col("c1") * lit(10)).eq(lit(2))
);

assert_eq!(
simplify(in_list(col("c1"), vec![lit(1), lit(2)], false)),
col("c1").eq(lit(2)).or(col("c1").eq(lit(1)))
);
assert_eq!(
simplify(in_list(col("c1"), vec![lit(1), lit(2)], true)),
col("c1").not_eq(lit(2)).and(col("c1").not_eq(lit(1)))
);
}

#[test]
fn simplify_expr_bool_and() {
// col & true is always col
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,8 @@ mod tests {
.unwrap()
.build()
.unwrap();
let expected = "Filter: test.d NOT IN ([Int32(1), Int32(2), Int32(3)])\
let expected =
"Filter: test.d != Int32(3) AND test.d != Int32(2) AND test.d != Int32(1)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected);
Expand All @@ -721,7 +722,8 @@ mod tests {
.unwrap()
.build()
.unwrap();
let expected = "Filter: test.d IN ([Int32(1), Int32(2), Int32(3)])\
let expected =
"Filter: test.d = Int32(3) OR test.d = Int32(2) OR test.d = Int32(1)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected);
Expand Down

0 comments on commit 60f3ef6

Please sign in to comment.