Skip to content

Commit

Permalink
chore: add an unsafe take_unchecked to TakeFn for bounds check ellisi…
Browse files Browse the repository at this point in the history
…on (#1611)
  • Loading branch information
a10y authored Dec 9, 2024
1 parent 69ef100 commit 173d499
Show file tree
Hide file tree
Showing 39 changed files with 298 additions and 519 deletions.
11 changes: 3 additions & 8 deletions encodings/alp/src/alp/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use vortex_array::compute::{
filter, scalar_at, slice, take, ComputeVTable, FilterFn, FilterMask, ScalarAtFn, SliceFn,
TakeFn, TakeOptions,
TakeFn,
};
use vortex_array::variants::PrimitiveArrayTrait;
use vortex_array::{ArrayDType, ArrayData, IntoArrayData};
Expand Down Expand Up @@ -48,15 +48,10 @@ impl ScalarAtFn<ALPArray> for ALPEncoding {
}

impl TakeFn<ALPArray> for ALPEncoding {
fn take(
&self,
array: &ALPArray,
indices: &ArrayData,
options: TakeOptions,
) -> VortexResult<ArrayData> {
fn take(&self, array: &ALPArray, indices: &ArrayData) -> VortexResult<ArrayData> {
// TODO(ngates): wrap up indices in an array that caches decompression?
Ok(ALPArray::try_new(
take(array.encoded(), indices, options)?,
take(array.encoded(), indices)?,
array.exponents(),
array
.patches()
Expand Down
29 changes: 10 additions & 19 deletions encodings/alp/src/alp_rd/compute/take.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,21 @@
use vortex_array::compute::{take, TakeFn, TakeOptions};
use vortex_array::compute::{take, TakeFn};
use vortex_array::{ArrayDType, ArrayData, IntoArrayData};
use vortex_error::VortexResult;

use crate::{ALPRDArray, ALPRDEncoding};

impl TakeFn<ALPRDArray> for ALPRDEncoding {
fn take(
&self,
array: &ALPRDArray,
indices: &ArrayData,
options: TakeOptions,
) -> VortexResult<ArrayData> {
fn take(&self, array: &ALPRDArray, indices: &ArrayData) -> VortexResult<ArrayData> {
let left_parts_exceptions = array
.left_parts_exceptions()
.map(|array| take(&array, indices, options))
.map(|array| take(&array, indices))
.transpose()?;

Ok(ALPRDArray::try_new(
array.dtype().clone(),
take(array.left_parts(), indices, options)?,
take(array.left_parts(), indices)?,
array.left_parts_dict(),
take(array.right_parts(), indices, options)?,
take(array.right_parts(), indices)?,
array.right_bit_width(),
left_parts_exceptions,
)?
Expand All @@ -32,7 +27,7 @@ impl TakeFn<ALPRDArray> for ALPRDEncoding {
mod test {
use rstest::rstest;
use vortex_array::array::PrimitiveArray;
use vortex_array::compute::{take, TakeOptions};
use vortex_array::compute::take;
use vortex_array::IntoArrayVariant;

use crate::{ALPRDFloat, RDEncoder};
Expand All @@ -46,14 +41,10 @@ mod test {

assert!(encoded.left_parts_exceptions().is_some());

let taken = take(
encoded.as_ref(),
PrimitiveArray::from(vec![0, 2]).as_ref(),
TakeOptions::default(),
)
.unwrap()
.into_primitive()
.unwrap();
let taken = take(encoded.as_ref(), PrimitiveArray::from(vec![0, 2]).as_ref())
.unwrap()
.into_primitive()
.unwrap();

assert_eq!(taken.maybe_null_slice::<T>(), &[a, outlier]);
}
Expand Down
11 changes: 2 additions & 9 deletions encodings/bytebool/src/compute.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use num_traits::AsPrimitive;
use vortex_array::compute::{
ComputeVTable, FillForwardFn, ScalarAtFn, SliceFn, TakeFn, TakeOptions,
};
use vortex_array::compute::{ComputeVTable, FillForwardFn, ScalarAtFn, SliceFn, TakeFn};
use vortex_array::validity::{ArrayValidity, Validity};
use vortex_array::variants::PrimitiveArrayTrait;
use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData};
Expand Down Expand Up @@ -49,12 +47,7 @@ impl SliceFn<ByteBoolArray> for ByteBoolEncoding {
}

impl TakeFn<ByteBoolArray> for ByteBoolEncoding {
fn take(
&self,
array: &ByteBoolArray,
indices: &ArrayData,
_options: TakeOptions,
) -> VortexResult<ArrayData> {
fn take(&self, array: &ByteBoolArray, indices: &ArrayData) -> VortexResult<ArrayData> {
let validity = array.validity();
let indices = indices.clone().into_primitive()?;
let bools = array.maybe_null_slice();
Expand Down
15 changes: 5 additions & 10 deletions encodings/datetime-parts/src/compute/take.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
use vortex_array::compute::{take, TakeFn, TakeOptions};
use vortex_array::compute::{take, TakeFn};
use vortex_array::{ArrayDType, ArrayData, IntoArrayData};
use vortex_error::VortexResult;

use crate::{DateTimePartsArray, DateTimePartsEncoding};

impl TakeFn<DateTimePartsArray> for DateTimePartsEncoding {
fn take(
&self,
array: &DateTimePartsArray,
indices: &ArrayData,
options: TakeOptions,
) -> VortexResult<ArrayData> {
fn take(&self, array: &DateTimePartsArray, indices: &ArrayData) -> VortexResult<ArrayData> {
Ok(DateTimePartsArray::try_new(
array.dtype().clone(),
take(array.days(), indices, options)?,
take(array.seconds(), indices, options)?,
take(array.subsecond(), indices, options)?,
take(array.days(), indices)?,
take(array.seconds(), indices)?,
take(array.subsecond(), indices)?,
)?
.into_array())
}
Expand Down
6 changes: 3 additions & 3 deletions encodings/dict/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::fmt::{Debug, Display};
use arrow_buffer::BooleanBuffer;
use serde::{Deserialize, Serialize};
use vortex_array::array::BoolArray;
use vortex_array::compute::{scalar_at, take, TakeOptions};
use vortex_array::compute::{scalar_at, take};
use vortex_array::encoding::ids;
use vortex_array::stats::StatsSet;
use vortex_array::validity::{ArrayValidity, LogicalValidity, ValidityVTable};
Expand Down Expand Up @@ -74,10 +74,10 @@ impl IntoCanonical for DictArray {
// copies of the view pointers.
DType::Utf8(_) | DType::Binary(_) => {
let canonical_values: ArrayData = self.values().into_canonical()?.into();
take(canonical_values, self.codes(), TakeOptions::default())?.into_canonical()
take(canonical_values, self.codes())?.into_canonical()
}
// Non-string case: take and then canonicalize
_ => take(self.values(), self.codes(), TakeOptions::default())?.into_canonical(),
_ => take(self.values(), self.codes())?.into_canonical(),
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions encodings/dict/src/compute/compare.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use vortex_array::array::ConstantArray;
use vortex_array::compute::{compare, take, CompareFn, Operator, TakeOptions};
use vortex_array::compute::{compare, take, CompareFn, Operator};
use vortex_array::ArrayData;
use vortex_error::VortexResult;

Expand All @@ -20,7 +20,7 @@ impl CompareFn<DictArray> for DictEncoding {
ConstantArray::new(const_scalar, lhs.values().len()),
operator,
)?;
return take(compare_result, lhs.codes(), TakeOptions::default()).map(Some);
return take(compare_result, lhs.codes()).map(Some);
}

// It's a little more complex, but we could perform a comparison against the dictionary
Expand Down
11 changes: 3 additions & 8 deletions encodings/dict/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod like;

use vortex_array::compute::{
filter, scalar_at, slice, take, CompareFn, ComputeVTable, FilterFn, FilterMask, LikeFn,
ScalarAtFn, SliceFn, TakeFn, TakeOptions,
ScalarAtFn, SliceFn, TakeFn,
};
use vortex_array::{ArrayData, IntoArrayData};
use vortex_error::VortexResult;
Expand Down Expand Up @@ -45,16 +45,11 @@ impl ScalarAtFn<DictArray> for DictEncoding {
}

impl TakeFn<DictArray> for DictEncoding {
fn take(
&self,
array: &DictArray,
indices: &ArrayData,
options: TakeOptions,
) -> VortexResult<ArrayData> {
fn take(&self, array: &DictArray, indices: &ArrayData) -> VortexResult<ArrayData> {
// Dict
// codes: 0 0 1
// dict: a b c d e f g h
let codes = take(array.codes(), indices, options)?;
let codes = take(array.codes(), indices)?;
DictArray::try_new(codes, array.values()).map(|a| a.into_array())
}
}
Expand Down
Loading

0 comments on commit 173d499

Please sign in to comment.