From e6b298709025308ffcfbc3c244f089d27051d879 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 19 Nov 2022 11:13:12 -0800 Subject: [PATCH] Add test --- arrow/src/compute/kernels/arithmetic.rs | 52 ++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index b44147afa3ca..c368a306603c 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -27,7 +27,8 @@ use crate::array::*; use crate::buffer::MutableBuffer; use crate::compute::kernels::arity::unary; use crate::compute::{ - binary, binary_mut, binary_opt, try_binary, try_unary, try_unary_dyn, unary_dyn, + binary, binary_mut, binary_opt, try_binary, try_binary_mut, try_unary, try_unary_dyn, + unary_dyn, }; use crate::datatypes::{ ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type, @@ -733,7 +734,7 @@ where math_op(left, right, |a, b| a.add_wrapping(b)) } -/// Perform `left + right` operation on two arrays while mutating `left` with operation results. +/// Perform `left + right` operation on two arrays by mutating `left` with operation results. /// If either left or right value is null then the result is also null. /// /// This only mutates the array if it is not shared buffers with other arrays. For shared @@ -771,6 +772,28 @@ where try_binary(left, right, |a, b| a.add_checked(b)) } +/// Perform `left + right` operation on two arrays by mutating `left` with operation results. +/// If either left or right value is null then the result is also null. +/// +/// This only mutates the array if it is not shared buffers with other arrays. For shared +/// array, it returns an `Err` which wraps input array. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `add_mut` instead. +pub fn add_checked_mut( + left: PrimitiveArray, + right: &PrimitiveArray, +) -> std::result::Result< + PrimitiveArray, + std::result::Result, ArrowError>, +> +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + try_binary_mut(left, right, |a, b| a.add_checked(b)) +} + /// Perform `left + right` operation on two arrays. If either left or right value is null /// then the result is also null. /// @@ -3120,4 +3143,29 @@ mod tests { assert_eq!(result.len(), 13); assert_eq!(result.null_count(), 13); } + + #[test] + fn test_primitive_array_add_mut() { + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]); + + let c = add_mut(a, &b).unwrap(); + let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]); + assert_eq!(c, expected); + } + + #[test] + fn test_primitive_add_mut_wrapping_overflow() { + let a = Int32Array::from(vec![i32::MAX, i32::MIN]); + let b = Int32Array::from(vec![1, 1]); + + let wrapped = add_mut(a, &b).unwrap(); + let expected = Int32Array::from(vec![-2147483648, -2147483647]); + assert_eq!(expected, wrapped); + + let a = Int32Array::from(vec![i32::MAX, i32::MIN]); + let b = Int32Array::from(vec![1, 1]); + let overflow = add_checked_mut(a, &b); + let _ = overflow.expect_err("overflow should be detected"); + } }