Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Optimized sort of utf8
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Jul 23, 2021
1 parent 128be64 commit ae41e1f
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 238 deletions.
117 changes: 110 additions & 7 deletions src/compute/sort/common.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
use crate::{array::PrimitiveArray, bitmap::Bitmap, buffer::MutableBuffer, datatypes::DataType};

use super::SortOptions;

/// # Safety
/// `indices[i] < values.len()` for all i
/// `limit < values.len()`
/// This function guarantees that:
/// * `get` is only called for `0 <= i < limit`
/// * `cmp` is only called from the co-domain of `get`.
#[inline]
unsafe fn k_element_sort_inner<T, G, F>(
fn k_element_sort_inner<T, G, F>(
indices: &mut [i32],
get: G,
descending: bool,
Expand Down Expand Up @@ -42,11 +47,11 @@ unsafe fn k_element_sort_inner<T, G, F>(
}

/// # Safety
/// Safe iff
/// * `indices[i] < values.len()` for all i
/// * `limit < values.len()`
/// This function guarantees that:
/// * `get` is only called for `0 <= i < limit`
/// * `cmp` is only called from the co-domain of `get`.
#[inline]
pub(super) unsafe fn sort_unstable_by<T, G, F>(
fn sort_unstable_by<T, G, F>(
indices: &mut [i32],
get: G,
mut cmp: F,
Expand Down Expand Up @@ -74,3 +79,101 @@ pub(super) unsafe fn sort_unstable_by<T, G, F>(
})
}
}

/// Builds the `i32` indices that order `length` slots, keeping null slots grouped
/// at one end and returning at most `limit` indices.
///
/// * `validity` - when `Some`, slots the bitmap reports as not valid are placed
///   before the valid ones if `options.nulls_first` is set and after them otherwise;
///   only the sub-slice holding valid indices is handed to `sort_unstable_by`.
/// * `get`, `cmp` - forwarded unchanged to `sort_unstable_by`.
/// * `length` - number of slots to index.
/// * `options` - supplies `descending` and `nulls_first`.
/// * `limit` - maximum number of indices to return; `None`, or any value larger
///   than `length`, is treated as `length`.
///
/// # Safety
/// This function guarantees that:
/// * `get` is only called for `0 <= i < length`
/// * `cmp` is only called from the co-domain of `get`.
#[inline]
pub(super) fn indices_sorted_unstable_by<T, G, F>(
    validity: &Option<Bitmap>,
    get: G,
    cmp: F,
    length: usize,
    options: &SortOptions,
    limit: Option<usize>,
) -> PrimitiveArray<i32>
where
    G: Fn(usize) -> T,
    F: Fn(&T, &T) -> std::cmp::Ordering,
{
    // Clamp the requested limit: without this, we go out of bounds when limit >= length.
    let limit = limit.map_or(length, |requested| requested.min(length));

    let sorted = match validity {
        Some(bitmap) => {
            let null_count = bitmap.null_count();
            let mut buffer = MutableBuffer::<i32>::from_len_zeroed(length);
            // Pairs every validity bit with the position it describes.
            let slots = bitmap.iter().zip(0..length as i32);

            if options.nulls_first {
                // Layout: [nulls | valids]; valid entries start right after the nulls.
                let mut null_pos = 0usize;
                let mut valid_pos = null_count;
                for (is_valid, index) in slots {
                    if is_valid {
                        buffer[valid_pos] = index;
                        valid_pos += 1;
                    } else {
                        buffer[null_pos] = index;
                        null_pos += 1;
                    }
                }

                // Sorting is only needed when the limit reaches past the block of nulls.
                if limit > null_count {
                    // Every stored index is `< length` by construction, and the
                    // remaining budget is what is left of `limit` after the nulls.
                    let valid_part = &mut buffer.as_mut_slice()[null_count..];
                    sort_unstable_by(
                        valid_part,
                        get,
                        cmp,
                        options.descending,
                        limit - null_count,
                    );
                }
            } else {
                // Layout: [valids | nulls]; null entries start right after the valids.
                let valid_count = length - null_count;
                let mut valid_pos = 0usize;
                let mut null_pos = valid_count;
                for (is_valid, index) in slots {
                    if is_valid {
                        buffer[valid_pos] = index;
                        valid_pos += 1;
                    } else {
                        buffer[null_pos] = index;
                        null_pos += 1;
                    }
                }

                // Every stored index is `< length` by construction; the limit handed
                // to the sort is capped to the number of valid entries.
                let valid_part = &mut buffer.as_mut_slice()[..valid_count];
                sort_unstable_by(
                    valid_part,
                    get,
                    cmp,
                    options.descending,
                    limit.min(valid_count),
                );
            }

            buffer.truncate(limit);
            buffer.shrink_to_fit();
            buffer
        }
        None => {
            // No validity bitmap: every slot takes part in the sort.
            // SAFETY: the iterator is a plain integer range, which is presumed to
            // satisfy the trusted-length requirement of this constructor.
            let mut buffer =
                unsafe { MutableBuffer::from_trusted_len_iter_unchecked(0..length as i32) };

            // Indices are `< length` by construction and `limit` was clamped above.
            sort_unstable_by(&mut buffer, get, cmp, options.descending, limit);

            buffer.truncate(limit);
            buffer.shrink_to_fit();
            buffer
        }
    };

    PrimitiveArray::<i32>::from_data(DataType::Int32, sorted.into(), None)
}
Loading

0 comments on commit ae41e1f

Please sign in to comment.