Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Added support for custom sort build_compare_fn (#1016)
Browse files Browse the repository at this point in the history
  • Loading branch information
b41sh authored May 28, 2022
1 parent 7cc874f commit 93bdde8
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 5 deletions.
11 changes: 10 additions & 1 deletion src/compute/merge_sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,15 @@ type IsValid<'a> = Box<dyn Fn(usize) -> bool + 'a>;
/// returns a comparison function between any two arrays of each pair of arrays, according to `SortOptions`.
pub fn build_comparator<'a>(
pairs: &'a [(&'a [&'a dyn Array], &SortOptions)],
) -> Result<Comparator<'a>> {
build_comparator_impl(pairs, &build_compare)
}

/// returns a comparison function between any two arrays of each pair of arrays, according to `SortOptions`.
/// Implementing custom `build_compare_fn` for unsupportd data types.
pub fn build_comparator_impl<'a>(
pairs: &'a [(&'a [&'a dyn Array], &SortOptions)],
build_compare_fn: &dyn Fn(&dyn Array, &dyn Array) -> Result<DynComparator>,
) -> Result<Comparator<'a>> {
// prepare the comparison function of _values_ between all pairs of arrays
let indices_pairs = (0..pairs[0].0.len())
Expand All @@ -483,7 +492,7 @@ pub fn build_comparator<'a>(
Ok((
Box::new(move |row| arrays[lhs_index].is_valid(row)) as IsValid<'a>,
Box::new(move |row| arrays[rhs_index].is_valid(row)) as IsValid<'a>,
build_compare(arrays[lhs_index], arrays[rhs_index])?,
build_compare_fn(arrays[lhs_index], arrays[rhs_index])?,
))
})
.collect::<Result<Vec<_>>>()?;
Expand Down
33 changes: 30 additions & 3 deletions src/compute/sort/lex_sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,16 @@ fn build_is_valid(array: &dyn Array) -> IsValid {
}

pub(crate) fn build_compare(array: &dyn Array, sort_option: SortOptions) -> Result<DynComparator> {
build_compare_impl(array, sort_option, &ord::build_compare)
}

pub(crate) fn build_compare_impl(
array: &dyn Array,
sort_option: SortOptions,
build_compare_fn: &dyn Fn(&dyn Array, &dyn Array) -> Result<DynComparator>,
) -> Result<DynComparator> {
let is_valid = build_is_valid(array);
let comparator = ord::build_compare(array, array)?;
let comparator = build_compare_fn(array, array)?;

Ok(match (sort_option.descending, sort_option.nulls_first) {
(true, true) => Box::new(move |i: usize, j: usize| match (is_valid(i), is_valid(j)) {
Expand Down Expand Up @@ -127,6 +135,17 @@ pub(crate) fn build_compare(array: &dyn Array, sort_option: SortOptions) -> Resu
pub fn lexsort_to_indices<I: Index>(
columns: &[SortColumn],
limit: Option<usize>,
) -> Result<PrimitiveArray<I>> {
lexsort_to_indices_impl(columns, limit, &ord::build_compare)
}

/// Sorts a list of [`SortColumn`] into a non-nullable [`PrimitiveArray`]
/// representing the indices that would sort the columns.
/// Implementing custom `build_compare_fn` for unsupportd data types.
pub fn lexsort_to_indices_impl<I: Index>(
columns: &[SortColumn],
limit: Option<usize>,
build_compare_fn: &dyn Fn(&dyn Array, &dyn Array) -> Result<DynComparator>,
) -> Result<PrimitiveArray<I>> {
if columns.is_empty() {
return Err(Error::InvalidArgumentError(
Expand All @@ -136,7 +155,11 @@ pub fn lexsort_to_indices<I: Index>(
if columns.len() == 1 {
// fallback to non-lexical sort
let column = &columns[0];
return sort_to_indices(column.values, &column.options.unwrap_or_default(), limit);
if let Ok(indices) =
sort_to_indices(column.values, &column.options.unwrap_or_default(), limit)
{
return Ok(indices);
}
}

let row_count = columns[0].values.len();
Expand All @@ -150,7 +173,11 @@ pub fn lexsort_to_indices<I: Index>(
let comparators = columns
.iter()
.map(|column| -> Result<DynComparator> {
build_compare(column.values, column.options.unwrap_or_default())
build_compare_impl(
column.values,
column.options.unwrap_or_default(),
build_compare_fn,
)
})
.collect::<Result<Vec<DynComparator>>>()?;

Expand Down
2 changes: 1 addition & 1 deletion src/compute/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ mod primitive;
mod utf8;

pub(crate) use lex_sort::build_compare;
pub use lex_sort::{lexsort, lexsort_to_indices, SortColumn};
pub use lex_sort::{lexsort, lexsort_to_indices, lexsort_to_indices_impl, SortColumn};

macro_rules! dyn_sort {
($ty:ty, $array:expr, $cmp:expr, $options:expr, $limit:expr) => {{
Expand Down

0 comments on commit 93bdde8

Please sign in to comment.