Skip to content

Commit

Permalink
Fix unary_dyn for decimal scalar arithmetic computation (#3345)
Browse files Browse the repository at this point in the history
* Fix unary for decimal arithmetic computation

* Use discriminant
  • Loading branch information
viirya authored Dec 18, 2022
1 parent 07284c5 commit 491b023
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 7 deletions.
20 changes: 19 additions & 1 deletion arrow/src/compute/kernels/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1633,7 +1633,7 @@ mod tests {
use super::*;
use crate::array::Int32Array;
use crate::compute::{binary_mut, try_binary_mut, try_unary_mut, unary_mut};
use crate::datatypes::{Date64Type, Int32Type, Int8Type};
use crate::datatypes::{Date64Type, Decimal128Type, Int32Type, Int8Type};
use arrow_buffer::i256;
use chrono::NaiveDate;
use half::f16;
Expand Down Expand Up @@ -3226,4 +3226,22 @@ mod tests {
])) as ArrayRef;
assert_eq!(&result, &expected);
}

#[test]
fn test_decimal_add_scalar_dyn() {
let a = Decimal128Array::from(vec![100, 210, 320])
.with_precision_and_scale(38, 2)
.unwrap();

let result = add_scalar_dyn::<Decimal128Type>(&a, 1).unwrap();
let result = as_primitive_array::<Decimal128Type>(&result)
.clone()
.with_precision_and_scale(38, 2)
.unwrap();
let expected = Decimal128Array::from(vec![101, 211, 321])
.with_precision_and_scale(38, 2)
.unwrap();

assert_eq!(&expected, &result);
}
}
17 changes: 11 additions & 6 deletions arrow/src/compute/kernels/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,12 @@ where
T: ArrowPrimitiveType,
F: Fn(T::Native) -> Result<T::Native>,
{
if array.value_type() != T::DATA_TYPE {
if std::mem::discriminant(&array.value_type())
!= std::mem::discriminant(&T::DATA_TYPE)
{
return Err(ArrowError::CastError(format!(
"Cannot perform the unary operation on dictionary array of value type {}",
"Cannot perform the unary operation of type {} on dictionary array of value type {}",
T::DATA_TYPE,
array.value_type()
)));
}
Expand All @@ -135,14 +138,15 @@ where
downcast_dictionary_array! {
array => unary_dict::<_, F, T>(array, op),
t => {
if t == &T::DATA_TYPE {
if std::mem::discriminant(t) == std::mem::discriminant(&T::DATA_TYPE) {
Ok(Arc::new(unary::<T, F, T>(
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
op,
)))
} else {
Err(ArrowError::NotYetImplemented(format!(
"Cannot perform unary operation on array of type {}",
"Cannot perform unary operation of type {} on array of type {}",
T::DATA_TYPE,
t
)))
}
Expand All @@ -166,14 +170,15 @@ where
)))
},
t => {
if t == &T::DATA_TYPE {
if std::mem::discriminant(t) == std::mem::discriminant(&T::DATA_TYPE) {
Ok(Arc::new(try_unary::<T, F, T>(
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
op,
)?))
} else {
Err(ArrowError::NotYetImplemented(format!(
"Cannot perform unary operation on array of type {}",
"Cannot perform unary operation of type {} on array of type {}",
T::DATA_TYPE,
t
)))
}
Expand Down

0 comments on commit 491b023

Please sign in to comment.