diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs index 3e7a81862927..2a9a9494d892 100644 --- a/arrow-arith/src/arity.rs +++ b/arrow-arith/src/arity.rs @@ -152,6 +152,37 @@ where } } +/// Applies an infallible unary function to an [`ArrayRef`] with primitive values. +/// If the buffer of ArrayRef is not shared with other arrays, then func will +/// mutate the buffer directly without allocating new buffer. +pub fn unary_dyn_mut(array: ArrayRef, op: F) -> Result +where + T: ArrowPrimitiveType, + F: Fn(T::Native) -> T::Native, +{ + let array_ref = array.as_ref(); + downcast_dictionary_array! { + array_ref => unary_dict::<_, F, T>(array_ref, op).map_err(|_| array), + t => { + //Todo support unary_dict_mut + if PrimitiveArray::::is_compatible(t) { + let primitive_array = array.as_any().downcast_ref::>().unwrap().clone(); + // Need drop the strong ref which clone before in this function. + std::mem::drop(array); + match unary_mut::( + primitive_array, + op, + ) { + Ok(arr) => Ok(Arc::new(arr)), + Err(arr) => Err(Arc::new(arr)), + } + } else { + Err(array) + } + } + } +} + /// Applies a fallible unary function to an array with primitive values. pub fn try_unary_dyn(array: &dyn Array, op: F) -> Result where @@ -576,6 +607,27 @@ mod tests { ); } + #[test] + fn test_unary_f64_mut() { + // 1. only have one strong ref, use copy on write. + let input = + Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None, Some(7.2)]); + let result = + unary_dyn_mut::<_, Float64Type>(make_array(input.into_data()), |n| n + 1.0) + .unwrap(); + assert_eq!( + result.as_any().downcast_ref::().unwrap(), + &Float64Array::from(vec![Some(6.1f64), None, Some(7.8), None, Some(8.2)]) + ); + + // 2. More than one strong ref + let input = + Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None, Some(7.2)]); + let slice = input.slice(1, 4); + let result = unary_dyn_mut::<_, Float64Type>(slice, |n| n + 1.0); + assert!(result.is_err()) + } + #[test] fn test_unary_dict_and_unary_dyn() { let mut builder = PrimitiveDictionaryBuilder::::new();