Skip to content

Commit

Permalink
[refactor]: Convert Vec<PhysicalExpr> to HashSet<PhysicalExpr> (apach…
Browse files Browse the repository at this point in the history
…e#13612)

* Initial commit

* Change implementation to take iterator

* Minor changes

* Update datafusion/physical-expr/src/equivalence/class.rs

Co-authored-by: Alex Huang <[email protected]>

* Remove leftover comment

---------

Co-authored-by: Alex Huang <[email protected]>
  • Loading branch information
akurmustafa and Weijun-H authored Dec 3, 2024
1 parent 86740bf commit fdb221f
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 82 deletions.
24 changes: 19 additions & 5 deletions datafusion/physical-expr-common/src/physical_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,11 +217,24 @@ pub fn with_new_children_if_necessary(
/// Returns [`Display`] able a list of [`PhysicalExpr`]
///
/// Example output: `[a + 1, b]`
pub fn format_physical_expr_list(exprs: &[Arc<dyn PhysicalExpr>]) -> impl Display + '_ {
struct DisplayWrapper<'a>(&'a [Arc<dyn PhysicalExpr>]);
impl Display for DisplayWrapper<'_> {
pub fn format_physical_expr_list<T>(exprs: T) -> impl Display
where
T: IntoIterator,
T::Item: Display,
T::IntoIter: Clone,
{
struct DisplayWrapper<I>(I)
where
I: Iterator + Clone,
I::Item: Display;

impl<I> Display for DisplayWrapper<I>
where
I: Iterator + Clone,
I::Item: Display,
{
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let mut iter = self.0.iter();
let mut iter = self.0.clone();
write!(f, "[")?;
if let Some(expr) = iter.next() {
write!(f, "{}", expr)?;
Expand All @@ -233,5 +246,6 @@ pub fn format_physical_expr_list(exprs: &[Arc<dyn PhysicalExpr>]) -> impl Displa
Ok(())
}
}
DisplayWrapper(exprs)

DisplayWrapper(exprs.into_iter())
}
35 changes: 17 additions & 18 deletions datafusion/physical-expr/src/equivalence/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@
// specific language governing permissions and limitations
// under the License.

use std::fmt::Display;
use std::sync::Arc;

use super::{add_offset_to_expr, collapse_lex_req, ProjectionMapping};
use crate::{
expressions::Column, physical_expr::deduplicate_physical_exprs,
physical_exprs_bag_equal, physical_exprs_contains, LexOrdering, LexRequirement,
expressions::Column, physical_exprs_contains, LexOrdering, LexRequirement,
PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement,
};
use indexmap::IndexSet;
use std::fmt::Display;
use std::sync::Arc;

use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::JoinType;
Expand Down Expand Up @@ -190,47 +189,47 @@ pub struct EquivalenceClass {
/// The expressions in this equivalence class. The order doesn't
/// matter for equivalence purposes
///
/// TODO: use a HashSet for this instead of a Vec
exprs: Vec<Arc<dyn PhysicalExpr>>,
exprs: IndexSet<Arc<dyn PhysicalExpr>>,
}

impl PartialEq for EquivalenceClass {
/// Returns true if other is equal in the sense
/// of bags (multi-sets), disregarding their orderings.
fn eq(&self, other: &Self) -> bool {
physical_exprs_bag_equal(&self.exprs, &other.exprs)
self.exprs.eq(&other.exprs)
}
}

impl EquivalenceClass {
/// Create a new empty equivalence class
pub fn new_empty() -> Self {
Self { exprs: vec![] }
Self {
exprs: IndexSet::new(),
}
}

// Create a new equivalence class from a pre-existing `Vec`
pub fn new(mut exprs: Vec<Arc<dyn PhysicalExpr>>) -> Self {
deduplicate_physical_exprs(&mut exprs);
Self { exprs }
pub fn new(exprs: Vec<Arc<dyn PhysicalExpr>>) -> Self {
Self {
exprs: exprs.into_iter().collect(),
}
}

/// Return the inner vector of expressions
pub fn into_vec(self) -> Vec<Arc<dyn PhysicalExpr>> {
self.exprs
self.exprs.into_iter().collect()
}

/// Return the "canonical" expression for this class (the first element)
/// if any
fn canonical_expr(&self) -> Option<Arc<dyn PhysicalExpr>> {
self.exprs.first().cloned()
self.exprs.iter().next().cloned()
}

/// Insert the expression into this class, meaning it is known to be equal to
/// all other expressions in this class
pub fn push(&mut self, expr: Arc<dyn PhysicalExpr>) {
if !self.contains(&expr) {
self.exprs.push(expr);
}
self.exprs.insert(expr);
}

/// Inserts all the expressions from other into this class
Expand All @@ -243,7 +242,7 @@ impl EquivalenceClass {

/// Returns true if this equivalence class contains t expression
pub fn contains(&self, expr: &Arc<dyn PhysicalExpr>) -> bool {
physical_exprs_contains(&self.exprs, expr)
self.exprs.contains(expr)
}

/// Returns true if this equivalence class has any entries in common with `other`
Expand Down
61 changes: 2 additions & 59 deletions datafusion/physical-expr/src/physical_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,34 +65,14 @@ pub fn physical_exprs_bag_equal(
}
}

/// This utility function removes duplicates from the given `exprs` vector.
/// Note that this function does not necessarily preserve its input ordering.
pub fn deduplicate_physical_exprs(exprs: &mut Vec<Arc<dyn PhysicalExpr>>) {
// TODO: Once we can use `HashSet`s with `Arc<dyn PhysicalExpr>`, this
// function should use a `HashSet` to reduce computational complexity.
// See issue: https://github.com/apache/datafusion/issues/8027
let mut idx = 0;
while idx < exprs.len() {
let mut rest_idx = idx + 1;
while rest_idx < exprs.len() {
if exprs[idx].eq(&exprs[rest_idx]) {
exprs.swap_remove(rest_idx);
} else {
rest_idx += 1;
}
}
idx += 1;
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use crate::expressions::{Column, Literal};
use crate::physical_expr::{
deduplicate_physical_exprs, physical_exprs_bag_equal, physical_exprs_contains,
physical_exprs_equal, PhysicalExpr,
physical_exprs_bag_equal, physical_exprs_contains, physical_exprs_equal,
PhysicalExpr,
};

use datafusion_common::ScalarValue;
Expand Down Expand Up @@ -208,41 +188,4 @@ mod tests {
assert!(physical_exprs_bag_equal(list3.as_slice(), list3.as_slice()));
assert!(physical_exprs_bag_equal(list4.as_slice(), list4.as_slice()));
}

#[test]
fn test_deduplicate_physical_exprs() {
let lit_true = &(Arc::new(Literal::new(ScalarValue::Boolean(Some(true))))
as Arc<dyn PhysicalExpr>);
let lit_false = &(Arc::new(Literal::new(ScalarValue::Boolean(Some(false))))
as Arc<dyn PhysicalExpr>);
let lit4 = &(Arc::new(Literal::new(ScalarValue::Int32(Some(4))))
as Arc<dyn PhysicalExpr>);
let lit2 = &(Arc::new(Literal::new(ScalarValue::Int32(Some(2))))
as Arc<dyn PhysicalExpr>);
let col_a_expr = &(Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>);
let col_b_expr = &(Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>);

// First vector in the tuple is arguments, second one is the expected value.
let test_cases = vec![
// ---------- TEST CASE 1----------//
(
vec![
lit_true, lit_false, lit4, lit2, col_a_expr, col_a_expr, col_b_expr,
lit_true, lit2,
],
vec![lit_true, lit_false, lit4, lit2, col_a_expr, col_b_expr],
),
// ---------- TEST CASE 2----------//
(
vec![lit_true, lit_true, lit_false, lit4],
vec![lit_true, lit4, lit_false],
),
];
for (exprs, expected) in test_cases {
let mut exprs = exprs.into_iter().cloned().collect::<Vec<_>>();
let expected = expected.into_iter().cloned().collect::<Vec<_>>();
deduplicate_physical_exprs(&mut exprs);
assert!(physical_exprs_equal(&exprs, &expected));
}
}
}

0 comments on commit fdb221f

Please sign in to comment.