Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve InListSImplifier -- add test, commend and avoid clones #8971

Merged
merged 10 commits into from
Feb 2, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -3364,6 +3364,12 @@ mod tests {
);
assert_eq!(simplify(expr.clone()), lit(true));

// 3.5 c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (4, 5, 6, 7) -> c1 != 4 (4 overlaps)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💯 This case shocks me (again) on how powerful this rule is

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jayzhan211 gets all the credit for implementing it in #8949 (🏆)

let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(
in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true),
);
assert_eq!(simplify(expr.clone()), col("c1").not_eq(lit(4)));

// 4. c1 NOT IN (1,2,3,4) AND c1 NOT IN (4,5,6,7) -> c1 NOT IN (1,2,3,4,5,6,7)
let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(
in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true),
Expand Down Expand Up @@ -3457,6 +3463,7 @@ mod tests {
true,
)));
// TODO: Further simplify this expression
// https://github.com/apache/arrow-datafusion/issues/8970
// assert_eq!(simplify(expr.clone()), lit(true));
assert_eq!(simplify(expr.clone()), expr);
}
Expand Down
133 changes: 70 additions & 63 deletions datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,85 +52,92 @@ impl TreeNodeRewriter for InListSimplifier {
type N = Expr;

fn mutate(&mut self, expr: Expr) -> Result<Expr> {
if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = &expr {
if let (Expr::InList(l1), Operator::And, Expr::InList(l2)) =
(left.as_ref(), op, right.as_ref())
{
if l1.expr == l2.expr && !l1.negated && !l2.negated {
return inlist_intersection(l1, l2, false);
} else if l1.expr == l2.expr && l1.negated && l2.negated {
return inlist_union(l1, l2, true);
} else if l1.expr == l2.expr && !l1.negated && l2.negated {
return inlist_except(l1, l2);
} else if l1.expr == l2.expr && l1.negated && !l2.negated {
return inlist_except(l2, l1);
if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pattern allows getting an owned left and right and avoiding the clone

match (*left, op, *right) {
(Expr::InList(l1), Operator::And, Expr::InList(l2))
if l1.expr == l2.expr && !l1.negated && !l2.negated =>
{
inlist_intersection(l1, l2, false)
}
} else if let (Expr::InList(l1), Operator::Or, Expr::InList(l2)) =
(left.as_ref(), op, right.as_ref())
{
if l1.expr == l2.expr && l1.negated && l2.negated {
return inlist_intersection(l1, l2, true);
(Expr::InList(l1), Operator::And, Expr::InList(l2))
if l1.expr == l2.expr && l1.negated && l2.negated =>
{
inlist_union(l1, l2, true)
}
(Expr::InList(l1), Operator::And, Expr::InList(l2))
if l1.expr == l2.expr && !l1.negated && l2.negated =>
{
inlist_except(l1, l2)
}
(Expr::InList(l1), Operator::And, Expr::InList(l2))
if l1.expr == l2.expr && l1.negated && !l2.negated =>
{
inlist_except(l2, l1)
}
(Expr::InList(l1), Operator::Or, Expr::InList(l2))
if l1.expr == l2.expr && l1.negated && l2.negated =>
{
inlist_intersection(l1, l2, true)
}
(left, op, right) => {
// put the expression back together
Ok(Expr::BinaryExpr(BinaryExpr {
left: Box::new(left),
op,
right: Box::new(right),
}))
}
}
} else {
Ok(expr)
}

Ok(expr)
}
}

fn inlist_union(l1: &InList, l2: &InList, negated: bool) -> Result<Expr> {
let mut seen: HashSet<Expr> = HashSet::new();
let list = l1
.list
.iter()
.chain(l2.list.iter())
.filter(|&e| seen.insert(e.to_owned()))
.cloned()
.collect::<Vec<_>>();
let merged_inlist = InList {
expr: l1.expr.clone(),
list,
negated,
};
Ok(Expr::InList(merged_inlist))
}
/// Return the union of two inlist expressions
/// maintaining the order of the elements in the two lists
fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result<Expr> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These functions now take owned InList and modify l1 in place

// extend the list in l1 with the elements in l2 that are not already in l1
let l1_items: HashSet<_> = l1.list.iter().collect();

fn inlist_intersection(l1: &InList, l2: &InList, negated: bool) -> Result<Expr> {
let l1_set: HashSet<Expr> = l1.list.iter().cloned().collect();
let intersect_list: Vec<Expr> = l2
// keep all l2 items that do not also appear in l1
let keep_l2: Vec<_> = l2
.list
.iter()
.filter(|x| l1_set.contains(x))
.cloned()
.into_iter()
.filter_map(|e| if l1_items.contains(&e) { None } else { Some(e) })
.collect();

l1.list.extend(keep_l2);
l1.negated = negated;
Ok(Expr::InList(l1))
}

/// Return the intersection of two inlist expressions
/// maintaining the order of the elements in the two lists
fn inlist_intersection(mut l1: InList, l2: InList, negated: bool) -> Result<Expr> {
let l2_items = l2.list.iter().collect::<HashSet<_>>();

// remove all items from l1 that are not in l2
l1.list.retain(|e| l2_items.contains(e));

// e in () is always false
// e not in () is always true
if intersect_list.is_empty() {
if l1.list.is_empty() {
return Ok(lit(negated));
}
let merged_inlist = InList {
expr: l1.expr.clone(),
list: intersect_list,
negated,
};
Ok(Expr::InList(merged_inlist))
Ok(Expr::InList(l1))
}

fn inlist_except(l1: &InList, l2: &InList) -> Result<Expr> {
let l2_set: HashSet<Expr> = l2.list.iter().cloned().collect();
let except_list: Vec<Expr> = l1
.list
.iter()
.filter(|x| !l2_set.contains(x))
.cloned()
.collect();
if except_list.is_empty() {
/// Return the all items in l1 that are not in l2
/// maintaining the order of the elements in the two lists
fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
let l2_items = l2.list.iter().collect::<HashSet<_>>();

// keep only items from l1 that are not in l2
l1.list.retain(|e| !l2_items.contains(e));

if l1.list.is_empty() {
return Ok(lit(false));
}
let merged_inlist = InList {
expr: l1.expr.clone(),
list: except_list,
negated: false,
};
Ok(Expr::InList(merged_inlist))
Ok(Expr::InList(l1))
}
Loading