diff --git a/arrow-arith/src/arithmetic.rs b/arrow-arith/src/arithmetic.rs index 00375d32a677..8e2b7915357a 100644 --- a/arrow-arith/src/arithmetic.rs +++ b/arrow-arith/src/arithmetic.rs @@ -26,8 +26,11 @@ use crate::arity::*; use arrow_array::cast::*; use arrow_array::types::*; use arrow_array::*; +use arrow_buffer::i256; +use arrow_buffer::ArrowNativeType; use arrow_schema::*; use num::traits::Pow; +use std::cmp::min; use std::sync::Arc; /// Helper function to perform math lambda function on values from two arrays. If either @@ -1165,6 +1168,77 @@ pub fn multiply_dyn_checked( } } +/// Perform `left * right` operation on two decimal arrays. If either left or right value is +/// null then the result is also null. +/// +/// This performs decimal multiplication which allows precision loss if an exact representation +/// is not possible for the result, according to the required scale. In the case, the result +/// will be rounded to the required scale. +/// +/// If the required scale is greater than the product scale, an error is returned. +/// +/// It is implemented for compatibility with precision loss `multiply` function provided by +/// other data processing engines. For multiplication with precision loss detection, use +/// `multiply` or `multiply_checked` instead. +pub fn multiply_fixed_point_checked( + left: &PrimitiveArray, + right: &PrimitiveArray, + required_scale: i8, +) -> Result, ArrowError> { + let product_scale = left.scale() + right.scale(); + let precision = min( + left.precision() + right.precision() + 1, + DECIMAL128_MAX_PRECISION, + ); + + if required_scale == product_scale { + return multiply_checked(left, right)? + .with_precision_and_scale(precision, required_scale); + } + + if required_scale > product_scale { + return Err(ArrowError::ComputeError(format!( + "Required scale {} is greater than product scale {}", + required_scale, product_scale + ))); + } + + let divisor = + i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32); + + try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| { + let a = i256::from_i128(a); + let b = i256::from_i128(b); + + let mut mul = a.wrapping_mul(b); + mul = divide_and_round::(mul, divisor); + mul.to_i128().ok_or_else(|| { + ArrowError::ComputeError(format!("Overflow happened on: {:?} * {:?}", a, b)) + }) + }) + .and_then(|a| a.with_precision_and_scale(precision, required_scale)) +} + +/// Divide a decimal native value by given divisor and round the result. +fn divide_and_round(input: I::Native, div: I::Native) -> I::Native +where + I: DecimalType, + I::Native: ArrowNativeTypeOp, +{ + let d = input.div_wrapping(div); + let r = input.mod_wrapping(div); + + let half = div.div_wrapping(I::Native::from_usize(2).unwrap()); + let half_neg = half.neg_wrapping(); + + // Round result + match input >= I::Native::ZERO { + true if r >= half => d.add_wrapping(I::Native::ONE), + false if r <= half_neg => d.sub_wrapping(I::Native::ONE), + _ => d, + } +} + /// Multiply every value in an array by a scalar. If any value in the array is null then the /// result is also null. /// @@ -3231,4 +3305,101 @@ mod tests { assert_eq!(&expected, &result); } + + #[test] + fn test_decimal_multiply_allow_precision_loss() { + // Overflow happening as i128 cannot hold multiplying result. + // [123456789] + let a = Decimal128Array::from(vec![123456789000000000000000000]) + .with_precision_and_scale(38, 18) + .unwrap(); + + // [10] + let b = Decimal128Array::from(vec![10000000000000000000]) + .with_precision_and_scale(38, 18) + .unwrap(); + + let err = multiply_dyn_checked(&a, &b).unwrap_err(); + assert!(err.to_string().contains( + "Overflow happened on: 123456789000000000000000000 * 10000000000000000000" + )); + + // Allow precision loss. + let result = multiply_fixed_point_checked(&a, &b, 28).unwrap(); + // [1234567890] + let expected = + Decimal128Array::from(vec![12345678900000000000000000000000000000]) + .with_precision_and_scale(38, 28) + .unwrap(); + + assert_eq!(&expected, &result); + assert_eq!( + result.value_as_string(0), + "1234567890.0000000000000000000000000000" + ); + + // Rounding case + // [0.000000000000000001, 123456789.555555555555555555, 1.555555555555555555] + let a = Decimal128Array::from(vec![ + 1, + 123456789555555555555555555, + 1555555555555555555, + ]) + .with_precision_and_scale(38, 18) + .unwrap(); + + // [1.555555555555555555, 11.222222222222222222, 0.000000000000000001] + let b = Decimal128Array::from(vec![1555555555555555555, 11222222222222222222, 1]) + .with_precision_and_scale(38, 18) + .unwrap(); + + let result = multiply_fixed_point_checked(&a, &b, 28).unwrap(); + // [ + // 0.0000000000000000015555555556, + // 1385459527.2345679012071330528765432099, + // 0.0000000000000000015555555556 + // ] + let expected = Decimal128Array::from(vec![ + 15555555556, + 13854595272345679012071330528765432099, + 15555555556, + ]) + .with_precision_and_scale(38, 28) + .unwrap(); + + assert_eq!(&expected, &result); + + // Rounded the value "1385459527.234567901207133052876543209876543210". + assert_eq!( + result.value_as_string(1), + "1385459527.2345679012071330528765432099" + ); + assert_eq!(result.value_as_string(0), "0.0000000000000000015555555556"); + assert_eq!(result.value_as_string(2), "0.0000000000000000015555555556"); + + let a = Decimal128Array::from(vec![1230]) + .with_precision_and_scale(4, 2) + .unwrap(); + + let b = Decimal128Array::from(vec![1000]) + .with_precision_and_scale(4, 2) + .unwrap(); + + // Required scale is same as the product of the input scales. Behavior is same as multiply. + let result = multiply_fixed_point_checked(&a, &b, 4).unwrap(); + assert_eq!(result.precision(), 9); + assert_eq!(result.scale(), 4); + + let expected = multiply_checked(&a, &b) + .unwrap() + .with_precision_and_scale(9, 4) + .unwrap(); + assert_eq!(&expected, &result); + + // Required scale cannot be larger than the product of the input scales. + let result = multiply_fixed_point_checked(&a, &b, 5).unwrap_err(); + assert!(result + .to_string() + .contains("Required scale 5 is greater than product scale 4")); + } }