From 7594db6367515473efdb130e7de91060079a4d88 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 14 Sep 2022 16:54:23 -0700 Subject: [PATCH] Add overflow-checking variants of arithmetic scalar dyn kernels (#2713) * Add overflow-checking variants of arithmetic scalar dyn kernels * Update doc * For review --- arrow/src/compute/kernels/arithmetic.rs | 199 +++++++++++++++++++++--- arrow/src/compute/kernels/arity.rs | 50 +++++- 2 files changed, 226 insertions(+), 23 deletions(-) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index a344407e426d..04fe2393ec4d 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -22,7 +22,7 @@ //! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. -use std::ops::{Add, Div, Mul, Neg, Rem, Sub}; +use std::ops::{Div, Neg, Rem}; use num::{One, Zero}; @@ -32,7 +32,9 @@ use crate::buffer::Buffer; use crate::buffer::MutableBuffer; use crate::compute::kernels::arity::unary; use crate::compute::util::combine_option_bitmap; -use crate::compute::{binary, binary_opt, try_binary, try_unary, unary_dyn}; +use crate::compute::{ + binary, binary_opt, try_binary, try_unary, try_unary_dyn, unary_dyn, +}; use crate::datatypes::{ native_op::ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, @@ -834,12 +836,39 @@ where /// Add every value in an array by a scalar. If any value in the array is null then the /// result is also null. The given array must be a `PrimitiveArray` of the type same as /// the scalar, or a `DictionaryArray` of the value type same as the scalar. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `add_scalar_checked_dyn` instead. +/// +/// This returns an `Err` when the input array is not supported for adding operation. pub fn add_scalar_dyn(array: &dyn Array, scalar: T::Native) -> Result where T: ArrowNumericType, - T::Native: Add, + T::Native: ArrowNativeTypeOp, +{ + unary_dyn::<_, T>(array, |value| value.add_wrapping(scalar)) +} + +/// Add every value in an array by a scalar. If any value in the array is null then the +/// result is also null. The given array must be a `PrimitiveArray` of the type same as +/// the scalar, or a `DictionaryArray` of the value type same as the scalar. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `add_scalar_dyn` instead. +/// +/// As this kernel has the branching costs and also prevents LLVM from vectorising it correctly, +/// it is usually much slower than non-checking variant. +pub fn add_scalar_checked_dyn(array: &dyn Array, scalar: T::Native) -> Result +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, { - unary_dyn::<_, T>(array, |value| value + scalar) + try_unary_dyn::<_, T>(array, |value| { + value.add_checked(scalar).ok_or_else(|| { + ArrowError::CastError(format!("Overflow: adding {:?} to {:?}", scalar, value)) + }) + }) + .map(|a| Arc::new(a) as ArrayRef) } /// Perform `left - right` operation on two arrays. If either left or right value is null @@ -937,16 +966,40 @@ where /// Subtract every value in an array by a scalar. If any value in the array is null then the /// result is also null. The given array must be a `PrimitiveArray` of the type same as /// the scalar, or a `DictionaryArray` of the value type same as the scalar. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `subtract_scalar_checked_dyn` instead. pub fn subtract_scalar_dyn(array: &dyn Array, scalar: T::Native) -> Result where - T: datatypes::ArrowNumericType, - T::Native: Add - + Sub - + Mul - + Div - + Zero, + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + unary_dyn::<_, T>(array, |value| value.sub_wrapping(scalar)) +} + +/// Subtract every value in an array by a scalar. If any value in the array is null then the +/// result is also null. The given array must be a `PrimitiveArray` of the type same as +/// the scalar, or a `DictionaryArray` of the value type same as the scalar. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `subtract_scalar_dyn` instead. +pub fn subtract_scalar_checked_dyn( + array: &dyn Array, + scalar: T::Native, +) -> Result +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, { - unary_dyn::<_, T>(array, |value| value - scalar) + try_unary_dyn::<_, T>(array, |value| { + value.sub_checked(scalar).ok_or_else(|| { + ArrowError::CastError(format!( + "Overflow: subtracting {:?} from {:?}", + scalar, value + )) + }) + }) + .map(|a| Arc::new(a) as ArrayRef) } /// Perform `-` operation on an array. If value is null then the result is also null. @@ -1065,18 +1118,40 @@ where /// Multiply every value in an array by a scalar. If any value in the array is null then the /// result is also null. The given array must be a `PrimitiveArray` of the type same as /// the scalar, or a `DictionaryArray` of the value type same as the scalar. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `multiply_scalar_checked_dyn` instead. pub fn multiply_scalar_dyn(array: &dyn Array, scalar: T::Native) -> Result where T: ArrowNumericType, - T::Native: Add - + Sub - + Mul - + Div - + Rem - + Zero - + One, + T::Native: ArrowNativeTypeOp, +{ + unary_dyn::<_, T>(array, |value| value.mul_wrapping(scalar)) +} + +/// Subtract every value in an array by a scalar. If any value in the array is null then the +/// result is also null. The given array must be a `PrimitiveArray` of the type same as +/// the scalar, or a `DictionaryArray` of the value type same as the scalar. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `multiply_scalar_dyn` instead. +pub fn multiply_scalar_checked_dyn( + array: &dyn Array, + scalar: T::Native, +) -> Result +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, { - unary_dyn::<_, T>(array, |value| value * scalar) + try_unary_dyn::<_, T>(array, |value| { + value.mul_checked(scalar).ok_or_else(|| { + ArrowError::CastError(format!( + "Overflow: multiplying {:?} by {:?}", + value, scalar + )) + }) + }) + .map(|a| Arc::new(a) as ArrayRef) } /// Perform `left % right` operation on two arrays. If either left or right value is null @@ -1223,15 +1298,48 @@ where /// result is also null. If the scalar is zero then the result of this operation will be /// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type /// same as the scalar, or a `DictionaryArray` of the value type same as the scalar. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `divide_scalar_checked_dyn` instead. pub fn divide_scalar_dyn(array: &dyn Array, divisor: T::Native) -> Result where T: ArrowNumericType, - T::Native: Div + Zero, + T::Native: ArrowNativeTypeOp + Zero, +{ + if divisor.is_zero() { + return Err(ArrowError::DivideByZero); + } + unary_dyn::<_, T>(array, |value| value.div_wrapping(divisor)) +} + +/// Divide every value in an array by a scalar. If any value in the array is null then the +/// result is also null. If the scalar is zero then the result of this operation will be +/// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type +/// same as the scalar, or a `DictionaryArray` of the value type same as the scalar. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `divide_scalar_dyn` instead. +pub fn divide_scalar_checked_dyn( + array: &dyn Array, + divisor: T::Native, +) -> Result +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp + Zero, { if divisor.is_zero() { return Err(ArrowError::DivideByZero); } - unary_dyn::<_, T>(array, |value| value / divisor) + + try_unary_dyn::<_, T>(array, |value| { + value.div_checked(divisor).ok_or_else(|| { + ArrowError::CastError(format!( + "Overflow: dividing {:?} by {:?}", + value, divisor + )) + }) + }) + .map(|a| Arc::new(a) as ArrayRef) } #[cfg(test)] @@ -2222,6 +2330,55 @@ mod tests { overflow.expect_err("overflow should be detected"); } + #[test] + fn test_primitive_add_scalar_dyn_wrapping_overflow() { + let a = Int32Array::from(vec![i32::MAX, i32::MIN]); + + let wrapped = add_scalar_dyn::(&a, 1).unwrap(); + let expected = + Arc::new(Int32Array::from(vec![-2147483648, -2147483647])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = add_scalar_checked_dyn::(&a, 1); + overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_primitive_subtract_scalar_dyn_wrapping_overflow() { + let a = Int32Array::from(vec![-2]); + + let wrapped = subtract_scalar_dyn::(&a, i32::MAX).unwrap(); + let expected = Arc::new(Int32Array::from(vec![i32::MAX])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = subtract_scalar_checked_dyn::(&a, i32::MAX); + overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_primitive_mul_scalar_dyn_wrapping_overflow() { + let a = Int32Array::from(vec![10]); + + let wrapped = multiply_scalar_dyn::(&a, i32::MAX).unwrap(); + let expected = Arc::new(Int32Array::from(vec![-10])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = multiply_scalar_checked_dyn::(&a, i32::MAX); + overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_primitive_div_scalar_dyn_wrapping_overflow() { + let a = Int32Array::from(vec![i32::MIN]); + + let wrapped = divide_scalar_dyn::(&a, -1).unwrap(); + let expected = Arc::new(Int32Array::from(vec![-2147483648])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = divide_scalar_checked_dyn::(&a, -1); + overflow.expect_err("overflow should be detected"); + } + #[test] fn test_primitive_div_opt_overflow_division_by_zero() { let a = Int32Array::from(vec![i32::MIN]); diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index fffa81af8190..21c633116ee0 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -123,7 +123,7 @@ where Ok(unsafe { build_primitive_array(len, buffer.finish(), null_count, null_buffer) }) } -/// A helper function that applies an unary function to a dictionary array with primitive value type. +/// A helper function that applies an infallible unary function to a dictionary array with primitive value type. fn unary_dict(array: &DictionaryArray, op: F) -> Result where K: ArrowNumericType, @@ -138,7 +138,22 @@ where Ok(Arc::new(new_dict)) } -/// Applies an unary function to an array with primitive values. +/// A helper function that applies a fallible unary function to a dictionary array with primitive value type. +fn try_unary_dict(array: &DictionaryArray, op: F) -> Result +where + K: ArrowNumericType, + T: ArrowPrimitiveType, + F: Fn(T::Native) -> Result, +{ + let dict_values = array.values().as_any().downcast_ref().unwrap(); + let values = try_unary::(dict_values, op)?.into_data(); + let data = array.data().clone().into_builder().child_data(vec![values]); + + let new_dict: DictionaryArray = unsafe { data.build_unchecked() }.into(); + Ok(Arc::new(new_dict)) +} + +/// Applies an infallible unary function to an array with primitive values. pub fn unary_dyn(array: &dyn Array, op: F) -> Result where T: ArrowPrimitiveType, @@ -162,6 +177,37 @@ where } } +/// Applies a fallible unary function to an array with primitive values. +pub fn try_unary_dyn(array: &dyn Array, op: F) -> Result +where + T: ArrowPrimitiveType, + F: Fn(T::Native) -> Result, +{ + downcast_dictionary_array! { + array => if array.values().data_type() == &T::DATA_TYPE { + try_unary_dict::<_, F, T>(array, op) + } else { + Err(ArrowError::NotYetImplemented(format!( + "Cannot perform unary operation on dictionary array of type {}", + array.data_type() + ))) + }, + t => { + if t == &T::DATA_TYPE { + Ok(Arc::new(try_unary::( + array.as_any().downcast_ref::>().unwrap(), + op, + )?)) + } else { + Err(ArrowError::NotYetImplemented(format!( + "Cannot perform unary operation on array of type {}", + t + ))) + } + } + } +} + /// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, collecting /// the results in a [`PrimitiveArray`]. If any index is null in either `a` or `b`, the /// corresponding index in the result will also be null