From 9a73e9b7a5f24ee652324505643b46567ffd8158 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 29 Jun 2023 15:19:01 -0700
Subject: [PATCH 1/2] Add multiply_fixed_point_scalar and multiply_fixed_point_scalar_checked

---
 arrow-arith/src/arithmetic.rs | 85 +++++++++++++++++++++++++++++++++++
 1 file changed, 85 insertions(+)

diff --git a/arrow-arith/src/arithmetic.rs b/arrow-arith/src/arithmetic.rs
index 8e7ab44042cf..97976588de8f 100644
--- a/arrow-arith/src/arithmetic.rs
+++ b/arrow-arith/src/arithmetic.rs
@@ -1269,6 +1269,91 @@ pub fn multiply_fixed_point(
     .and_then(|a| a.with_precision_and_scale(precision, required_scale))
 }
 
+/// Perform `left * right` operation on decimal array and a scalar. If any value in the array is null
+/// then the result is also null.
+///
+/// This performs decimal multiplication which allows precision loss if an exact representation
+/// is not possible for the result, according to the required scale. In the case, the result
+/// will be rounded to the required scale.
+///
+/// If the required scale is greater than the product scale, an error is returned.
+///
+/// It is implemented for compatibility with precision loss `multiply` function provided by
+/// other data processing engines. For multiplication with precision loss detection, use
+/// `multiply_scalar` or `multiply_scalar_checked` instead.
+pub fn multiply_fixed_point_scalar_checked(
+    left: &PrimitiveArray<Decimal128Type>,
+    scalar: i128,
+    required_scale: i8,
+) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
+    let (precision, product_scale, divisor) = get_fixed_point_info(
+        (left.precision(), left.scale()),
+        (left.precision(), left.scale()),
+        required_scale,
+    )?;
+
+    if required_scale == product_scale {
+        return multiply_scalar_checked(left, scalar)?
+            .with_precision_and_scale(precision, required_scale);
+    }
+
+    let b = i256::from_i128(scalar);
+
+    try_unary::<_, _, Decimal128Type>(left, |a| {
+        let a = i256::from_i128(a);
+
+        let mut mul = a.wrapping_mul(b);
+        mul = divide_and_round::<Decimal256Type>(mul, divisor);
+        mul.to_i128().ok_or_else(|| {
+            ArrowError::ComputeError(format!("Overflow happened on: {:?} * {:?}", a, b))
+        })
+    })
+    .and_then(|a| a.with_precision_and_scale(precision, required_scale))
+}
+
+/// Perform `left * right` operation on decimal array and a scalar. If any value in the array is null
+/// then the result is also null.
+///
+/// This performs decimal multiplication which allows precision loss if an exact representation
+/// is not possible for the result, according to the required scale. In the case, the result
+/// will be rounded to the required scale.
+///
+/// If the required scale is greater than the product scale, an error is returned.
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap around.
+/// For an overflow-checking variant, use `multiply_fixed_point_scalar_checked` instead.
+///
+/// It is implemented for compatibility with precision loss `multiply` function provided by
+/// other data processing engines. For multiplication with precision loss detection, use
+/// `multiply_scalar` or `multiply_scalar_checked` instead.
+pub fn multiply_fixed_point_scalar(
+    left: &PrimitiveArray<Decimal128Type>,
+    scalar: i128,
+    required_scale: i8,
+) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
+    let (precision, product_scale, divisor) = get_fixed_point_info(
+        (left.precision(), left.scale()),
+        (left.precision(), left.scale()),
+        required_scale,
+    )?;
+
+    if required_scale == product_scale {
+        return multiply_scalar(left, scalar)?
+            .with_precision_and_scale(precision, required_scale);
+    }
+
+    let b = i256::from_i128(scalar);
+
+    unary::<_, _, Decimal128Type>(left, |a| {
+        let a = i256::from_i128(a);
+
+        let mut mul = a.wrapping_mul(b);
+        mul = divide_and_round::<Decimal256Type>(mul, divisor);
+        mul.as_i128()
+    })
+    .with_precision_and_scale(precision, required_scale)
+}
+
 /// Divide a decimal native value by given divisor and round the result.
 fn divide_and_round<I>(input: I::Native, div: I::Native) -> I::Native
 where

From 8781c7c5856a3845677daf4c03c8f1cad848a9df Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 3 Jul 2023 14:31:48 -0700
Subject: [PATCH 2/2] Add tests

---
 arrow-arith/src/arithmetic.rs | 57 +++++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 55 insertions(+), 2 deletions(-)

diff --git a/arrow-arith/src/arithmetic.rs b/arrow-arith/src/arithmetic.rs
index 97976588de8f..fb4514ea5d84 100644
--- a/arrow-arith/src/arithmetic.rs
+++ b/arrow-arith/src/arithmetic.rs
@@ -1285,10 +1285,12 @@ pub fn multiply_fixed_point_scalar_checked(
     left: &PrimitiveArray<Decimal128Type>,
     scalar: i128,
     required_scale: i8,
+    scalar_precision: u8,
+    scalar_scale: i8,
 ) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
     let (precision, product_scale, divisor) = get_fixed_point_info(
         (left.precision(), left.scale()),
-        (left.precision(), left.scale()),
+        (scalar_precision, scalar_scale),
         required_scale,
     )?;
 
@@ -1330,10 +1332,12 @@ pub fn multiply_fixed_point_scalar(
     left: &PrimitiveArray<Decimal128Type>,
     scalar: i128,
     required_scale: i8,
+    scalar_precision: u8,
+    scalar_scale: i8,
 ) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
     let (precision, product_scale, divisor) = get_fixed_point_info(
         (left.precision(), left.scale()),
-        (left.precision(), left.scale()),
+        (scalar_precision, scalar_scale),
         required_scale,
     )?;
 
@@ -3324,6 +3328,55 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_decimal_multiply_fixed_point_scalar() {
+        // [123456789]
+        let a = Decimal128Array::from(vec![123456789000000000000000000])
+            .with_precision_and_scale(38, 18)
+            .unwrap();
+
+        // 10
+        let b = 10000000000000000000;
+
+        // `multiply_scalar` overflows on this case.
+        let result = multiply_scalar(&a, b).unwrap();
+        let expected =
+            Decimal128Array::from(vec![-16672482290199102048610367863168958464])
+                .with_precision_and_scale(38, 10)
+                .unwrap();
+        assert_eq!(&expected, &result);
+
+        // Avoid overflow by reducing the scale.
+        let result = multiply_fixed_point_scalar(&a, b, 28, 38, 18).unwrap();
+        // [1234567890]
+        let expected =
+            Decimal128Array::from(vec![12345678900000000000000000000000000000])
+                .with_precision_and_scale(38, 28)
+                .unwrap();
+
+        assert_eq!(&expected, &result);
+        assert_eq!(
+            result.value_as_string(0),
+            "1234567890.0000000000000000000000000000"
+        );
+    }
+
+    #[test]
+    fn test_decimal_multiply_fixed_point_scalar_checked_overflow() {
+        // [99999999999123456789]
+        let a = Decimal128Array::from(vec![99999999999123456789000000000000000000])
+            .with_precision_and_scale(38, 18)
+            .unwrap();
+
+        // 9999999999910
+        let b = 9999999999910000000000000000000;
+
+        let err = multiply_fixed_point_scalar_checked(&a, b, 28, 38, 18).unwrap_err();
+        assert!(err.to_string().contains(
+            "Overflow happened on: 99999999999123456789000000000000000000 * 9999999999910000000000000000000"
+        ));
+    }
+
     #[test]
     fn test_timestamp_second_add_interval() {
         // timestamp second + interval year month