Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multiply_fixed_point_scalar and multiply_fixed_point_scalar_checked #4468

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions arrow-arith/src/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1269,6 +1269,95 @@ pub fn multiply_fixed_point(
.and_then(|a| a.with_precision_and_scale(precision, required_scale))
}

/// Perform `left * right` operation on decimal array and a scalar. If any value in the array is null
/// then the result is also null.
///
/// This performs decimal multiplication which allows precision loss if an exact representation
/// is not possible for the result, according to the required scale. In the case, the result
/// will be rounded to the required scale.
///
/// If the required scale is greater than the product scale, an error is returned.
///
/// It is implemented for compatibility with precision loss `multiply` function provided by
/// other data processing engines. For multiplication with precision loss detection, use
/// `multiply_scalar` or `multiply_scalar_checked` instead.
pub fn multiply_fixed_point_scalar_checked(
left: &PrimitiveArray<Decimal128Type>,
scalar: i128,
required_scale: i8,
scalar_precision: u8,
scalar_scale: i8,
) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
let (precision, product_scale, divisor) = get_fixed_point_info(
(left.precision(), left.scale()),
(scalar_precision, scalar_scale),
required_scale,
)?;

if required_scale == product_scale {
return multiply_scalar_checked(left, scalar)?
.with_precision_and_scale(precision, required_scale);
}

let b = i256::from_i128(scalar);

try_unary::<_, _, Decimal128Type>(left, |a| {
let a = i256::from_i128(a);

let mut mul = a.wrapping_mul(b);
mul = divide_and_round::<Decimal256Type>(mul, divisor);
mul.to_i128().ok_or_else(|| {
ArrowError::ComputeError(format!("Overflow happened on: {:?} * {:?}", a, b))
})
})
.and_then(|a| a.with_precision_and_scale(precision, required_scale))
}

/// Perform `left * right` operation on decimal array and a scalar. If any value in the array is null
/// then the result is also null.
///
/// This performs decimal multiplication which allows precision loss if an exact representation
/// is not possible for the result, according to the required scale. In the case, the result
/// will be rounded to the required scale.
///
/// If the required scale is greater than the product scale, an error is returned.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `multiply_fixed_point_scalar_checked` instead.
///
/// It is implemented for compatibility with precision loss `multiply` function provided by
/// other data processing engines. For multiplication with precision loss detection, use
/// `multiply_scalar` or `multiply_scalar_checked` instead.
pub fn multiply_fixed_point_scalar(
left: &PrimitiveArray<Decimal128Type>,
scalar: i128,
required_scale: i8,
scalar_precision: u8,
scalar_scale: i8,
) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
let (precision, product_scale, divisor) = get_fixed_point_info(
(left.precision(), left.scale()),
(scalar_precision, scalar_scale),
required_scale,
)?;

if required_scale == product_scale {
return multiply_scalar(left, scalar)?
.with_precision_and_scale(precision, required_scale);
}

let b = i256::from_i128(scalar);

unary::<_, _, Decimal128Type>(left, |a| {
let a = i256::from_i128(a);

let mut mul = a.wrapping_mul(b);
mul = divide_and_round::<Decimal256Type>(mul, divisor);
mul.as_i128()
})
.with_precision_and_scale(precision, required_scale)
}

/// Divide a decimal native value by given divisor and round the result.
fn divide_and_round<I>(input: I::Native, div: I::Native) -> I::Native
where
@@ -3239,6 +3328,55 @@ mod tests {
);
}

#[test]
fn test_decimal_multiply_fixed_point_scalar() {
// [123456789]
let a = Decimal128Array::from(vec![123456789000000000000000000])
.with_precision_and_scale(38, 18)
.unwrap();

// 10
let b = 10000000000000000000;

// `multiply_scalar` overflows on this case.
let result = multiply_scalar(&a, b).unwrap();
let expected =
Decimal128Array::from(vec![-16672482290199102048610367863168958464])
.with_precision_and_scale(38, 10)
.unwrap();
assert_eq!(&expected, &result);

// Avoid overflow by reducing the scale.
let result = multiply_fixed_point_scalar(&a, b, 28, 38, 18).unwrap();
// [1234567890]
let expected =
Decimal128Array::from(vec![12345678900000000000000000000000000000])
.with_precision_and_scale(38, 28)
.unwrap();

assert_eq!(&expected, &result);
assert_eq!(
result.value_as_string(0),
"1234567890.0000000000000000000000000000"
);
}

#[test]
fn test_decimal_multiply_fixed_point_scalar_checked_overflow() {
// [99999999999123456789]
let a = Decimal128Array::from(vec![99999999999123456789000000000000000000])
.with_precision_and_scale(38, 18)
.unwrap();

// 9999999999910
let b = 9999999999910000000000000000000;

let err = multiply_fixed_point_scalar_checked(&a, b, 28, 38, 18).unwrap_err();
assert!(err.to_string().contains(
"Overflow happened on: 99999999999123456789000000000000000000 * 9999999999910000000000000000000"
));
}

#[test]
fn test_timestamp_second_add_interval() {
// timestamp second + interval year month