From 6dd9dae1cea7618a7e136285e7927e4d802ec058 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Sun, 6 Nov 2022 16:05:16 -0800
Subject: [PATCH] Check overflow when casting floating point value to decimal256 (#3033)

* Check overflow when casting floating point value to decimal256

* Add from_f64
---
 arrow-buffer/src/bigint.rs | 15 +++++++++-
 arrow-cast/src/cast.rs     | 59 +++++++++++++++++++++++++++++++++++---
 2 files changed, 69 insertions(+), 5 deletions(-)

diff --git a/arrow-buffer/src/bigint.rs b/arrow-buffer/src/bigint.rs
index e87c05826fe2..8dd57d2c4646 100644
--- a/arrow-buffer/src/bigint.rs
+++ b/arrow-buffer/src/bigint.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use num::cast::AsPrimitive;
-use num::BigInt;
+use num::{BigInt, FromPrimitive};
 use std::cmp::Ordering;
 
 /// A signed 256-bit integer
@@ -102,6 +102,19 @@ impl i256 {
         Self::from_parts(v as u128, v >> 127)
     }
 
+    /// Create an optional i256 from the provided `f64`. Returning `None`
+    /// if overflow occurred
+    pub fn from_f64(v: f64) -> Option<Self> {
+        BigInt::from_f64(v).and_then(|i| {
+            let (integer, overflow) = i256::from_bigint_with_overflow(i);
+            if overflow {
+                None
+            } else {
+                Some(integer)
+            }
+        })
+    }
+
     /// Create an i256 from the provided low u128 and high i128
     #[inline]
     pub const fn from_parts(low: u128, high: i128) -> Self {
diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index 3e23a059bf3e..5bf8c19c5baf 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -387,16 +387,38 @@ fn cast_floating_point_to_decimal256<T: ArrowPrimitiveType>(
     array: &PrimitiveArray<T>,
     precision: u8,
     scale: u8,
+    cast_options: &CastOptions,
 ) -> Result<ArrayRef, ArrowError>
 where
     <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
 {
     let mul = 10_f64.powi(scale as i32);
 
-    array
-        .unary::<_, Decimal256Type>(|v| i256::from_i128((v.as_() * mul).round() as i128))
-        .with_precision_and_scale(precision, scale)
-        .map(|a| Arc::new(a) as ArrayRef)
+    if cast_options.safe {
+        let iter = array
+            .iter()
+            .map(|v| v.and_then(|v| i256::from_f64((v.as_() * mul).round())));
+        let casted_array =
+            unsafe { PrimitiveArray::<Decimal256Type>::from_trusted_len_iter(iter) };
+        casted_array
+            .with_precision_and_scale(precision, scale)
+            .map(|a| Arc::new(a) as ArrayRef)
+    } else {
+        array
+            .try_unary::<_, Decimal256Type, _>(|v| {
+                i256::from_f64((v.as_() * mul).round()).ok_or_else(|| {
+                    ArrowError::CastError(format!(
+                        "Cannot cast to {}({}, {}). Overflowing on {:?}",
+                        Decimal256Type::PREFIX,
+                        precision,
+                        scale,
+                        v
+                    ))
+                })
+            })
+            .and_then(|a| a.with_precision_and_scale(precision, scale))
+            .map(|a| Arc::new(a) as ArrayRef)
+    }
 }
 
 /// Cast the primitive array using [`PrimitiveArray::reinterpret_cast`]
@@ -666,11 +688,13 @@ pub fn cast_with_options(
                     as_primitive_array::<Float32Type>(array),
                     *precision,
                     *scale,
+                    cast_options,
                 ),
                 Float64 => cast_floating_point_to_decimal256(
                     as_primitive_array::<Float64Type>(array),
                     *precision,
                     *scale,
+                    cast_options,
                 ),
                 Null => Ok(new_null_array(to_type, array.len())),
                 _ => Err(ArrowError::CastError(format!(
@@ -6166,4 +6190,31 @@ mod tests {
             err
         );
     }
+
+    #[test]
+    fn test_cast_floating_point_to_decimal256_overflow() {
+        let array = Float64Array::from(vec![f64::MAX]);
+        let array = Arc::new(array) as ArrayRef;
+        let casted_array = cast_with_options(
+            &array,
+            &DataType::Decimal256(76, 50),
+            &CastOptions { safe: true },
+        );
+        assert!(casted_array.is_ok());
+        assert!(casted_array.unwrap().is_null(0));
+
+        let casted_array = cast_with_options(
+            &array,
+            &DataType::Decimal256(76, 50),
+            &CastOptions { safe: false },
+        );
+        let err = casted_array.unwrap_err().to_string();
+        let expected_error = "Cast error: Cannot cast to Decimal256(76, 50)";
+        assert!(
+            err.contains(expected_error),
+            "did not find expected error '{}' in actual error '{}'",
+            expected_error,
+            err
+        );
+    }
 }