From 7e7ac153c69a0b227ae11e0caf0f00b04b85cd23 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Mon, 25 Sep 2023 17:45:55 +0800 Subject: [PATCH] fix: add missing precision overflow checking for `cast_string_to_decimal` (#4830) * fix: add missing precision overflow checking for `cast_string_to_decimal` * Add test_cast_string_to_decimal256_precision_overflow --- arrow-cast/src/cast.rs | 75 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 68 insertions(+), 7 deletions(-) diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index 7b8e6144bb49..e7727565c981 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -2801,6 +2801,11 @@ where if cast_options.safe { let iter = from.iter().map(|v| { v.and_then(|v| parse_string_to_decimal_native::(v, scale as usize).ok()) + .and_then(|v| { + T::validate_decimal_precision(v, precision) + .is_ok() + .then_some(v) + }) }); // Benefit: // 20% performance improvement @@ -2815,13 +2820,17 @@ where .iter() .map(|v| { v.map(|v| { - parse_string_to_decimal_native::(v, scale as usize).map_err(|_| { - ArrowError::CastError(format!( - "Cannot cast string '{}' to value of {:?} type", - v, - T::DATA_TYPE, - )) - }) + parse_string_to_decimal_native::(v, scale as usize) + .map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast string '{}' to value of {:?} type", + v, + T::DATA_TYPE, + )) + }) + .and_then(|v| { + T::validate_decimal_precision(v, precision).map(|_| v) + }) }) .transpose() }) @@ -8152,6 +8161,32 @@ mod tests { ); } + #[test] + fn test_cast_string_to_decimal128_precision_overflow() { + let array = StringArray::from(vec!["1000".to_string()]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(10, 8), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let err = cast_with_options( + &array, + &DataType::Decimal128(10, 8), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal128 of precision 10. Max is 9999999999", err.unwrap_err().to_string()); + } + #[test] fn test_cast_utf8_to_decimal128_overflow() { let overflow_str_array = StringArray::from(vec![ @@ -8209,6 +8244,32 @@ mod tests { assert!(decimal_arr.is_null(6)); } + #[test] + fn test_cast_string_to_decimal256_precision_overflow() { + let array = StringArray::from(vec!["1000".to_string()]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(10, 8), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let err = cast_with_options( + &array, + &DataType::Decimal256(10, 8), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal256 of precision 10. Max is 9999999999", err.unwrap_err().to_string()); + } + #[test] fn test_cast_utf8_to_decimal256_overflow() { let overflow_str_array = StringArray::from(vec![