Skip to content

Commit

Permalink
WIP: Use IndexSet for OrderingEquivalenceSet
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jan 15, 2025
1 parent 0c229d7 commit 544e058
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 33 deletions.
8 changes: 8 additions & 0 deletions datafusion/physical-expr-common/src/sort_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,14 @@ impl LexOrdering {
}
output
}

/// applies the method to each expr in this ordering
pub fn map(mut self, mut f: impl FnMut(&mut PhysicalSortExpr)) -> Self {
for sort_expr in self.inner.iter_mut() {
f(sort_expr)
}
self
}
}

impl From<Vec<PhysicalSortExpr>> for LexOrdering {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/equivalence/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ pub fn collapse_lex_req(input: LexRequirement) -> LexRequirement {

/// Adds the `offset` value to `Column` indices inside `expr`. This function is
/// generally used during the update of the right table schema in join operations.
pub fn add_offset_to_expr(
fn add_offset_to_expr(
expr: Arc<dyn PhysicalExpr>,
offset: usize,
) -> Arc<dyn PhysicalExpr> {
Expand Down
58 changes: 40 additions & 18 deletions datafusion/physical-expr/src/equivalence/ordering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
// under the License.

use std::fmt::Display;
use std::hash::Hash;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::vec::IntoIter;

use crate::equivalence::add_offset_to_expr;
use crate::{LexOrdering, PhysicalExpr};
use arrow_schema::SortOptions;
use indexmap::IndexSet;

/// An `OrderingEquivalenceClass` object keeps track of different alternative
/// orderings than can describe a schema. For example, consider the following table:
Expand All @@ -37,9 +37,18 @@ use arrow_schema::SortOptions;
///
/// Here, both `vec![a ASC, b ASC]` and `vec![c DESC, d ASC]` describe the table
/// ordering. In this case, we say that these orderings are equivalent.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Default)]
#[derive(Debug, Clone, Eq, PartialEq, Default)]
pub struct OrderingEquivalenceClass {
orderings: Vec<LexOrdering>,
/// Use index set to maintain order but avoid duplicates.
orderings: IndexSet<LexOrdering>,
}

impl Hash for OrderingEquivalenceClass {
fn hash<H: Hasher>(&self, state: &mut H) {
for ordering in &self.orderings {
ordering.hash(state);
}
}
}

impl OrderingEquivalenceClass {
Expand All @@ -56,15 +65,16 @@ impl OrderingEquivalenceClass {
/// Creates new ordering equivalence class from the given orderings
///
/// Any redundant entries are removed
pub fn new(orderings: Vec<LexOrdering>) -> Self {
let mut result = Self { orderings };
result.remove_redundant_entries();
result
pub fn new(orderings: impl IntoIterator<Item = LexOrdering>) -> Self {
let orderings = orderings.into_iter().collect();
Self { orderings }
}

/// Converts this OrderingEquivalenceClass to a vector of orderings.
///
// TODO remove / rename into_vec if it is needed
pub fn into_inner(self) -> Vec<LexOrdering> {
self.orderings
self.orderings.into_iter().collect()
}

/// Checks whether `ordering` is a member of this equivalence class.
Expand Down Expand Up @@ -175,32 +185,38 @@ impl OrderingEquivalenceClass {
let n_ordering = self.orderings.len();
// Replicate entries before cross product
let n_cross = std::cmp::max(n_ordering, other.len() * n_ordering);
self.orderings = self
let mut new_orderings: Vec<_> = self
.orderings
.iter()
.cloned()
.cycle()
.take(n_cross)
.collect();

// Suffix orderings of other to the current orderings.
for (outer_idx, ordering) in other.iter().enumerate() {
for idx in 0..n_ordering {
// Calculate cross product index
let idx = outer_idx * n_ordering + idx;
self.orderings[idx].inner.extend(ordering.iter().cloned());
new_orderings[idx].inner.extend(ordering.iter().cloned());
}
}
// turn back to indexset
self.orderings = new_orderings.into_iter().collect();
self
}

/// Adds `offset` value to the index of each expression inside this
/// ordering equivalence class.
pub fn add_offset(&mut self, offset: usize) {
for ordering in self.orderings.iter_mut() {
for sort_expr in ordering.inner.iter_mut() {
sort_expr.expr = add_offset_to_expr(Arc::clone(&sort_expr.expr), offset);
}
}
pub fn add_offset(mut self, offset: usize) -> Self {
self.orderings =
// update each offset and then recollect the set
self.orderings.into_iter()
.map(|ordering| {
ordering.map(|sort_expr| { sort_expr.expr = add_offset_to_expr(Arc::clone(&sort_expr.expr), offset); })
})
.collect();
self
}

/// Gets sort options associated with this expression if it is a leading
Expand All @@ -219,7 +235,7 @@ impl OrderingEquivalenceClass {
/// Convert the `OrderingEquivalenceClass` into an iterator of LexOrderings
impl IntoIterator for OrderingEquivalenceClass {
type Item = LexOrdering;
type IntoIter = IntoIter<LexOrdering>;
type IntoIter = indexmap::set::IntoIter<LexOrdering>;

fn into_iter(self) -> Self::IntoIter {
self.orderings.into_iter()
Expand Down Expand Up @@ -255,6 +271,12 @@ impl Display for OrderingEquivalenceClass {
}
}

impl From<IndexSet<LexOrdering>> for OrderingEquivalenceClass {
fn from(orderings: IndexSet<LexOrdering>) -> Self {
Self { orderings }
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;
Expand Down
29 changes: 15 additions & 14 deletions datafusion/physical-expr/src/equivalence/properties.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,7 @@ impl EquivalenceProperties {
OrderingEquivalenceClass::new(
self.oeq_class
.iter()
.map(|ordering| self.normalize_sort_exprs(ordering))
.collect(),
.map(|ordering| self.normalize_sort_exprs(ordering)),
)
}

Expand Down Expand Up @@ -715,7 +714,7 @@ impl EquivalenceProperties {
.iter()
.map(|order| self.substitute_ordering_component(mapping, order))
.collect::<Result<Vec<_>>>()?;
let new_order = new_order.into_iter().flatten().collect();
let new_order = new_order.into_iter().flatten();
self.oeq_class = OrderingEquivalenceClass::new(new_order);
Ok(())
}
Expand Down Expand Up @@ -1852,16 +1851,16 @@ pub fn join_equivalence_properties(
} = left;
let EquivalenceProperties {
constants: right_constants,
oeq_class: mut right_oeq_class,
oeq_class: right_oeq_class,
..
} = right;
match maintains_input_order {
[true, false] => {
// In this special case, right side ordering can be prefixed with
// the left side ordering.
if let (Some(JoinSide::Left), JoinType::Inner) = (probe_side, join_type) {
updated_right_ordering_equivalence_class(
&mut right_oeq_class,
let right_oeq_class = updated_right_ordering_equivalence_class(
right_oeq_class,
join_type,
left_size,
);
Expand All @@ -1881,8 +1880,8 @@ pub fn join_equivalence_properties(
}
}
[false, true] => {
updated_right_ordering_equivalence_class(
&mut right_oeq_class,
let right_oeq_class = updated_right_ordering_equivalence_class(
right_oeq_class,
join_type,
left_size,
);
Expand Down Expand Up @@ -1927,15 +1926,17 @@ pub fn join_equivalence_properties(
/// is the case for `Inner`, `Left`, `Full` and `Right` joins. For other cases,
/// indices do not change.
fn updated_right_ordering_equivalence_class(
right_oeq_class: &mut OrderingEquivalenceClass,
right_oeq_class: OrderingEquivalenceClass,
join_type: &JoinType,
left_size: usize,
) {
) -> OrderingEquivalenceClass {
if matches!(
join_type,
JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right
) {
right_oeq_class.add_offset(left_size);
right_oeq_class.add_offset(left_size)
} else {
right_oeq_class
}
}

Expand Down Expand Up @@ -2496,7 +2497,7 @@ mod tests {
];
let orderings = convert_to_orderings(&orderings);
// Right child ordering equivalences
let mut right_oeq_class = OrderingEquivalenceClass::new(orderings);
let right_oeq_class = OrderingEquivalenceClass::new(orderings);

let left_columns_len = 4;

Expand All @@ -2519,8 +2520,8 @@ mod tests {
join_eq_properties.add_equal_conditions(col_a, col_x)?;
join_eq_properties.add_equal_conditions(col_d, col_w)?;

updated_right_ordering_equivalence_class(
&mut right_oeq_class,
let right_oeq_class = updated_right_ordering_equivalence_class(
right_oeq_class,
&join_type,
left_columns_len,
);
Expand Down

0 comments on commit 544e058

Please sign in to comment.