Skip to content

Commit

Permalink
Add add_scalar_mut and add_scalar_checked_mut
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Nov 18, 2022
1 parent 5bce104 commit c98d21a
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 1 deletion.
36 changes: 36 additions & 0 deletions arrow-array/src/array/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,42 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
})
}

/// Applies an unary and fallible function to all valid values in a mutable primitive array.
/// Mutable primitive array means that the buffer is not shared with other arrays.
/// As a result, this mutates the buffer directly without allocating new buffer.
///
/// This is unlike [`Self::unary_mut`] which will apply an infallible function to all rows
/// regardless of validity, in many cases this will be significantly faster and should
/// be preferred if `op` is infallible.
///
/// This returns an `Err` for two cases. First is input array is shared buffer with other
/// array. In the case, returned `Err` wraps a `Ok` of input array. Second, if the function
/// encounters an error during applying on values. In the case, returned `Err` wraps an
/// `Err` of the actual error.
///
/// Note: LLVM is currently unable to effectively vectorize fallible operations
pub fn try_unary_mut<F, E>(
self,
op: F,
) -> Result<PrimitiveArray<T>, Result<PrimitiveArray<T>, E>>
where
F: Fn(T::Native) -> Result<T::Native, E>,
{
let len = self.len();
let null_count = self.null_count();
let mut builder = self.into_builder().map_err(|arr| Ok(arr))?;

let (slice, null_buffer) = builder.as_slice();

try_for_each_valid_idx(len, 0, null_count, null_buffer, |idx| {
unsafe { *slice.get_unchecked_mut(idx) = op(*slice.get_unchecked(idx))? };
Ok::<_, E>(())
})
.map_err(|err| Err(err))?;

Ok(builder.finish())
}

/// Applies a unary and nullable function to all valid values in a primitive array
///
/// This is unlike [`Self::unary`] which will apply an infallible function to all rows
Expand Down
5 changes: 5 additions & 0 deletions arrow-array/src/builder/null_buffer_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ impl NullBufferBuilder {
self.bitmap_builder = Some(b);
}
}

#[inline]
pub fn as_slice(&self) -> Option<&[u8]> {
self.bitmap_builder.as_ref().map(|b| b.as_slice())
}
}

impl NullBufferBuilder {
Expand Down
8 changes: 8 additions & 0 deletions arrow-array/src/builder/primitive_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,14 @@ impl<T: ArrowPrimitiveType> PrimitiveBuilder<T> {
pub fn values_slice_mut(&mut self) -> &mut [T::Native] {
self.values_builder.as_slice_mut()
}

/// Returns the current values buffer and null buffer as a slice
pub fn as_slice(&mut self) -> (&mut [T::Native], Option<&[u8]>) {
(
self.values_builder.as_slice_mut(),
self.null_buffer_builder.as_slice(),
)
}
}

#[cfg(test)]
Expand Down
66 changes: 65 additions & 1 deletion arrow/src/compute/kernels/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ use crate::array::*;
use crate::buffer::MutableBuffer;
use crate::compute::kernels::arity::unary;
use crate::compute::{
binary, binary_opt, try_binary, try_unary, try_unary_dyn, unary_dyn,
binary, binary_opt, try_binary, try_unary, try_unary_dyn, try_unary_mut, unary_dyn,
unary_mut,
};
use crate::datatypes::{
ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type,
Expand Down Expand Up @@ -914,6 +915,47 @@ where
Ok(unary(array, |value| value.add_wrapping(scalar)))
}

/// Mutate an array by adding every value in an array by a scalar. If any value in the array
/// is null then the result is also null.
///
/// This only mutates the array if it is not shared buffers with other arrays. For shared
/// array, it returns an `Err` which wraps input array.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `add_scalar_checked_mut` instead.
pub fn add_scalar_mut<T>(
array: PrimitiveArray<T>,
scalar: T::Native,
) -> std::result::Result<PrimitiveArray<T>, PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
unary_mut(array, |value| value.add_wrapping(scalar))
}

/// Mutate an array by adding every value in an array by a scalar. If any value in the array
/// is null then the result is also null.
///
/// This only mutates the array if it is not shared buffers with other arrays. For shared
/// array, it returns an `Err` which wraps input array with a `Ok`.
///
/// This detects overflow and returns an `Err` which wraps an `Erro` of actual error.
/// For an non-overflow-checking variant, use `add_scalar_mut` instead.
pub fn add_scalar_checked_mut<T>(
array: PrimitiveArray<T>,
scalar: T::Native,
) -> std::result::Result<
PrimitiveArray<T>,
std::result::Result<PrimitiveArray<T>, ArrowError>,
>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
try_unary_mut(array, |value| value.add_checked(scalar))
}

/// Add every value in an array by a scalar. If any value in the array is null then the
/// result is also null.
///
Expand Down Expand Up @@ -3098,4 +3140,26 @@ mod tests {
assert_eq!(result.len(), 13);
assert_eq!(result.null_count(), 13);
}

#[test]
fn test_primitive_array_add_scalar_mut() {
let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
let b = 3;
let c = add_scalar_mut(a, b).unwrap();
let expected = Int32Array::from(vec![18, 17, 12, 11, 4]);
assert_eq!(c, expected);
}

#[test]
fn test_primitive_add_scalar_mut_wrapping_overflow() {
let a = Int32Array::from(vec![i32::MAX, i32::MIN]);

let wrapped = add_scalar_mut(a, 1).unwrap();
let expected = Int32Array::from(vec![-2147483648, -2147483647]);
assert_eq!(expected, wrapped);

let a = Int32Array::from(vec![i32::MAX, i32::MIN]);
let overflow = add_scalar_checked_mut(a, 1);
let _ = overflow.expect_err("overflow should be detected");
}
}
27 changes: 27 additions & 0 deletions arrow/src/compute/kernels/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ where
array.unary(op)
}

/// See [`PrimitiveArray::unary_mut`]
pub fn unary_mut<I, F>(
array: PrimitiveArray<I>,
op: F,
) -> std::result::Result<PrimitiveArray<I>, PrimitiveArray<I>>
where
I: ArrowPrimitiveType,
F: Fn(I::Native) -> I::Native,
{
array.unary_mut(op)
}

/// See [`PrimitiveArray::try_unary`]
pub fn try_unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> Result<PrimitiveArray<O>>
where
Expand All @@ -68,6 +80,21 @@ where
array.try_unary(op)
}

/// See [`PrimitiveArray::try_unary_mut`]
pub fn try_unary_mut<I, F>(
array: PrimitiveArray<I>,
op: F,
) -> std::result::Result<
PrimitiveArray<I>,
std::result::Result<PrimitiveArray<I>, ArrowError>,
>
where
I: ArrowPrimitiveType,
F: Fn(I::Native) -> Result<I::Native>,
{
array.try_unary_mut(op)
}

/// A helper function that applies an infallible unary function to a dictionary array with primitive value type.
fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
where
Expand Down

0 comments on commit c98d21a

Please sign in to comment.