Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Mar 1, 2023
1 parent 9a377d7 commit 82fa873
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 30 deletions.
67 changes: 40 additions & 27 deletions arrow-arith/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1179,14 +1179,15 @@ pub fn multiply_decimal(
let precision = left.precision();
let product_scale = left.scale() + right.scale();
let mut min_product_scale = product_scale;
let mut scales = vec![0; left.len()];

math_checked_op(left, right, |a, b| {
try_binary::<_, _, _, Decimal256Type>(left, right, |a, b| {
let a = i256::from_i128(a);
let b = i256::from_i128(b);

a.checked_mul(b)
.map(|mut a| {
let mut product_scale = product_scale;
let mut scale = product_scale;

// Round the value if an exact representation is not possible.
// ref: java.math.BigDecimal#doRound
Expand All @@ -1196,25 +1197,16 @@ pub fn multiply_decimal(
while diff > 0 {
let divisor = i256::from_i128(10).pow_wrapping(diff as u32);
a = divide_and_round::<Decimal256Type>(a, divisor);
product_scale -= diff;
scale -= diff;

digits = a.to_string().len() as i8;
diff = digits - (Decimal128Type::MAX_PRECISION as i8);
}
if product_scale < min_product_scale {
min_product_scale = product_scale;
}
(a, product_scale)
})
.and_then(|(a, scale)| {
if scale > min_product_scale {
let divisor = i256::from_i128(10)
.pow_wrapping((scale - min_product_scale) as u32);
let a = divide_and_round::<Decimal256Type>(a, divisor);
a.to_i128()
} else {
a.to_i128()
if scale < min_product_scale {
min_product_scale = scale;
}
scales.push(scale);
a
})
.ok_or_else(|| {
ArrowError::ComputeError(format!(
Expand All @@ -1226,10 +1218,25 @@ pub fn multiply_decimal(
})
})
.and_then(|a| {
Ok(
Arc::new(a.with_precision_and_scale(precision, min_product_scale)?)
as ArrayRef,
)
try_unary::<Decimal256Type, _, Decimal128Type>(&a, |a| {
let scale = scales.pop().unwrap();

let scaled = if scale > min_product_scale {
let divisor =
i256::from_i128(10).pow_wrapping((scale - min_product_scale) as u32);
divide_and_round::<Decimal256Type>(a, divisor)
} else {
a
};

scaled.to_i128().ok_or_else(|| {
ArrowError::ComputeError(format!("Overflow happened on: {:?}", a))
})
})
.and_then(|a| {
a.with_precision_and_scale(precision, min_product_scale)
.map(|a| Arc::new(a) as ArrayRef)
})
})
}

Expand Down Expand Up @@ -3350,30 +3357,36 @@ mod tests {
);

// Rounding case
let a =
Decimal128Array::from(vec![123456789555555555555555555, 1555555555555555555])
.with_precision_and_scale(38, 18)
.unwrap();
let a = Decimal128Array::from(vec![
1,
123456789555555555555555555,
1555555555555555555,
])
.with_precision_and_scale(38, 18)
.unwrap();

let b = Decimal128Array::from(vec![11222222222222222222, 1])
let b = Decimal128Array::from(vec![1555555555555555555, 11222222222222222222, 1])
.with_precision_and_scale(38, 18)
.unwrap();

let result = multiply_decimal(&a, &b).unwrap();
let result = as_primitive_array::<Decimal128Type>(&result).clone();
let expected = Decimal128Array::from(vec![
15555555556,
13854595272345679012071330528765432099,
15555555556,
])
.with_precision_and_scale(38, 28)
.unwrap();

assert_eq!(&expected, &result);

// Rounded the value "1385459527.234567901207133052876543209876543210".
assert_eq!(
result.value_as_string(0),
result.value_as_string(1),
"1385459527.2345679012071330528765432099"
);
assert_eq!(result.value_as_string(1), "0.0000000000000000015555555556");
assert_eq!(result.value_as_string(0), "0.0000000000000000015555555556");
assert_eq!(result.value_as_string(2), "0.0000000000000000015555555556");
}
}
2 changes: 1 addition & 1 deletion arrow-arith/src/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub fn try_unary<I, F, O>(
where
I: ArrowPrimitiveType,
O: ArrowPrimitiveType,
F: Fn(I::Native) -> Result<O::Native, ArrowError>,
F: FnMut(I::Native) -> Result<O::Native, ArrowError>,
{
array.try_unary(op)
}
Expand Down
4 changes: 2 additions & 2 deletions arrow-array/src/array/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -491,10 +491,10 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
/// be preferred if `op` is infallible.
///
/// Note: LLVM is currently unable to effectively vectorize fallible operations
pub fn try_unary<F, O, E>(&self, op: F) -> Result<PrimitiveArray<O>, E>
pub fn try_unary<F, O, E>(&self, mut op: F) -> Result<PrimitiveArray<O>, E>
where
O: ArrowPrimitiveType,
F: Fn(T::Native) -> Result<O::Native, E>,
F: FnMut(T::Native) -> Result<O::Native, E>,
{
let data = self.data();
let len = self.len();
Expand Down

0 comments on commit 82fa873

Please sign in to comment.