This repository has been archived by the owner on Feb 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 224
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added support for limited sort (#218)
- Loading branch information
1 parent
65486f8
commit a6e8b69
Showing
11 changed files
with
737 additions
and
452 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
use crate::{ | ||
array::{Array, BooleanArray, Int32Array}, | ||
buffer::MutableBuffer, | ||
datatypes::DataType, | ||
}; | ||
|
||
use super::SortOptions; | ||
|
||
/// Returns the indices that would sort a [`BooleanArray`]. | ||
pub fn sort_boolean( | ||
values: &BooleanArray, | ||
value_indices: Vec<i32>, | ||
null_indices: Vec<i32>, | ||
options: &SortOptions, | ||
limit: Option<usize>, | ||
) -> Int32Array { | ||
let descending = options.descending; | ||
|
||
// create tuples that are used for sorting | ||
let mut valids = value_indices | ||
.into_iter() | ||
.map(|index| (index, values.value(index as usize))) | ||
.collect::<Vec<(i32, bool)>>(); | ||
|
||
let mut nulls = null_indices; | ||
|
||
if !descending { | ||
valids.sort_by(|a, b| a.1.cmp(&b.1)); | ||
} else { | ||
valids.sort_by(|a, b| a.1.cmp(&b.1).reverse()); | ||
// reverse to keep a stable ordering | ||
nulls.reverse(); | ||
} | ||
|
||
let mut values = MutableBuffer::<i32>::with_capacity(values.len()); | ||
|
||
if options.nulls_first { | ||
values.extend_from_slice(nulls.as_slice()); | ||
valids.iter().for_each(|x| values.push(x.0)); | ||
} else { | ||
// nulls last | ||
valids.iter().for_each(|x| values.push(x.0)); | ||
values.extend_from_slice(nulls.as_slice()); | ||
} | ||
|
||
// un-efficient; there are much more performant ways of sorting nulls above, anyways. | ||
if let Some(limit) = limit { | ||
values.truncate(limit); | ||
} | ||
|
||
Int32Array::from_data(DataType::Int32, values.into(), None) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
use crate::{array::PrimitiveArray, bitmap::Bitmap, buffer::MutableBuffer, datatypes::DataType}; | ||
|
||
use super::SortOptions; | ||
|
||
/// Orders the leading `limit` slots of `indices` according to `compare`.
///
/// After the call, `indices[..limit]` holds the `limit` smallest entries (per
/// `compare`) in sorted order; the remaining slots are in unspecified order.
/// When `limit >= indices.len()` the whole slice is sorted instead, because
/// `select_nth_unstable_by` panics on an out-of-range position.
#[inline]
fn select_then_sort<C>(indices: &mut [i32], limit: usize, mut compare: C)
where
    C: FnMut(&i32, &i32) -> std::cmp::Ordering,
{
    if limit >= indices.len() {
        indices.sort_unstable_by(compare);
        return;
    }
    // Partition so that everything before position `limit` compares
    // less-or-equal to everything after it, then sort only that prefix.
    let (smallest, _, _) = indices.select_nth_unstable_by(limit, &mut compare);
    smallest.sort_unstable_by(compare);
}

/// Partially sorts `indices` so that its first `limit` entries are the first
/// `limit` entries of the fully sorted order.
///
/// * `indices` - positions to reorder in place; each one is cast to `usize`
///   and handed to `get`.
/// * `get` - maps a position to the value that is compared.
/// * `descending` - when `true`, the result of `cmp` is reversed.
/// * `limit` - number of leading slots that must end up sorted; a value of
///   `indices.len()` or more sorts the whole slice.
/// * `cmp` - ordering between two values produced by `get`.
///
/// # Safety
/// This function guarantees that:
/// * `get` is only called with values stored in `indices` (cast to `usize`)
/// * `cmp` is only called from the co-domain of `get`.
#[inline]
fn k_element_sort_inner<T, G, F>(
    indices: &mut [i32],
    get: G,
    descending: bool,
    limit: usize,
    mut cmp: F,
) where
    G: Fn(usize) -> T,
    F: FnMut(&T, &T) -> std::cmp::Ordering,
{
    // The direction is decided once, outside the comparator, so the closure
    // that runs per comparison carries no extra branch.
    if descending {
        select_then_sort(indices, limit, |lhs, rhs| {
            cmp(&get(*lhs as usize), &get(*rhs as usize)).reverse()
        });
    } else {
        select_then_sort(indices, limit, |lhs, rhs| {
            cmp(&get(*lhs as usize), &get(*rhs as usize))
        });
    }
}
|
||
/// # Safety | ||
/// This function guarantees that: | ||
/// * `get` is only called for `0 <= i < limit` | ||
/// * `cmp` is only called from the co-domain of `get`. | ||
#[inline] | ||
fn sort_unstable_by<T, G, F>( | ||
indices: &mut [i32], | ||
get: G, | ||
mut cmp: F, | ||
descending: bool, | ||
limit: usize, | ||
) where | ||
G: Fn(usize) -> T, | ||
F: FnMut(&T, &T) -> std::cmp::Ordering, | ||
{ | ||
if limit != indices.len() { | ||
return k_element_sort_inner(indices, get, descending, limit, cmp); | ||
} | ||
|
||
if descending { | ||
indices.sort_unstable_by(|lhs, rhs| { | ||
let lhs = get(*lhs as usize); | ||
let rhs = get(*rhs as usize); | ||
cmp(&lhs, &rhs).reverse() | ||
}) | ||
} else { | ||
indices.sort_unstable_by(|lhs, rhs| { | ||
let lhs = get(*lhs as usize); | ||
let rhs = get(*rhs as usize); | ||
cmp(&lhs, &rhs) | ||
}) | ||
} | ||
} | ||
|
||
/// # Safety | ||
/// This function guarantees that: | ||
/// * `get` is only called for `0 <= i < length` | ||
/// * `cmp` is only called from the co-domain of `get`. | ||
#[inline] | ||
pub(super) fn indices_sorted_unstable_by<T, G, F>( | ||
validity: &Option<Bitmap>, | ||
get: G, | ||
cmp: F, | ||
length: usize, | ||
options: &SortOptions, | ||
limit: Option<usize>, | ||
) -> PrimitiveArray<i32> | ||
where | ||
G: Fn(usize) -> T, | ||
F: Fn(&T, &T) -> std::cmp::Ordering, | ||
{ | ||
let descending = options.descending; | ||
|
||
let limit = limit.unwrap_or(length); | ||
// Safety: without this, we go out of bounds when limit >= length. | ||
let limit = limit.min(length); | ||
|
||
let indices = if let Some(validity) = validity { | ||
let mut indices = MutableBuffer::<i32>::from_len_zeroed(length); | ||
|
||
if options.nulls_first { | ||
let mut nulls = 0; | ||
let mut valids = 0; | ||
validity | ||
.iter() | ||
.zip(0..length as i32) | ||
.for_each(|(is_valid, index)| { | ||
if is_valid { | ||
indices[validity.null_count() + valids] = index; | ||
valids += 1; | ||
} else { | ||
indices[nulls] = index; | ||
nulls += 1; | ||
} | ||
}); | ||
|
||
if limit > validity.null_count() { | ||
// when limit is larger, we must sort values: | ||
|
||
// Soundness: | ||
// all indices in `indices` are by construction `< array.len() == values.len()` | ||
// limit is by construction < indices.len() | ||
let limit = limit - validity.null_count(); | ||
let indices = &mut indices.as_mut_slice()[validity.null_count()..]; | ||
sort_unstable_by(indices, get, cmp, options.descending, limit) | ||
} | ||
} else { | ||
let last_valid_index = length - validity.null_count(); | ||
let mut nulls = 0; | ||
let mut valids = 0; | ||
validity | ||
.iter() | ||
.zip(0..length as i32) | ||
.for_each(|(x, index)| { | ||
if x { | ||
indices[valids] = index; | ||
valids += 1; | ||
} else { | ||
indices[last_valid_index + nulls] = index; | ||
nulls += 1; | ||
} | ||
}); | ||
|
||
// Soundness: | ||
// all indices in `indices` are by construction `< array.len() == values.len()` | ||
// limit is by construction <= values.len() | ||
let limit = limit.min(last_valid_index); | ||
let indices = &mut indices.as_mut_slice()[..last_valid_index]; | ||
sort_unstable_by(indices, get, cmp, options.descending, limit); | ||
} | ||
|
||
indices.truncate(limit); | ||
indices.shrink_to_fit(); | ||
|
||
indices | ||
} else { | ||
let mut indices = | ||
unsafe { MutableBuffer::from_trusted_len_iter_unchecked(0..length as i32) }; | ||
|
||
// Soundness: | ||
// indices are by construction `< values.len()` | ||
// limit is by construction `< values.len()` | ||
sort_unstable_by(&mut indices, get, cmp, descending, limit); | ||
|
||
indices.truncate(limit); | ||
indices.shrink_to_fit(); | ||
|
||
indices | ||
}; | ||
PrimitiveArray::<i32>::from_data(DataType::Int32, indices.into(), None) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.