diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index b44147afa3ca..c368a306603c 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -27,7 +27,8 @@ use crate::array::*; use crate::buffer::MutableBuffer; use crate::compute::kernels::arity::unary; use crate::compute::{ - binary, binary_mut, binary_opt, try_binary, try_unary, try_unary_dyn, unary_dyn, + binary, binary_mut, binary_opt, try_binary, try_binary_mut, try_unary, try_unary_dyn, + unary_dyn, }; use crate::datatypes::{ ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type, @@ -733,7 +734,7 @@ where math_op(left, right, |a, b| a.add_wrapping(b)) } -/// Perform `left + right` operation on two arrays while mutating `left` with operation results. +/// Perform `left + right` operation on two arrays by mutating `left` with operation results. /// If either left or right value is null then the result is also null. /// /// This only mutates the array if it is not shared buffers with other arrays. For shared @@ -771,6 +772,28 @@ where try_binary(left, right, |a, b| a.add_checked(b)) } +/// Perform `left + right` operation on two arrays by mutating `left` with operation results. +/// If either left or right value is null then the result is also null. +/// +/// This only mutates the array if it is not shared buffers with other arrays. For shared +/// array, it returns an `Err` which wraps input array. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `add_mut` instead. +pub fn add_checked_mut( + left: PrimitiveArray, + right: &PrimitiveArray, +) -> std::result::Result< + PrimitiveArray, + std::result::Result, ArrowError>, +> +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + try_binary_mut(left, right, |a, b| a.add_checked(b)) +} + /// Perform `left + right` operation on two arrays. If either left or right value is null /// then the result is also null. /// @@ -3120,4 +3143,29 @@ mod tests { assert_eq!(result.len(), 13); assert_eq!(result.null_count(), 13); } + + #[test] + fn test_primitive_array_add_mut() { + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]); + + let c = add_mut(a, &b).unwrap(); + let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]); + assert_eq!(c, expected); + } + + #[test] + fn test_primitive_add_mut_wrapping_overflow() { + let a = Int32Array::from(vec![i32::MAX, i32::MIN]); + let b = Int32Array::from(vec![1, 1]); + + let wrapped = add_mut(a, &b).unwrap(); + let expected = Int32Array::from(vec![-2147483648, -2147483647]); + assert_eq!(expected, wrapped); + + let a = Int32Array::from(vec![i32::MAX, i32::MIN]); + let b = Int32Array::from(vec![1, 1]); + let overflow = add_checked_mut(a, &b); + let _ = overflow.expect_err("overflow should be detected"); + } } diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index 0cedcce64049..152be00e675b 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -251,7 +251,7 @@ where .map(|x| len - x.count_set_bits_offset(0, len)) .unwrap_or_default(); - let mut builder = a.into_builder().map_err(|arr| Ok(arr))?; + let mut builder = a.into_builder().map_err(Ok)?; builder .values_slice_mut() @@ -375,7 +375,7 @@ where .map(|x| len - x.count_set_bits_offset(0, len)) .unwrap_or_default(); - let mut builder = a.into_builder().map_err(|arr| Ok(arr))?; + let mut builder = a.into_builder().map_err(Ok)?; let slice = builder.values_slice_mut(); @@ -386,7 +386,7 @@ where }; Ok::<_, ArrowError>(()) }) - .map_err(|err| Err(err))?; + .map_err(Err)?; let array_builder = builder .finish() @@ -437,14 +437,13 @@ where T: ArrowPrimitiveType, F: Fn(T::Native, T::Native) -> Result, { - let mut builder = a.into_builder().map_err(|arr| Ok(arr))?; + let mut builder = a.into_builder().map_err(Ok)?; let slice = builder.values_slice_mut(); for idx in 0..len { unsafe { *slice.get_unchecked_mut(idx) = - op(*slice.get_unchecked(idx), b.value_unchecked(idx)) - .map_err(|err| Err(err))?; + op(*slice.get_unchecked(idx), b.value_unchecked(idx)).map_err(Err)?; }; } Ok(builder.finish())