Skip to content

Commit

Permalink
Round instead of Truncate while casting float to decimal (#3000)
Browse files Browse the repository at this point in the history
* add .round() before casting to integer

* add more test cases

* update test cases

* add doc

* Format

Co-authored-by: Raphael Taylor-Davies <[email protected]>
  • Loading branch information
waitingkuo and tustvold authored Nov 3, 2022
1 parent 24afac4 commit 61cf6f7
Showing 1 changed file with 79 additions and 24 deletions.
103 changes: 79 additions & 24 deletions arrow/src/compute/kernels/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
/// * Time32 and Time64: precision lost when going to higher interval
/// * Timestamp and Date{32|64}: precision lost when going to higher interval
/// * Temporal to/from backing primitive: zero-copy with data type change
/// * Casting from `float32/float64` to `Decimal(precision, scale)` rounds to `scale` decimal places
/// (e.g. casting 6.4999 to Decimal(10, 1) yields 6.5). This is a breaking change from `26.0.0`,
/// which truncated instead of rounding (producing 6.4 in this example).
///
/// Unsupported Casts
/// * To or from `StructArray`
Expand Down Expand Up @@ -353,7 +356,7 @@ where
{
let mul = 10_f64.powi(scale as i32);

unary::<T, _, Decimal128Type>(array, |v| (v.as_() * mul) as i128)
unary::<T, _, Decimal128Type>(array, |v| (v.as_() * mul).round() as i128)
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
}
Expand All @@ -368,9 +371,11 @@ where
{
let mul = 10_f64.powi(scale as i32);

unary::<T, _, Decimal256Type>(array, |v| i256::from_i128((v.as_() * mul) as i128))
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
unary::<T, _, Decimal256Type>(array, |v| {
i256::from_i128((v.as_() * mul).round() as i128)
})
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
}

/// Cast the primitive array using [`PrimitiveArray::reinterpret_cast`]
Expand Down Expand Up @@ -3192,8 +3197,8 @@ mod tests {
Some(2.2),
Some(4.4),
None,
Some(1.123_456_7),
Some(1.123_456_7),
Some(1.123_456_4), // round down
Some(1.123_456_7), // round up
]);
let array = Arc::new(array) as ArrayRef;
generate_cast_test_case!(
Expand All @@ -3205,8 +3210,8 @@ mod tests {
Some(2200000_i128),
Some(4400000_i128),
None,
Some(1123456_i128),
Some(1123456_i128),
Some(1123456_i128), // round down
Some(1123457_i128), // round up
]
);

Expand All @@ -3216,9 +3221,10 @@ mod tests {
Some(2.2),
Some(4.4),
None,
Some(1.123_456_789_123_4),
Some(1.123_456_789_012_345_6),
Some(1.123_456_789_012_345_6),
Some(1.123_456_489_123_4), // round down
Some(1.123_456_789_123_4), // round up
Some(1.123_456_489_012_345_6), // round down
Some(1.123_456_789_012_345_6), // round up
]);
let array = Arc::new(array) as ArrayRef;
generate_cast_test_case!(
Expand All @@ -3230,9 +3236,10 @@ mod tests {
Some(2200000_i128),
Some(4400000_i128),
None,
Some(1123456_i128),
Some(1123456_i128),
Some(1123456_i128),
Some(1123456_i128), // round down
Some(1123457_i128), // round up
Some(1123456_i128), // round down
Some(1123457_i128), // round up
]
);
}
Expand Down Expand Up @@ -3307,8 +3314,8 @@ mod tests {
Some(2.2),
Some(4.4),
None,
Some(1.123_456_7),
Some(1.123_456_7),
Some(1.123_456_4), // round down
Some(1.123_456_7), // round up
]);
let array = Arc::new(array) as ArrayRef;
generate_cast_test_case!(
Expand All @@ -3320,8 +3327,8 @@ mod tests {
Some(i256::from_i128(2200000_i128)),
Some(i256::from_i128(4400000_i128)),
None,
Some(i256::from_i128(1123456_i128)),
Some(i256::from_i128(1123456_i128)),
Some(i256::from_i128(1123456_i128)), // round down
Some(i256::from_i128(1123457_i128)), // round up
]
);

Expand All @@ -3331,9 +3338,10 @@ mod tests {
Some(2.2),
Some(4.4),
None,
Some(1.123_456_789_123_4),
Some(1.123_456_789_012_345_6),
Some(1.123_456_789_012_345_6),
Some(1.123_456_489_123_4), // round down
Some(1.123_456_789_123_4), // round up
Some(1.123_456_489_012_345_6), // round down
Some(1.123_456_789_012_345_6), // round up
]);
let array = Arc::new(array) as ArrayRef;
generate_cast_test_case!(
Expand All @@ -3345,9 +3353,10 @@ mod tests {
Some(i256::from_i128(2200000_i128)),
Some(i256::from_i128(4400000_i128)),
None,
Some(i256::from_i128(1123456_i128)),
Some(i256::from_i128(1123456_i128)),
Some(i256::from_i128(1123456_i128)),
Some(i256::from_i128(1123456_i128)), // round down
Some(i256::from_i128(1123457_i128)), // round up
Some(i256::from_i128(1123456_i128)), // round down
Some(i256::from_i128(1123457_i128)), // round up
]
);
}
Expand Down Expand Up @@ -5994,4 +6003,50 @@ mod tests {
.collect::<Vec<_>>();
assert_eq!(&out, &vec!["[0, 1, 2]", "[3, 4, 5]", "[6, 7]"]);
}

/// Regression test for https://github.com/apache/arrow-rs/issues/2997:
/// casting `Float64` to `Decimal128` must round to the target scale.
#[test]
#[cfg(not(feature = "force_validate"))]
fn test_cast_f64_to_decimal128() {
    // to reproduce https://github.com/apache/arrow-rs/issues/2997

    // Both cases below cast the same four floats; a fresh array is built per case.
    let build_input = || -> ArrayRef {
        Arc::new(Float64Array::from(vec![
            Some(0.0699999999),
            Some(0.0659999999),
            Some(0.0650000000),
            Some(0.0649999999),
        ])) as ArrayRef
    };

    // Case 1: scale 2, expected values are in hundredths.
    let two_places = DataType::Decimal128(18, 2);
    let input = build_input();
    generate_cast_test_case!(
        &input,
        Decimal128Array,
        &two_places,
        vec![
            Some(7_i128), // ~6.99999999 rounds up
            Some(7_i128), // ~6.59999999 rounds up
            Some(7_i128), // ~6.5 rounds up
            Some(6_i128), // ~6.49999999 rounds down
        ]
    );

    // Case 2: scale 3, expected values are in thousandths.
    let three_places = DataType::Decimal128(18, 3);
    let input = build_input();
    generate_cast_test_case!(
        &input,
        Decimal128Array,
        &three_places,
        vec![
            Some(70_i128), // ~69.9999999 rounds up
            Some(66_i128), // ~65.9999999 rounds up
            Some(65_i128), // ~65.0 stays at 65
            Some(65_i128), // ~64.9999999 rounds up
        ]
    );
}
}

0 comments on commit 61cf6f7

Please sign in to comment.