Skip to content

Commit

Permalink
[feat] Support unary_dyn_mut in arrow-arth.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ted-Jiang committed Feb 13, 2023
1 parent 3cf64df commit d3fe797
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions arrow-arith/src/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,37 @@ where
}
}

/// Applies an infallible unary function to an [`ArrayRef`] with primitive values.
/// If the buffer of ArrayRef is not shared with other arrays, then func will
/// mutate the buffer directly without allocating new buffer.
pub fn unary_dyn_mut<F, T>(array: ArrayRef, op: F) -> Result<ArrayRef, ArrayRef>
where
T: ArrowPrimitiveType,
F: Fn(T::Native) -> T::Native,
{
let array_ref = array.as_ref();
downcast_dictionary_array! {
array_ref => unary_dict::<_, F, T>(array_ref, op).map_err(|_| array),
t => {
//Todo support unary_dict_mut
if PrimitiveArray::<T>::is_compatible(t) {
let primitive_array = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap().clone();
// Need drop the strong ref which clone before in this function.
std::mem::drop(array);
match unary_mut::<T, F>(
primitive_array,
op,
) {
Ok(arr) => Ok(Arc::new(arr)),
Err(arr) => Err(Arc::new(arr)),
}
} else {
Err(array)
}
}
}
}

/// Applies a fallible unary function to an array with primitive values.
pub fn try_unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef, ArrowError>
where
Expand Down Expand Up @@ -576,6 +607,27 @@ mod tests {
);
}

#[test]
fn test_unary_f64_mut() {
// 1. only have one strong ref, use copy on write.
let input =
Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None, Some(7.2)]);
let result =
unary_dyn_mut::<_, Float64Type>(make_array(input.into_data()), |n| n + 1.0)
.unwrap();
assert_eq!(
result.as_any().downcast_ref::<Float64Array>().unwrap(),
&Float64Array::from(vec![Some(6.1f64), None, Some(7.8), None, Some(8.2)])
);

// 2. More than one strong ref
let input =
Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None, Some(7.2)]);
let slice = input.slice(1, 4);
let result = unary_dyn_mut::<_, Float64Type>(slice, |n| n + 1.0);
assert!(result.is_err())
}

#[test]
fn test_unary_dict_and_unary_dyn() {
let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
Expand Down

0 comments on commit d3fe797

Please sign in to comment.