diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs
index 1d96532598ca..147af1e301d6 100644
--- a/arrow-ord/src/sort.rs
+++ b/arrow-ord/src/sort.rs
@@ -22,6 +22,7 @@ use arrow_array::builder::BufferBuilder;
 use arrow_array::cast::*;
 use arrow_array::types::*;
 use arrow_array::*;
+use arrow_buffer::BooleanBufferBuilder;
 use arrow_buffer::{ArrowNativeType, MutableBuffer, NullBuffer};
 use arrow_data::ArrayData;
 use arrow_data::ArrayDataBuilder;
@@ -57,11 +58,91 @@ pub fn sort(
     values: &dyn Array,
     options: Option<SortOptions>,
 ) -> Result<ArrayRef, ArrowError> {
-    if let DataType::RunEndEncoded(_, _) = values.data_type() {
-        return sort_run(values, options, None);
+    // Primitive arrays are sorted by value directly; all other types still
+    // sort indices first and then gather with `take`.
+    downcast_primitive_array!(
+        values => sort_native_type(values, options),
+        DataType::RunEndEncoded(_, _) => sort_run(values, options, None),
+        _ => {
+            let indices = sort_to_indices(values, options, None)?;
+            take(values, &indices, None)
+        }
+    )
+}
+
+/// Sorts the values of a primitive array directly into a new buffer.
+///
+/// Valid values are gathered into one contiguous region of the output,
+/// sorted with `sort_unstable_by` using the native type's `compare`, and
+/// reversed when `descending` is set. Nulls occupy the leading or trailing
+/// `null_count` slots depending on `nulls_first`; those slots hold
+/// `T::default_value()`.
+///
+/// The result keeps the input's `DataType`. When `options` is `None` the
+/// default `SortOptions` are used.
+fn sort_native_type<T>(
+    primitive_values: &PrimitiveArray<T>,
+    options: Option<SortOptions>,
+) -> Result<ArrayRef, ArrowError>
+where
+    T: ArrowPrimitiveType,
+{
+    let sort_options = options.unwrap_or_default();
+
+    // Output buffer; slots that end up null keep the default value.
+    let mut mutable_buffer = vec![T::default_value(); primitive_values.len()];
+    let mutable_slice = &mut mutable_buffer;
+
+    let input_values = primitive_values.values().as_ref();
+
+    let nulls_count = primitive_values.null_count();
+    let valid_count = primitive_values.len() - nulls_count;
+
+    // Validity bitmap of the output: all nulls are grouped at one end.
+    let null_bit_buffer = match nulls_count > 0 {
+        true => {
+            let mut validity_buffer = BooleanBufferBuilder::new(primitive_values.len());
+            if sort_options.nulls_first {
+                validity_buffer.append_n(nulls_count, false);
+                validity_buffer.append_n(valid_count, true);
+            } else {
+                validity_buffer.append_n(valid_count, true);
+                validity_buffer.append_n(nulls_count, false);
+            }
+            Some(validity_buffer.finish().into())
+        }
+        false => None,
+    };
+
+    if let Some(nulls) = primitive_values.nulls().filter(|n| n.null_count() > 0) {
+        // Region of the output that receives the valid values.
+        let values_slice = match sort_options.nulls_first {
+            true => &mut mutable_slice[nulls_count..],
+            false => &mut mutable_slice[..valid_count],
+        };
+
+        // Gather only the valid input values, skipping null slots.
+        for (write_index, index) in nulls.valid_indices().enumerate() {
+            values_slice[write_index] = primitive_values.value(index);
+        }
+
+        values_slice.sort_unstable_by(|a, b| a.compare(*b));
+        if sort_options.descending {
+            values_slice.reverse();
+        }
+    } else {
+        // No nulls: sort a plain copy of the whole values buffer.
+        mutable_slice.copy_from_slice(input_values);
+        mutable_slice.sort_unstable_by(|a, b| a.compare(*b));
+        if sort_options.descending {
+            mutable_slice.reverse();
+        }
     }
-    let indices = sort_to_indices(values, options, None)?;
-    take(values, &indices, None)
+
+    Ok(Arc::new(
+        PrimitiveArray::<T>::new(mutable_buffer.into(), null_bit_buffer)
+            .with_data_type(primitive_values.data_type().clone()),
+    ))
 }
 
 /// Sort the `ArrayRef` partially.
diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml
index 998d077fa105..cbb33de6d1fa 100644
--- a/arrow/Cargo.toml
+++ b/arrow/Cargo.toml
@@ -188,6 +188,11 @@ name = "sort_kernel"
 harness = false
 required-features = ["test_utils"]
 
+[[bench]]
+name = "sort_kernel_primitives"
+harness = false
+required-features = ["test_utils"]
+
 [[bench]]
 name = "partition_kernels"
 harness = false
diff --git a/arrow/benches/sort_kernel_primitives.rs b/arrow/benches/sort_kernel_primitives.rs
new file mode 100644
index 000000000000..ca9183580bd2
--- /dev/null
+++ b/arrow/benches/sort_kernel_primitives.rs
@@ -0,0 +1,65 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#[macro_use]
+extern crate criterion;
+use arrow_ord::sort::sort;
+use criterion::Criterion;
+
+use std::sync::Arc;
+
+extern crate arrow;
+
+use arrow::util::bench_util::*;
+use arrow::{array::*, datatypes::Int64Type};
+
+/// Creates an `Int64` array of `size` elements with a null density of 0.5
+/// when `with_nulls` is set and 0.0 otherwise.
+fn create_i64_array(size: usize, with_nulls: bool) -> ArrayRef {
+    let null_density = if with_nulls { 0.5 } else { 0.0 };
+    let array = create_primitive_array::<Int64Type>(size, null_density);
+    Arc::new(array)
+}
+
+/// Sorts `array` without explicit `SortOptions`; input and output are
+/// passed through `black_box`.
+fn bench_sort(array: &ArrayRef) {
+    criterion::black_box(sort(criterion::black_box(array), None).unwrap());
+}
+
+/// Registers sort benchmarks for arrays of 2^10 and 2^12 elements, first
+/// without nulls and then with nulls.
+fn add_benchmark(c: &mut Criterion) {
+    let arr_a = create_i64_array(2u64.pow(10) as usize, false);
+
+    c.bench_function("sort 2^10", |b| b.iter(|| bench_sort(&arr_a)));
+
+    let arr_a = create_i64_array(2u64.pow(12) as usize, false);
+
+    c.bench_function("sort 2^12", |b| b.iter(|| bench_sort(&arr_a)));
+
+    let arr_a = create_i64_array(2u64.pow(10) as usize, true);
+
+    c.bench_function("sort nulls 2^10", |b| b.iter(|| bench_sort(&arr_a)));
+
+    let arr_a = create_i64_array(2u64.pow(12) as usize, true);
+
+    c.bench_function("sort nulls 2^12", |b| b.iter(|| bench_sort(&arr_a)));
+}
+
+criterion_group!(benches, add_benchmark);
+criterion_main!(benches);