From 4fde8f7ba43ef5dd74323bf40572c80447333db4 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 12 Dec 2024 11:13:54 +0000 Subject: [PATCH] FoR compare --- .../fastlanes/src/for/compute/compare.rs | 73 +++++++++++++++++++ .../src/for/{compute.rs => compute/mod.rs} | 10 ++- 2 files changed, 81 insertions(+), 2 deletions(-) create mode 100644 encodings/fastlanes/src/for/compute/compare.rs rename encodings/fastlanes/src/for/{compute.rs => compute/mod.rs} (97%) diff --git a/encodings/fastlanes/src/for/compute/compare.rs b/encodings/fastlanes/src/for/compute/compare.rs new file mode 100644 index 0000000000..f9c5137d80 --- /dev/null +++ b/encodings/fastlanes/src/for/compute/compare.rs @@ -0,0 +1,73 @@ +use num_traits::{CheckedShl, WrappingSub}; +use vortex_array::array::ConstantArray; +use vortex_array::compute::{compare, CompareFn, Operator}; +use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayData}; +use vortex_dtype::{match_each_integer_ptype, NativePType}; +use vortex_error::{vortex_err, VortexError, VortexResult}; +use vortex_scalar::{PValue, PrimitiveScalar, Scalar}; + +use crate::{FoRArray, FoREncoding}; + +impl CompareFn for FoREncoding { + fn compare( + &self, + lhs: &FoRArray, + rhs: &ArrayData, + operator: Operator, + ) -> VortexResult> { + if let Some(constant) = rhs.as_constant() { + if let Ok(constant) = PrimitiveScalar::try_from(&constant) { + match_each_integer_ptype!(constant.ptype(), |$T| { + return compare_constant(lhs, constant.typed_value::<$T>(), operator); + }) + } + } + + Ok(None) + } +} + +fn compare_constant( + lhs: &FoRArray, + rhs: Option, + operator: Operator, +) -> VortexResult> +where + T: NativePType + WrappingSub + CheckedShl, + T: TryFrom, + Scalar: From>, +{ + // For now, we only support equals and not equals. Comparisons are a little more fiddly to + // get right regarding how to handle overflow and the wrapping subtraction. + if !matches!(operator, Operator::Eq | Operator::NotEq) { + return Ok(None); + } + + let reference = lhs.reference_scalar(); + let reference = reference.as_primitive().typed_value::(); + + // We encode the RHS into the FoR domain. + let rhs = rhs + .map(|mut rhs| { + if let Some(reference) = reference { + rhs = rhs.wrapping_sub(&reference); + } + if lhs.shift() > 0 { + rhs.checked_shl(lhs.shift() as u32) + .ok_or_else(|| vortex_err!("Shift overflow"))?; + } + Ok::<_, VortexError>(rhs) + }) + .transpose()?; + + // Wrap up the RHS into a scalar and cast to the encoded DType (this will be the equivalent + // unsigned integer type). + let rhs = Scalar::from(rhs).cast(lhs.encoded().dtype())?; + + compare( + lhs.encoded(), + ConstantArray::new(rhs, lhs.len()).into_array(), + operator, + ) + .map(Some) +} diff --git a/encodings/fastlanes/src/for/compute.rs b/encodings/fastlanes/src/for/compute/mod.rs similarity index 97% rename from encodings/fastlanes/src/for/compute.rs rename to encodings/fastlanes/src/for/compute/mod.rs index 87e269353c..30ef14a860 100644 --- a/encodings/fastlanes/src/for/compute.rs +++ b/encodings/fastlanes/src/for/compute/mod.rs @@ -1,9 +1,11 @@ +mod compare; + use std::ops::AddAssign; use num_traits::{CheckedShl, CheckedShr, WrappingAdd, WrappingSub}; use vortex_array::compute::{ - filter, scalar_at, search_sorted, slice, take, ComputeVTable, FilterFn, FilterMask, ScalarAtFn, - SearchResult, SearchSortedFn, SearchSortedSide, SliceFn, TakeFn, + filter, scalar_at, search_sorted, slice, take, CompareFn, ComputeVTable, FilterFn, FilterMask, + ScalarAtFn, SearchResult, SearchSortedFn, SearchSortedSide, SliceFn, TakeFn, }; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; @@ -14,6 +16,10 @@ use vortex_scalar::{PValue, Scalar}; use crate::{FoRArray, FoREncoding}; impl ComputeVTable for FoREncoding { + fn compare_fn(&self) -> Option<&dyn CompareFn> { + Some(self) + } + fn filter_fn(&self) -> Option<&dyn FilterFn> { Some(self) }