Skip to content

Commit

Permalink
Add multiply_scalar (#1159)
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya authored Jan 12, 2022
1 parent 884c6a6 commit d03cd47
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions arrow/src/compute/kernels/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1112,6 +1112,31 @@ where
return math_op(left, right, |a, b| a * b);
}

/// Multiply every value in an array by a scalar. If any value in the array is null then the
/// result is also null.
pub fn multiply_scalar<T>(
array: &PrimitiveArray<T>,
scalar: T::Native,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: Add<Output = T::Native>
+ Sub<Output = T::Native>
+ Mul<Output = T::Native>
+ Div<Output = T::Native>
+ Rem<Output = T::Native>
+ Zero
+ One,
{
#[cfg(feature = "simd")]
{
let scalar_vector = T::init(scalar);
return simd_unary_math_op(array, |x| x * scalar_vector, |x| x * scalar);
}
#[cfg(not(feature = "simd"))]
return Ok(unary(array, |value| value * scalar));
}

/// Perform `left % right` operation on two arrays. If either left or right value is null
/// then the result is also null. If any right hand value is zero then the result of this
/// operation will be `Err(ArrowError::DivideByZero)`.
Expand Down Expand Up @@ -1298,6 +1323,25 @@ mod tests {
assert_eq!(72, c.value(4));
}

#[test]
fn test_primitive_array_multiply_scalar() {
let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
let b = 3;
let c = multiply_scalar(&a, b).unwrap();
let expected = Int32Array::from(vec![45, 42, 27, 24, 3]);
assert_eq!(c, expected);
}

#[test]
fn test_primitive_array_multiply_scalar_sliced() {
let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]);
let a = a.slice(1, 4);
let a = as_primitive_array(&a);
let actual = multiply_scalar(a, 3).unwrap();
let expected = Int32Array::from(vec![None, Some(27), Some(24), None]);
assert_eq!(actual, expected);
}

#[test]
fn test_primitive_array_divide() {
let a = Int32Array::from(vec![15, 15, 8, 1, 9]);
Expand Down

0 comments on commit d03cd47

Please sign in to comment.