Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Added support for limited sort (#218)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao authored Jul 24, 2021
1 parent 65486f8 commit a6e8b69
Show file tree
Hide file tree
Showing 11 changed files with 737 additions and 452 deletions.
9 changes: 7 additions & 2 deletions benches/sort_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ fn bench_lexsort(arr_a: &dyn Array, array_b: &dyn Array) {
},
];

criterion::black_box(lexsort(&columns).unwrap());
criterion::black_box(lexsort(&columns, None).unwrap());
}

fn bench_sort(arr_a: &dyn Array) {
sort(criterion::black_box(arr_a), &SortOptions::default()).unwrap();
sort(criterion::black_box(arr_a), &SortOptions::default(), None).unwrap();
}

fn add_benchmark(c: &mut Criterion) {
Expand All @@ -66,6 +66,11 @@ fn add_benchmark(c: &mut Criterion) {
c.bench_function(&format!("lexsort null 2^{} f32", log2_size), |b| {
b.iter(|| bench_lexsort(&arr_a, &arr_b))
});

let arr_a = create_string_array::<i32>(size, 0.1);
c.bench_function(&format!("sort utf8 null 2^{}", log2_size), |b| {
b.iter(|| bench_sort(&arr_a))
});
});
}

Expand Down
9 changes: 9 additions & 0 deletions src/buffer/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,15 @@ impl<T: NativeType> MutableBuffer<T> {
self.len = 0
}

/// Shortens the buffer.
/// If `len` is greater or equal to the buffers' current length, this has no effect.
#[inline]
pub fn truncate(&mut self, len: usize) {
if len < self.len {
self.len = len;
}
}

/// Returns the data stored in this buffer as a slice.
#[inline]
pub fn as_slice(&self) -> &[T] {
Expand Down
4 changes: 2 additions & 2 deletions src/compute/merge_sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -637,8 +637,8 @@ mod tests {
let options = SortOptions::default();

// sort individually, potentially in parallel.
let a0 = sort(a0, &options)?;
let a1 = sort(a1, &options)?;
let a0 = sort(a0, &options, None)?;
let a1 = sort(a1, &options, None)?;

// merge then. If multiple arrays, this can be applied in parallel.
let result = merge_sort(a0.as_ref(), a1.as_ref(), &options)?;
Expand Down
52 changes: 52 additions & 0 deletions src/compute/sort/boolean.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use crate::{
array::{Array, BooleanArray, Int32Array},
buffer::MutableBuffer,
datatypes::DataType,
};

use super::SortOptions;

/// Returns the indices that would sort a [`BooleanArray`].
pub fn sort_boolean(
values: &BooleanArray,
value_indices: Vec<i32>,
null_indices: Vec<i32>,
options: &SortOptions,
limit: Option<usize>,
) -> Int32Array {
let descending = options.descending;

// create tuples that are used for sorting
let mut valids = value_indices
.into_iter()
.map(|index| (index, values.value(index as usize)))
.collect::<Vec<(i32, bool)>>();

let mut nulls = null_indices;

if !descending {
valids.sort_by(|a, b| a.1.cmp(&b.1));
} else {
valids.sort_by(|a, b| a.1.cmp(&b.1).reverse());
// reverse to keep a stable ordering
nulls.reverse();
}

let mut values = MutableBuffer::<i32>::with_capacity(values.len());

if options.nulls_first {
values.extend_from_slice(nulls.as_slice());
valids.iter().for_each(|x| values.push(x.0));
} else {
// nulls last
valids.iter().for_each(|x| values.push(x.0));
values.extend_from_slice(nulls.as_slice());
}

// un-efficient; there are much more performant ways of sorting nulls above, anyways.
if let Some(limit) = limit {
values.truncate(limit);
}

Int32Array::from_data(DataType::Int32, values.into(), None)
}
179 changes: 179 additions & 0 deletions src/compute/sort/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
use crate::{array::PrimitiveArray, bitmap::Bitmap, buffer::MutableBuffer, datatypes::DataType};

use super::SortOptions;

/// # Safety
/// This function guarantees that:
/// * `get` is only called for `0 <= i < limit`
/// * `cmp` is only called from the co-domain of `get`.
#[inline]
fn k_element_sort_inner<T, G, F>(
indices: &mut [i32],
get: G,
descending: bool,
limit: usize,
mut cmp: F,
) where
G: Fn(usize) -> T,
F: FnMut(&T, &T) -> std::cmp::Ordering,
{
if descending {
let compare = |lhs: &i32, rhs: &i32| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
cmp(&lhs, &rhs).reverse()
};
let (before, _, _) = indices.select_nth_unstable_by(limit, compare);
let compare = |lhs: &i32, rhs: &i32| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
cmp(&lhs, &rhs).reverse()
};
before.sort_unstable_by(compare);
} else {
let compare = |lhs: &i32, rhs: &i32| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
cmp(&lhs, &rhs)
};
let (before, _, _) = indices.select_nth_unstable_by(limit, compare);
let compare = |lhs: &i32, rhs: &i32| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
cmp(&lhs, &rhs)
};
before.sort_unstable_by(compare);
}
}

/// # Safety
/// This function guarantees that:
/// * `get` is only called for `0 <= i < limit`
/// * `cmp` is only called from the co-domain of `get`.
#[inline]
fn sort_unstable_by<T, G, F>(
indices: &mut [i32],
get: G,
mut cmp: F,
descending: bool,
limit: usize,
) where
G: Fn(usize) -> T,
F: FnMut(&T, &T) -> std::cmp::Ordering,
{
if limit != indices.len() {
return k_element_sort_inner(indices, get, descending, limit, cmp);
}

if descending {
indices.sort_unstable_by(|lhs, rhs| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
cmp(&lhs, &rhs).reverse()
})
} else {
indices.sort_unstable_by(|lhs, rhs| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
cmp(&lhs, &rhs)
})
}
}

/// # Safety
/// This function guarantees that:
/// * `get` is only called for `0 <= i < length`
/// * `cmp` is only called from the co-domain of `get`.
#[inline]
pub(super) fn indices_sorted_unstable_by<T, G, F>(
validity: &Option<Bitmap>,
get: G,
cmp: F,
length: usize,
options: &SortOptions,
limit: Option<usize>,
) -> PrimitiveArray<i32>
where
G: Fn(usize) -> T,
F: Fn(&T, &T) -> std::cmp::Ordering,
{
let descending = options.descending;

let limit = limit.unwrap_or(length);
// Safety: without this, we go out of bounds when limit >= length.
let limit = limit.min(length);

let indices = if let Some(validity) = validity {
let mut indices = MutableBuffer::<i32>::from_len_zeroed(length);

if options.nulls_first {
let mut nulls = 0;
let mut valids = 0;
validity
.iter()
.zip(0..length as i32)
.for_each(|(is_valid, index)| {
if is_valid {
indices[validity.null_count() + valids] = index;
valids += 1;
} else {
indices[nulls] = index;
nulls += 1;
}
});

if limit > validity.null_count() {
// when limit is larger, we must sort values:

// Soundness:
// all indices in `indices` are by construction `< array.len() == values.len()`
// limit is by construction < indices.len()
let limit = limit - validity.null_count();
let indices = &mut indices.as_mut_slice()[validity.null_count()..];
sort_unstable_by(indices, get, cmp, options.descending, limit)
}
} else {
let last_valid_index = length - validity.null_count();
let mut nulls = 0;
let mut valids = 0;
validity
.iter()
.zip(0..length as i32)
.for_each(|(x, index)| {
if x {
indices[valids] = index;
valids += 1;
} else {
indices[last_valid_index + nulls] = index;
nulls += 1;
}
});

// Soundness:
// all indices in `indices` are by construction `< array.len() == values.len()`
// limit is by construction <= values.len()
let limit = limit.min(last_valid_index);
let indices = &mut indices.as_mut_slice()[..last_valid_index];
sort_unstable_by(indices, get, cmp, options.descending, limit);
}

indices.truncate(limit);
indices.shrink_to_fit();

indices
} else {
let mut indices =
unsafe { MutableBuffer::from_trusted_len_iter_unchecked(0..length as i32) };

// Soundness:
// indices are by construction `< values.len()`
// limit is by construction `< values.len()`
sort_unstable_by(&mut indices, get, cmp, descending, limit);

indices.truncate(limit);
indices.shrink_to_fit();

indices
};
PrimitiveArray::<i32>::from_data(DataType::Int32, indices.into(), None)
}
53 changes: 27 additions & 26 deletions src/compute/sort/lex_sort.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,3 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::cmp::Ordering;

use crate::compute::take;
Expand Down Expand Up @@ -66,14 +49,14 @@ pub struct SortColumn<'a> {
/// nulls_first: false,
/// }),
/// },
/// ]).unwrap();
/// ], None).unwrap();
///
/// let sorted = sorted_columns[0].as_any().downcast_ref::<Int64Array>().unwrap();
/// assert_eq!(sorted.value(1), -64);
/// assert!(sorted.is_null(0));
/// ```
pub fn lexsort(columns: &[SortColumn]) -> Result<Vec<Box<dyn Array>>> {
let indices = lexsort_to_indices(columns)?;
pub fn lexsort(columns: &[SortColumn], limit: Option<usize>) -> Result<Vec<Box<dyn Array>>> {
let indices = lexsort_to_indices(columns, limit)?;
columns
.iter()
.map(|c| take::take(c.values, &indices))
Expand Down Expand Up @@ -135,9 +118,12 @@ pub(crate) fn build_compare(array: &dyn Array, sort_option: SortOptions) -> Resu
})
}

/// Sort elements lexicographically from a list of `ArrayRef` into an unsigned integer
/// [`Int32Array`] of indices.
pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result<PrimitiveArray<i32>> {
/// Sorts a list of [`SortColumn`] into a non-nullable [`PrimitiveArray<i32>`]
/// representing the indices that would sort the columns.
pub fn lexsort_to_indices(
columns: &[SortColumn],
limit: Option<usize>,
) -> Result<PrimitiveArray<i32>> {
if columns.is_empty() {
return Err(ArrowError::InvalidArgumentError(
"Sort requires at least one column".to_string(),
Expand All @@ -146,7 +132,7 @@ pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result<PrimitiveArray<i32>>
if columns.len() == 1 {
// fallback to non-lexical sort
let column = &columns[0];
return sort_to_indices(column.values, &column.options.unwrap_or_default());
return sort_to_indices(column.values, &column.options.unwrap_or_default(), limit);
}

let row_count = columns[0].values.len();
Expand Down Expand Up @@ -180,7 +166,15 @@ pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result<PrimitiveArray<i32>>
// Safety: `0..row_count` is TrustedLen
let mut values =
unsafe { MutableBuffer::<i32>::from_trusted_len_iter_unchecked(0..row_count as i32) };
values.sort_unstable_by(lex_comparator);

if let Some(limit) = limit {
let limit = limit.min(row_count);
let (before, _, _) = values.select_nth_unstable_by(limit, lex_comparator);
before.sort_unstable_by(lex_comparator);
values.truncate(limit);
} else {
values.sort_unstable_by(lex_comparator);
}

Ok(PrimitiveArray::<i32>::from_data(
DataType::Int32,
Expand All @@ -196,7 +190,14 @@ mod tests {
use super::*;

fn test_lex_sort_arrays(input: Vec<SortColumn>, expected: Vec<Box<dyn Array>>) {
let sorted = lexsort(&input).unwrap();
let sorted = lexsort(&input, None).unwrap();
assert_eq!(sorted, expected);

let sorted = lexsort(&input, Some(2)).unwrap();
let expected = expected
.into_iter()
.map(|x| x.slice(0, 2))
.collect::<Vec<_>>();
assert_eq!(sorted, expected);
}

Expand Down
Loading

0 comments on commit a6e8b69

Please sign in to comment.