Skip to content

Commit

Permalink
Prevent precision=0 for decimal type (#3162)
Browse files Browse the repository at this point in the history
* Adding decimal precision checks

* Doc edits
  • Loading branch information
psvri authored Nov 22, 2022
1 parent a110004 commit 6455e34
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions arrow-array/src/array/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -993,12 +993,13 @@ impl<T: ArrowPrimitiveType> From<ArrayData> for PrimitiveArray<T> {

impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
/// Returns a Decimal array with the same data as self, with the
/// specified precision.
/// specified precision and scale.
///
/// Returns an Error if:
/// 1. `precision` is larger than `T:MAX_PRECISION`
/// 2. `scale` is larger than `T::MAX_SCALE`
/// 3. `scale` is > `precision`
/// - `precision` is zero
/// - `precision` is larger than `T:MAX_PRECISION`
/// - `scale` is larger than `T::MAX_SCALE`
/// - `scale` is > `precision`
pub fn with_precision_and_scale(
self,
precision: u8,
Expand All @@ -1025,18 +1026,24 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
precision: u8,
scale: u8,
) -> Result<(), ArrowError> {
if precision == 0 {
return Err(ArrowError::InvalidArgumentError(format!(
"precision cannot be 0, has to be between [1, {}]",
T::MAX_PRECISION
)));
}
if precision > T::MAX_PRECISION {
return Err(ArrowError::InvalidArgumentError(format!(
"precision {} is greater than max {}",
precision,
Decimal128Type::MAX_PRECISION
T::MAX_PRECISION
)));
}
if scale > T::MAX_SCALE {
return Err(ArrowError::InvalidArgumentError(format!(
"scale {} is greater than max {}",
scale,
Decimal128Type::MAX_SCALE
T::MAX_SCALE
)));
}
if scale > precision {
Expand Down Expand Up @@ -1934,6 +1941,14 @@ mod tests {
arr.validate_decimal_precision(5).unwrap();
}

#[test]
#[should_panic(expected = "precision cannot be 0, has to be between [1, 38]")]
fn test_decimal_array_with_precision_zero() {
Decimal128Array::from_iter_values([12345, 456])
.with_precision_and_scale(0, 2)
.unwrap();
}

#[test]
#[should_panic(expected = "precision 40 is greater than max 38")]
fn test_decimal_array_with_precision_and_scale_invalid_precision() {
Expand Down

0 comments on commit 6455e34

Please sign in to comment.