diff --git a/arrow-ord/src/ord.rs b/arrow-ord/src/ord.rs index 6122f9cb3f33..b7737c6de61f 100644 --- a/arrow-ord/src/ord.rs +++ b/arrow-ord/src/ord.rs @@ -153,6 +153,12 @@ where }) } +macro_rules! cmp_dict_primitive_helper { + ($t:ty, $key_type_lhs:expr, $left:expr, $right:expr) => { + cmp_dict_primitive::<$t>($key_type_lhs, $left, $right)? + }; +} + /// returns a comparison function that compares two values at two different positions /// between the two arrays. /// The arrays' types must be equal. @@ -193,6 +199,12 @@ pub fn build_compare( (Int64, Int64) => compare_primitives::(left, right), (Float32, Float32) => compare_float::(left, right), (Float64, Float64) => compare_float::(left, right), + (Decimal128(_, _), Decimal128(_, _)) => { + compare_primitives::(left, right) + } + (Decimal256(_, _), Decimal256(_, _)) => { + compare_primitives::(left, right) + } (Date32, Date32) => compare_primitives::(left, right), (Date64, Date64) => compare_primitives::(left, right), (Time32(Second), Time32(Second)) => { @@ -253,83 +265,8 @@ pub fn build_compare( } let key_type_lhs = key_type_lhs.as_ref(); - - match value_type_lhs.as_ref() { - Int8 => cmp_dict_primitive::(key_type_lhs, left, right)?, - Int16 => cmp_dict_primitive::(key_type_lhs, left, right)?, - Int32 => cmp_dict_primitive::(key_type_lhs, left, right)?, - Int64 => cmp_dict_primitive::(key_type_lhs, left, right)?, - UInt8 => cmp_dict_primitive::(key_type_lhs, left, right)?, - UInt16 => cmp_dict_primitive::(key_type_lhs, left, right)?, - UInt32 => cmp_dict_primitive::(key_type_lhs, left, right)?, - UInt64 => cmp_dict_primitive::(key_type_lhs, left, right)?, - Float32 => cmp_dict_primitive::(key_type_lhs, left, right)?, - Float64 => cmp_dict_primitive::(key_type_lhs, left, right)?, - Date32 => cmp_dict_primitive::(key_type_lhs, left, right)?, - Date64 => cmp_dict_primitive::(key_type_lhs, left, right)?, - Time32(Second) => { - cmp_dict_primitive::(key_type_lhs, left, right)? - } - Time32(Millisecond) => cmp_dict_primitive::( - key_type_lhs, - left, - right, - )?, - Time64(Microsecond) => cmp_dict_primitive::( - key_type_lhs, - left, - right, - )?, - Time64(Nanosecond) => { - cmp_dict_primitive::(key_type_lhs, left, right)? - } - Timestamp(Second, _) => { - cmp_dict_primitive::(key_type_lhs, left, right)? - } - Timestamp(Millisecond, _) => cmp_dict_primitive::< - TimestampMillisecondType, - >(key_type_lhs, left, right)?, - Timestamp(Microsecond, _) => cmp_dict_primitive::< - TimestampMicrosecondType, - >(key_type_lhs, left, right)?, - Timestamp(Nanosecond, _) => { - cmp_dict_primitive::( - key_type_lhs, - left, - right, - )? - } - Interval(YearMonth) => cmp_dict_primitive::( - key_type_lhs, - left, - right, - )?, - Interval(DayTime) => { - cmp_dict_primitive::(key_type_lhs, left, right)? - } - Interval(MonthDayNano) => cmp_dict_primitive::( - key_type_lhs, - left, - right, - )?, - Duration(Second) => { - cmp_dict_primitive::(key_type_lhs, left, right)? - } - Duration(Millisecond) => cmp_dict_primitive::( - key_type_lhs, - left, - right, - )?, - Duration(Microsecond) => cmp_dict_primitive::( - key_type_lhs, - left, - right, - )?, - Duration(Nanosecond) => cmp_dict_primitive::( - key_type_lhs, - left, - right, - )?, + downcast_primitive! { + value_type_lhs.as_ref() => (cmp_dict_primitive_helper, key_type_lhs, left, right), Utf8 => match key_type_lhs { UInt8 => compare_dict_string::(left, right), UInt16 => compare_dict_string::(left, right), @@ -354,11 +291,6 @@ pub fn build_compare( } } } - (Decimal128(_, _), Decimal128(_, _)) => { - let left: Decimal128Array = Decimal128Array::from(left.data().clone()); - let right: Decimal128Array = Decimal128Array::from(right.data().clone()); - Box::new(move |i, j| left.value(i).cmp(&right.value(j))) - } (FixedSizeBinary(_), FixedSizeBinary(_)) => { let left: FixedSizeBinaryArray = FixedSizeBinaryArray::from(left.data().clone()); @@ -380,6 +312,7 @@ pub fn build_compare( pub mod tests { use super::*; use arrow_array::{FixedSizeBinaryArray, Float64Array, Int32Array}; + use arrow_buffer::i256; use std::cmp::Ordering; #[test] @@ -464,6 +397,23 @@ pub mod tests { assert_eq!(Ordering::Greater, (cmp)(0, 2)); } + #[test] + fn test_decimali256() { + let array = vec![ + Some(i256::from_i128(5_i128)), + Some(i256::from_i128(2_i128)), + Some(i256::from_i128(3_i128)), + ] + .into_iter() + .collect::() + .with_precision_and_scale(53, 6) + .unwrap(); + + let cmp = build_compare(&array, &array).unwrap(); + assert_eq!(Ordering::Less, (cmp)(1, 0)); + assert_eq!(Ordering::Greater, (cmp)(0, 2)); + } + #[test] fn test_dict() { let data = vec!["a", "b", "c", "a", "a", "c", "c"]; @@ -584,4 +534,52 @@ pub mod tests { assert_eq!(Ordering::Greater, (cmp)(3, 1)); assert_eq!(Ordering::Greater, (cmp)(3, 2)); } + + #[test] + fn test_decimal_dict() { + let values = Decimal128Array::from(vec![1, 0, 2, 5]); + let keys = Int8Array::from_iter_values([0, 0, 1, 3]); + let array1 = DictionaryArray::::try_new(&keys, &values).unwrap(); + + let values = Decimal128Array::from(vec![2, 3, 4, 5]); + let keys = Int8Array::from_iter_values([0, 1, 1, 3]); + let array2 = DictionaryArray::::try_new(&keys, &values).unwrap(); + + let cmp = build_compare(&array1, &array2).unwrap(); + + assert_eq!(Ordering::Less, (cmp)(0, 0)); + assert_eq!(Ordering::Less, (cmp)(0, 3)); + assert_eq!(Ordering::Equal, (cmp)(3, 3)); + assert_eq!(Ordering::Greater, (cmp)(3, 1)); + assert_eq!(Ordering::Greater, (cmp)(3, 2)); + } + + #[test] + fn test_decimal256_dict() { + let values = Decimal256Array::from(vec![ + i256::from_i128(1), + i256::from_i128(0), + i256::from_i128(2), + i256::from_i128(5), + ]); + let keys = Int8Array::from_iter_values([0, 0, 1, 3]); + let array1 = DictionaryArray::::try_new(&keys, &values).unwrap(); + + let values = Decimal256Array::from(vec![ + i256::from_i128(2), + i256::from_i128(3), + i256::from_i128(4), + i256::from_i128(5), + ]); + let keys = Int8Array::from_iter_values([0, 1, 1, 3]); + let array2 = DictionaryArray::::try_new(&keys, &values).unwrap(); + + let cmp = build_compare(&array1, &array2).unwrap(); + + assert_eq!(Ordering::Less, (cmp)(0, 0)); + assert_eq!(Ordering::Less, (cmp)(0, 3)); + assert_eq!(Ordering::Equal, (cmp)(3, 3)); + assert_eq!(Ordering::Greater, (cmp)(3, 1)); + assert_eq!(Ordering::Greater, (cmp)(3, 2)); + } }