Skip to content

Commit

Permalink
perf: take_run improvements (#3705)
Browse files Browse the repository at this point in the history
* take_run improvements

* doc fix

* test case update per pr comment

---------

Co-authored-by: ask <ask@local>
  • Loading branch information
askoa and ask authored Feb 13, 2023
1 parent e37e379 commit d011e6a
Showing 1 changed file with 66 additions and 58 deletions.
124 changes: 66 additions & 58 deletions arrow-select/src/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@
use std::sync::Arc;

use arrow_array::builder::BufferBuilder;
use arrow_array::types::*;
use arrow_array::*;
use arrow_array::{builder::PrimitiveRunBuilder, types::*};
use arrow_buffer::{bit_util, ArrowNativeType, Buffer, MutableBuffer};
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::{ArrowError, DataType, Field};

use arrow_array::cast::{
as_generic_binary_array, as_largestring_array, as_primitive_array, as_string_array,
};
use arrow_array::cast::{as_generic_binary_array, as_largestring_array, as_string_array};
use num::{ToPrimitive, Zero};

/// Take elements by index from [Array], creating a new [Array] from those indexes.
Expand Down Expand Up @@ -816,22 +815,14 @@ where
Ok(DictionaryArray::<T>::from(data))
}

// Callback macro handed to `downcast_primitive!` by `take_run`.
//
// Expands to a call to `take_primitive_run_values::<$o, $t>`, forwarding
// `$indices` unchanged and downcasting `$value.values()` with
// `as_primitive_array::<$t>`.
//
// * `$t` - primitive type of the run array's values.
// * `$o` - run-end index type of the output run array.
// * `$indices` - identifier bound to the physical indices to take.
// * `$value` - identifier bound to the input run array.
macro_rules! primitive_run_take {
($t:ty, $o:ty, $indices:ident, $value:ident) => {
take_primitive_run_values::<$o, $t>(
$indices,
as_primitive_array::<$t>($value.values()),
)
};
}

/// `take` implementation for run arrays
///
/// Finds physical indices for the given logical indices and builds output run array
/// by taking values in the input run array at the physical indices.
/// For example, an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `indices=[2,7]`
/// would be converted to `physical_indices=[1,3]` which will be used to build
/// output `RunArray{ run_ends=[2], values=[2] }`
/// by taking values in the input run_array.values at the physical indices.
/// The output run array will be run encoded on the physical indices and not on output values.
/// For example, an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `logical_indices=[2,3,6,7]`
/// would be converted to `physical_indices=[1,1,3,3]` which will be used to build
/// output `RunArray{ run_ends=[2,4], values=[2,2] }`.
fn take_run<T, I>(
run_array: &RunArray<T>,
logical_indices: &PrimitiveArray<I>,
Expand All @@ -842,43 +833,60 @@ where
I: ArrowPrimitiveType,
I::Native: ToPrimitive,
{
match run_array.data_type() {
DataType::RunEndEncoded(_, fl) => {
let physical_indices =
run_array.get_physical_indices(logical_indices.values())?;

downcast_primitive! {
fl.data_type() => (primitive_run_take, T, physical_indices, run_array),
dt => Err(ArrowError::NotYetImplemented(format!("take_run is not implemented for {dt:?}")))
}
// get physical indices for the input logical indices
let physical_indices = run_array.get_physical_indices(logical_indices.values())?;

// Run encode the physical indices into new_run_ends_builder
// Keep track of the physical indices to take in take_value_indices
// `unwrap` is used in this function because the unwrapped values are bounded by the corresponding `::Native`.
let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
let mut new_physical_len = 1;
for ix in 1..physical_indices.len() {
if physical_indices[ix] != physical_indices[ix - 1] {
take_value_indices
.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
new_physical_len += 1;
}
dt => Err(ArrowError::InvalidArgumentError(format!(
"Expected DataType::RunEndEncoded found {dt:?}"
))),
}
}
take_value_indices.append(
I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap(),
);
new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
let new_run_ends = unsafe {
// Safety:
// The function builds a valid run_ends array and hence need not be validated.
ArrayDataBuilder::new(T::DATA_TYPE)
.len(new_physical_len)
.null_count(0)
.add_buffer(new_run_ends_builder.finish())
.build_unchecked()
};

// Builds a `RunArray` by taking values from given array for the given indices.
//
// Each entry of `physical_indices` is looked up in `values` and appended to a
// `PrimitiveRunBuilder`, so null entries are appended as nulls and non-null
// entries are appended by value.
//
// Returns `ArrowError::InvalidArgumentError` if any index in `physical_indices`
// is greater than or equal to `values.len()`.
fn take_primitive_run_values<R, V>(
    physical_indices: Vec<usize>,
    values: &PrimitiveArray<V>,
) -> Result<RunArray<R>, ArrowError>
where
    R: RunEndIndexType,
    V: ArrowPrimitiveType,
{
    let mut builder = PrimitiveRunBuilder::<R, V>::new();
    let values_len = values.len();
    for ix in physical_indices {
        if ix >= values_len {
            // `format!` is needed here: a plain string literal followed by
            // `.to_string()` would emit the placeholders `{ix}` and
            // `{values_len}` verbatim instead of their values.
            return Err(ArrowError::InvalidArgumentError(format!(
                "The requested index {ix} is out of bounds for values array with length {values_len}"
            )));
        } else if values.is_null(ix) {
            builder.append_null()
        } else {
            builder.append_value(values.value(ix))
        }
    }
    Ok(builder.finish())
let take_value_indices: PrimitiveArray<I> = unsafe {
// Safety:
// The function builds a valid take_value_indices array and hence need not be validated.
ArrayDataBuilder::new(I::DATA_TYPE)
.len(new_physical_len)
.null_count(0)
.add_buffer(take_value_indices.finish())
.build_unchecked()
.into()
};

let new_values = take(run_array.values(), &take_value_indices, None)?;

let builder = ArrayDataBuilder::new(run_array.data_type().clone())
.len(physical_indices.len())
.add_child_data(new_run_ends)
.add_child_data(new_values.into_data());
let array_data = unsafe {
// Safety:
// This function builds a valid run array and hence can skip validation.
builder.build_unchecked()
};
Ok(array_data.into())
}

/// Takes/filters a list array's inner data using the offsets of the list array.
Expand Down Expand Up @@ -983,7 +991,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::builder::*;
use arrow_array::{builder::*, cast::as_primitive_array};
use arrow_schema::TimeUnit;

fn test_take_decimal_arrays(
Expand Down Expand Up @@ -2159,24 +2167,24 @@ mod tests {

#[test]
fn test_take_runs() {
let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 2, 2, 1, 1, 2, 2];
let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];

let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
builder.extend(logical_array.into_iter().map(Some));
let run_array = builder.finish();

let take_indices: PrimitiveArray<Int32Type> =
vec![2, 7, 10].into_iter().collect();
vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();

let take_out = take_run(&run_array, &take_indices).unwrap();

assert_eq!(take_out.len(), 3);
assert_eq!(take_out.len(), 7);

assert_eq!(take_out.run_ends().len(), 1);
assert_eq!(take_out.run_ends().value(0), 3);
assert_eq!(take_out.run_ends().len(), 5);
assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]);

let take_out_values = as_primitive_array::<Int32Type>(take_out.values());
assert_eq!(take_out_values.value(0), 2);
assert_eq!(take_out_values.values(), &[2, 2, 2, 2, 1]);
}

#[test]
Expand Down

0 comments on commit d011e6a

Please sign in to comment.