Skip to content

Commit

Permalink
Improve formatting of binary expressions (#3884)
Browse files Browse the repository at this point in the history
* Respect operator precedence when writing binary expressions

* update tests

* update tests

* refactor to reduce duplicate code
  • Loading branch information
andygrove authored Oct 25, 2022
1 parent 7cba758 commit 7559c44
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 19 deletions.
6 changes: 3 additions & 3 deletions benchmarks/expected-plans/q19.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Projection: SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS re
Projection: lineitem.l_extendedprice, lineitem.l_discount
Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15)
Inner Join: lineitem.l_partkey = part.p_partkey
Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND lineitem.l_shipmode IN ([Utf8("AIR"), Utf8("AIR REG")]) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")
Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND lineitem.l_shipmode IN ([Utf8("AIR"), Utf8("AIR REG")]) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")
TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode]
Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15) AND part.p_size >= Int32(1)
TableScan: part projection=[p_partkey, p_brand, p_size, p_container]
Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1)
TableScan: part projection=[p_partkey, p_brand, p_size, p_container]
11 changes: 8 additions & 3 deletions datafusion/core/src/physical_optimizer/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1481,7 +1481,7 @@ mod tests {
let expr = col("c1")
.lt(lit(1))
.and(col("c2").eq(lit(2)).or(col("c2").eq(lit(3))));
let expected_expr = "c1_min < Int32(1) AND c2_min <= Int32(2) AND Int32(2) <= c2_max OR c2_min <= Int32(3) AND Int32(3) <= c2_max";
let expected_expr = "c1_min < Int32(1) AND (c2_min <= Int32(2) AND Int32(2) <= c2_max OR c2_min <= Int32(3) AND Int32(3) <= c2_max)";
let predicate_expr =
build_predicate_expression(&expr, &schema, &mut required_columns)?;
assert_eq!(format!("{:?}", predicate_expr), expected_expr);
Expand Down Expand Up @@ -1560,7 +1560,9 @@ mod tests {
list: vec![lit(1), lit(2), lit(3)],
negated: true,
};
let expected_expr = "c1_min != Int32(1) OR Int32(1) != c1_max AND c1_min != Int32(2) OR Int32(2) != c1_max AND c1_min != Int32(3) OR Int32(3) != c1_max";
let expected_expr = "(c1_min != Int32(1) OR Int32(1) != c1_max) \
AND (c1_min != Int32(2) OR Int32(2) != c1_max) \
AND (c1_min != Int32(3) OR Int32(3) != c1_max)";
let predicate_expr =
build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?;
assert_eq!(format!("{:?}", predicate_expr), expected_expr);
Expand Down Expand Up @@ -1632,7 +1634,10 @@ mod tests {
],
negated: true,
};
let expected_expr = "CAST(c1_min AS Int64) != Int64(1) OR Int64(1) != CAST(c1_max AS Int64) AND CAST(c1_min AS Int64) != Int64(2) OR Int64(2) != CAST(c1_max AS Int64) AND CAST(c1_min AS Int64) != Int64(3) OR Int64(3) != CAST(c1_max AS Int64)";
let expected_expr =
"(CAST(c1_min AS Int64) != Int64(1) OR Int64(1) != CAST(c1_max AS Int64)) \
AND (CAST(c1_min AS Int64) != Int64(2) OR Int64(2) != CAST(c1_max AS Int64)) \
AND (CAST(c1_min AS Int64) != Int64(3) OR Int64(3) != CAST(c1_max AS Int64))";
let predicate_expr =
build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?;
assert_eq!(format!("{:?}", predicate_expr), expected_expr);
Expand Down
58 changes: 54 additions & 4 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use datafusion_common::Result;
use datafusion_common::{plan_err, Column};
use datafusion_common::{DataFusionError, ScalarValue};
use std::fmt;
use std::fmt::Write;
use std::fmt::{Display, Formatter, Write};
use std::hash::{BuildHasher, Hash, Hasher};
use std::ops::Not;
use std::sync::Arc;
Expand Down Expand Up @@ -260,6 +260,58 @@ impl BinaryExpr {
pub fn new(left: Box<Expr>, op: Operator, right: Box<Expr>) -> Self {
Self { left, op, right }
}

/// Get the operator precedence
/// use https://www.postgresql.org/docs/7.0/operators.htm#AEN2026 as a reference
pub fn precedence(&self) -> u8 {
match self.op {
Operator::Or => 5,
Operator::And => 10,
Operator::Like | Operator::NotLike => 19,
Operator::NotEq
| Operator::Eq
| Operator::Lt
| Operator::LtEq
| Operator::Gt
| Operator::GtEq => 20,
Operator::Plus | Operator::Minus => 30,
Operator::Multiply | Operator::Divide | Operator::Modulo => 40,
_ => 0,
}
}
}

impl Display for BinaryExpr {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
// Put parentheses around child binary expressions so that we can see the difference
// between `(a OR b) AND c` and `a OR (b AND c)`. We only insert parentheses when needed,
// based on operator precedence. For example, `(a AND b) OR c` and `a AND b OR c` are
// equivalent and the parentheses are not necessary.

fn write_child(
f: &mut Formatter<'_>,
expr: &Expr,
precedence: u8,
) -> fmt::Result {
match expr {
Expr::BinaryExpr(child) => {
let p = child.precedence();
if p == 0 || p < precedence {
write!(f, "({})", child)?;
} else {
write!(f, "{}", child)?;
}
}
_ => write!(f, "{}", expr)?,
}
Ok(())
}

let precedence = self.precedence();
write_child(f, self.left.as_ref(), precedence)?;
write!(f, " {} ", self.op)?;
write_child(f, self.right.as_ref(), precedence)
}
}

/// CASE expression
Expand Down Expand Up @@ -728,9 +780,7 @@ impl fmt::Debug for Expr {
negated: false,
} => write!(f, "{:?} IN ({:?})", expr, subquery),
Expr::ScalarSubquery(subquery) => write!(f, "({:?})", subquery),
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
write!(f, "{:?} {} {:?}", left, op, right)
}
Expr::BinaryExpr(expr) => write!(f, "{}", expr),
Expr::Sort {
expr,
asc,
Expand Down
6 changes: 3 additions & 3 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ mod test {
)?;

let expected = vec![
(9, "SUM(a + Int32(1)) - AVG(c) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
(9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
(7, "SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
(4, "SUM(a + Int32(1))a + Int32(1)Int32(1)a"),
(3, "a + Int32(1)Int32(1)a"),
Expand Down Expand Up @@ -671,8 +671,8 @@ mod test {
)?
.build()?;

let expected = "Aggregate: groupBy=[[]], aggr=[[SUM(test.a * Int32(1) - test.bInt32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b), SUM(test.a * Int32(1) - test.bInt32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b * Int32(1) + test.c)]]\
\n Projection: test.a * Int32(1) - test.b AS test.a * Int32(1) - test.bInt32(1) - test.btest.bInt32(1)test.a, test.a, test.b, test.c\
let expected = "Aggregate: groupBy=[[]], aggr=[[SUM(test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b), SUM(test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]\
\n Projection: test.a * (Int32(1) - test.b) AS test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a, test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/filter_push_down.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,7 @@ mod tests {
let expected = "\
Projection: b * Int32(3) AS a, test.c\
\n Projection: test.a * Int32(2) + test.c AS b, test.c\
\n Filter: test.a * Int32(2) + test.c * Int32(3) = Int64(1)\
\n Filter: (test.a * Int32(2) + test.c) * Int32(3) = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
Expand Down
4 changes: 2 additions & 2 deletions datafusion/optimizer/src/reduce_cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,7 @@ mod tests {
.build()?;

let expected = vec![
"Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) AND t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
"Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Filter: t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
Expand Down Expand Up @@ -936,7 +936,7 @@ mod tests {
.build()?;

let expected = vec![
"Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) AND t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b AND t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
"Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/subquery_filter_to_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ mod tests {
.build()?;

let expected = "Projection: test.b [b:UInt32]\
\n Filter: test.a = UInt32(1) OR test.b IN (<subquery>) AND test.c IN (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
\n Filter: (test.a = UInt32(1) OR test.b IN (<subquery>)) AND test.c IN (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
\n Subquery: [c:UInt32]\
\n Projection: sq1.c [c:UInt32]\
\n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\
Expand Down
4 changes: 2 additions & 2 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3414,7 +3414,7 @@ mod tests {
#[test]
fn select_binary_expr_nested() {
let sql = "SELECT (age + salary)/2 from person";
let expected = "Projection: person.age + person.salary / Int64(2)\
let expected = "Projection: (person.age + person.salary) / Int64(2)\
\n TableScan: person";
quick_test(sql, expected);
}
Expand Down Expand Up @@ -3849,7 +3849,7 @@ mod tests {
fn select_where_nullif_division() {
let sql = "SELECT c3/(c4+c5) \
FROM aggregate_test_100 WHERE c3/nullif(c4+c5, 0) > 0.1";
let expected = "Projection: aggregate_test_100.c3 / aggregate_test_100.c4 + aggregate_test_100.c5\
let expected = "Projection: aggregate_test_100.c3 / (aggregate_test_100.c4 + aggregate_test_100.c5)\
\n Filter: aggregate_test_100.c3 / nullif(aggregate_test_100.c4 + aggregate_test_100.c5, Int64(0)) > Float64(0.1)\
\n TableScan: aggregate_test_100";
quick_test(sql, expected);
Expand Down

0 comments on commit 7559c44

Please sign in to comment.