From 6c6c9b72f084dea3c759a79e613bf67627d0149e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 14 Dec 2022 23:14:19 -0800 Subject: [PATCH 1/2] Fix unary for decimal arithmetic computation --- arrow-schema/src/datatype.rs | 7 +++++++ arrow/src/compute/kernels/arithmetic.rs | 20 +++++++++++++++++++- arrow/src/compute/kernels/arity.rs | 17 +++++++++++------ 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs index da1c20ddbd38..e82609c98ef0 100644 --- a/arrow-schema/src/datatype.rs +++ b/arrow-schema/src/datatype.rs @@ -320,6 +320,13 @@ impl DataType { ) } + /// Returns true if this type is decimal: (Decimal*). + #[inline] + pub fn is_decimal(&self) -> bool { + use DataType::*; + matches!(self, Decimal128(_, _) | Decimal256(_, _)) + } + /// Returns true if this type is temporal: (Date*, Time*, Duration, or Interval). #[inline] pub fn is_temporal(&self) -> bool { diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 23cefe48e2c8..913a2cad6c93 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -1633,7 +1633,7 @@ mod tests { use super::*; use crate::array::Int32Array; use crate::compute::{binary_mut, try_binary_mut, try_unary_mut, unary_mut}; - use crate::datatypes::{Date64Type, Int32Type, Int8Type}; + use crate::datatypes::{Date64Type, Decimal128Type, Int32Type, Int8Type}; use arrow_buffer::i256; use chrono::NaiveDate; use half::f16; @@ -3226,4 +3226,22 @@ mod tests { ])) as ArrayRef; assert_eq!(&result, &expected); } + + #[test] + fn test_decimal_add_scalar_dyn() { + let a = Decimal128Array::from(vec![100, 210, 320]) + .with_precision_and_scale(38, 2) + .unwrap(); + + let result = add_scalar_dyn::(&a, 1).unwrap(); + let result = as_primitive_array::(&result) + .clone() + .with_precision_and_scale(38, 2) + .unwrap(); + let expected = Decimal128Array::from(vec![101, 211, 321]) + .with_precision_and_scale(38, 2) + .unwrap(); + + assert_eq!(&expected, &result); + } } diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index 6207ab63935d..4e267fbc9535 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -114,9 +114,12 @@ where T: ArrowPrimitiveType, F: Fn(T::Native) -> Result, { - if array.value_type() != T::DATA_TYPE { + if array.value_type() != T::DATA_TYPE + && !(array.value_type().is_decimal() && T::DATA_TYPE.is_decimal()) + { return Err(ArrowError::CastError(format!( - "Cannot perform the unary operation on dictionary array of value type {}", + "Cannot perform the unary operation of type {} on dictionary array of value type {}", + T::DATA_TYPE, array.value_type() ))); } @@ -135,14 +138,15 @@ where downcast_dictionary_array! { array => unary_dict::<_, F, T>(array, op), t => { - if t == &T::DATA_TYPE { + if t == &T::DATA_TYPE || (t.is_decimal() && T::DATA_TYPE.is_decimal()) { Ok(Arc::new(unary::( array.as_any().downcast_ref::>().unwrap(), op, ))) } else { Err(ArrowError::NotYetImplemented(format!( - "Cannot perform unary operation on array of type {}", + "Cannot perform unary operation of type {} on array of type {}", + T::DATA_TYPE, t ))) } @@ -166,14 +170,15 @@ where ))) }, t => { - if t == &T::DATA_TYPE { + if t == &T::DATA_TYPE || (t.is_decimal() && T::DATA_TYPE.is_decimal()) { Ok(Arc::new(try_unary::( array.as_any().downcast_ref::>().unwrap(), op, )?)) } else { Err(ArrowError::NotYetImplemented(format!( - "Cannot perform unary operation on array of type {}", + "Cannot perform unary operation of type {} on array of type {}", + T::DATA_TYPE, t ))) } From 91daa8fd099b905643ebe15a0b0c968e6a69196f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 15 Dec 2022 16:00:20 -0800 Subject: [PATCH 2/2] Use discriminant --- arrow-schema/src/datatype.rs | 7 ------- arrow/src/compute/kernels/arity.rs | 8 ++++---- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs index e82609c98ef0..da1c20ddbd38 100644 --- a/arrow-schema/src/datatype.rs +++ b/arrow-schema/src/datatype.rs @@ -320,13 +320,6 @@ impl DataType { ) } - /// Returns true if this type is decimal: (Decimal*). - #[inline] - pub fn is_decimal(&self) -> bool { - use DataType::*; - matches!(self, Decimal128(_, _) | Decimal256(_, _)) - } - /// Returns true if this type is temporal: (Date*, Time*, Duration, or Interval). #[inline] pub fn is_temporal(&self) -> bool { diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index 4e267fbc9535..02659a5a7738 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -114,8 +114,8 @@ where T: ArrowPrimitiveType, F: Fn(T::Native) -> Result, { - if array.value_type() != T::DATA_TYPE - && !(array.value_type().is_decimal() && T::DATA_TYPE.is_decimal()) + if std::mem::discriminant(&array.value_type()) + != std::mem::discriminant(&T::DATA_TYPE) { return Err(ArrowError::CastError(format!( "Cannot perform the unary operation of type {} on dictionary array of value type {}", @@ -138,7 +138,7 @@ where downcast_dictionary_array! { array => unary_dict::<_, F, T>(array, op), t => { - if t == &T::DATA_TYPE || (t.is_decimal() && T::DATA_TYPE.is_decimal()) { + if std::mem::discriminant(t) == std::mem::discriminant(&T::DATA_TYPE) { Ok(Arc::new(unary::( array.as_any().downcast_ref::>().unwrap(), op, @@ -170,7 +170,7 @@ where ))) }, t => { - if t == &T::DATA_TYPE || (t.is_decimal() && T::DATA_TYPE.is_decimal()) { + if std::mem::discriminant(t) == std::mem::discriminant(&T::DATA_TYPE) { Ok(Arc::new(try_unary::( array.as_any().downcast_ref::>().unwrap(), op,