Skip to content

Commit

Permalink
have filter_run_end_array use filter array with run_ends max value size
Browse files Browse the repository at this point in the history
  • Loading branch information
delamarch3 committed Nov 3, 2024
1 parent 9c661bf commit 1f5b61c
Showing 1 changed file with 29 additions and 25 deletions.
54 changes: 29 additions & 25 deletions arrow-select/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
use std::ops::AddAssign;
use std::sync::Arc;

use arrow_array::builder::BooleanBufferBuilder;
use arrow_array::builder::{BooleanBufferBuilder, BooleanBuilder};
use arrow_array::cast::AsArray;
use arrow_array::types::{
ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
Expand Down Expand Up @@ -378,17 +378,8 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<Array
Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
}
DataType::RunEndEncoded(_, _) => {
if predicate.filter.len() != values.len() {
return Err(ArrowError::InvalidArgumentError(format!(
"Filter predicate of length {} is not equal to the target array of length {}",
predicate.filter.len(),
values.len()
)));
}

// Safety: We have checked that the predicate and values have the same length
downcast_run_array!{
values => Ok(Arc::new(unsafe { filter_run_end_array(values, predicate)? })),
values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
}
}
Expand Down Expand Up @@ -431,31 +422,45 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<Array
}

/// Filter any supported [`RunArray`] based on a [`FilterPredicate`]
///
/// # Safety
/// The caller must ensure that the `pred` and `re_arr` are the same length.
unsafe fn filter_run_end_array<R: RunEndIndexType>(
fn filter_run_end_array<R: RunEndIndexType>(
re_arr: &RunArray<R>,
pred: &FilterPredicate,
) -> Result<RunArray<R>, ArrowError>
where
R::Native: Into<i64> + From<bool>,
R::Native: AddAssign,
{
let mut resized_filter = None;
let required_filter_len = re_arr.run_ends().max_value();
let diff = required_filter_len - pred.filter.len();
if diff > 0 {
let mut builder = BooleanBuilder::with_capacity(required_filter_len);
pred.filter
.values()
.iter()
.for_each(|v| builder.append_value(v));
builder.append_n(diff, false);
resized_filter = Some(builder.finish());
}

let run_ends: &RunEndBuffer<R::Native> = re_arr.run_ends();
let mut values_filter = BooleanBufferBuilder::new(run_ends.len());
let mut new_run_ends = vec![R::default_value(); run_ends.len()];

let mut start = 0i64;
let mut i = 0;
let filter_values = pred.filter.values();
let mut count = R::default_value();
let filter_values = if let Some(ref resized_filter) = resized_filter {
resized_filter.values()
} else {
pred.filter.values()
};

for end in run_ends.inner().into_iter().map(|i| (*i).into()) {
let mut keep = false;
// in filter_array the predicate array is checked to have the same len as the run end array
// this means the largest value in the run_ends is == to pred.len()
// so we're always within bounds when calling value_unchecked

// Safety: we create new filter values above if the given values are less than the
// run_ends max value so we're always within bounds when calling value_unchecked
for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
count += R::Native::from(pred);
keep |= pred
Expand Down Expand Up @@ -1293,15 +1298,14 @@ mod tests {
}

#[test]
fn test_filter_run_end_encoding_array_safety() {
fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
let values = Int64Array::from(vec![7, -2, 9, -8]);
let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
let b = BooleanArray::from(vec![
true, false, false, false, false, false, false, false, false,
]);
let c = filter(&a, &b);
assert!(c.is_err());
let b = BooleanArray::from(vec![false, true, false]);
let c = filter(&a, &b).unwrap();
let actual: &RunArray<Int64Type> = as_run_array(&c);
assert_eq!(1, actual.len());
}

#[test]
Expand Down

0 comments on commit 1f5b61c

Please sign in to comment.