Skip to content

Commit

Permalink
Bug fix: Fix lexicographical column search among provided ordering (#156
Browse files Browse the repository at this point in the history
)
  • Loading branch information
mustafasrepo authored and metesynnada committed Dec 15, 2023
1 parent 4abaa79 commit f9c53f5
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1313,15 +1313,15 @@ mod order_preserving_join_swap_tests {
let expected_input = vec![
"BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]",
" HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3",
" MemoryExec: partitions=0, partition_sizes=[]",
" MemoryExec: partitions=0, partition_sizes=[]",
" MemoryExec: partitions=0, partition_sizes=[], output_ordering=a@0 ASC",
" MemoryExec: partitions=0, partition_sizes=[], output_ordering=d@0 ASC",
];
let expected_optimized = vec![
"BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]",
" SortExec: expr=[e@4 ASC]",
" HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3",
" MemoryExec: partitions=0, partition_sizes=[]",
" MemoryExec: partitions=0, partition_sizes=[]",
" MemoryExec: partitions=0, partition_sizes=[], output_ordering=a@0 ASC",
" MemoryExec: partitions=0, partition_sizes=[], output_ordering=d@0 ASC",
];
assert_optimized_orthogonal!(expected_input, expected_optimized, physical_plan);
Ok(())
Expand Down
37 changes: 29 additions & 8 deletions datafusion/physical-expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,12 +343,35 @@ pub fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result<ArrayRef> {
let data = mutable.freeze();
Ok(make_array(data))
}

/// Return indices of each item in `required_exprs` inside `provided_exprs`.
/// All of the indices should be found inside `provided_exprs`.
/// Also found indices should be a permutation of range consecutive range from 0 to n.
/// Such as \[2,1,0\] is valid (\[0,1,2\] is consecutive). However, \[3,1,0\] is not
/// valid (\[0,1,3\] is not consecutive).
fn get_lexicographical_match_indices(
required_exprs: &[Arc<dyn PhysicalExpr>],
provided_exprs: &[Arc<dyn PhysicalExpr>],
) -> Option<Vec<usize>> {
let indices_of_equality = get_indices_of_exprs_strict(required_exprs, provided_exprs);
let mut ordered_indices = indices_of_equality.clone();
ordered_indices.sort();
let n_match = indices_of_equality.len();
let first_n = longest_consecutive_prefix(ordered_indices);
// If we found all the expressions, return early:
if n_match == required_exprs.len() && first_n == n_match && n_match > 0 {
return Some(indices_of_equality);
}
None
}

/// This function attempts to find a full match between required and provided
/// sorts, returning the indices and sort options of the matches found.
///
/// First, it normalizes the sort requirements and then checks for matches.
/// If no full match is found, it then checks against ordering equivalence properties.
/// If still no full match is found, it returns `None`.
/// required_columns columns of lexicographical ordering.
pub fn get_indices_of_matching_sort_exprs_with_order_eq<
F: Fn() -> EquivalenceProperties,
F2: Fn() -> OrderingEquivalenceProperties,
Expand Down Expand Up @@ -397,10 +420,9 @@ pub fn get_indices_of_matching_sort_exprs_with_order_eq<
.map(|req| req.expr.clone())
.collect::<Vec<_>>();

let indices_of_equality =
get_indices_of_exprs_strict(&normalized_required_expr, &provided_sorts);
// If we found all the expressions, return early:
if indices_of_equality.len() == normalized_required_expr.len() {
if let Some(indices_of_equality) =
get_lexicographical_match_indices(&normalized_required_expr, &provided_sorts)
{
return Some((
indices_of_equality
.iter()
Expand All @@ -415,16 +437,15 @@ pub fn get_indices_of_matching_sort_exprs_with_order_eq<
let head = class.head();
for ordering in class.others().iter().chain(std::iter::once(head)) {
let order_eq_class_exprs = convert_to_expr(ordering);
let indices_of_equality = get_indices_of_exprs_strict(
if let Some(indices_of_equality) = get_lexicographical_match_indices(
&normalized_required_expr,
&order_eq_class_exprs,
);
if indices_of_equality.len() == normalized_required_expr.len() {
) {
return Some((
indices_of_equality
.iter()
.map(|index| ordering[*index].options)
.collect::<Vec<_>>(),
.collect(),
indices_of_equality,
));
}
Expand Down

0 comments on commit f9c53f5

Please sign in to comment.