diff --git a/encodings/alp/src/alp/compute/mod.rs b/encodings/alp/src/alp/compute/mod.rs index a6c0bde235..74c0546ff6 100644 --- a/encodings/alp/src/alp/compute/mod.rs +++ b/encodings/alp/src/alp/compute/mod.rs @@ -1,6 +1,6 @@ use vortex_array::compute::{ filter, scalar_at, slice, take, ComputeVTable, FilterFn, FilterMask, ScalarAtFn, SliceFn, - TakeFn, TakeOptions, + TakeFn, }; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; @@ -48,15 +48,10 @@ impl ScalarAtFn for ALPEncoding { } impl TakeFn for ALPEncoding { - fn take( - &self, - array: &ALPArray, - indices: &ArrayData, - options: TakeOptions, - ) -> VortexResult { + fn take(&self, array: &ALPArray, indices: &ArrayData) -> VortexResult { // TODO(ngates): wrap up indices in an array that caches decompression? Ok(ALPArray::try_new( - take(array.encoded(), indices, options)?, + take(array.encoded(), indices)?, array.exponents(), array .patches() diff --git a/encodings/alp/src/alp_rd/compute/take.rs b/encodings/alp/src/alp_rd/compute/take.rs index 6c4d6871e8..321935b7f0 100644 --- a/encodings/alp/src/alp_rd/compute/take.rs +++ b/encodings/alp/src/alp_rd/compute/take.rs @@ -1,26 +1,21 @@ -use vortex_array::compute::{take, TakeFn, TakeOptions}; +use vortex_array::compute::{take, TakeFn}; use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; use vortex_error::VortexResult; use crate::{ALPRDArray, ALPRDEncoding}; impl TakeFn for ALPRDEncoding { - fn take( - &self, - array: &ALPRDArray, - indices: &ArrayData, - options: TakeOptions, - ) -> VortexResult { + fn take(&self, array: &ALPRDArray, indices: &ArrayData) -> VortexResult { let left_parts_exceptions = array .left_parts_exceptions() - .map(|array| take(&array, indices, options)) + .map(|array| take(&array, indices)) .transpose()?; Ok(ALPRDArray::try_new( array.dtype().clone(), - take(array.left_parts(), indices, options)?, + take(array.left_parts(), indices)?, array.left_parts_dict(), - 
take(array.right_parts(), indices, options)?, + take(array.right_parts(), indices)?, array.right_bit_width(), left_parts_exceptions, )? @@ -32,7 +27,7 @@ impl TakeFn for ALPRDEncoding { mod test { use rstest::rstest; use vortex_array::array::PrimitiveArray; - use vortex_array::compute::{take, TakeOptions}; + use vortex_array::compute::take; use vortex_array::IntoArrayVariant; use crate::{ALPRDFloat, RDEncoder}; @@ -46,14 +41,10 @@ mod test { assert!(encoded.left_parts_exceptions().is_some()); - let taken = take( - encoded.as_ref(), - PrimitiveArray::from(vec![0, 2]).as_ref(), - TakeOptions::default(), - ) - .unwrap() - .into_primitive() - .unwrap(); + let taken = take(encoded.as_ref(), PrimitiveArray::from(vec![0, 2]).as_ref()) + .unwrap() + .into_primitive() + .unwrap(); assert_eq!(taken.maybe_null_slice::(), &[a, outlier]); } diff --git a/encodings/bytebool/src/compute.rs b/encodings/bytebool/src/compute.rs index 5f8d63b59d..6c9a31d0cf 100644 --- a/encodings/bytebool/src/compute.rs +++ b/encodings/bytebool/src/compute.rs @@ -1,7 +1,5 @@ use num_traits::AsPrimitive; -use vortex_array::compute::{ - ComputeVTable, FillForwardFn, ScalarAtFn, SliceFn, TakeFn, TakeOptions, -}; +use vortex_array::compute::{ComputeVTable, FillForwardFn, ScalarAtFn, SliceFn, TakeFn}; use vortex_array::validity::{ArrayValidity, Validity}; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData}; @@ -49,12 +47,7 @@ impl SliceFn for ByteBoolEncoding { } impl TakeFn for ByteBoolEncoding { - fn take( - &self, - array: &ByteBoolArray, - indices: &ArrayData, - _options: TakeOptions, - ) -> VortexResult { + fn take(&self, array: &ByteBoolArray, indices: &ArrayData) -> VortexResult { let validity = array.validity(); let indices = indices.clone().into_primitive()?; let bools = array.maybe_null_slice(); diff --git a/encodings/datetime-parts/src/compute/take.rs b/encodings/datetime-parts/src/compute/take.rs 
index dcd697e326..c9d0c68366 100644 --- a/encodings/datetime-parts/src/compute/take.rs +++ b/encodings/datetime-parts/src/compute/take.rs @@ -1,21 +1,16 @@ -use vortex_array::compute::{take, TakeFn, TakeOptions}; +use vortex_array::compute::{take, TakeFn}; use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; use vortex_error::VortexResult; use crate::{DateTimePartsArray, DateTimePartsEncoding}; impl TakeFn for DateTimePartsEncoding { - fn take( - &self, - array: &DateTimePartsArray, - indices: &ArrayData, - options: TakeOptions, - ) -> VortexResult { + fn take(&self, array: &DateTimePartsArray, indices: &ArrayData) -> VortexResult { Ok(DateTimePartsArray::try_new( array.dtype().clone(), - take(array.days(), indices, options)?, - take(array.seconds(), indices, options)?, - take(array.subsecond(), indices, options)?, + take(array.days(), indices)?, + take(array.seconds(), indices)?, + take(array.subsecond(), indices)?, )? .into_array()) } diff --git a/encodings/dict/src/array.rs b/encodings/dict/src/array.rs index c8f3c6d024..e7c17e5e5c 100644 --- a/encodings/dict/src/array.rs +++ b/encodings/dict/src/array.rs @@ -3,7 +3,7 @@ use std::fmt::{Debug, Display}; use arrow_buffer::BooleanBuffer; use serde::{Deserialize, Serialize}; use vortex_array::array::BoolArray; -use vortex_array::compute::{scalar_at, take, TakeOptions}; +use vortex_array::compute::{scalar_at, take}; use vortex_array::encoding::ids; use vortex_array::stats::StatsSet; use vortex_array::validity::{ArrayValidity, LogicalValidity, ValidityVTable}; @@ -74,10 +74,10 @@ impl IntoCanonical for DictArray { // copies of the view pointers. 
DType::Utf8(_) | DType::Binary(_) => { let canonical_values: ArrayData = self.values().into_canonical()?.into(); - take(canonical_values, self.codes(), TakeOptions::default())?.into_canonical() + take(canonical_values, self.codes())?.into_canonical() } // Non-string case: take and then canonicalize - _ => take(self.values(), self.codes(), TakeOptions::default())?.into_canonical(), + _ => take(self.values(), self.codes())?.into_canonical(), } } } diff --git a/encodings/dict/src/compute/compare.rs b/encodings/dict/src/compute/compare.rs index 04b485da33..db95fd9540 100644 --- a/encodings/dict/src/compute/compare.rs +++ b/encodings/dict/src/compute/compare.rs @@ -1,5 +1,5 @@ use vortex_array::array::ConstantArray; -use vortex_array::compute::{compare, take, CompareFn, Operator, TakeOptions}; +use vortex_array::compute::{compare, take, CompareFn, Operator}; use vortex_array::ArrayData; use vortex_error::VortexResult; @@ -20,7 +20,7 @@ impl CompareFn for DictEncoding { ConstantArray::new(const_scalar, lhs.values().len()), operator, )?; - return take(compare_result, lhs.codes(), TakeOptions::default()).map(Some); + return take(compare_result, lhs.codes()).map(Some); } // It's a little more complex, but we could perform a comparison against the dictionary diff --git a/encodings/dict/src/compute/mod.rs b/encodings/dict/src/compute/mod.rs index 7454b5e368..51490e9ce2 100644 --- a/encodings/dict/src/compute/mod.rs +++ b/encodings/dict/src/compute/mod.rs @@ -3,7 +3,7 @@ mod like; use vortex_array::compute::{ filter, scalar_at, slice, take, CompareFn, ComputeVTable, FilterFn, FilterMask, LikeFn, - ScalarAtFn, SliceFn, TakeFn, TakeOptions, + ScalarAtFn, SliceFn, TakeFn, }; use vortex_array::{ArrayData, IntoArrayData}; use vortex_error::VortexResult; @@ -45,16 +45,11 @@ impl ScalarAtFn for DictEncoding { } impl TakeFn for DictEncoding { - fn take( - &self, - array: &DictArray, - indices: &ArrayData, - options: TakeOptions, - ) -> VortexResult { + fn take(&self, array: 
&DictArray, indices: &ArrayData) -> VortexResult { // Dict // codes: 0 0 1 // dict: a b c d e f g h - let codes = take(array.codes(), indices, options)?; + let codes = take(array.codes(), indices)?; DictArray::try_new(codes, array.values()).map(|a| a.into_array()) } } diff --git a/encodings/fastlanes/benches/bitpacking_take.rs b/encodings/fastlanes/benches/bitpacking_take.rs index 6eb215042d..6706c95640 100644 --- a/encodings/fastlanes/benches/bitpacking_take.rs +++ b/encodings/fastlanes/benches/bitpacking_take.rs @@ -6,7 +6,7 @@ use itertools::Itertools; use rand::distributions::Uniform; use rand::{thread_rng, Rng}; use vortex_array::array::PrimitiveArray; -use vortex_array::compute::{take, TakeOptions}; +use vortex_array::compute::take; use vortex_fastlanes::{find_best_bit_width, BitPackedArray}; fn values(len: usize, bits: usize) -> Vec { @@ -27,30 +27,12 @@ fn bench_take(c: &mut Criterion) { let stratified_indices: PrimitiveArray = (0..10).map(|i| i * 10_000).collect::>().into(); c.bench_function("take_10_stratified", |b| { - b.iter(|| { - black_box( - take( - packed.as_ref(), - stratified_indices.as_ref(), - TakeOptions::default(), - ) - .unwrap(), - ) - }); + b.iter(|| black_box(take(packed.as_ref(), stratified_indices.as_ref()).unwrap())); }); let contiguous_indices: PrimitiveArray = (0..10).collect::>().into(); c.bench_function("take_10_contiguous", |b| { - b.iter(|| { - black_box( - take( - packed.as_ref(), - contiguous_indices.as_ref(), - TakeOptions::default(), - ) - .unwrap(), - ) - }); + b.iter(|| black_box(take(packed.as_ref(), contiguous_indices.as_ref()).unwrap())); }); let rng = thread_rng(); @@ -62,30 +44,12 @@ fn bench_take(c: &mut Criterion) { .collect_vec() .into(); c.bench_function("take_10K_random", |b| { - b.iter(|| { - black_box( - take( - packed.as_ref(), - random_indices.as_ref(), - TakeOptions::default(), - ) - .unwrap(), - ) - }); + b.iter(|| black_box(take(packed.as_ref(), random_indices.as_ref()).unwrap())); }); let 
contiguous_indices: PrimitiveArray = (0..10_000).collect::>().into(); c.bench_function("take_10K_contiguous", |b| { - b.iter(|| { - black_box( - take( - packed.as_ref(), - contiguous_indices.as_ref(), - TakeOptions::default(), - ) - .unwrap(), - ) - }); + b.iter(|| black_box(take(packed.as_ref(), contiguous_indices.as_ref()).unwrap())); }); let lots_of_indices: PrimitiveArray = (0..200_000) @@ -93,16 +57,7 @@ fn bench_take(c: &mut Criterion) { .collect::>() .into(); c.bench_function("take_200K_dispersed", |b| { - b.iter(|| { - black_box( - take( - packed.as_ref(), - lots_of_indices.as_ref(), - TakeOptions::default(), - ) - .unwrap(), - ) - }); + b.iter(|| black_box(take(packed.as_ref(), lots_of_indices.as_ref()).unwrap())); }); let lots_of_indices: PrimitiveArray = (0..200_000) @@ -110,16 +65,7 @@ fn bench_take(c: &mut Criterion) { .collect::>() .into(); c.bench_function("take_200K_first_chunk_only", |b| { - b.iter(|| { - black_box( - take( - packed.as_ref(), - lots_of_indices.as_ref(), - TakeOptions::default(), - ) - .unwrap(), - ) - }); + b.iter(|| black_box(take(packed.as_ref(), lots_of_indices.as_ref()).unwrap())); }); } @@ -142,30 +88,12 @@ fn bench_patched_take(c: &mut Criterion) { let stratified_indices: PrimitiveArray = (0..10).map(|i| i * 10_000).collect::>().into(); c.bench_function("patched_take_10_stratified", |b| { - b.iter(|| { - black_box( - take( - packed.as_ref(), - stratified_indices.as_ref(), - TakeOptions::default(), - ) - .unwrap(), - ) - }); + b.iter(|| black_box(take(packed.as_ref(), stratified_indices.as_ref()).unwrap())); }); let contiguous_indices: PrimitiveArray = (0..10).collect::>().into(); c.bench_function("patched_take_10_contiguous", |b| { - b.iter(|| { - black_box( - take( - packed.as_ref(), - contiguous_indices.as_ref(), - TakeOptions::default(), - ) - .unwrap(), - ) - }); + b.iter(|| black_box(take(packed.as_ref(), contiguous_indices.as_ref()).unwrap())); }); let rng = thread_rng(); @@ -177,16 +105,7 @@ fn bench_patched_take(c: 
&mut Criterion) { .collect_vec() .into(); c.bench_function("patched_take_10K_random", |b| { - b.iter(|| { - black_box( - take( - packed.as_ref(), - random_indices.as_ref(), - TakeOptions::default(), - ) - .unwrap(), - ) - }); + b.iter(|| black_box(take(packed.as_ref(), random_indices.as_ref()).unwrap())); }); let not_patch_indices: PrimitiveArray = (0u32..num_exceptions) @@ -195,16 +114,7 @@ fn bench_patched_take(c: &mut Criterion) { .collect_vec() .into(); c.bench_function("patched_take_10K_contiguous_not_patches", |b| { - b.iter(|| { - black_box( - take( - packed.as_ref(), - not_patch_indices.as_ref(), - TakeOptions::default(), - ) - .unwrap(), - ) - }); + b.iter(|| black_box(take(packed.as_ref(), not_patch_indices.as_ref()).unwrap())); }); let patch_indices: PrimitiveArray = (big_base2..big_base2 + num_exceptions) @@ -213,16 +123,7 @@ fn bench_patched_take(c: &mut Criterion) { .collect_vec() .into(); c.bench_function("patched_take_10K_contiguous_patches", |b| { - b.iter(|| { - black_box( - take( - packed.as_ref(), - patch_indices.as_ref(), - TakeOptions::default(), - ) - .unwrap(), - ) - }); + b.iter(|| black_box(take(packed.as_ref(), patch_indices.as_ref()).unwrap())); }); let lots_of_indices: PrimitiveArray = (0..200_000) @@ -230,16 +131,7 @@ fn bench_patched_take(c: &mut Criterion) { .collect::>() .into(); c.bench_function("patched_take_200K_dispersed", |b| { - b.iter(|| { - black_box( - take( - packed.as_ref(), - lots_of_indices.as_ref(), - TakeOptions::default(), - ) - .unwrap(), - ) - }); + b.iter(|| black_box(take(packed.as_ref(), lots_of_indices.as_ref()).unwrap())); }); let lots_of_indices: PrimitiveArray = (0..200_000) @@ -247,16 +139,7 @@ fn bench_patched_take(c: &mut Criterion) { .collect::>() .into(); c.bench_function("patched_take_200K_first_chunk_only", |b| { - b.iter(|| { - black_box( - take( - packed.as_ref(), - lots_of_indices.as_ref(), - TakeOptions::default(), - ) - .unwrap(), - ) - }); + b.iter(|| black_box(take(packed.as_ref(), 
lots_of_indices.as_ref()).unwrap())); }); // There are currently 2 magic parameters of note: @@ -280,16 +163,7 @@ fn bench_patched_take(c: &mut Criterion) { .collect_vec() .into(); c.bench_function("patched_take_10K_adversarial", |b| { - b.iter(|| { - black_box( - take( - packed.as_ref(), - adversarial_indices.as_ref(), - TakeOptions::default(), - ) - .unwrap(), - ) - }); + b.iter(|| black_box(take(packed.as_ref(), adversarial_indices.as_ref()).unwrap())); }); } diff --git a/encodings/fastlanes/src/bitpacking/compute/slice.rs b/encodings/fastlanes/src/bitpacking/compute/slice.rs index d7b3728530..be83727c23 100644 --- a/encodings/fastlanes/src/bitpacking/compute/slice.rs +++ b/encodings/fastlanes/src/bitpacking/compute/slice.rs @@ -39,7 +39,7 @@ impl SliceFn for BitPackedEncoding { mod test { use itertools::Itertools; use vortex_array::array::PrimitiveArray; - use vortex_array::compute::{scalar_at, slice, take, TakeOptions}; + use vortex_array::compute::{scalar_at, slice, take}; use vortex_array::{ArrayLen, IntoArrayData}; use crate::BitPackedArray; @@ -191,7 +191,6 @@ mod test { let taken = take( &sliced, PrimitiveArray::from(vec![101i64, 1125i64, 1138i64]).as_ref(), - TakeOptions::default(), ) .unwrap(); assert_eq!(taken.len(), 3); diff --git a/encodings/fastlanes/src/bitpacking/compute/take.rs b/encodings/fastlanes/src/bitpacking/compute/take.rs index b1af0599df..c433fd3e52 100644 --- a/encodings/fastlanes/src/bitpacking/compute/take.rs +++ b/encodings/fastlanes/src/bitpacking/compute/take.rs @@ -1,7 +1,7 @@ use fastlanes::BitPacking; use itertools::Itertools; use vortex_array::array::PrimitiveArray; -use vortex_array::compute::{take, try_cast, TakeFn, TakeOptions}; +use vortex_array::compute::{take, try_cast, TakeFn}; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, IntoCanonical, ToArrayData, @@ -20,29 +20,20 @@ use crate::{unpack_single_primitive, BitPackedArray, 
BitPackedEncoding}; pub(super) const UNPACK_CHUNK_THRESHOLD: usize = 8; impl TakeFn for BitPackedEncoding { - fn take( - &self, - array: &BitPackedArray, - indices: &ArrayData, - options: TakeOptions, - ) -> VortexResult { + fn take(&self, array: &BitPackedArray, indices: &ArrayData) -> VortexResult { // If the indices are large enough, it's faster to flatten and take the primitive array. if indices.len() * UNPACK_CHUNK_THRESHOLD > array.len() { - return take( - array.clone().into_canonical()?.into_primitive()?, - indices, - options, - ); + return take(array.clone().into_canonical()?.into_primitive()?, indices); } let ptype: PType = array.dtype().try_into()?; let validity = array.validity(); - let taken_validity = validity.take(indices, options)?; + let taken_validity = validity.take(indices)?; let indices = indices.clone().into_primitive()?; let taken = match_each_unsigned_integer_ptype!(ptype, |$T| { match_each_integer_ptype!(indices.ptype(), |$I| { - PrimitiveArray::from_vec(take_primitive::<$T, $I>(array, &indices, options)?, taken_validity) + PrimitiveArray::from_vec(take_primitive::<$T, $I>(array, &indices)?, taken_validity) }) }); Ok(taken.reinterpret_cast(ptype).into_array()) @@ -52,7 +43,6 @@ impl TakeFn for BitPackedEncoding { fn take_primitive( array: &BitPackedArray, indices: &PrimitiveArray, - _options: TakeOptions, ) -> VortexResult> { if indices.is_empty() { return Ok(vec![]); @@ -153,7 +143,7 @@ mod test { use rand::distributions::Uniform; use rand::{thread_rng, Rng}; use vortex_array::array::PrimitiveArray; - use vortex_array::compute::{scalar_at, slice, take, TakeOptions}; + use vortex_array::compute::{scalar_at, slice, take}; use vortex_array::{IntoArrayData, IntoArrayVariant}; use crate::BitPackedArray; @@ -166,7 +156,7 @@ mod test { let unpacked = PrimitiveArray::from((0..4096).map(|i| (i % 63) as u8).collect::>()); let bitpacked = BitPackedArray::encode(unpacked.as_ref(), 6).unwrap(); - let primitive_result = take(bitpacked.as_ref(), &indices, 
TakeOptions::default()) + let primitive_result = take(bitpacked.as_ref(), &indices) .unwrap() .into_primitive() .unwrap(); @@ -181,7 +171,7 @@ mod test { let indices = PrimitiveArray::from(vec![0, 2, 4, 6]); - let primitive_result = take(bitpacked.as_ref(), &indices, TakeOptions::default()) + let primitive_result = take(bitpacked.as_ref(), &indices) .unwrap() .into_primitive() .unwrap(); @@ -198,10 +188,7 @@ mod test { let bitpacked = BitPackedArray::encode(unpacked.as_ref(), 6).unwrap(); let sliced = slice(bitpacked.as_ref(), 128, 2050).unwrap(); - let primitive_result = take(&sliced, &indices, TakeOptions::default()) - .unwrap() - .into_primitive() - .unwrap(); + let primitive_result = take(&sliced, &indices).unwrap().into_primitive().unwrap(); let res_bytes = primitive_result.maybe_null_slice::(); assert_eq!(res_bytes, &[31, 33]); } @@ -223,12 +210,7 @@ mod test { .map(|i| i as u32) .collect_vec() .into(); - let taken = take( - packed.as_ref(), - random_indices.as_ref(), - TakeOptions::default(), - ) - .unwrap(); + let taken = take(packed.as_ref(), random_indices.as_ref()).unwrap(); // sanity check random_indices diff --git a/encodings/fastlanes/src/for/compute.rs b/encodings/fastlanes/src/for/compute.rs index 5324c58224..87e269353c 100644 --- a/encodings/fastlanes/src/for/compute.rs +++ b/encodings/fastlanes/src/for/compute.rs @@ -3,7 +3,7 @@ use std::ops::AddAssign; use num_traits::{CheckedShl, CheckedShr, WrappingAdd, WrappingSub}; use vortex_array::compute::{ filter, scalar_at, search_sorted, slice, take, ComputeVTable, FilterFn, FilterMask, ScalarAtFn, - SearchResult, SearchSortedFn, SearchSortedSide, SliceFn, TakeFn, TakeOptions, + SearchResult, SearchSortedFn, SearchSortedSide, SliceFn, TakeFn, }; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; @@ -36,14 +36,9 @@ impl ComputeVTable for FoREncoding { } impl TakeFn for FoREncoding { - fn take( - &self, - array: &FoRArray, - indices: &ArrayData, - 
options: TakeOptions, - ) -> VortexResult { + fn take(&self, array: &FoRArray, indices: &ArrayData) -> VortexResult { FoRArray::try_new( - take(array.encoded(), indices, options)?, + take(array.encoded(), indices)?, array.reference_scalar(), array.shift(), ) diff --git a/encodings/fsst/src/compute/mod.rs b/encodings/fsst/src/compute/mod.rs index 065d42b8a1..923d69e7e6 100644 --- a/encodings/fsst/src/compute/mod.rs +++ b/encodings/fsst/src/compute/mod.rs @@ -3,7 +3,7 @@ mod compare; use vortex_array::array::varbin_scalar; use vortex_array::compute::{ filter, scalar_at, slice, take, CompareFn, ComputeVTable, FilterFn, FilterMask, ScalarAtFn, - SliceFn, TakeFn, TakeOptions, + SliceFn, TakeFn, }; use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; use vortex_buffer::Buffer; @@ -51,18 +51,13 @@ impl SliceFn for FSSTEncoding { impl TakeFn for FSSTEncoding { // Take on an FSSTArray is a simple take on the codes array. - fn take( - &self, - array: &FSSTArray, - indices: &ArrayData, - options: TakeOptions, - ) -> VortexResult { + fn take(&self, array: &FSSTArray, indices: &ArrayData) -> VortexResult { Ok(FSSTArray::try_new( array.dtype().clone(), array.symbols(), array.symbol_lengths(), - take(array.codes(), indices, options)?, - take(array.uncompressed_lengths(), indices, options)?, + take(array.codes(), indices)?, + take(array.uncompressed_lengths(), indices)?, )? 
.into_array()) } diff --git a/encodings/fsst/tests/fsst_tests.rs b/encodings/fsst/tests/fsst_tests.rs index 98641285ab..69a4801c37 100644 --- a/encodings/fsst/tests/fsst_tests.rs +++ b/encodings/fsst/tests/fsst_tests.rs @@ -2,7 +2,7 @@ use vortex_array::array::builder::VarBinBuilder; use vortex_array::array::PrimitiveArray; -use vortex_array::compute::{filter, scalar_at, slice, take, FilterMask, TakeOptions}; +use vortex_array::compute::{filter, scalar_at, slice, take, FilterMask}; use vortex_array::encoding::Encoding; use vortex_array::validity::Validity; use vortex_array::{ArrayData, IntoArrayData, IntoCanonical}; @@ -71,7 +71,7 @@ fn test_fsst_array_ops() { // test take let indices = PrimitiveArray::from_vec(vec![0, 2], Validity::NonNullable).into_array(); - let fsst_taken = take(&fsst_array, &indices, TakeOptions::default()).unwrap(); + let fsst_taken = take(&fsst_array, &indices).unwrap(); assert_eq!(fsst_taken.len(), 2); assert_nth_scalar!( fsst_taken, diff --git a/encodings/runend-bool/src/array.rs b/encodings/runend-bool/src/array.rs index d9303dd483..7189b1bc4f 100644 --- a/encodings/runend-bool/src/array.rs +++ b/encodings/runend-bool/src/array.rs @@ -255,7 +255,7 @@ mod test { use itertools::Itertools as _; use rstest::rstest; use vortex_array::array::{BoolArray, PrimitiveArray}; - use vortex_array::compute::{scalar_at, slice, take, TakeOptions}; + use vortex_array::compute::{scalar_at, slice, take}; use vortex_array::stats::ArrayStatistics; use vortex_array::validity::Validity; use vortex_array::{ @@ -345,7 +345,6 @@ mod test { ) .unwrap(), vec![0, 0, 6, 4].into_array(), - TakeOptions::default(), ) .unwrap(); diff --git a/encodings/runend-bool/src/compute/mod.rs b/encodings/runend-bool/src/compute/mod.rs index 1e533963d4..0bb53ea1ea 100644 --- a/encodings/runend-bool/src/compute/mod.rs +++ b/encodings/runend-bool/src/compute/mod.rs @@ -1,9 +1,7 @@ mod invert; use vortex_array::array::BoolArray; -use vortex_array::compute::{ - slice, ComputeVTable, 
InvertFn, ScalarAtFn, SliceFn, TakeFn, TakeOptions, -}; +use vortex_array::compute::{slice, ComputeVTable, InvertFn, ScalarAtFn, SliceFn, TakeFn}; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData}; use vortex_dtype::match_each_integer_ptype; @@ -39,12 +37,7 @@ impl ScalarAtFn for RunEndBoolEncoding { } impl TakeFn for RunEndBoolEncoding { - fn take( - &self, - array: &RunEndBoolArray, - indices: &ArrayData, - _options: TakeOptions, - ) -> VortexResult { + fn take(&self, array: &RunEndBoolArray, indices: &ArrayData) -> VortexResult { let primitive_indices = indices.clone().into_primitive()?; let physical_indices = match_each_integer_ptype!(primitive_indices.ptype(), |$P| { primitive_indices diff --git a/encodings/runend/src/compute/mod.rs b/encodings/runend/src/compute/mod.rs index 4a21ad3b6e..6786045f7d 100644 --- a/encodings/runend/src/compute/mod.rs +++ b/encodings/runend/src/compute/mod.rs @@ -8,7 +8,7 @@ use num_traits::AsPrimitive; use vortex_array::array::{BooleanBuffer, ConstantArray, PrimitiveArray, SparseArray}; use vortex_array::compute::{ filter, scalar_at, slice, take, CompareFn, ComputeVTable, FilterFn, FilterMask, InvertFn, - ScalarAtFn, SliceFn, TakeFn, TakeOptions, + ScalarAtFn, SliceFn, TakeFn, }; use vortex_array::validity::Validity; use vortex_array::variants::PrimitiveArrayTrait; @@ -53,12 +53,7 @@ impl ScalarAtFn for RunEndEncoding { impl TakeFn for RunEndEncoding { #[allow(deprecated)] - fn take( - &self, - array: &RunEndArray, - indices: &ArrayData, - options: TakeOptions, - ) -> VortexResult { + fn take(&self, array: &RunEndArray, indices: &ArrayData) -> VortexResult { let primitive_indices = indices.clone().into_primitive()?; let usize_indices = match_each_integer_ptype!(primitive_indices.ptype(), |$P| { primitive_indices @@ -80,7 +75,7 @@ impl TakeFn for RunEndEncoding { .map(|idx| idx as u64) .collect::>(); let physical_indices_array = 
PrimitiveArray::from(physical_indices).into_array(); - let dense_values = take(array.values(), &physical_indices_array, options)?; + let dense_values = take(array.values(), &physical_indices_array)?; Ok(match array.validity() { Validity::NonNullable => dense_values, @@ -89,8 +84,7 @@ impl TakeFn for RunEndEncoding { ConstantArray::new(Scalar::null(array.dtype().clone()), indices.len()).into_array() } Validity::Array(original_validity) => { - let dense_validity = - FilterMask::try_from(take(&original_validity, indices, options)?)?; + let dense_validity = FilterMask::try_from(take(&original_validity, indices)?)?; let length = dense_validity.len(); let dense_nonnull_indices = PrimitiveArray::from( dense_validity @@ -197,9 +191,7 @@ fn filter_run_ends + AsPrimitive>( #[cfg(test)] mod test { use vortex_array::array::{BoolArray, PrimitiveArray}; - use vortex_array::compute::{ - filter, scalar_at, slice, take, try_cast, FilterMask, TakeOptions, - }; + use vortex_array::compute::{filter, scalar_at, slice, take, try_cast, FilterMask}; use vortex_array::validity::{ArrayValidity, Validity}; use vortex_array::{ArrayDType, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData}; use vortex_dtype::{DType, Nullability, PType}; @@ -219,7 +211,6 @@ mod test { let taken = take( ree_array().as_ref(), PrimitiveArray::from(vec![9, 8, 1, 3]).as_ref(), - TakeOptions::default(), ) .unwrap(); assert_eq!( @@ -233,7 +224,6 @@ mod test { let taken = take( ree_array().as_ref(), PrimitiveArray::from(vec![11]).as_ref(), - TakeOptions::default(), ) .unwrap(); assert_eq!( @@ -248,7 +238,6 @@ mod test { take( ree_array().as_ref(), PrimitiveArray::from(vec![12]).as_ref(), - TakeOptions::default(), ) .unwrap(); } @@ -409,7 +398,7 @@ mod test { .unwrap(); let test_indices = PrimitiveArray::from_vec(vec![0, 2, 4, 6], Validity::NonNullable); - let taken = take(arr.as_ref(), test_indices.as_ref(), TakeOptions::default()).unwrap(); + let taken = take(arr.as_ref(), test_indices.as_ref()).unwrap(); 
assert_eq!(taken.len(), test_indices.len()); @@ -447,7 +436,6 @@ mod test { let taken = take( sliced.as_ref(), PrimitiveArray::from(vec![1, 3, 4]).as_ref(), - TakeOptions::default(), ) .unwrap(); diff --git a/fuzz/fuzz_targets/array_ops.rs b/fuzz/fuzz_targets/array_ops.rs index a998f2e222..7fa289f677 100644 --- a/fuzz/fuzz_targets/array_ops.rs +++ b/fuzz/fuzz_targets/array_ops.rs @@ -6,7 +6,7 @@ use vortex_array::array::{ BoolEncoding, PrimitiveEncoding, StructEncoding, VarBinEncoding, VarBinViewEncoding, }; use vortex_array::compute::{ - filter, scalar_at, search_sorted, slice, take, SearchResult, SearchSortedSide, TakeOptions, + filter, scalar_at, search_sorted, slice, take, SearchResult, SearchSortedSide, }; use vortex_array::encoding::EncodingRef; use vortex_array::{ArrayData, IntoCanonical}; @@ -36,7 +36,7 @@ fuzz_target!(|fuzz_action: FuzzArrayAction| -> Corpus { if indices.is_empty() { return Corpus::Reject; } - current_array = take(¤t_array, &indices, TakeOptions::default()).unwrap(); + current_array = take(¤t_array, &indices).unwrap(); assert_array_eq(&expected.array(), ¤t_array, i); } Action::SearchSorted(s, side) => { diff --git a/pyvortex/src/array.rs b/pyvortex/src/array.rs index d69a6df8d6..bb239a129d 100644 --- a/pyvortex/src/array.rs +++ b/pyvortex/src/array.rs @@ -4,9 +4,7 @@ use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::{IntoPyDict, PyInt, PyList}; use vortex::array::ChunkedArray; -use vortex::compute::{ - compare, fill_forward, scalar_at, slice, take, FilterMask, Operator, TakeOptions, -}; +use vortex::compute::{compare, fill_forward, scalar_at, slice, take, FilterMask, Operator}; use vortex::{ArrayDType, ArrayData, IntoCanonical}; use crate::dtype::PyDType; @@ -439,7 +437,7 @@ impl PyArray { ))); } - let inner = take(&self.inner, indices, TakeOptions::default())?; + let inner = take(&self.inner, indices)?; Ok(PyArray { inner }) } diff --git a/vortex-array/benches/take_strings.rs b/vortex-array/benches/take_strings.rs 
index 23accc3cec..69f651970e 100644 --- a/vortex-array/benches/take_strings.rs +++ b/vortex-array/benches/take_strings.rs @@ -2,7 +2,7 @@ use criterion::{criterion_group, criterion_main, Criterion}; use vortex_array::array::{PrimitiveArray, VarBinArray}; -use vortex_array::compute::{take, TakeOptions}; +use vortex_array::compute::take; use vortex_array::validity::Validity; use vortex_array::{ArrayData, IntoArrayData, IntoArrayVariant}; use vortex_dtype::{DType, Nullability}; @@ -33,9 +33,7 @@ fn bench_varbin(c: &mut Criterion) { let array = fixture(65_535); let indices = indices(1024); - c.bench_function("varbin", |b| { - b.iter(|| take(&array, &indices, TakeOptions::default()).unwrap()) - }); + c.bench_function("varbin", |b| b.iter(|| take(&array, &indices).unwrap())); } fn bench_varbinview(c: &mut Criterion) { @@ -43,7 +41,7 @@ fn bench_varbinview(c: &mut Criterion) { let indices = indices(1024); c.bench_function("varbinview", |b| { - b.iter(|| take(array.as_ref(), &indices, TakeOptions::default()).unwrap()) + b.iter(|| take(array.as_ref(), &indices).unwrap()) }); } diff --git a/vortex-array/src/array/bool/compute/take.rs b/vortex-array/src/array/bool/compute/take.rs index a6263ce9d3..74d59583ee 100644 --- a/vortex-array/src/array/bool/compute/take.rs +++ b/vortex-array/src/array/bool/compute/take.rs @@ -5,16 +5,35 @@ use vortex_dtype::match_each_integer_ptype; use vortex_error::VortexResult; use crate::array::{BoolArray, BoolEncoding}; -use crate::compute::{TakeFn, TakeOptions}; +use crate::compute::TakeFn; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant}; impl TakeFn for BoolEncoding { - fn take( + fn take(&self, array: &BoolArray, indices: &ArrayData) -> VortexResult { + let validity = array.validity(); + let indices = indices.clone().into_primitive()?; + + // For boolean arrays that roughly fit into a single page (at least, on Linux), it's worth + // the overhead to convert to a Vec. 
+ let buffer = if array.len() <= 4096 { + let bools = array.boolean_buffer().into_iter().collect_vec(); + match_each_integer_ptype!(indices.ptype(), |$I| { + take_byte_bool(bools, indices.maybe_null_slice::<$I>()) + }) + } else { + match_each_integer_ptype!(indices.ptype(), |$I| { + take_bool(&array.boolean_buffer(), indices.maybe_null_slice::<$I>()) + }) + }; + + Ok(BoolArray::try_new(buffer, validity.take(indices.as_ref())?)?.into_array()) + } + + unsafe fn take_unchecked( &self, array: &BoolArray, indices: &ArrayData, - options: TakeOptions, ) -> VortexResult { let validity = array.validity(); let indices = indices.clone().into_primitive()?; @@ -24,23 +43,17 @@ impl TakeFn for BoolEncoding { let buffer = if array.len() <= 4096 { let bools = array.boolean_buffer().into_iter().collect_vec(); match_each_integer_ptype!(indices.ptype(), |$I| { - if options.skip_bounds_check { - take_byte_bool_unchecked(bools, indices.maybe_null_slice::<$I>()) - } else { - take_byte_bool(bools, indices.maybe_null_slice::<$I>()) - } + take_byte_bool_unchecked(bools, indices.maybe_null_slice::<$I>()) }) } else { match_each_integer_ptype!(indices.ptype(), |$I| { - if options.skip_bounds_check { - take_bool_unchecked(&array.boolean_buffer(), indices.maybe_null_slice::<$I>()) - } else { - take_bool(&array.boolean_buffer(), indices.maybe_null_slice::<$I>()) - } + take_bool_unchecked(&array.boolean_buffer(), indices.maybe_null_slice::<$I>()) }) }; - Ok(BoolArray::try_new(buffer, validity.take(indices.as_ref(), options)?)?.into_array()) + // SAFETY: caller enforces indices are valid for array, and array has same len as validity. + let validity = unsafe { validity.take_unchecked(indices.as_ref())? 
}; + Ok(BoolArray::try_new(buffer, validity)?.into_array()) } } @@ -80,7 +93,7 @@ fn take_bool_unchecked>( mod test { use crate::array::primitive::PrimitiveArray; use crate::array::BoolArray; - use crate::compute::{take, TakeOptions}; + use crate::compute::take; #[test] fn take_nullable() { @@ -92,15 +105,8 @@ mod test { Some(false), ]); - let b = BoolArray::try_from( - take( - &reference, - PrimitiveArray::from(vec![0, 3, 4]), - TakeOptions::default(), - ) - .unwrap(), - ) - .unwrap(); + let b = BoolArray::try_from(take(&reference, PrimitiveArray::from(vec![0, 3, 4])).unwrap()) + .unwrap(); assert_eq!( b.boolean_buffer(), BoolArray::from_iter(vec![Some(false), None, Some(false)]).boolean_buffer() diff --git a/vortex-array/src/array/chunked/compute/filter.rs b/vortex-array/src/array/chunked/compute/filter.rs index ad547d7e6a..c8f70ff2b2 100644 --- a/vortex-array/src/array/chunked/compute/filter.rs +++ b/vortex-array/src/array/chunked/compute/filter.rs @@ -2,9 +2,7 @@ use arrow_buffer::BooleanBufferBuilder; use vortex_error::{VortexExpect, VortexResult, VortexUnwrap}; use crate::array::{ChunkedArray, ChunkedEncoding, PrimitiveArray}; -use crate::compute::{ - filter, take, FilterFn, FilterMask, SearchSorted, SearchSortedSide, TakeOptions, -}; +use crate::compute::{filter, take, FilterFn, FilterMask, SearchSorted, SearchSortedSide}; use crate::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoCanonical}; // This is modeled after the constant with the equivalent name in arrow-rs. @@ -159,7 +157,6 @@ fn filter_indices(array: &ChunkedArray, mask: FilterMask) -> VortexResult VortexResult for ChunkedEncoding { - fn take( - &self, - array: &ChunkedArray, - indices: &ArrayData, - options: TakeOptions, - ) -> VortexResult { + fn take(&self, array: &ChunkedArray, indices: &ArrayData) -> VortexResult { // Fast path for strict sorted indices. 
if indices .statistics() @@ -29,7 +24,7 @@ impl TakeFn for ChunkedEncoding { return Ok(array.to_array()); } - return take_strict_sorted(array, indices, options); + return take_strict_sorted(array, indices); } let indices = try_cast(indices, PType::U64.into())?.into_primitive()?; @@ -50,7 +45,6 @@ impl TakeFn for ChunkedEncoding { chunks.push(take( &array.chunk(prev_chunk_idx)?, &indices_in_chunk_array, - options, )?); indices_in_chunk = Vec::new(); } @@ -64,7 +58,6 @@ impl TakeFn for ChunkedEncoding { chunks.push(take( &array.chunk(prev_chunk_idx)?, &indices_in_chunk_array, - options, )?); } @@ -73,11 +66,7 @@ impl TakeFn for ChunkedEncoding { } /// When the indices are non-null and strict-sorted, we can do better -fn take_strict_sorted( - chunked: &ChunkedArray, - indices: &ArrayData, - options: TakeOptions, -) -> VortexResult { +fn take_strict_sorted(chunked: &ChunkedArray, indices: &ArrayData) -> VortexResult { let mut indices_by_chunk = vec![None; chunked.nchunks()]; // Track our position in the indices array @@ -125,7 +114,7 @@ fn take_strict_sorted( .into_iter() .enumerate() .filter_map(|(chunk_idx, indices)| indices.map(|i| (chunk_idx, i))) - .map(|(chunk_idx, chunk_indices)| take(&chunked.chunk(chunk_idx)?, &chunk_indices, options)) + .map(|(chunk_idx, chunk_indices)| take(&chunked.chunk(chunk_idx)?, &chunk_indices)) .try_collect()?; Ok(ChunkedArray::try_new(chunks, chunked.dtype().clone())?.into_array()) @@ -134,7 +123,7 @@ fn take_strict_sorted( #[cfg(test)] mod test { use crate::array::chunked::ChunkedArray; - use crate::compute::{take, TakeOptions}; + use crate::compute::take; use crate::{ArrayDType, ArrayLen, IntoArrayData, IntoArrayVariant}; #[test] @@ -146,12 +135,11 @@ mod test { assert_eq!(arr.len(), 9); let indices = vec![0u64, 0, 6, 4].into_array(); - let result = - &ChunkedArray::try_from(take(arr.as_ref(), &indices, TakeOptions::default()).unwrap()) - .unwrap() - .into_array() - .into_primitive() - .unwrap(); + let result = 
&ChunkedArray::try_from(take(arr.as_ref(), &indices).unwrap()) + .unwrap() + .into_array() + .into_primitive() + .unwrap(); assert_eq!(result.maybe_null_slice::(), &[1, 1, 1, 2]); } } diff --git a/vortex-array/src/array/constant/compute/mod.rs b/vortex-array/src/array/constant/compute/mod.rs index c3bd9994fb..56cdaf0ce3 100644 --- a/vortex-array/src/array/constant/compute/mod.rs +++ b/vortex-array/src/array/constant/compute/mod.rs @@ -10,7 +10,7 @@ use crate::array::constant::ConstantArray; use crate::array::ConstantEncoding; use crate::compute::{ BinaryBooleanFn, CompareFn, ComputeVTable, FilterFn, FilterMask, InvertFn, ScalarAtFn, - SearchSortedFn, SliceFn, TakeFn, TakeOptions, + SearchSortedFn, SliceFn, TakeFn, }; use crate::{ArrayData, IntoArrayData}; @@ -55,12 +55,7 @@ impl ScalarAtFn for ConstantEncoding { } impl TakeFn for ConstantEncoding { - fn take( - &self, - array: &ConstantArray, - indices: &ArrayData, - _options: TakeOptions, - ) -> VortexResult { + fn take(&self, array: &ConstantArray, indices: &ArrayData) -> VortexResult { Ok(ConstantArray::new(array.scalar(), indices.len()).into_array()) } } diff --git a/vortex-array/src/array/extension/compute/mod.rs b/vortex-array/src/array/extension/compute/mod.rs index 8f3eb8ed5b..46c0ec9f83 100644 --- a/vortex-array/src/array/extension/compute/mod.rs +++ b/vortex-array/src/array/extension/compute/mod.rs @@ -7,7 +7,6 @@ use crate::array::extension::ExtensionArray; use crate::array::ExtensionEncoding; use crate::compute::{ scalar_at, slice, take, CastFn, CompareFn, ComputeVTable, ScalarAtFn, SliceFn, TakeFn, - TakeOptions, }; use crate::variants::ExtensionArrayTrait; use crate::{ArrayData, IntoArrayData}; @@ -57,16 +56,10 @@ impl SliceFn for ExtensionEncoding { } impl TakeFn for ExtensionEncoding { - fn take( - &self, - array: &ExtensionArray, - indices: &ArrayData, - options: TakeOptions, - ) -> VortexResult { - Ok(ExtensionArray::new( - array.ext_dtype().clone(), - take(array.storage(), indices, options)?, + 
fn take(&self, array: &ExtensionArray, indices: &ArrayData) -> VortexResult { + Ok( + ExtensionArray::new(array.ext_dtype().clone(), take(array.storage(), indices)?) + .into_array(), ) - .into_array()) } } diff --git a/vortex-array/src/array/null/compute.rs b/vortex-array/src/array/null/compute.rs index eaa15f9943..40938d8560 100644 --- a/vortex-array/src/array/null/compute.rs +++ b/vortex-array/src/array/null/compute.rs @@ -4,7 +4,7 @@ use vortex_scalar::Scalar; use crate::array::null::NullArray; use crate::array::NullEncoding; -use crate::compute::{ComputeVTable, ScalarAtFn, SliceFn, TakeFn, TakeOptions}; +use crate::compute::{ComputeVTable, ScalarAtFn, SliceFn, TakeFn}; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant}; @@ -35,27 +35,28 @@ impl ScalarAtFn for NullEncoding { } impl TakeFn for NullEncoding { - fn take( - &self, - array: &NullArray, - indices: &ArrayData, - options: TakeOptions, - ) -> VortexResult { + fn take(&self, array: &NullArray, indices: &ArrayData) -> VortexResult { let indices = indices.clone().into_primitive()?; // Enforce all indices are valid - if !options.skip_bounds_check { - match_each_integer_ptype!(indices.ptype(), |$T| { - for index in indices.maybe_null_slice::<$T>() { - if !((*index as usize) < array.len()) { - vortex_bail!(OutOfBounds: *index as usize, 0, array.len()); - } + match_each_integer_ptype!(indices.ptype(), |$T| { + for index in indices.maybe_null_slice::<$T>() { + if !((*index as usize) < array.len()) { + vortex_bail!(OutOfBounds: *index as usize, 0, array.len()); } - }); - } + } + }); Ok(NullArray::new(indices.len()).into_array()) } + + unsafe fn take_unchecked( + &self, + _array: &NullArray, + indices: &ArrayData, + ) -> VortexResult { + Ok(NullArray::new(indices.len()).into_array()) + } } #[cfg(test)] @@ -63,7 +64,7 @@ mod test { use vortex_dtype::DType; use crate::array::null::NullArray; - use crate::compute::{scalar_at, slice, take, TakeOptions}; + use 
crate::compute::{scalar_at, slice, take}; use crate::validity::{ArrayValidity, LogicalValidity}; use crate::{ArrayLen, IntoArrayData}; @@ -83,15 +84,8 @@ mod test { #[test] fn test_take_nulls() { let nulls = NullArray::new(10); - let taken = NullArray::try_from( - take( - nulls, - vec![0u64, 2, 4, 6, 8].into_array(), - TakeOptions::default(), - ) - .unwrap(), - ) - .unwrap(); + let taken = + NullArray::try_from(take(nulls, vec![0u64, 2, 4, 6, 8].into_array()).unwrap()).unwrap(); assert_eq!(taken.len(), 5); assert!(matches!( diff --git a/vortex-array/src/array/primitive/compute/take.rs b/vortex-array/src/array/primitive/compute/take.rs index 1003d1f041..d2570b40dc 100644 --- a/vortex-array/src/array/primitive/compute/take.rs +++ b/vortex-array/src/array/primitive/compute/take.rs @@ -4,29 +4,36 @@ use vortex_error::VortexResult; use crate::array::primitive::PrimitiveArray; use crate::array::PrimitiveEncoding; -use crate::compute::{TakeFn, TakeOptions}; +use crate::compute::TakeFn; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayData, IntoArrayData, IntoArrayVariant}; impl TakeFn for PrimitiveEncoding { #[allow(clippy::cognitive_complexity)] - fn take( + fn take(&self, array: &PrimitiveArray, indices: &ArrayData) -> VortexResult { + let indices = indices.clone().into_primitive()?; + let validity = array.validity().take(indices.as_ref())?; + + match_each_native_ptype!(array.ptype(), |$T| { + match_each_integer_ptype!(indices.ptype(), |$I| { + let values = take_primitive(array.maybe_null_slice::<$T>(), indices.into_maybe_null_slice::<$I>()); + Ok(PrimitiveArray::from_vec(values, validity).into_array()) + }) + }) + } + + unsafe fn take_unchecked( &self, array: &PrimitiveArray, indices: &ArrayData, - options: TakeOptions, ) -> VortexResult { let indices = indices.clone().into_primitive()?; - let validity = array.validity().take(indices.as_ref(), options)?; + let validity = unsafe { array.validity().take_unchecked(indices.as_ref())? 
}; match_each_native_ptype!(array.ptype(), |$T| { match_each_integer_ptype!(indices.ptype(), |$I| { - let values = if options.skip_bounds_check { - take_primitive_unchecked(array.maybe_null_slice::<$T>(), indices.into_maybe_null_slice::<$I>()) - } else { - take_primitive(array.maybe_null_slice::<$T>(), indices.into_maybe_null_slice::<$I>()) - }; - Ok(PrimitiveArray::from_vec(values,validity).into_array()) + let values = take_primitive_unchecked(array.maybe_null_slice::<$T>(), indices.into_maybe_null_slice::<$I>()); + Ok(PrimitiveArray::from_vec(values, validity).into_array()) }) }) } @@ -43,7 +50,7 @@ fn take_primitive>( // We pass a Vec in case we're T == u64. // In which case, Rust should reuse the same Vec the result. -fn take_primitive_unchecked>( +unsafe fn take_primitive_unchecked>( array: &[T], indices: Vec, ) -> Vec { diff --git a/vortex-array/src/array/sparse/compute/mod.rs b/vortex-array/src/array/sparse/compute/mod.rs index 6c3c224f19..aaf0fd5ce4 100644 --- a/vortex-array/src/array/sparse/compute/mod.rs +++ b/vortex-array/src/array/sparse/compute/mod.rs @@ -7,7 +7,6 @@ use crate::array::{PrimitiveArray, SparseEncoding}; use crate::compute::{ scalar_at, search_sorted, take, ComputeVTable, FilterFn, FilterMask, InvertFn, ScalarAtFn, SearchResult, SearchSortedFn, SearchSortedSide, SearchSortedUsizeFn, SliceFn, TakeFn, - TakeOptions, }; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant}; @@ -127,11 +126,7 @@ impl FilterFn for SparseEncoding { Ok(SparseArray::try_new( PrimitiveArray::from(coordinate_indices).into_array(), - take( - array.values(), - PrimitiveArray::from(value_indices), - TakeOptions::default(), - )?, + take(array.values(), PrimitiveArray::from(value_indices))?, buffer.count_set_bits(), array.fill_scalar(), )? 
diff --git a/vortex-array/src/array/sparse/compute/take.rs b/vortex-array/src/array/sparse/compute/take.rs index 4af03fa723..d305dea441 100644 --- a/vortex-array/src/array/sparse/compute/take.rs +++ b/vortex-array/src/array/sparse/compute/take.rs @@ -8,17 +8,12 @@ use crate::aliases::hash_map::HashMap; use crate::array::primitive::PrimitiveArray; use crate::array::sparse::SparseArray; use crate::array::SparseEncoding; -use crate::compute::{take, try_cast, TakeFn, TakeOptions}; +use crate::compute::{take, try_cast, TakeFn}; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayData, IntoArrayData, IntoArrayVariant}; impl TakeFn for SparseEncoding { - fn take( - &self, - array: &SparseArray, - indices: &ArrayData, - options: TakeOptions, - ) -> VortexResult { + fn take(&self, array: &SparseArray, indices: &ArrayData) -> VortexResult { let flat_indices = indices.clone().into_primitive()?; // if we are taking a lot of values we should build a hashmap let (positions, physical_take_indices) = if indices.len() > 128 { @@ -27,7 +22,7 @@ impl TakeFn for SparseEncoding { take_search_sorted(array, &flat_indices)? 
}; - let taken_values = take(array.values(), physical_take_indices, options)?; + let taken_values = take(array.values(), physical_take_indices)?; Ok(SparseArray::try_new( positions.into_array(), @@ -107,7 +102,7 @@ mod test { use crate::array::primitive::PrimitiveArray; use crate::array::sparse::compute::take::take_map; use crate::array::sparse::SparseArray; - use crate::compute::{take, TakeOptions}; + use crate::compute::take; use crate::validity::Validity; use crate::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant}; @@ -126,15 +121,9 @@ mod test { #[test] fn sparse_take() { let sparse = sparse_array(); - let taken = SparseArray::try_from( - take( - sparse, - vec![0, 47, 47, 0, 99].into_array(), - TakeOptions::default(), - ) - .unwrap(), - ) - .unwrap(); + let taken = + SparseArray::try_from(take(sparse, vec![0, 47, 47, 0, 99].into_array()).unwrap()) + .unwrap(); assert_eq!( taken .indices() @@ -156,10 +145,7 @@ mod test { #[test] fn nonexistent_take() { let sparse = sparse_array(); - let taken = SparseArray::try_from( - take(sparse, vec![69].into_array(), TakeOptions::default()).unwrap(), - ) - .unwrap(); + let taken = SparseArray::try_from(take(sparse, vec![69].into_array()).unwrap()).unwrap(); assert!(taken .indices() .into_primitive() @@ -177,10 +163,8 @@ mod test { #[test] fn ordered_take() { let sparse = sparse_array(); - let taken = SparseArray::try_from( - take(&sparse, vec![69, 37].into_array(), TakeOptions::default()).unwrap(), - ) - .unwrap(); + let taken = + SparseArray::try_from(take(&sparse, vec![69, 37].into_array()).unwrap()).unwrap(); assert_eq!( taken .indices() diff --git a/vortex-array/src/array/struct_/compute.rs b/vortex-array/src/array/struct_/compute.rs index 3d0a08dc4e..ba3b44dce7 100644 --- a/vortex-array/src/array/struct_/compute.rs +++ b/vortex-array/src/array/struct_/compute.rs @@ -6,7 +6,7 @@ use crate::array::struct_::StructArray; use crate::array::StructEncoding; use crate::compute::{ filter, scalar_at, slice, take, 
ComputeVTable, FilterFn, FilterMask, ScalarAtFn, SliceFn, - TakeFn, TakeOptions, + TakeFn, }; use crate::variants::StructArrayTrait; use crate::{ArrayDType, ArrayData, IntoArrayData}; @@ -42,20 +42,15 @@ impl ScalarAtFn for StructEncoding { } impl TakeFn for StructEncoding { - fn take( - &self, - array: &StructArray, - indices: &ArrayData, - options: TakeOptions, - ) -> VortexResult { + fn take(&self, array: &StructArray, indices: &ArrayData) -> VortexResult { StructArray::try_new( array.names().clone(), array .children() - .map(|field| take(&field, indices, options)) + .map(|field| take(&field, indices)) .try_collect()?, indices.len(), - array.validity().take(indices, options)?, + array.validity().take(indices)?, ) .map(|a| a.into_array()) } diff --git a/vortex-array/src/array/varbin/compute/take.rs b/vortex-array/src/array/varbin/compute/take.rs index 111141ba04..c3788e6c28 100644 --- a/vortex-array/src/array/varbin/compute/take.rs +++ b/vortex-array/src/array/varbin/compute/take.rs @@ -5,18 +5,13 @@ use vortex_error::{vortex_err, vortex_panic, VortexResult}; use crate::array::varbin::builder::VarBinBuilder; use crate::array::varbin::VarBinArray; use crate::array::VarBinEncoding; -use crate::compute::{TakeFn, TakeOptions}; +use crate::compute::TakeFn; use crate::validity::Validity; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; impl TakeFn for VarBinEncoding { - fn take( - &self, - array: &VarBinArray, - indices: &ArrayData, - _options: TakeOptions, - ) -> VortexResult { + fn take(&self, array: &VarBinArray, indices: &ArrayData) -> VortexResult { let offsets = array.offsets().into_primitive()?; let data = array.bytes().into_primitive()?; let indices = indices.clone().into_primitive()?; diff --git a/vortex-array/src/array/varbinview/compute/mod.rs b/vortex-array/src/array/varbinview/compute/mod.rs index 44ae9bbb6b..53d5a1ca2e 100644 --- a/vortex-array/src/array/varbinview/compute/mod.rs +++ 
b/vortex-array/src/array/varbinview/compute/mod.rs @@ -11,7 +11,7 @@ use vortex_scalar::Scalar; use crate::array::varbin::varbin_scalar; use crate::array::varbinview::{VarBinViewArray, VIEW_SIZE_BYTES}; use crate::array::{PrimitiveArray, VarBinViewEncoding}; -use crate::compute::{slice, ComputeVTable, ScalarAtFn, SliceFn, TakeFn, TakeOptions}; +use crate::compute::{slice, ComputeVTable, ScalarAtFn, SliceFn, TakeFn}; use crate::validity::Validity; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; @@ -58,14 +58,43 @@ impl SliceFn for VarBinViewEncoding { /// Take involves creating a new array that references the old array, just with the given set of views. impl TakeFn for VarBinViewEncoding { - fn take( + fn take(&self, array: &VarBinViewArray, indices: &ArrayData) -> VortexResult { + // Compute the new validity + let validity = array.validity().take(indices)?; + + // Convert our views array into an Arrow u128 ScalarBuffer (16 bytes per view) + let views_buffer = + ScalarBuffer::::from(array.views().into_primitive()?.into_buffer().into_arrow()); + + let indices = indices.clone().into_primitive()?; + + let views_buffer = match_each_integer_ptype!(indices.ptype(), |$I| { + take_views(views_buffer, indices.maybe_null_slice::<$I>()) + }); + + // Cast views back to u8 + let views_array = PrimitiveArray::new( + views_buffer.into_inner().into(), + PType::U8, + Validity::NonNullable, + ); + + Ok(VarBinViewArray::try_new( + views_array.into_array(), + array.buffers().collect_vec(), + array.dtype().clone(), + validity, + )? 
+ .into_array()) + } + + unsafe fn take_unchecked( &self, array: &VarBinViewArray, indices: &ArrayData, - options: TakeOptions, ) -> VortexResult { // Compute the new validity - let validity = array.validity().take(indices, options)?; + let validity = array.validity().take(indices)?; // Convert our views array into an Arrow u128 ScalarBuffer (16 bytes per view) let views_buffer = @@ -74,11 +103,7 @@ impl TakeFn for VarBinViewEncoding { let indices = indices.clone().into_primitive()?; let views_buffer = match_each_integer_ptype!(indices.ptype(), |$I| { - if options.skip_bounds_check { - take_views_unchecked(views_buffer, indices.maybe_null_slice::<$I>()) - } else { - take_views(views_buffer, indices.maybe_null_slice::<$I>()) - } + take_views_unchecked(views_buffer, indices.maybe_null_slice::<$I>()) }); // Cast views back to u8 @@ -124,7 +149,7 @@ fn take_views_unchecked>( mod tests { use crate::accessor::ArrayAccessor; use crate::array::{PrimitiveArray, VarBinViewArray}; - use crate::compute::{take, TakeOptions}; + use crate::compute::take; use crate::{ArrayDType, IntoArrayData, IntoArrayVariant}; #[test] @@ -138,12 +163,7 @@ mod tests { Some("six"), ]); - let taken = take( - arr, - PrimitiveArray::from(vec![0, 3]).into_array(), - TakeOptions::default(), - ) - .unwrap(); + let taken = take(arr, PrimitiveArray::from(vec![0, 3]).into_array()).unwrap(); assert!(taken.dtype().is_nullable()); assert_eq!( diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 6d4bc94fcf..43c72d53ee 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -21,7 +21,7 @@ pub use scalar_at::{scalar_at, ScalarAtFn}; pub use scalar_subtract::{subtract_scalar, SubtractScalarFn}; pub use search_sorted::*; pub use slice::{slice, SliceFn}; -pub use take::{take, TakeFn, TakeOptions}; +pub use take::{take, TakeFn}; use crate::ArrayData; diff --git a/vortex-array/src/compute/take.rs b/vortex-array/src/compute/take.rs index 
9f9376c0ed..12acf43869 100644 --- a/vortex-array/src/compute/take.rs +++ b/vortex-array/src/compute/take.rs @@ -5,18 +5,26 @@ use crate::encoding::Encoding; use crate::stats::{ArrayStatistics, Stat}; use crate::{ArrayDType, ArrayData, IntoArrayData, IntoCanonical}; -#[derive(Default, Debug, Clone, Copy)] -pub struct TakeOptions { - pub skip_bounds_check: bool, -} - pub trait TakeFn { - fn take( - &self, - array: &Array, - indices: &ArrayData, - options: TakeOptions, - ) -> VortexResult; + /// Create a new array by taking the values from the `array` at the + /// given `indices`. + /// + /// # Panics + /// + /// Using `indices` that are invalid for the given `array` will cause a panic or an error, depending on the encoding. + fn take(&self, array: &Array, indices: &ArrayData) -> VortexResult; + + /// Create a new array by taking the values from the `array` at the + /// given `indices`. + /// + /// # Safety + /// + /// This take variant will not perform bounds checking on indices, so it is the caller's + /// responsibility to ensure that the `indices` are all valid for the provided `array`. + /// Failure to do so could result in out of bounds memory access or UB.
+ unsafe fn take_unchecked(&self, array: &Array, indices: &ArrayData) -> VortexResult { + self.take(array, indices) + } } impl TakeFn for E @@ -24,26 +32,20 @@ where E: TakeFn, for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>, { - fn take( - &self, - array: &ArrayData, - indices: &ArrayData, - options: TakeOptions, - ) -> VortexResult { + fn take(&self, array: &ArrayData, indices: &ArrayData) -> VortexResult { let array_ref = <&E::Array>::try_from(array)?; let encoding = array .encoding() .as_any() .downcast_ref::() .ok_or_else(|| vortex_err!("Mismatched encoding"))?; - TakeFn::take(encoding, array_ref, indices, options) + TakeFn::take(encoding, array_ref, indices) } } pub fn take( array: impl AsRef, indices: impl AsRef, - mut options: TakeOptions, ) -> VortexResult { let array = array.as_ref(); let indices = indices.as_ref(); @@ -59,27 +61,38 @@ pub fn take( // the filter function since they're typically optimised for this case. // If the indices are all within bounds, we can skip bounds checking. - if indices + let checked_indices = indices .statistics() .get_as::(Stat::Max) - .is_some_and(|max| max < array.len()) - { - options.skip_bounds_check = true; - } + .is_some_and(|max| max < array.len()); // TODO(ngates): if indices min is quite high, we could slice self and offset the indices // such that canonicalize does less work. + // If TakeFn defined for the encoding, delegate to TakeFn. + // If we know from stats that indices are all valid, we can avoid all bounds checks. if let Some(take_fn) = array.encoding().take_fn() { - return take_fn.take(array, indices, options); + return if checked_indices { + // SAFETY: indices are all inbounds per stats. + // TODO(aduffy): this means stats must be trusted, can still trigger UB if stats are bad. + unsafe { take_fn.take_unchecked(array, indices) } + } else { + take_fn.take(array, indices) + }; } // Otherwise, flatten and try again. 
info!("TakeFn not implemented for {}, flattening", array); let canonical = array.clone().into_canonical()?.into_array(); - canonical + let canonical_take_fn = canonical .encoding() .take_fn() - .ok_or_else(|| vortex_err!(NotImplemented: "take", canonical.encoding().id()))? - .take(&canonical, indices, options) + .ok_or_else(|| vortex_err!(NotImplemented: "take", canonical.encoding().id()))?; + + if checked_indices { + // SAFETY: indices are known to be in-bound from stats + unsafe { canonical_take_fn.take_unchecked(&canonical, indices) } + } else { + canonical_take_fn.take(&canonical, indices) + } } diff --git a/vortex-array/src/patches.rs b/vortex-array/src/patches.rs index 5581d90916..b91b906b5c 100644 --- a/vortex-array/src/patches.rs +++ b/vortex-array/src/patches.rs @@ -9,7 +9,7 @@ use vortex_scalar::Scalar; use crate::array::PrimitiveArray; use crate::compute::{ scalar_at, search_sorted, search_sorted_many, search_sorted_usize, slice, subtract_scalar, - take, try_cast, FilterMask, SearchResult, SearchSortedSide, TakeOptions, + take, try_cast, FilterMask, SearchResult, SearchSortedSide, }; use crate::stats::{ArrayStatistics, Stat}; use crate::validity::Validity; @@ -190,11 +190,7 @@ impl Patches { } let indices = PrimitiveArray::from(coordinate_indices).into_array(); - let values = take( - self.values(), - PrimitiveArray::from(value_indices), - TakeOptions::default(), - )?; + let values = take(self.values(), PrimitiveArray::from(value_indices))?; Ok(Some(Self::new(mask.len(), indices, values))) } @@ -252,7 +248,7 @@ impl Patches { let values_indices = PrimitiveArray::from_vec(values_indices, Validity::NonNullable).into_array(); - let new_values = take(self.values(), values_indices, TakeOptions::default())?; + let new_values = take(self.values(), values_indices)?; Ok(Some(Self::new(indices.len(), new_indices, new_values))) } diff --git a/vortex-array/src/stream/take_rows.rs b/vortex-array/src/stream/take_rows.rs index f252a252de..9ddd86e847 100644 --- 
a/vortex-array/src/stream/take_rows.rs +++ b/vortex-array/src/stream/take_rows.rs @@ -7,9 +7,7 @@ use vortex_dtype::match_each_integer_ptype; use vortex_error::{vortex_bail, VortexResult}; use vortex_scalar::Scalar; -use crate::compute::{ - search_sorted_usize, slice, subtract_scalar, take, SearchSortedSide, TakeOptions, -}; +use crate::compute::{search_sorted_usize, slice, subtract_scalar, take, SearchSortedSide}; use crate::stats::{ArrayStatistics, Stat}; use crate::stream::ArrayStream; use crate::variants::PrimitiveArrayTrait; @@ -95,11 +93,7 @@ impl Stream for TakeRows { let shifted_arr = match_each_integer_ptype!(indices_for_batch.ptype(), |$T| { subtract_scalar(&indices_for_batch.into_array(), &Scalar::from(curr_offset as $T))? }); - return Poll::Ready( - take(&batch, &shifted_arr, TakeOptions::default()) - .map(Some) - .transpose(), - ); + return Poll::Ready(take(&batch, &shifted_arr).map(Some).transpose()); } Poll::Ready(None) diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index 8491d39ab1..8021eb3a2e 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -12,7 +12,7 @@ use vortex_error::{ }; use crate::array::{BoolArray, ConstantArray}; -use crate::compute::{filter, scalar_at, slice, take, FilterMask, TakeOptions}; +use crate::compute::{filter, scalar_at, slice, take, FilterMask}; use crate::encoding::Encoding; use crate::patches::Patches; use crate::stats::ArrayStatistics; @@ -196,12 +196,37 @@ impl Validity { } } - pub fn take(&self, indices: &ArrayData, options: TakeOptions) -> VortexResult { + pub fn take(&self, indices: &ArrayData) -> VortexResult { match self { Self::NonNullable => Ok(Self::NonNullable), Self::AllValid => Ok(Self::AllValid), Self::AllInvalid => Ok(Self::AllInvalid), - Self::Array(a) => Ok(Self::Array(take(a, indices, options)?)), + Self::Array(a) => Ok(Self::Array(take(a, indices)?)), + } + } + + /// Take the validity buffer at the provided indices. 
+ /// + /// # Safety + /// + /// It is assumed the caller has checked that all indices are < the length of this validity + /// buffer. + /// + /// Failure to do so may result in UB. + pub unsafe fn take_unchecked(&self, indices: &ArrayData) -> VortexResult { + match self { + Self::NonNullable => Ok(Self::NonNullable), + Self::AllValid => Ok(Self::AllValid), + Self::AllInvalid => Ok(Self::AllInvalid), + Self::Array(a) => { + let taken = if let Some(take_fn) = a.encoding().take_fn() { + unsafe { take_fn.take_unchecked(a, indices) } + } else { + take(a, indices) + }; + + taken.map(Self::Array) + } } } diff --git a/vortex-datafusion/src/memory/plans.rs b/vortex-datafusion/src/memory/plans.rs index 1a05c7640e..a416fc6ff1 100644 --- a/vortex-datafusion/src/memory/plans.rs +++ b/vortex-datafusion/src/memory/plans.rs @@ -20,7 +20,7 @@ use futures::{ready, Stream}; use pin_project::pin_project; use vortex_array::array::ChunkedArray; use vortex_array::arrow::FromArrowArray; -use vortex_array::compute::{take, TakeOptions}; +use vortex_array::compute::take; use vortex_array::{ArrayData, IntoArrayVariant, IntoCanonical}; use vortex_dtype::field::Field; use vortex_error::{vortex_err, vortex_panic, VortexError}; @@ -348,8 +348,7 @@ where // We should find a way to avoid decoding the filter columns and only decode the other // columns, then stitch the StructArray back together from those. let projected_for_output = chunk.project(this.output_projection)?; - let decoded = - take(projected_for_output, &row_indices, TakeOptions::default())?.into_arrow()?; + let decoded = take(projected_for_output, &row_indices)?.into_arrow()?; // Send back a single record batch of the decoded data.
let output_batch = RecordBatch::from(decoded.as_struct()); diff --git a/vortex-file/src/read/layouts/chunked.rs b/vortex-file/src/read/layouts/chunked.rs index cb085c5b55..39b1be0027 100644 --- a/vortex-file/src/read/layouts/chunked.rs +++ b/vortex-file/src/read/layouts/chunked.rs @@ -4,7 +4,7 @@ use std::sync::{Arc, OnceLock, RwLock}; use itertools::Itertools; use vortex_array::aliases::hash_map::HashMap; use vortex_array::array::ChunkedArray; -use vortex_array::compute::{scalar_at, take, TakeOptions}; +use vortex_array::compute::{scalar_at, take}; use vortex_array::stats::{stats_from_bitset_bytes, ArrayStatistics as _, Stat}; use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; use vortex_dtype::{DType, Nullability, StructDType}; @@ -266,13 +266,7 @@ impl ChunkedLayoutReader { .iter() .map(|x| *x as u64) .collect::>(); - let chunks_prunable = take( - chunk_prunability, - ArrayData::from(layouts), - TakeOptions { - skip_bounds_check: false, - }, - )?; + let chunks_prunable = take(chunk_prunability, ArrayData::from(layouts))?; if !chunks_prunable .statistics() diff --git a/vortex-ipc/benches/ipc_take.rs b/vortex-ipc/benches/ipc_take.rs index 1291fd928a..89c7a3cc48 100644 --- a/vortex-ipc/benches/ipc_take.rs +++ b/vortex-ipc/benches/ipc_take.rs @@ -15,7 +15,7 @@ use futures_util::{pin_mut, TryStreamExt}; use itertools::Itertools; use vortex_array::array::PrimitiveArray; use vortex_array::compress::CompressionStrategy; -use vortex_array::compute::{take, TakeOptions}; +use vortex_array::compute::take; use vortex_array::{Context, IntoArrayData}; use vortex_io::VortexBufReader; use vortex_ipc::stream_reader::StreamArrayReader; @@ -83,7 +83,7 @@ fn ipc_take(c: &mut Criterion) { let reader = stream_reader.into_array_stream(); pin_mut!(reader); let array_view = reader.try_next().await?.unwrap(); - black_box(take(&array_view, indices_ref, TakeOptions::default())) + black_box(take(&array_view, indices_ref)) }); }); }