Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow precision loss on multiplying decimal arrays #3690

Merged
merged 18 commits into from
Mar 16, 2023
Merged
139 changes: 138 additions & 1 deletion arrow-arith/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ use crate::arity::*;
use arrow_array::cast::*;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::i256;
use arrow_buffer::ArrowNativeType;
use arrow_schema::*;
use num::traits::Pow;
use std::sync::Arc;
Expand Down Expand Up @@ -61,7 +63,7 @@ fn math_checked_op<LT, RT, F>(
where
LT: ArrowNumericType,
RT: ArrowNumericType,
F: Fn(LT::Native, RT::Native) -> Result<LT::Native, ArrowError>,
F: FnMut(LT::Native, RT::Native) -> Result<LT::Native, ArrowError>,
{
try_binary(left, right, op)
}
Expand Down Expand Up @@ -1165,6 +1167,77 @@ pub fn multiply_dyn_checked(
}
}

/// Perform `left * right` operation on two decimal arrays. If either left or right value is
/// null then the result is also null.
///
/// This performs decimal multiplication which allows precision loss if an exact representation
/// is not possible for the result, according to the required scale. In the case, the result
/// will be rounded to the required scale.
///
/// It is implemented for compatibility with precision loss `multiply` function provided by
/// other data processing engines. For multiplication with precision loss detection, use
/// `multiply` or `multiply_checked` instead.
pub fn mul_fixed_point_checked(
viirya marked this conversation as resolved.
Show resolved Hide resolved
left: &PrimitiveArray<Decimal128Type>,
right: &PrimitiveArray<Decimal128Type>,
required_scale: i8,
) -> Result<ArrayRef, ArrowError> {
let precision = left.precision();
viirya marked this conversation as resolved.
Show resolved Hide resolved
let product_scale = left.scale() + right.scale();

try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
viirya marked this conversation as resolved.
Show resolved Hide resolved
let a = i256::from_i128(a);
let b = i256::from_i128(b);

a.checked_mul(b)
tustvold marked this conversation as resolved.
Show resolved Hide resolved
viirya marked this conversation as resolved.
Show resolved Hide resolved
.map(|mut a| {
if required_scale < product_scale {
viirya marked this conversation as resolved.
Show resolved Hide resolved
let divisor = i256::from_i128(10)
viirya marked this conversation as resolved.
Show resolved Hide resolved
.pow_wrapping((product_scale - required_scale) as u32);
a = divide_and_round::<Decimal256Type>(a, divisor);
}
a
})
.ok_or_else(|| {
ArrowError::ComputeError(format!(
"Overflow happened on: {:?} * {:?}, {:?}",
a,
b,
a.checked_mul(b)
))
})
.and_then(|a| {
viirya marked this conversation as resolved.
Show resolved Hide resolved
a.to_i128().ok_or_else(|| {
ArrowError::ComputeError(format!("Overflow happened on: {:?}", a))
})
})
})
.and_then(|a| {
a.with_precision_and_scale(precision, required_scale)
.map(|a| Arc::new(a) as ArrayRef)
})
}

/// Divide a decimal native value by given divisor and round the result.
fn divide_and_round<I>(input: I::Native, div: I::Native) -> I::Native
where
I: DecimalType,
I::Native: ArrowNativeTypeOp,
{
let d = input.div_wrapping(div);
let r = input.mod_wrapping(div);

let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
let half_neg = half.neg_wrapping();

// Round result
match input >= I::Native::ZERO {
true if r >= half => d.add_wrapping(I::Native::ONE),
false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
_ => d,
}
}

/// Multiply every value in an array by a scalar. If any value in the array is null then the
/// result is also null.
///
Expand Down Expand Up @@ -3231,4 +3304,68 @@ mod tests {

assert_eq!(&expected, &result);
}

#[test]
fn test_decimal_multiply_allow_precision_loss() {
// Overflow happening as i128 cannot hold multiplying result.
let a = Decimal128Array::from(vec![123456789000000000000000000])
viirya marked this conversation as resolved.
Show resolved Hide resolved
.with_precision_and_scale(38, 18)
.unwrap();

let b = Decimal128Array::from(vec![10000000000000000000])
viirya marked this conversation as resolved.
Show resolved Hide resolved
.with_precision_and_scale(38, 18)
.unwrap();

let err = multiply_dyn_checked(&a, &b).unwrap_err();
assert!(err.to_string().contains(
"Overflow happened on: 123456789000000000000000000 * 10000000000000000000"
));

// Allow precision loss.
let result = mul_fixed_point_checked(&a, &b, 28).unwrap();
let result = as_primitive_array::<Decimal128Type>(&result).clone();
let expected =
viirya marked this conversation as resolved.
Show resolved Hide resolved
Decimal128Array::from(vec![12345678900000000000000000000000000000])
.with_precision_and_scale(38, 28)
.unwrap();

assert_eq!(&expected, &result);
assert_eq!(
result.value_as_string(0),
"1234567890.0000000000000000000000000000"
);

// Rounding case
let a = Decimal128Array::from(vec![
viirya marked this conversation as resolved.
Show resolved Hide resolved
1,
123456789555555555555555555,
1555555555555555555,
])
.with_precision_and_scale(38, 18)
.unwrap();

let b = Decimal128Array::from(vec![1555555555555555555, 11222222222222222222, 1])
viirya marked this conversation as resolved.
Show resolved Hide resolved
.with_precision_and_scale(38, 18)
.unwrap();

let result = mul_fixed_point_checked(&a, &b, 28).unwrap();
let result = as_primitive_array::<Decimal128Type>(&result).clone();
let expected = Decimal128Array::from(vec![
viirya marked this conversation as resolved.
Show resolved Hide resolved
15555555556,
13854595272345679012071330528765432099,
15555555556,
])
.with_precision_and_scale(38, 28)
.unwrap();

assert_eq!(&expected, &result);

// Rounded the value "1385459527.234567901207133052876543209876543210".
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

assert_eq!(
result.value_as_string(1),
"1385459527.2345679012071330528765432099"
);
assert_eq!(result.value_as_string(0), "0.0000000000000000015555555556");
assert_eq!(result.value_as_string(2), "0.0000000000000000015555555556");
}
}
10 changes: 5 additions & 5 deletions arrow-arith/src/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ pub fn try_unary<I, F, O>(
where
I: ArrowPrimitiveType,
O: ArrowPrimitiveType,
F: Fn(I::Native) -> Result<O::Native, ArrowError>,
F: FnMut(I::Native) -> Result<O::Native, ArrowError>,
viirya marked this conversation as resolved.
Show resolved Hide resolved
{
array.try_unary(op)
}
Expand Down Expand Up @@ -307,11 +307,11 @@ where
pub fn try_binary<A: ArrayAccessor, B: ArrayAccessor, F, O>(
a: A,
b: B,
op: F,
mut op: F,
) -> Result<PrimitiveArray<O>, ArrowError>
where
O: ArrowPrimitiveType,
F: Fn(A::Item, B::Item) -> Result<O::Native, ArrowError>,
F: FnMut(A::Item, B::Item) -> Result<O::Native, ArrowError>,
{
if a.len() != b.len() {
return Err(ArrowError::ComputeError(
Expand Down Expand Up @@ -431,11 +431,11 @@ fn try_binary_no_nulls<A: ArrayAccessor, B: ArrayAccessor, F, O>(
len: usize,
a: A,
b: B,
op: F,
mut op: F,
) -> Result<PrimitiveArray<O>, ArrowError>
where
O: ArrowPrimitiveType,
F: Fn(A::Item, B::Item) -> Result<O::Native, ArrowError>,
F: FnMut(A::Item, B::Item) -> Result<O::Native, ArrowError>,
{
let mut buffer = MutableBuffer::new(len * O::get_byte_width());
for idx in 0..len {
Expand Down
4 changes: 2 additions & 2 deletions arrow-array/src/array/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -491,10 +491,10 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
/// be preferred if `op` is infallible.
///
/// Note: LLVM is currently unable to effectively vectorize fallible operations
pub fn try_unary<F, O, E>(&self, op: F) -> Result<PrimitiveArray<O>, E>
pub fn try_unary<F, O, E>(&self, mut op: F) -> Result<PrimitiveArray<O>, E>
where
O: ArrowPrimitiveType,
F: Fn(T::Native) -> Result<O::Native, E>,
F: FnMut(T::Native) -> Result<O::Native, E>,
{
let data = self.data();
let len = self.len();
Expand Down