Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Nov 19, 2022
1 parent 526abd5 commit c1e3f4d
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 8 deletions.
52 changes: 50 additions & 2 deletions arrow/src/compute/kernels/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ use crate::array::*;
use crate::buffer::MutableBuffer;
use crate::compute::kernels::arity::unary;
use crate::compute::{
binary, binary_mut, binary_opt, try_binary, try_unary, try_unary_dyn, unary_dyn,
binary, binary_mut, binary_opt, try_binary, try_binary_mut, try_unary, try_unary_dyn,
unary_dyn,
};
use crate::datatypes::{
ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type,
Expand Down Expand Up @@ -733,7 +734,7 @@ where
math_op(left, right, |a, b| a.add_wrapping(b))
}

/// Perform `left + right` operation on two arrays while mutating `left` with operation results.
/// Perform `left + right` operation on two arrays by mutating `left` with operation results.
/// If either left or right value is null then the result is also null.
///
/// This only mutates the array if it is not shared buffers with other arrays. For shared
Expand Down Expand Up @@ -771,6 +772,28 @@ where
try_binary(left, right, |a, b| a.add_checked(b))
}

/// Perform `left + right` operation on two arrays by mutating `left` with operation results.
/// If either left or right value is null then the result is also null.
///
/// This only mutates the array if it is not shared buffers with other arrays. For shared
/// array, it returns an `Err` which wraps input array.
///
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
/// use `add_mut` instead.
pub fn add_checked_mut<T>(
left: PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> std::result::Result<
PrimitiveArray<T>,
std::result::Result<PrimitiveArray<T>, ArrowError>,
>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
try_binary_mut(left, right, |a, b| a.add_checked(b))
}

/// Perform `left + right` operation on two arrays. If either left or right value is null
/// then the result is also null.
///
Expand Down Expand Up @@ -3120,4 +3143,29 @@ mod tests {
assert_eq!(result.len(), 13);
assert_eq!(result.null_count(), 13);
}

#[test]
fn test_primitive_array_add_mut() {
let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);

let c = add_mut(a, &b).unwrap();
let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]);
assert_eq!(c, expected);
}

#[test]
fn test_primitive_add_mut_wrapping_overflow() {
let a = Int32Array::from(vec![i32::MAX, i32::MIN]);
let b = Int32Array::from(vec![1, 1]);

let wrapped = add_mut(a, &b).unwrap();
let expected = Int32Array::from(vec![-2147483648, -2147483647]);
assert_eq!(expected, wrapped);

let a = Int32Array::from(vec![i32::MAX, i32::MIN]);
let b = Int32Array::from(vec![1, 1]);
let overflow = add_checked_mut(a, &b);
let _ = overflow.expect_err("overflow should be detected");
}
}
11 changes: 5 additions & 6 deletions arrow/src/compute/kernels/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ where
.map(|x| len - x.count_set_bits_offset(0, len))
.unwrap_or_default();

let mut builder = a.into_builder().map_err(|arr| Ok(arr))?;
let mut builder = a.into_builder().map_err(Ok)?;

builder
.values_slice_mut()
Expand Down Expand Up @@ -375,7 +375,7 @@ where
.map(|x| len - x.count_set_bits_offset(0, len))
.unwrap_or_default();

let mut builder = a.into_builder().map_err(|arr| Ok(arr))?;
let mut builder = a.into_builder().map_err(Ok)?;

let slice = builder.values_slice_mut();

Expand All @@ -386,7 +386,7 @@ where
};
Ok::<_, ArrowError>(())
})
.map_err(|err| Err(err))?;
.map_err(Err)?;

let array_builder = builder
.finish()
Expand Down Expand Up @@ -437,14 +437,13 @@ where
T: ArrowPrimitiveType,
F: Fn(T::Native, T::Native) -> Result<T::Native>,
{
let mut builder = a.into_builder().map_err(|arr| Ok(arr))?;
let mut builder = a.into_builder().map_err(Ok)?;
let slice = builder.values_slice_mut();

for idx in 0..len {
unsafe {
*slice.get_unchecked_mut(idx) =
op(*slice.get_unchecked(idx), b.value_unchecked(idx))
.map_err(|err| Err(err))?;
op(*slice.get_unchecked(idx), b.value_unchecked(idx)).map_err(Err)?;
};
}
Ok(builder.finish())
Expand Down

0 comments on commit c1e3f4d

Please sign in to comment.