From c20c4d409ef14cedebc982a01054b6058777d1e5 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 18 Dec 2024 16:48:06 -0500 Subject: [PATCH 1/2] fix: binary_numeric is now correct (+ tests!) Lessons learned: do not merge without at least basic tests. We check for a constant array but do not create a new constant array with the appropriate length for the child on which we want to delegate the binary numeric operation. That is now fixed. I also added a very basic test suite for applying the (now six) numeric operations on pairs of arrays where one of the two is constant. In order to properly support constant arrays on the left-hand-side, I added FlippedSub and FlippedDiv which allow us to commute/flip the operator so that we need not teach ConstantArray to check if its right-hand-side supports an operation on constants. --- Cargo.lock | 1 + encodings/dict/src/compute/binary_numeric.rs | 29 +++++ encodings/dict/src/compute/mod.rs | 44 ++++--- encodings/runend/Cargo.toml | 3 + .../runend/src/compute/binary_numeric.rs | 31 +++++ encodings/runend/src/compute/mod.rs | 36 ++---- vortex-array/Cargo.toml | 2 + .../array/chunked/compute/binary_numeric.rs | 28 +++++ vortex-array/src/array/chunked/compute/mod.rs | 1 + vortex-array/src/array/chunked/mod.rs | 39 ++----- .../array/constant/compute/binary_numeric.rs | 2 +- vortex-array/src/array/null/compute.rs | 20 +--- .../array/sparse/compute/binary_numeric.rs | 12 +- vortex-array/src/array/sparse/compute/mod.rs | 6 + vortex-array/src/compute/binary_numeric.rs | 110 +++++++++++++++++- vortex-array/src/compute/mod.rs | 6 +- vortex-scalar/src/primitive.rs | 42 ++++++- 17 files changed, 306 insertions(+), 106 deletions(-) create mode 100644 encodings/dict/src/compute/binary_numeric.rs create mode 100644 encodings/runend/src/compute/binary_numeric.rs create mode 100644 vortex-array/src/array/chunked/compute/binary_numeric.rs diff --git a/Cargo.lock b/Cargo.lock index 328978b6d0..b12ced2a54 100644 --- a/Cargo.lock +++ 
b/Cargo.lock @@ -4928,6 +4928,7 @@ dependencies = [ "rstest", "serde", "static_assertions", + "vortex-array", "vortex-buffer", "vortex-datetime-dtype", "vortex-dtype", diff --git a/encodings/dict/src/compute/binary_numeric.rs b/encodings/dict/src/compute/binary_numeric.rs new file mode 100644 index 0000000000..e41f8a40ee --- /dev/null +++ b/encodings/dict/src/compute/binary_numeric.rs @@ -0,0 +1,29 @@ +use vortex_array::array::ConstantArray; +use vortex_array::compute::{binary_numeric, BinaryNumericFn}; +use vortex_array::{ArrayData, IntoArrayData}; +use vortex_error::VortexResult; +use vortex_scalar::BinaryNumericOperator; + +use crate::{DictArray, DictEncoding}; + +impl BinaryNumericFn for DictEncoding { + fn binary_numeric( + &self, + array: &DictArray, + rhs: &ArrayData, + op: BinaryNumericOperator, + ) -> VortexResult> { + let Some(rhs_scalar) = rhs.as_constant() else { + return Ok(None); + }; + + let rhs_const_array = ConstantArray::new(rhs_scalar, array.values().len()).into_array(); + + DictArray::try_new( + array.codes(), + binary_numeric(&array.values(), &rhs_const_array, op)?, + ) + .map(IntoArrayData::into_array) + .map(Some) + } +} diff --git a/encodings/dict/src/compute/mod.rs b/encodings/dict/src/compute/mod.rs index 09ee3e9eea..8b9429123e 100644 --- a/encodings/dict/src/compute/mod.rs +++ b/encodings/dict/src/compute/mod.rs @@ -1,13 +1,14 @@ +mod binary_numeric; mod compare; mod like; use vortex_array::compute::{ - binary_numeric, filter, scalar_at, slice, take, BinaryNumericFn, CompareFn, ComputeVTable, - FilterFn, FilterMask, LikeFn, ScalarAtFn, SliceFn, TakeFn, + filter, scalar_at, slice, take, BinaryNumericFn, CompareFn, ComputeVTable, FilterFn, + FilterMask, LikeFn, ScalarAtFn, SliceFn, TakeFn, }; use vortex_array::{ArrayData, IntoArrayData}; use vortex_error::VortexResult; -use vortex_scalar::{BinaryNumericOperator, Scalar}; +use vortex_scalar::Scalar; use crate::{DictArray, DictEncoding}; @@ -41,23 +42,6 @@ impl ComputeVTable for DictEncoding 
{ } } -impl BinaryNumericFn for DictEncoding { - fn binary_numeric( - &self, - array: &DictArray, - rhs: &ArrayData, - op: BinaryNumericOperator, - ) -> VortexResult> { - if !rhs.is_constant() { - return Ok(None); - } - - DictArray::try_new(array.codes(), binary_numeric(&array.values(), rhs, op)?) - .map(IntoArrayData::into_array) - .map(Some) - } -} - impl ScalarAtFn for DictEncoding { fn scalar_at(&self, array: &DictArray, index: usize) -> VortexResult { let dict_index: usize = scalar_at(array.codes(), index)?.as_ref().try_into()?; @@ -94,8 +78,9 @@ impl SliceFn for DictEncoding { mod test { use vortex_array::accessor::ArrayAccessor; use vortex_array::array::{ConstantArray, PrimitiveArray, VarBinViewArray}; + use vortex_array::compute::binary_numeric::test_harness::test_binary_numeric; use vortex_array::compute::{compare, scalar_at, slice, Operator}; - use vortex_array::{ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData}; + use vortex_array::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData}; use vortex_dtype::{DType, Nullability}; use vortex_scalar::Scalar; @@ -143,8 +128,7 @@ mod test { ); } - #[test] - fn compare_sliced_dict() { + fn sliced_dict_array() -> ArrayData { let reference = PrimitiveArray::from_nullable_vec(vec![ Some(42), Some(-9), @@ -155,8 +139,14 @@ mod test { ]); let (codes, values) = dict_encode_primitive(&reference); let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap(); - let sliced = slice(dict, 1, 4).unwrap(); + slice(dict, 1, 4).unwrap() + } + + #[test] + fn compare_sliced_dict() { + let sliced = sliced_dict_array(); let compared = compare(sliced, ConstantArray::new(42, 3), Operator::Eq).unwrap(); + assert_eq!( scalar_at(&compared, 0).unwrap(), Scalar::bool(false, Nullability::Nullable) @@ -170,4 +160,10 @@ mod test { Scalar::bool(true, Nullability::Nullable) ); } + + #[test] + fn test_dict_binary_numeric() { + let array = sliced_dict_array(); + test_binary_numeric::(array) + } } diff --git 
a/encodings/runend/Cargo.toml b/encodings/runend/Cargo.toml index e440a17fa1..1246a20ce9 100644 --- a/encodings/runend/Cargo.toml +++ b/encodings/runend/Cargo.toml @@ -23,5 +23,8 @@ vortex-dtype = { workspace = true } vortex-error = { workspace = true } vortex-scalar = { workspace = true } +[dev-dependencies] +vortex-array = { workspace = true, features = ["test-harness"] } + [lints] workspace = true diff --git a/encodings/runend/src/compute/binary_numeric.rs b/encodings/runend/src/compute/binary_numeric.rs new file mode 100644 index 0000000000..6805423a76 --- /dev/null +++ b/encodings/runend/src/compute/binary_numeric.rs @@ -0,0 +1,31 @@ +use vortex_array::array::ConstantArray; +use vortex_array::compute::{binary_numeric, BinaryNumericFn}; +use vortex_array::{ArrayData, ArrayLen, IntoArrayData}; +use vortex_error::VortexResult; +use vortex_scalar::BinaryNumericOperator; + +use crate::{RunEndArray, RunEndEncoding}; + +impl BinaryNumericFn for RunEndEncoding { + fn binary_numeric( + &self, + array: &RunEndArray, + rhs: &ArrayData, + op: BinaryNumericOperator, + ) -> VortexResult> { + let Some(rhs_scalar) = rhs.as_constant() else { + return Ok(None); + }; + + let rhs_const_array = ConstantArray::new(rhs_scalar, array.values().len()).into_array(); + + RunEndArray::with_offset_and_length( + array.ends(), + binary_numeric(&array.values(), &rhs_const_array, op)?, + array.offset(), + array.len(), + ) + .map(IntoArrayData::into_array) + .map(Some) + } +} diff --git a/encodings/runend/src/compute/mod.rs b/encodings/runend/src/compute/mod.rs index 974bc82520..ef545c942a 100644 --- a/encodings/runend/src/compute/mod.rs +++ b/encodings/runend/src/compute/mod.rs @@ -1,3 +1,4 @@ +mod binary_numeric; mod compare; mod fill_null; mod invert; @@ -9,14 +10,14 @@ use std::ops::AddAssign; use num_traits::AsPrimitive; use vortex_array::array::{BooleanBuffer, PrimitiveArray}; use vortex_array::compute::{ - binary_numeric, filter, scalar_at, slice, BinaryNumericFn, CompareFn, 
ComputeVTable, - FillNullFn, FilterFn, FilterMask, InvertFn, ScalarAtFn, SliceFn, TakeFn, + filter, scalar_at, slice, BinaryNumericFn, CompareFn, ComputeVTable, FillNullFn, FilterFn, + FilterMask, InvertFn, ScalarAtFn, SliceFn, TakeFn, }; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant}; use vortex_dtype::{match_each_unsigned_integer_ptype, NativePType}; use vortex_error::{VortexResult, VortexUnwrap}; -use vortex_scalar::{BinaryNumericOperator, Scalar}; +use vortex_scalar::Scalar; use crate::{RunEndArray, RunEndEncoding}; @@ -54,28 +55,6 @@ impl ComputeVTable for RunEndEncoding { } } -impl BinaryNumericFn for RunEndEncoding { - fn binary_numeric( - &self, - array: &RunEndArray, - rhs: &ArrayData, - op: BinaryNumericOperator, - ) -> VortexResult> { - if !rhs.is_constant() { - return Ok(None); - } - - RunEndArray::with_offset_and_length( - array.ends(), - binary_numeric(&array.values(), rhs, op)?, - array.offset(), - array.len(), - ) - .map(IntoArrayData::into_array) - .map(Some) - } -} - impl ScalarAtFn for RunEndEncoding { fn scalar_at(&self, array: &RunEndArray, index: usize) -> VortexResult { scalar_at(array.values(), array.find_physical_index(index)?) 
@@ -163,6 +142,7 @@ fn filter_run_ends + AsPrimitive>( #[cfg(test)] mod test { use vortex_array::array::PrimitiveArray; + use vortex_array::compute::binary_numeric::test_harness::test_binary_numeric; use vortex_array::compute::{filter, scalar_at, slice, FilterMask}; use vortex_array::{ArrayDType, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData}; use vortex_dtype::{DType, Nullability, PType}; @@ -345,4 +325,10 @@ mod test { [1, 4, 2] ); } + + #[test] + fn test_runend_binary_numeric() { + let array = ree_array().into_array(); + test_binary_numeric::(array) + } } diff --git a/vortex-array/Cargo.toml b/vortex-array/Cargo.toml index 6a61606750..8927380b0b 100644 --- a/vortex-array/Cargo.toml +++ b/vortex-array/Cargo.toml @@ -58,6 +58,7 @@ vortex-scalar = { workspace = true, features = ["flatbuffers", "serde"] } [features] arbitrary = ["dep:arbitrary", "vortex-dtype/arbitrary"] canonical_counter = [] +test-harness = [] [target.'cfg(target_arch = "wasm32")'.dependencies] # Enable the JS feature of getrandom (via rand) to supprt wasm32 target @@ -66,6 +67,7 @@ getrandom = { workspace = true, features = ["js"] } [dev-dependencies] criterion = { workspace = true } rstest = { workspace = true } +vortex-array = { workspace = true, features = ["test-harness"] } [[bench]] name = "search_sorted" diff --git a/vortex-array/src/array/chunked/compute/binary_numeric.rs b/vortex-array/src/array/chunked/compute/binary_numeric.rs new file mode 100644 index 0000000000..7531092be2 --- /dev/null +++ b/vortex-array/src/array/chunked/compute/binary_numeric.rs @@ -0,0 +1,28 @@ +use vortex_error::VortexResult; +use vortex_scalar::BinaryNumericOperator; + +use crate::array::{ChunkedArray, ChunkedEncoding}; +use crate::compute::{binary_numeric, slice, BinaryNumericFn}; +use crate::{ArrayDType as _, ArrayData, IntoArrayData}; + +impl BinaryNumericFn for ChunkedEncoding { + fn binary_numeric( + &self, + array: &ChunkedArray, + rhs: &ArrayData, + op: BinaryNumericOperator, + ) -> 
VortexResult> { + let mut start = 0; + + let mut new_chunks = Vec::with_capacity(array.nchunks()); + for chunk in array.chunks() { + let end = start + chunk.len(); + new_chunks.push(binary_numeric(&chunk, &slice(rhs, start, end)?, op)?); + start = end; + } + + ChunkedArray::try_new(new_chunks, array.dtype().clone()) + .map(IntoArrayData::into_array) + .map(Some) + } +} diff --git a/vortex-array/src/array/chunked/compute/mod.rs b/vortex-array/src/array/chunked/compute/mod.rs index 242082ae03..675dacfd15 100644 --- a/vortex-array/src/array/chunked/compute/mod.rs +++ b/vortex-array/src/array/chunked/compute/mod.rs @@ -9,6 +9,7 @@ use crate::compute::{ }; use crate::{ArrayData, IntoArrayData}; +mod binary_numeric; mod boolean; mod compare; mod fill_null; diff --git a/vortex-array/src/array/chunked/mod.rs b/vortex-array/src/array/chunked/mod.rs index 32ed452964..14921f7cf7 100644 --- a/vortex-array/src/array/chunked/mod.rs +++ b/vortex-array/src/array/chunked/mod.rs @@ -9,12 +9,9 @@ use itertools::Itertools; use serde::{Deserialize, Serialize}; use vortex_dtype::{DType, Nullability, PType}; use vortex_error::{vortex_bail, vortex_panic, VortexExpect as _, VortexResult, VortexUnwrap}; -use vortex_scalar::BinaryNumericOperator; use crate::array::primitive::PrimitiveArray; -use crate::compute::{ - binary_numeric, scalar_at, search_sorted_usize, slice, BinaryNumericFn, SearchSortedSide, -}; +use crate::compute::{scalar_at, search_sorted_usize, SearchSortedSide}; use crate::encoding::ids; use crate::iter::{ArrayIterator, ArrayIteratorAdapter}; use crate::stats::StatsSet; @@ -234,35 +231,14 @@ impl ValidityVTable for ChunkedEncoding { } } -impl BinaryNumericFn for ChunkedEncoding { - fn binary_numeric( - &self, - array: &ChunkedArray, - rhs: &ArrayData, - op: BinaryNumericOperator, - ) -> VortexResult> { - let mut start = 0; - - let mut new_chunks = Vec::with_capacity(array.nchunks()); - for chunk in array.chunks() { - let end = start + chunk.len(); - 
new_chunks.push(binary_numeric(&chunk, &slice(rhs, start, end)?, op)?); - start = end; - } - - ChunkedArray::try_new(new_chunks, array.dtype().clone()) - .map(IntoArrayData::into_array) - .map(Some) - } -} - #[cfg(test)] mod test { use vortex_dtype::{DType, Nullability, PType}; use vortex_error::VortexResult; use crate::array::chunked::ChunkedArray; - use crate::compute::{scalar_at, sub_scalar}; + use crate::compute::binary_numeric::test_harness::test_binary_numeric; + use crate::compute::{scalar_at, sub_scalar, try_cast}; use crate::{assert_arrays_eq, ArrayDType, IntoArrayData, IntoArrayVariant}; fn chunked_array() -> ChunkedArray { @@ -374,4 +350,13 @@ mod test { assert_eq!(rechunked.nchunks(), 4); assert_arrays_eq!(chunked, rechunked); } + + #[test] + fn test_chunked_binary_numeric() { + let array = chunked_array().into_array(); + // The tests test both X - 1 and 1 - X, so we need signed values + let signed_dtype = DType::from(PType::try_from(array.dtype()).unwrap().to_signed()); + let array = try_cast(array, &signed_dtype).unwrap(); + test_binary_numeric::(array) + } } diff --git a/vortex-array/src/array/constant/compute/binary_numeric.rs b/vortex-array/src/array/constant/compute/binary_numeric.rs index b7994b2167..aaf8bc26e4 100644 --- a/vortex-array/src/array/constant/compute/binary_numeric.rs +++ b/vortex-array/src/array/constant/compute/binary_numeric.rs @@ -21,7 +21,7 @@ impl BinaryNumericFn for ConstantEncoding { array .scalar() .as_primitive() - .checked_numeric_operator(rhs.as_primitive(), op)? + .checked_binary_numeric(rhs.as_primitive(), op)? 
.ok_or_else(|| vortex_err!("numeric overflow"))?, array.len(), ) diff --git a/vortex-array/src/array/null/compute.rs b/vortex-array/src/array/null/compute.rs index dbe3ca459f..40938d8560 100644 --- a/vortex-array/src/array/null/compute.rs +++ b/vortex-array/src/array/null/compute.rs @@ -1,10 +1,10 @@ use vortex_dtype::{match_each_integer_ptype, DType}; use vortex_error::{vortex_bail, VortexResult}; -use vortex_scalar::{BinaryNumericOperator, Scalar}; +use vortex_scalar::Scalar; use crate::array::null::NullArray; use crate::array::NullEncoding; -use crate::compute::{BinaryNumericFn, ComputeVTable, ScalarAtFn, SliceFn, TakeFn}; +use crate::compute::{ComputeVTable, ScalarAtFn, SliceFn, TakeFn}; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant}; @@ -13,10 +13,6 @@ impl ComputeVTable for NullEncoding { Some(self) } - fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn> { - Some(self) - } - fn slice_fn(&self) -> Option<&dyn SliceFn> { Some(self) } @@ -26,18 +22,6 @@ impl ComputeVTable for NullEncoding { } } -impl BinaryNumericFn for NullEncoding { - fn binary_numeric( - &self, - array: &NullArray, - _rhs: &ArrayData, - _op: BinaryNumericOperator, - ) -> VortexResult> { - // for any arithmetic operation, forall X. 
NULL op X = NULL - Ok(Some(NullArray::new(array.len()).into_array())) - } -} - impl SliceFn for NullEncoding { fn slice(&self, _array: &NullArray, start: usize, stop: usize) -> VortexResult { Ok(NullArray::new(stop - start).into_array()) diff --git a/vortex-array/src/array/sparse/compute/binary_numeric.rs b/vortex-array/src/array/sparse/compute/binary_numeric.rs index 50d1e686e8..0806d3cf54 100644 --- a/vortex-array/src/array/sparse/compute/binary_numeric.rs +++ b/vortex-array/src/array/sparse/compute/binary_numeric.rs @@ -1,7 +1,7 @@ use vortex_error::{vortex_err, VortexResult}; use vortex_scalar::BinaryNumericOperator; -use crate::array::{SparseArray, SparseEncoding}; +use crate::array::{ConstantArray, SparseArray, SparseEncoding}; use crate::compute::{binary_numeric, BinaryNumericFn}; use crate::{ArrayData, ArrayLen as _, IntoArrayData}; @@ -16,13 +16,15 @@ impl BinaryNumericFn for SparseEncoding { return Ok(None); }; - let new_patches = array - .patches() - .map_values(|values| binary_numeric(&values, rhs, op))?; + let new_patches = array.patches().map_values(|values| { + let rhs_const_array = ConstantArray::new(rhs_scalar.clone(), values.len()).into_array(); + + binary_numeric(&values, &rhs_const_array, op) + })?; let new_fill_value = array .fill_scalar() .as_primitive() - .checked_numeric_operator(rhs_scalar.as_primitive(), op)? + .checked_binary_numeric(rhs_scalar.as_primitive(), op)? 
.ok_or_else(|| vortex_err!("numeric overflow"))?; SparseArray::try_new_from_patches( new_patches, diff --git a/vortex-array/src/array/sparse/compute/mod.rs b/vortex-array/src/array/sparse/compute/mod.rs index fa4765632a..3a699386f1 100644 --- a/vortex-array/src/array/sparse/compute/mod.rs +++ b/vortex-array/src/array/sparse/compute/mod.rs @@ -108,6 +108,7 @@ mod test { use crate::array::primitive::PrimitiveArray; use crate::array::sparse::SparseArray; + use crate::compute::binary_numeric::test_harness::test_binary_numeric; use crate::compute::{ filter, search_sorted, slice, FilterMask, SearchResult, SearchSortedSide, }; @@ -218,4 +219,9 @@ mod test { assert_eq!(primitive.maybe_null_slice::(), &[1, 3]); } + + #[rstest] + fn test_sparse_binary_numeric(array: ArrayData) { + test_binary_numeric::(array) + } } diff --git a/vortex-array/src/compute/binary_numeric.rs b/vortex-array/src/compute/binary_numeric.rs index e0a1889d06..e00c22e8d5 100644 --- a/vortex-array/src/compute/binary_numeric.rs +++ b/vortex-array/src/compute/binary_numeric.rs @@ -101,7 +101,11 @@ pub fn binary_numeric( op: BinaryNumericOperator, ) -> VortexResult { if lhs.len() != rhs.len() { - vortex_bail!("Numeric operations aren't supported on arrays of different lengths") + vortex_bail!( + "Numeric operations aren't supported on arrays of different lengths {} {}", + lhs.len(), + rhs.len() + ) } if !matches!(lhs.dtype(), DType::Primitive(_, _)) || !matches!(rhs.dtype(), DType::Primitive(_, _)) @@ -138,7 +142,7 @@ pub fn binary_numeric( // Check if RHS supports the operation directly. if let Some(fun) = rhs.encoding().binary_numeric_fn() { - if let Some(result) = fun.binary_numeric(rhs, lhs, op)? { + if let Some(result) = fun.binary_numeric(rhs, lhs, op.flip_parameters())? 
{ debug_assert_eq!( result.len(), lhs.len(), @@ -186,15 +190,115 @@ fn arrow_numeric( let array = match operator { BinaryNumericOperator::Add => arrow_arith::numeric::add(&lhs, &rhs)?, BinaryNumericOperator::Sub => arrow_arith::numeric::sub(&lhs, &rhs)?, - BinaryNumericOperator::Div => arrow_arith::numeric::div(&lhs, &rhs)?, + BinaryNumericOperator::FlippedSub => arrow_arith::numeric::sub(&rhs, &lhs)?, BinaryNumericOperator::Mul => arrow_arith::numeric::mul(&lhs, &rhs)?, + BinaryNumericOperator::Div => arrow_arith::numeric::div(&lhs, &rhs)?, + BinaryNumericOperator::FlippedDiv => arrow_arith::numeric::div(&rhs, &lhs)?, }; Ok(ArrayData::from_arrow(Arc::new(array) as ArrayRef, nullable)) } +#[cfg(feature = "test-harness")] +pub mod test_harness { + use num_traits::Num; + use vortex_dtype::NativePType; + use vortex_error::{vortex_err, VortexResult}; + use vortex_scalar::{BinaryNumericOperator, Scalar}; + + use crate::array::ConstantArray; + use crate::compute::{binary_numeric, scalar_at}; + use crate::{ArrayDType as _, ArrayData, IntoArrayData as _, IntoCanonical as _}; + + #[allow(clippy::unwrap_used)] + fn to_vec_of_scalar(array: &ArrayData) -> Vec { + // Not fast, but obviously correct + (0..array.len()) + .map(|index| scalar_at(array, index)) + .collect::>>() + .unwrap() + } + + #[allow(clippy::unwrap_used)] + pub fn test_binary_numeric(array: ArrayData) + where + Scalar: From, + { + let canonicalized_array = array + .clone() + .into_canonical() + .unwrap() + .into_primitive() + .unwrap(); + let original_values = to_vec_of_scalar(&canonicalized_array.into_array()); + + let one = T::from(1) + .ok_or_else(|| vortex_err!("could not convert 1 into array native type")) + .unwrap(); + let scalar_one = Scalar::from(one).cast(array.dtype()).unwrap(); + + let operators: [BinaryNumericOperator; 6] = [ + BinaryNumericOperator::Add, + BinaryNumericOperator::Sub, + BinaryNumericOperator::FlippedSub, + BinaryNumericOperator::Mul, + BinaryNumericOperator::Div, + 
BinaryNumericOperator::FlippedDiv, + ]; + + for operator in operators { + assert_eq!( + to_vec_of_scalar( + &binary_numeric( + &array, + &ConstantArray::new(scalar_one.clone(), array.len()).into_array(), + operator + ) + .unwrap() + ), + original_values + .iter() + .map(|x| x + .as_primitive() + .checked_binary_numeric(scalar_one.as_primitive(), operator) + .unwrap() + .unwrap()) + .collect::>(), + "(Constant array of {}) {} ({}) did not produce expected results", + scalar_one, + operator.math_symbol(), + array, + ); + + assert_eq!( + to_vec_of_scalar( + &binary_numeric( + &ConstantArray::new(scalar_one.clone(), array.len()).into_array(), + &array, + operator + ) + .unwrap() + ), + original_values + .iter() + .map(|x| scalar_one + .as_primitive() + .checked_binary_numeric(x.as_primitive(), operator) + .unwrap() + .unwrap()) + .collect::>(), + "(Constant array of {}) {} ({}) did not produce expected results", + scalar_one, + operator.math_symbol(), + array, + ); + } + } +} + #[cfg(test)] mod test { + use vortex_scalar::Scalar; use crate::array::PrimitiveArray; diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 92f39021bf..b954d2028a 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -7,7 +7,9 @@ //! implementations of these operators, else we will decode, and perform the equivalent operator //! from Arrow. 
-pub use binary_numeric::*; +pub use binary_numeric::{ + add_scalar, binary_numeric, div_scalar, mul_scalar, sub_scalar, BinaryNumericFn, +}; pub use boolean::{ and, and_kleene, binary_boolean, or, or_kleene, BinaryBooleanFn, BinaryOperator, }; @@ -25,7 +27,7 @@ pub use take::{take, TakeFn}; use crate::ArrayData; -mod binary_numeric; +pub mod binary_numeric; mod boolean; mod cast; mod compare; diff --git a/vortex-scalar/src/primitive.rs b/vortex-scalar/src/primitive.rs index feebc731a7..930523f44e 100644 --- a/vortex-scalar/src/primitive.rs +++ b/vortex-scalar/src/primitive.rs @@ -274,17 +274,51 @@ impl From for Scalar { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] +/// Binary element-wise operations on two arrays or two scalars. pub enum BinaryNumericOperator { + /// Binary element-wise addition of two arrays or of two scalars. Add, + /// Binary element-wise subtraction of two arrays or of two scalars. Sub, + /// Same as [BinaryNumericOperator::Sub] but with the parameters flipped: `right - left`. + FlippedSub, + /// Binary element-wise multiplication of two arrays or of two scalars. Mul, + /// Binary element-wise division of two arrays or of two scalars. Div, + /// Same as [BinaryNumericOperator::Div] but with the parameters flipped: `right / left`. 
+ FlippedDiv, // Missing from arrow-rs: // Min, // Max, // Pow, } +impl BinaryNumericOperator { + pub fn flip_parameters(self) -> Self { + match self { + BinaryNumericOperator::Add => BinaryNumericOperator::Add, + BinaryNumericOperator::Sub => BinaryNumericOperator::FlippedSub, + BinaryNumericOperator::FlippedSub => BinaryNumericOperator::Sub, + BinaryNumericOperator::Mul => BinaryNumericOperator::Mul, + BinaryNumericOperator::Div => BinaryNumericOperator::FlippedDiv, + BinaryNumericOperator::FlippedDiv => BinaryNumericOperator::Div, + } + } + + pub fn math_symbol(&self) -> String { + match self { + BinaryNumericOperator::Add => "+", + BinaryNumericOperator::Sub => "-", + BinaryNumericOperator::FlippedSub => "+", + BinaryNumericOperator::Mul => "*", + BinaryNumericOperator::Div => "/", + BinaryNumericOperator::FlippedDiv => "/", + } + .to_string() + } +} + impl PrimitiveScalar<'_> { /// Apply the (checked) operator to self and other using SQL-style null semantics. /// @@ -293,7 +327,7 @@ impl PrimitiveScalar<'_> { /// If the types are incompatible (ignoring nullability), an error is returned. /// /// If either value is null, the result is null. 
- pub fn checked_numeric_operator( + pub fn checked_binary_numeric( self, other: PrimitiveScalar<'_>, op: BinaryNumericOperator, @@ -317,10 +351,14 @@ impl PrimitiveScalar<'_> { lhs.checked_add(rhs).map(|result| Scalar::primitive(result, nullability)), BinaryNumericOperator::Sub => lhs.checked_sub(rhs).map(|result| Scalar::primitive(result, nullability)), + BinaryNumericOperator::FlippedSub => + rhs.checked_sub(lhs).map(|result| Scalar::primitive(result, nullability)), BinaryNumericOperator::Mul => lhs.checked_mul(rhs).map(|result| Scalar::primitive(result, nullability)), BinaryNumericOperator::Div => lhs.checked_div(rhs).map(|result| Scalar::primitive(result, nullability)), + BinaryNumericOperator::FlippedDiv => + rhs.checked_div(lhs).map(|result| Scalar::primitive(result, nullability)), } } } @@ -332,8 +370,10 @@ impl PrimitiveScalar<'_> { (Some(lhs), Some(rhs)) => match op { BinaryNumericOperator::Add => Scalar::primitive(lhs + rhs, nullability), BinaryNumericOperator::Sub => Scalar::primitive(lhs - rhs, nullability), + BinaryNumericOperator::FlippedSub => Scalar::primitive(rhs - lhs, nullability), BinaryNumericOperator::Mul => Scalar::primitive(lhs - rhs, nullability), BinaryNumericOperator::Div => Scalar::primitive(lhs - rhs, nullability), + BinaryNumericOperator::FlippedDiv => Scalar::primitive(rhs - lhs, nullability), } }) } From 7b84989ccd6a37be0eeacf01c54f10f1be48b8ea Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 18 Dec 2024 17:56:49 -0500 Subject: [PATCH 2/2] need test harness in dict too --- encodings/dict/Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/encodings/dict/Cargo.toml b/encodings/dict/Cargo.toml index f45593a9a7..a1d5d2143c 100644 --- a/encodings/dict/Cargo.toml +++ b/encodings/dict/Cargo.toml @@ -29,6 +29,7 @@ workspace = true [dev-dependencies] criterion = { workspace = true } rand = { workspace = true } +vortex-array = { workspace = true, features = ["test-harness"] } [[bench]] name = "dict_compress"