From 491b0239a81bb3e7e2829d69c5a59799a0d4f6e6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 18 Dec 2022 14:15:31 -0800 Subject: [PATCH] Fix unary_dyn for decimal scalar arithmetic computation (#3345) * Fix unary for decimal arithmetic computation * Use discriminant --- arrow/src/compute/kernels/arithmetic.rs | 20 +++++++++++++++++++- arrow/src/compute/kernels/arity.rs | 17 +++++++++++------ 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 23cefe48e2c8..913a2cad6c93 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -1633,7 +1633,7 @@ mod tests { use super::*; use crate::array::Int32Array; use crate::compute::{binary_mut, try_binary_mut, try_unary_mut, unary_mut}; - use crate::datatypes::{Date64Type, Int32Type, Int8Type}; + use crate::datatypes::{Date64Type, Decimal128Type, Int32Type, Int8Type}; use arrow_buffer::i256; use chrono::NaiveDate; use half::f16; @@ -3226,4 +3226,22 @@ mod tests { ])) as ArrayRef; assert_eq!(&result, &expected); } + + #[test] + fn test_decimal_add_scalar_dyn() { + let a = Decimal128Array::from(vec![100, 210, 320]) + .with_precision_and_scale(38, 2) + .unwrap(); + + let result = add_scalar_dyn::(&a, 1).unwrap(); + let result = as_primitive_array::(&result) + .clone() + .with_precision_and_scale(38, 2) + .unwrap(); + let expected = Decimal128Array::from(vec![101, 211, 321]) + .with_precision_and_scale(38, 2) + .unwrap(); + + assert_eq!(&expected, &result); + } } diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index 6207ab63935d..02659a5a7738 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -114,9 +114,12 @@ where T: ArrowPrimitiveType, F: Fn(T::Native) -> Result, { - if array.value_type() != T::DATA_TYPE { + if std::mem::discriminant(&array.value_type()) + != std::mem::discriminant(&T::DATA_TYPE) + { return Err(ArrowError::CastError(format!( - "Cannot perform the unary operation on dictionary array of value type {}", + "Cannot perform the unary operation of type {} on dictionary array of value type {}", + T::DATA_TYPE, array.value_type() ))); } @@ -135,14 +138,15 @@ where downcast_dictionary_array! { array => unary_dict::<_, F, T>(array, op), t => { - if t == &T::DATA_TYPE { + if std::mem::discriminant(t) == std::mem::discriminant(&T::DATA_TYPE) { Ok(Arc::new(unary::( array.as_any().downcast_ref::>().unwrap(), op, ))) } else { Err(ArrowError::NotYetImplemented(format!( - "Cannot perform unary operation on array of type {}", + "Cannot perform unary operation of type {} on array of type {}", + T::DATA_TYPE, t ))) } @@ -166,14 +170,15 @@ where ))) }, t => { - if t == &T::DATA_TYPE { + if std::mem::discriminant(t) == std::mem::discriminant(&T::DATA_TYPE) { Ok(Arc::new(try_unary::( array.as_any().downcast_ref::>().unwrap(), op, )?)) } else { Err(ArrowError::NotYetImplemented(format!( - "Cannot perform unary operation on array of type {}", + "Cannot perform unary operation of type {} on array of type {}", + T::DATA_TYPE, t ))) }