From 77c814cfc66845192c1e0be6ae8b0de6a1a0d5f0 Mon Sep 17 00:00:00 2001 From: Remzi Yang <59198230+HaoYang670@users.noreply.github.com> Date: Tue, 9 Aug 2022 18:38:46 +0800 Subject: [PATCH] Rewrite `Decimal` and `DecimalArray` using `const_generic` (#2383) * const generic decimal Signed-off-by: remzi <13716567376yh@gmail.com> * fix docs and lint Signed-off-by: remzi <13716567376yh@gmail.com> * add bound Signed-off-by: remzi <13716567376yh@gmail.com> --- arrow/src/array/array_decimal.rs | 422 ++++++++---------- arrow/src/array/builder/decimal_builder.rs | 22 +- arrow/src/array/equal/mod.rs | 1 - arrow/src/array/iterator.rs | 11 +- arrow/src/array/ord.rs | 1 - arrow/src/array/transform/mod.rs | 2 - arrow/src/compute/kernels/cast.rs | 1 - arrow/src/compute/kernels/sort.rs | 1 - arrow/src/compute/kernels/take.rs | 2 - arrow/src/csv/reader.rs | 1 - arrow/src/ffi.rs | 1 - arrow/src/util/decimal.rs | 229 +++++----- arrow/src/util/display.rs | 1 - arrow/src/util/integration_util.rs | 2 +- arrow/src/util/pretty.rs | 2 +- .../src/arrow/array_reader/primitive_array.rs | 2 +- parquet/src/arrow/arrow_writer/mod.rs | 1 - parquet/src/arrow/buffer/converter.rs | 6 +- 18 files changed, 324 insertions(+), 384 deletions(-) diff --git a/arrow/src/array/array_decimal.rs b/arrow/src/array/array_decimal.rs index 9d7644befd6e..781ed5f8f625 100644 --- a/arrow/src/array/array_decimal.rs +++ b/arrow/src/array/array_decimal.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::array::{ArrayAccessor, Decimal128Iter, Decimal256Iter}; +use crate::array::ArrayAccessor; use num::BigInt; use std::borrow::Borrow; use std::convert::From; @@ -25,17 +25,17 @@ use std::{any::Any, iter::FromIterator}; use super::{ array::print_long_array, raw_pointer::RawPtrBox, Array, ArrayData, FixedSizeListArray, }; -use super::{BooleanBufferBuilder, FixedSizeBinaryArray}; +use super::{BasicDecimalIter, BooleanBufferBuilder, FixedSizeBinaryArray}; #[allow(deprecated)] pub use crate::array::DecimalIter; use crate::buffer::{Buffer, MutableBuffer}; +use crate::datatypes::DataType; use crate::datatypes::{ validate_decimal256_precision, validate_decimal_precision, DECIMAL256_MAX_PRECISION, - DECIMAL256_MAX_SCALE, DECIMAL_DEFAULT_SCALE, + DECIMAL_DEFAULT_SCALE, }; -use crate::datatypes::{DataType, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE}; use crate::error::{ArrowError, Result}; -use crate::util::decimal::{BasicDecimal, Decimal128, Decimal256}; +use crate::util::decimal::{BasicDecimal, Decimal256}; /// `Decimal128Array` stores fixed width decimal numbers, /// with a fixed precision and scale. @@ -71,19 +71,9 @@ use crate::util::decimal::{BasicDecimal, Decimal128, Decimal256}; /// assert_eq!(6, decimal_array.scale()); /// ``` /// -pub struct Decimal128Array { - data: ArrayData, - value_data: RawPtrBox, - precision: usize, - scale: usize, -} +pub type Decimal128Array = BasicDecimalArray<16>; -pub struct Decimal256Array { - data: ArrayData, - value_data: RawPtrBox, - precision: usize, - scale: usize, -} +pub type Decimal256Array = BasicDecimalArray<32>; mod private_decimal { pub trait DecimalArrayPrivate { @@ -91,24 +81,37 @@ mod private_decimal { } } -pub trait BasicDecimalArray>: - private_decimal::DecimalArrayPrivate -{ - const VALUE_LENGTH: i32; - const DEFAULT_TYPE: DataType; - const MAX_PRECISION: usize; - const MAX_SCALE: usize; +pub struct BasicDecimalArray { + data: ArrayData, + value_data: RawPtrBox, + precision: usize, + scale: usize, +} + +impl BasicDecimalArray { + pub const VALUE_LENGTH: i32 = BYTE_WIDTH as i32; + pub const DEFAULT_TYPE: DataType = BasicDecimal::::DEFAULT_TYPE; + pub const MAX_PRECISION: usize = BasicDecimal::::MAX_PRECISION; + pub const MAX_SCALE: usize = BasicDecimal::::MAX_SCALE; + pub const TYPE_CONSTRUCTOR: fn(usize, usize) -> DataType = + BasicDecimal::::TYPE_CONSTRUCTOR; - fn data(&self) -> &ArrayData; + pub fn data(&self) -> &ArrayData { + &self.data + } /// Return the precision (total digits) that can be stored by this array - fn precision(&self) -> usize; + pub fn precision(&self) -> usize { + self.precision + } /// Return the scale (digits after the decimal) that can be stored by this array - fn scale(&self) -> usize; + pub fn scale(&self) -> usize { + self.scale + } /// Returns the element at index `i`. - fn value(&self, i: usize) -> T { + pub fn value(&self, i: usize) -> BasicDecimal { assert!(i < self.data().len(), "Out of bounds access"); unsafe { self.value_unchecked(i) } @@ -117,7 +120,7 @@ pub trait BasicDecimalArray>: /// Returns the element at index `i`. /// # Safety /// Caller is responsible for ensuring that the index is within the bounds of the array - unsafe fn value_unchecked(&self, i: usize) -> T { + pub unsafe fn value_unchecked(&self, i: usize) -> BasicDecimal { let data = self.data(); let offset = i + data.offset(); let raw_val = { @@ -127,14 +130,14 @@ pub trait BasicDecimalArray>: Self::VALUE_LENGTH as usize, ) }; - T::new(self.precision(), self.scale(), raw_val) + BasicDecimal::::new(self.precision(), self.scale(), raw_val) } /// Returns the offset for the element at index `i`. /// /// Note this doesn't do any bound checking, for performance reason. #[inline] - fn value_offset(&self, i: usize) -> i32 { + pub fn value_offset(&self, i: usize) -> i32 { self.value_offset_at(self.data().offset() + i) } @@ -142,22 +145,22 @@ pub trait BasicDecimalArray>: /// /// All elements have the same length as the array is a fixed size. #[inline] - fn value_length(&self) -> i32 { + pub fn value_length(&self) -> i32 { Self::VALUE_LENGTH } /// Returns a clone of the value data buffer - fn value_data(&self) -> Buffer { + pub fn value_data(&self) -> Buffer { self.data().buffers()[0].clone() } #[inline] - fn value_offset_at(&self, i: usize) -> i32 { + pub fn value_offset_at(&self, i: usize) -> i32 { Self::VALUE_LENGTH * i as i32 } #[inline] - fn value_as_string(&self, row: usize) -> String { + pub fn value_as_string(&self, row: usize) -> String { self.value(row).to_string() } @@ -165,11 +168,11 @@ pub trait BasicDecimalArray>: /// /// NB: This function does not validate that each value is in the permissible /// range for a decimal - fn from_fixed_size_binary_array( + pub fn from_fixed_size_binary_array( v: FixedSizeBinaryArray, precision: usize, scale: usize, - ) -> U { + ) -> Self { assert!( v.value_length() == Self::VALUE_LENGTH, "Value length of the array ({}) must equal to the byte width of the decimal ({})", @@ -184,7 +187,7 @@ pub trait BasicDecimalArray>: let builder = v.into_data().into_builder().data_type(data_type); let array_data = unsafe { builder.build_unchecked() }; - U::from(array_data) + Self::from(array_data) } /// Build a decimal array from [`FixedSizeListArray`]. @@ -192,11 +195,11 @@ pub trait BasicDecimalArray>: /// NB: This function does not validate that each value is in the permissible /// range for a decimal. #[deprecated(note = "please use `from_fixed_size_binary_array` instead")] - fn from_fixed_size_list_array( + pub fn from_fixed_size_list_array( v: FixedSizeListArray, precision: usize, scale: usize, - ) -> U { + ) -> Self { assert_eq!( v.data_ref().child_data().len(), 1, @@ -242,14 +245,49 @@ pub trait BasicDecimalArray>: .offset(list_offset); let array_data = unsafe { builder.build_unchecked() }; - U::from(array_data) + Self::from(array_data) } /// The default precision and scale used when not specified. - fn default_type() -> DataType { + pub const fn default_type() -> DataType { Self::DEFAULT_TYPE } + fn raw_value_data_ptr(&self) -> *const u8 { + self.value_data.as_ptr() + } +} + +impl Decimal128Array { + /// Creates a [Decimal128Array] with default precision and scale, + /// based on an iterator of `i128` values without nulls + pub fn from_iter_values>(iter: I) -> Self { + let val_buf: Buffer = iter.into_iter().collect(); + let data = unsafe { + ArrayData::new_unchecked( + Self::default_type(), + val_buf.len() / std::mem::size_of::(), + None, + None, + 0, + vec![val_buf], + vec![], + ) + }; + Decimal128Array::from(data) + } + + /// Validates decimal values in this array can be properly interpreted + /// with the specified precision. + pub fn validate_decimal_precision(&self, precision: usize) -> Result<()> { + if precision < self.precision { + for v in self.iter().flatten() { + validate_decimal_precision(v.as_i128(), precision)?; + } + } + Ok(()) + } + /// Returns a Decimal array with the same data as self, with the /// specified precision. /// @@ -257,7 +295,7 @@ pub trait BasicDecimalArray>: /// 1. `precision` is larger than [`Self::MAX_PRECISION`] /// 2. `scale` is larger than [`Self::MAX_SCALE`]; /// 3. `scale` is > `precision` - fn with_precision_and_scale(self, precision: usize, scale: usize) -> Result + pub fn with_precision_and_scale(self, precision: usize, scale: usize) -> Result where Self: Sized, { @@ -287,116 +325,86 @@ pub trait BasicDecimalArray>: // decreased self.validate_decimal_precision(precision)?; - let data_type = if Self::VALUE_LENGTH == 16 { - DataType::Decimal128(self.precision(), self.scale()) - } else { - DataType::Decimal256(self.precision(), self.scale()) - }; + let data_type = Self::TYPE_CONSTRUCTOR(self.precision, self.scale); assert_eq!(self.data().data_type(), &data_type); // safety: self.data is valid DataType::Decimal as checked above - let new_data_type = if Self::VALUE_LENGTH == 16 { - DataType::Decimal128(precision, scale) - } else { - DataType::Decimal256(precision, scale) - }; + let new_data_type = Self::TYPE_CONSTRUCTOR(precision, scale); Ok(self.data().clone().with_data_type(new_data_type).into()) } +} +impl Decimal256Array { /// Validates decimal values in this array can be properly interpreted /// with the specified precision. - fn validate_decimal_precision(&self, precision: usize) -> Result<()>; -} - -impl BasicDecimalArray for Decimal128Array { - const VALUE_LENGTH: i32 = 16; - const DEFAULT_TYPE: DataType = - DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE); - const MAX_PRECISION: usize = DECIMAL128_MAX_PRECISION; - const MAX_SCALE: usize = DECIMAL128_MAX_SCALE; - - fn data(&self) -> &ArrayData { - &self.data - } - - fn precision(&self) -> usize { - self.precision - } - - fn scale(&self) -> usize { - self.scale - } - - fn validate_decimal_precision(&self, precision: usize) -> Result<()> { + pub fn validate_decimal_precision(&self, precision: usize) -> Result<()> { if precision < self.precision { for v in self.iter().flatten() { - validate_decimal_precision(v.as_i128(), precision)?; + validate_decimal256_precision(&v.to_string(), precision)?; } } Ok(()) } -} - -impl BasicDecimalArray for Decimal256Array { - const VALUE_LENGTH: i32 = 32; - const DEFAULT_TYPE: DataType = - DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE); - const MAX_PRECISION: usize = DECIMAL256_MAX_PRECISION; - const MAX_SCALE: usize = DECIMAL256_MAX_SCALE; - fn data(&self) -> &ArrayData { - &self.data - } + /// Returns a Decimal array with the same data as self, with the + /// specified precision. + /// + /// Returns an Error if: + /// 1. `precision` is larger than [`Self::MAX_PRECISION`] + /// 2. `scale` is larger than [`Self::MAX_SCALE`]; + /// 3. `scale` is > `precision` + pub fn with_precision_and_scale(self, precision: usize, scale: usize) -> Result + where + Self: Sized, + { + if precision > Self::MAX_PRECISION { + return Err(ArrowError::InvalidArgumentError(format!( + "precision {} is greater than max {}", + precision, + Self::MAX_PRECISION + ))); + } + if scale > Self::MAX_SCALE { + return Err(ArrowError::InvalidArgumentError(format!( + "scale {} is greater than max {}", + scale, + Self::MAX_SCALE + ))); + } + if scale > precision { + return Err(ArrowError::InvalidArgumentError(format!( + "scale {} is greater than precision {}", + scale, precision + ))); + } - fn precision(&self) -> usize { - self.precision - } + // Ensure that all values are within the requested + // precision. For performance, only check if the precision is + // decreased + self.validate_decimal_precision(precision)?; - fn scale(&self) -> usize { - self.scale - } + let data_type = Self::TYPE_CONSTRUCTOR(self.precision, self.scale); + assert_eq!(self.data().data_type(), &data_type); - fn validate_decimal_precision(&self, precision: usize) -> Result<()> { - if precision < self.precision { - for v in self.iter().flatten() { - validate_decimal256_precision(&v.to_string(), precision)?; - } - } - Ok(()) - } -} + // safety: self.data is valid DataType::Decimal as checked above + let new_data_type = Self::TYPE_CONSTRUCTOR(precision, scale); -impl Decimal128Array { - /// Creates a [Decimal128Array] with default precision and scale, - /// based on an iterator of `i128` values without nulls - pub fn from_iter_values>(iter: I) -> Self { - let val_buf: Buffer = iter.into_iter().collect(); - let data = unsafe { - ArrayData::new_unchecked( - Self::default_type(), - val_buf.len() / std::mem::size_of::(), - None, - None, - 0, - vec![val_buf], - vec![], - ) - }; - Decimal128Array::from(data) + Ok(self.data().clone().with_data_type(new_data_type).into()) } } -impl From for Decimal128Array { +impl From for BasicDecimalArray { fn from(data: ArrayData) -> Self { assert_eq!( data.buffers().len(), 1, - "Decimal128Array data should contain 1 buffer only (values)" + "DecimalArray data should contain 1 buffer only (values)" ); let values = data.buffers()[0].as_ptr(); - let (precision, scale) = match data.data_type() { - DataType::Decimal128(precision, scale) => (*precision, *scale), + let (precision, scale) = match (data.data_type(), BYTE_WIDTH) { + (DataType::Decimal128(precision, scale), 16) + | (DataType::Decimal256(precision, scale), 32) => (*precision, *scale), _ => panic!("Expected data type to be Decimal"), }; Self { @@ -408,27 +416,6 @@ impl From for Decimal128Array { } } -impl From for Decimal256Array { - fn from(data: ArrayData) -> Self { - assert_eq!( - data.buffers().len(), - 1, - "Decimal256Array data should contain 1 buffer only (values)" - ); - let values = data.buffers()[0].as_ptr(); - let (precision, scale) = match data.data_type() { - DataType::Decimal256(precision, scale) => (*precision, *scale), - _ => panic!("Expected data type to be Decimal256"), - }; - Self { - data, - value_data: unsafe { RawPtrBox::new(values) }, - precision, - scale, - } - } -} - impl<'a> Decimal128Array { /// Constructs a new iterator that iterates `Decimal128` values as i128 values. /// This is kept mostly for back-compatibility purpose. @@ -446,17 +433,13 @@ impl From for Decimal256 { } } -fn build_decimal_array_from, T>( +fn build_decimal_array_from( null_buf: BooleanBufferBuilder, buffer: Buffer, -) -> U -where - T: BasicDecimal, - U: From, -{ +) -> BasicDecimalArray { let data = unsafe { ArrayData::new_unchecked( - U::default_type(), + BasicDecimalArray::::default_type(), null_buf.len(), None, Some(null_buf.into()), @@ -465,7 +448,7 @@ where vec![], ) }; - U::from(data) + BasicDecimalArray::::from(data) } impl> FromIterator> for Decimal256Array { @@ -488,7 +471,7 @@ impl> FromIterator> for Decimal256Array { } }); - build_decimal_array_from::(null_buf, buffer.into()) + build_decimal_array_from::<32>(null_buf, buffer.into()) } } @@ -513,96 +496,75 @@ impl>> FromIterator for Decimal128Array { }) .collect(); - build_decimal_array_from::(null_buf, buffer) + build_decimal_array_from::<16>(null_buf, buffer) } } -macro_rules! def_decimal_array { - ($ty:ident, $array_name:expr, $decimal_ty:ident, $iter_ty:ident) => { - impl private_decimal::DecimalArrayPrivate for $ty { - fn raw_value_data_ptr(&self) -> *const u8 { - self.value_data.as_ptr() - } - } - - impl Array for $ty { - fn as_any(&self) -> &dyn Any { - self - } - - fn data(&self) -> &ArrayData { - &self.data - } +impl Array for BasicDecimalArray { + fn as_any(&self) -> &dyn Any { + self + } - fn into_data(self) -> ArrayData { - self.into() - } - } + fn data(&self) -> &ArrayData { + &self.data + } - impl From<$ty> for ArrayData { - fn from(array: $ty) -> Self { - array.data - } - } + fn into_data(self) -> ArrayData { + self.into() + } +} - impl fmt::Debug for $ty { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "{}<{}, {}>\n[\n", - $array_name, self.precision, self.scale - )?; - print_long_array(self, f, |array, index, f| { - let formatted_decimal = array.value_as_string(index); - - write!(f, "{}", formatted_decimal) - })?; - write!(f, "]") - } - } +impl From> for ArrayData { + fn from(array: BasicDecimalArray) -> Self { + array.data + } +} - impl<'a> ArrayAccessor for &'a $ty { - type Item = $decimal_ty; +impl fmt::Debug for BasicDecimalArray { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "Decimal{}Array<{}, {}>\n[\n", + BYTE_WIDTH * 8, + self.precision, + self.scale + )?; + print_long_array(self, f, |array, index, f| { + let formatted_decimal = array.value_as_string(index); + + write!(f, "{}", formatted_decimal) + })?; + write!(f, "]") + } +} - fn value(&self, index: usize) -> Self::Item { - $ty::value(self, index) - } +impl<'a, const BYTE_WIDTH: usize> ArrayAccessor for &'a BasicDecimalArray { + type Item = BasicDecimal; - unsafe fn value_unchecked(&self, index: usize) -> Self::Item { - $ty::value_unchecked(self, index) - } - } + fn value(&self, index: usize) -> Self::Item { + BasicDecimalArray::::value(self, index) + } - impl<'a> IntoIterator for &'a $ty { - type Item = Option<$decimal_ty>; - type IntoIter = $iter_ty<'a>; + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + BasicDecimalArray::::value_unchecked(self, index) + } +} - fn into_iter(self) -> Self::IntoIter { - $iter_ty::<'a>::new(self) - } - } +impl<'a, const BYTE_WIDTH: usize> IntoIterator for &'a BasicDecimalArray { + type Item = Option>; + type IntoIter = BasicDecimalIter<'a, BYTE_WIDTH>; - impl<'a> $ty { - /// constructs a new iterator - pub fn iter(&'a self) -> $iter_ty<'a> { - $iter_ty::<'a>::new(self) - } - } - }; + fn into_iter(self) -> Self::IntoIter { + BasicDecimalIter::<'a, BYTE_WIDTH>::new(self) + } } -def_decimal_array!( - Decimal128Array, - "Decimal128Array", - Decimal128, - Decimal128Iter -); -def_decimal_array!( - Decimal256Array, - "Decimal256Array", - Decimal256, - Decimal256Iter -); +impl<'a, const BYTE_WIDTH: usize> BasicDecimalArray { + /// constructs a new iterator + pub fn iter(&'a self) -> BasicDecimalIter<'a, BYTE_WIDTH> { + BasicDecimalIter::<'a, BYTE_WIDTH>::new(self) + } +} #[cfg(test)] mod tests { diff --git a/arrow/src/array/builder/decimal_builder.rs b/arrow/src/array/builder/decimal_builder.rs index 22c1490e86f3..bd43100b7b71 100644 --- a/arrow/src/array/builder/decimal_builder.rs +++ b/arrow/src/array/builder/decimal_builder.rs @@ -19,7 +19,7 @@ use num::BigInt; use std::any::Any; use std::sync::Arc; -use crate::array::array_decimal::{BasicDecimalArray, Decimal256Array}; +use crate::array::array_decimal::Decimal256Array; use crate::array::ArrayRef; use crate::array::Decimal128Array; use crate::array::{ArrayBuilder, FixedSizeBinaryBuilder}; @@ -27,7 +27,7 @@ use crate::array::{ArrayBuilder, FixedSizeBinaryBuilder}; use crate::error::{ArrowError, Result}; use crate::datatypes::{validate_decimal256_precision, validate_decimal_precision}; -use crate::util::decimal::{BasicDecimal, Decimal256}; +use crate::util::decimal::Decimal256; /// Array Builder for [`Decimal128Array`] /// @@ -258,7 +258,7 @@ mod tests { use super::*; use num::Num; - use crate::array::array_decimal::{BasicDecimalArray, Decimal128Array}; + use crate::array::array_decimal::Decimal128Array; use crate::array::{array_decimal, Array}; use crate::datatypes::DataType; use crate::util::decimal::{Decimal128, Decimal256}; @@ -305,21 +305,21 @@ mod tests { fn test_decimal256_builder() { let mut builder = Decimal256Builder::new(30, 40, 6); - let mut bytes = vec![0; 32]; + let mut bytes = [0_u8; 32]; bytes[0..16].clone_from_slice(&8_887_000_000_i128.to_le_bytes()); - let value = Decimal256::try_new_from_bytes(40, 6, bytes.as_slice()).unwrap(); + let value = Decimal256::try_new_from_bytes(40, 6, &bytes).unwrap(); builder.append_value(&value).unwrap(); builder.append_null(); - bytes = vec![255; 32]; - let value = Decimal256::try_new_from_bytes(40, 6, bytes.as_slice()).unwrap(); + bytes = [255; 32]; + let value = Decimal256::try_new_from_bytes(40, 6, &bytes).unwrap(); builder.append_value(&value).unwrap(); - bytes = vec![0; 32]; + bytes = [0; 32]; bytes[0..16].clone_from_slice(&0_i128.to_le_bytes()); bytes[15] = 128; - let value = Decimal256::try_new_from_bytes(40, 6, bytes.as_slice()).unwrap(); + let value = Decimal256::try_new_from_bytes(40, 6, &bytes).unwrap(); builder.append_value(&value).unwrap(); builder.append_option(None::<&Decimal256>).unwrap(); @@ -349,9 +349,9 @@ mod tests { fn test_decimal256_builder_unmatched_precision_scale() { let mut builder = Decimal256Builder::new(30, 10, 6); - let mut bytes = vec![0; 32]; + let mut bytes = [0_u8; 32]; bytes[0..16].clone_from_slice(&8_887_000_000_i128.to_le_bytes()); - let value = Decimal256::try_new_from_bytes(40, 6, bytes.as_slice()).unwrap(); + let value = Decimal256::try_new_from_bytes(40, 6, &bytes).unwrap(); builder.append_value(&value).unwrap(); } diff --git a/arrow/src/array/equal/mod.rs b/arrow/src/array/equal/mod.rs index 3387e2842264..6fdc06f837c0 100644 --- a/arrow/src/array/equal/mod.rs +++ b/arrow/src/array/equal/mod.rs @@ -262,7 +262,6 @@ mod tests { use std::convert::TryFrom; use std::sync::Arc; - use crate::array::BasicDecimalArray; use crate::array::{ array::Array, ArrayData, ArrayDataBuilder, ArrayRef, BooleanArray, FixedSizeBinaryBuilder, FixedSizeListBuilder, GenericBinaryArray, Int32Builder, diff --git a/arrow/src/array/iterator.rs b/arrow/src/array/iterator.rs index 8ee9f25447d3..7cc9bde6b4c5 100644 --- a/arrow/src/array/iterator.rs +++ b/arrow/src/array/iterator.rs @@ -16,7 +16,7 @@ // under the License. use crate::array::array::ArrayAccessor; -use crate::array::{BasicDecimalArray, Decimal256Array}; +use crate::array::BasicDecimalArray; use super::{ Array, BooleanArray, Decimal128Array, GenericBinaryArray, GenericListArray, @@ -104,14 +104,15 @@ pub type GenericStringIter<'a, T> = ArrayIter<&'a GenericStringArray>; pub type GenericBinaryIter<'a, T> = ArrayIter<&'a GenericBinaryArray>; pub type GenericListArrayIter<'a, O> = ArrayIter<&'a GenericListArray>; +pub type BasicDecimalIter<'a, const BYTE_WIDTH: usize> = + ArrayIter<&'a BasicDecimalArray>; /// an iterator that returns `Some(Decimal128)` or `None`, that can be used on a /// [`Decimal128Array`] -pub type Decimal128Iter<'a> = ArrayIter<&'a Decimal128Array>; +pub type Decimal128Iter<'a> = BasicDecimalIter<'a, 16>; /// an iterator that returns `Some(Decimal256)` or `None`, that can be used on a -/// [`Decimal256Array`] -pub type Decimal256Iter<'a> = ArrayIter<&'a Decimal256Array>; - +/// [`super::Decimal256Array`] +pub type Decimal256Iter<'a> = BasicDecimalIter<'a, 32>; /// an iterator that returns `Some(i128)` or `None`, that can be used on a /// [`Decimal128Array`] #[derive(Debug)] diff --git a/arrow/src/array/ord.rs b/arrow/src/array/ord.rs index 1e19c7cc2fca..47173aa7d927 100644 --- a/arrow/src/array/ord.rs +++ b/arrow/src/array/ord.rs @@ -19,7 +19,6 @@ use std::cmp::Ordering; -use crate::array::BasicDecimalArray; use crate::array::*; use crate::datatypes::TimeUnit; use crate::datatypes::*; diff --git a/arrow/src/array/transform/mod.rs b/arrow/src/array/transform/mod.rs index 409f2dd143ea..f0fccef14fd7 100644 --- a/arrow/src/array/transform/mod.rs +++ b/arrow/src/array/transform/mod.rs @@ -670,8 +670,6 @@ mod tests { use std::{convert::TryFrom, sync::Arc}; use super::*; - - use crate::array::BasicDecimalArray; use crate::array::Decimal128Array; use crate::{ array::{ diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index c6b8f477986f..c9082afefad9 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -2420,7 +2420,6 @@ where #[cfg(test)] mod tests { use super::*; - use crate::array::BasicDecimalArray; use crate::datatypes::TimeUnit; use crate::util::decimal::Decimal128; use crate::{buffer::Buffer, util::display::array_value_to_string}; diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs index 0a3d0541ce3c..dca09a66a8cf 100644 --- a/arrow/src/compute/kernels/sort.rs +++ b/arrow/src/compute/kernels/sort.rs @@ -17,7 +17,6 @@ //! Defines sort kernel for `ArrayRef` -use crate::array::BasicDecimalArray; use crate::array::*; use crate::buffer::MutableBuffer; use crate::compute::take; diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs index 05832e4f5279..fb8f75651882 100644 --- a/arrow/src/compute/kernels/take.rs +++ b/arrow/src/compute/kernels/take.rs @@ -19,8 +19,6 @@ use std::{ops::AddAssign, sync::Arc}; -use crate::array::BasicDecimalArray; - use crate::buffer::{Buffer, MutableBuffer}; use crate::compute::util::{ take_value_indices_from_fixed_size_list, take_value_indices_from_list, diff --git a/arrow/src/csv/reader.rs b/arrow/src/csv/reader.rs index 7c533a8f8b24..f01ce37c7399 100644 --- a/arrow/src/csv/reader.rs +++ b/arrow/src/csv/reader.rs @@ -1116,7 +1116,6 @@ mod tests { use std::io::{Cursor, Write}; use tempfile::NamedTempFile; - use crate::array::BasicDecimalArray; use crate::array::*; use crate::compute::cast; use crate::datatypes::Field; diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs index 0716a49d634c..528f3adc2d84 100644 --- a/arrow/src/ffi.rs +++ b/arrow/src/ffi.rs @@ -911,7 +911,6 @@ impl<'a> ArrowArrayChild<'a> { #[cfg(test)] mod tests { use super::*; - use crate::array::BasicDecimalArray; use crate::array::{ export_array_into_raw, make_array, Array, ArrayData, BooleanArray, Decimal128Array, DictionaryArray, DurationSecondArray, FixedSizeBinaryArray, diff --git a/arrow/src/util/decimal.rs b/arrow/src/util/decimal.rs index 62a950795378..74f3379f4c10 100644 --- a/arrow/src/util/decimal.rs +++ b/arrow/src/util/decimal.rs @@ -18,21 +18,51 @@ //! Decimal related utils use crate::datatypes::{ - DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, - DECIMAL256_MAX_SCALE, + DataType, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, + DECIMAL256_MAX_SCALE, DECIMAL_DEFAULT_SCALE, }; use crate::error::{ArrowError, Result}; use num::bigint::BigInt; use num::Signed; use std::cmp::{min, Ordering}; -pub trait BasicDecimal: PartialOrd + Ord + PartialEq + Eq { - /// The bit-width of the internal representation. - const BIT_WIDTH: usize; - /// The maximum precision. - const MAX_PRECISION: usize; - /// The maximum scale. - const MAX_SCALE: usize; +#[derive(Debug)] +pub struct BasicDecimal { + precision: usize, + scale: usize, + value: [u8; BYTE_WIDTH], +} + +impl BasicDecimal { + #[allow(clippy::type_complexity)] + const _MAX_PRECISION_SCALE_CONSTRUCTOR_DEFAULT_TYPE: ( + usize, + usize, + fn(usize, usize) -> DataType, + DataType, + ) = match BYTE_WIDTH { + 16 => ( + DECIMAL128_MAX_PRECISION, + DECIMAL128_MAX_SCALE, + DataType::Decimal128, + DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE), + ), + 32 => ( + DECIMAL256_MAX_PRECISION, + DECIMAL256_MAX_SCALE, + DataType::Decimal256, + DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE), + ), + _ => panic!("invalid byte width"), + }; + + pub const MAX_PRECISION: usize = + Self::_MAX_PRECISION_SCALE_CONSTRUCTOR_DEFAULT_TYPE.0; + pub const MAX_SCALE: usize = Self::_MAX_PRECISION_SCALE_CONSTRUCTOR_DEFAULT_TYPE.1; + pub const TYPE_CONSTRUCTOR: fn(usize, usize) -> DataType = + Self::_MAX_PRECISION_SCALE_CONSTRUCTOR_DEFAULT_TYPE.2; + pub const DEFAULT_TYPE: DataType = + Self::_MAX_PRECISION_SCALE_CONSTRUCTOR_DEFAULT_TYPE.3; /// Tries to create a decimal value from precision, scale and bytes. /// If the length of bytes isn't same as the bit width of this decimal, @@ -41,7 +71,11 @@ pub trait BasicDecimal: PartialOrd + Ord + PartialEq + Eq { /// Safety: /// This method doesn't validate if the decimal value represented by the bytes /// can be fitted into the specified precision. - fn try_new_from_bytes(precision: usize, scale: usize, bytes: &[u8]) -> Result + pub fn try_new_from_bytes( + precision: usize, + scale: usize, + bytes: &[u8; BYTE_WIDTH], + ) -> Result where Self: Sized, { @@ -67,13 +101,13 @@ pub trait BasicDecimal: PartialOrd + Ord + PartialEq + Eq { ))); } - if bytes.len() == Self::BIT_WIDTH / 8 { + if bytes.len() == BYTE_WIDTH { Ok(Self::new(precision, scale, bytes)) } else { Err(ArrowError::InvalidArgumentError(format!( "Input to Decimal{} must be {} bytes", - Self::BIT_WIDTH, - Self::BIT_WIDTH / 8 + BYTE_WIDTH * 8, + BYTE_WIDTH ))) } } @@ -83,21 +117,33 @@ pub trait BasicDecimal: PartialOrd + Ord + PartialEq + Eq { /// Safety: /// This method doesn't check if the length of bytes is compatible with this decimal. /// Use `try_new_from_bytes` for safe constructor. - fn new(precision: usize, scale: usize, bytes: &[u8]) -> Self; - + pub fn new(precision: usize, scale: usize, bytes: &[u8]) -> Self { + Self { + precision, + scale, + value: bytes.try_into().unwrap(), + } + } /// Returns the raw bytes of the integer representation of the decimal. - fn raw_value(&self) -> &[u8]; + pub fn raw_value(&self) -> &[u8] { + &self.value + } /// Returns the precision of the decimal. - fn precision(&self) -> usize; + pub fn precision(&self) -> usize { + self.precision + } /// Returns the scale of the decimal. - fn scale(&self) -> usize; + pub fn scale(&self) -> usize { + self.scale + } /// Returns the string representation of the decimal. /// If the string representation cannot be fitted with the precision of the decimal, /// the string will be truncated. - fn to_string(&self) -> String { + #[allow(clippy::inherent_to_string)] + pub fn to_string(&self) -> String { let raw_bytes = self.raw_value(); let integer = BigInt::from_signed_bytes_le(raw_bytes); let value_str = integer.to_string(); @@ -119,15 +165,44 @@ pub trait BasicDecimal: PartialOrd + Ord + PartialEq + Eq { } } +impl PartialOrd for BasicDecimal { + fn partial_cmp(&self, other: &Self) -> Option { + assert_eq!( + self.scale, other.scale, + "Cannot compare two Decimals with different scale: {}, {}", + self.scale, other.scale + ); + Some(singed_cmp_le_bytes(&self.value, &other.value)) + } +} + +impl Ord for BasicDecimal { + fn cmp(&self, other: &Self) -> Ordering { + assert_eq!( + self.scale, other.scale, + "Cannot compare two Decimals with different scale: {}, {}", + self.scale, other.scale + ); + singed_cmp_le_bytes(&self.value, &other.value) + } +} + +impl PartialEq for BasicDecimal { + fn eq(&self, other: &Self) -> bool { + assert_eq!( + self.scale, other.scale, + "Cannot compare two Decimals with different scale: {}, {}", + self.scale, other.scale + ); + self.value.eq(&other.value) + } +} + +impl Eq for BasicDecimal {} + /// Represents a decimal value with precision and scale. /// The decimal value could represented by a signed 128-bit integer. -#[derive(Debug)] -pub struct Decimal128 { - #[allow(dead_code)] - precision: usize, - scale: usize, - value: [u8; 16], -} +pub type Decimal128 = BasicDecimal<16>; impl Decimal128 { /// Creates `Decimal128` from an `i128` value. @@ -154,13 +229,7 @@ impl From for i128 { /// Represents a decimal value with precision and scale. /// The decimal value could be represented by a signed 256-bit integer. -#[derive(Debug)] -pub struct Decimal256 { - #[allow(dead_code)] - precision: usize, - scale: usize, - value: [u8; 32], -} +pub type Decimal256 = BasicDecimal<32>; impl Decimal256 { /// Constructs a `Decimal256` value from a `BigInt`. @@ -170,9 +239,9 @@ impl Decimal256 { scale: usize, ) -> Result { let mut bytes = if num.is_negative() { - vec![255; 32] + [255_u8; 32] } else { - vec![0; 32] + [0; 32] }; let num_bytes = &num.to_signed_bytes_le(); bytes[0..num_bytes.len()].clone_from_slice(num_bytes); @@ -180,71 +249,6 @@ impl Decimal256 { } } -macro_rules! def_decimal { - ($ty:ident, $bit:expr, $max_p:expr, $max_s:expr) => { - impl BasicDecimal for $ty { - const BIT_WIDTH: usize = $bit; - const MAX_PRECISION: usize = $max_p; - const MAX_SCALE: usize = $max_s; - - fn new(precision: usize, scale: usize, bytes: &[u8]) -> Self { - $ty { - precision, - scale, - value: bytes.try_into().unwrap(), - } - } - - fn raw_value(&self) -> &[u8] { - &self.value - } - - fn precision(&self) -> usize { - self.precision - } - - fn scale(&self) -> usize { - self.scale - } - } - - impl PartialOrd for $ty { - fn partial_cmp(&self, other: &Self) -> Option { - assert_eq!( - self.scale, other.scale, - "Cannot compare two Decimals with different scale: {}, {}", - self.scale, other.scale - ); - Some(singed_cmp_le_bytes(&self.value, &other.value)) - } - } - - impl Ord for $ty { - fn cmp(&self, other: &Self) -> Ordering { - assert_eq!( - self.scale, other.scale, - "Cannot compare two Decimals with different scale: {}, {}", - self.scale, other.scale - ); - singed_cmp_le_bytes(&self.value, &other.value) - } - } - - impl PartialEq for $ty { - fn eq(&self, other: &Self) -> bool { - assert_eq!( - self.scale, other.scale, - "Cannot compare two Decimals with different scale: {}, {}", - self.scale, other.scale - ); - self.value.eq(&other.value) - } - } - - impl Eq for $ty {} - }; -} - // compare two signed integer which are encoded with little endian. // left bytes and right bytes must have the same length. fn singed_cmp_le_bytes(left: &[u8], right: &[u8]) -> Ordering { @@ -286,24 +290,9 @@ fn singed_cmp_le_bytes(left: &[u8], right: &[u8]) -> Ordering { Ordering::Equal } -def_decimal!( - Decimal128, - 128, - DECIMAL128_MAX_PRECISION, - DECIMAL128_MAX_SCALE -); -def_decimal!( - Decimal256, - 256, - DECIMAL256_MAX_PRECISION, - DECIMAL256_MAX_SCALE -); - #[cfg(test)] mod tests { - use crate::util::decimal::{ - singed_cmp_le_bytes, BasicDecimal, Decimal128, Decimal256, - }; + use super::*; use num::{BigInt, Num}; use rand::random; @@ -356,9 +345,9 @@ mod tests { #[test] fn decimal_256_from_bytes() { - let mut bytes = vec![0; 32]; + let mut bytes = [0_u8; 32]; bytes[0..16].clone_from_slice(&100_i128.to_le_bytes()); - let value = Decimal256::try_new_from_bytes(5, 2, bytes.as_slice()).unwrap(); + let value = Decimal256::try_new_from_bytes(5, 2, &bytes).unwrap(); assert_eq!(value.to_string(), "1.00"); bytes[0..16].clone_from_slice(&i128::MAX.to_le_bytes()); @@ -378,7 +367,7 @@ mod tests { ); // smaller than i128 minimum - bytes = vec![255; 32]; + bytes = [255; 32]; bytes[31] = 128; let value = Decimal256::try_new_from_bytes(76, 4, &bytes).unwrap(); assert_eq!( @@ -386,7 +375,7 @@ mod tests { "-574437317700748313234121683441537667865831564552201235664496608164256541.5731" ); - bytes = vec![255; 32]; + bytes = [255; 32]; let value = Decimal256::try_new_from_bytes(5, 2, &bytes).unwrap(); assert_eq!(value.to_string(), "-0.01"); } diff --git a/arrow/src/util/display.rs b/arrow/src/util/display.rs index 26bc8a1923a6..aa4fd4200870 100644 --- a/arrow/src/util/display.rs +++ b/arrow/src/util/display.rs @@ -23,7 +23,6 @@ use std::fmt::Write; use std::sync::Arc; use crate::array::Array; -use crate::array::BasicDecimalArray; use crate::datatypes::{ ArrowNativeType, ArrowPrimitiveType, DataType, Field, Int16Type, Int32Type, Int64Type, Int8Type, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type, diff --git a/arrow/src/util/integration_util.rs b/arrow/src/util/integration_util.rs index 65d54a02b64f..ee5c947a2fff 100644 --- a/arrow/src/util/integration_util.rs +++ b/arrow/src/util/integration_util.rs @@ -34,7 +34,7 @@ use crate::datatypes::*; use crate::error::{ArrowError, Result}; use crate::record_batch::{RecordBatch, RecordBatchReader}; use crate::util::bit_util; -use crate::util::decimal::{BasicDecimal, Decimal256}; +use crate::util::decimal::Decimal256; /// A struct that represents an Arrow file with a schema and record batches #[derive(Deserialize, Serialize, Debug)] diff --git a/arrow/src/util/pretty.rs b/arrow/src/util/pretty.rs index 84d445e9a1f8..6f4d9e34a99b 100644 --- a/arrow/src/util/pretty.rs +++ b/arrow/src/util/pretty.rs @@ -107,7 +107,7 @@ fn create_column(field: &str, columns: &[ArrayRef]) -> Result { mod tests { use crate::{ array::{ - self, new_null_array, Array, BasicDecimalArray, Date32Array, Date64Array, + self, new_null_array, Array, Date32Array, Date64Array, FixedSizeBinaryBuilder, Float16Array, Int32Array, PrimitiveBuilder, StringArray, StringBuilder, StringDictionaryBuilder, StructArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, diff --git a/parquet/src/arrow/array_reader/primitive_array.rs b/parquet/src/arrow/array_reader/primitive_array.rs index 516a3f50c712..61883bc7029a 100644 --- a/parquet/src/arrow/array_reader/primitive_array.rs +++ b/parquet/src/arrow/array_reader/primitive_array.rs @@ -25,7 +25,7 @@ use crate::data_type::DataType; use crate::errors::{ParquetError, Result}; use crate::schema::types::ColumnDescPtr; use arrow::array::{ - ArrayDataBuilder, ArrayRef, BasicDecimalArray, BooleanArray, BooleanBufferBuilder, + ArrayDataBuilder, ArrayRef, BooleanArray, BooleanBufferBuilder, Decimal128Array, Float32Array, Float64Array, Int32Array, Int64Array, }; use arrow::buffer::Buffer; diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index 800aff98a6f4..08f37c395658 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -23,7 +23,6 @@ use std::sync::Arc; use arrow::array as arrow_array; use arrow::array::ArrayRef; -use arrow::array::BasicDecimalArray; use arrow::datatypes::{DataType as ArrowDataType, IntervalUnit, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow_array::Array; diff --git a/parquet/src/arrow/buffer/converter.rs b/parquet/src/arrow/buffer/converter.rs index d8cbd256a460..aa98b48d5d1e 100644 --- a/parquet/src/arrow/buffer/converter.rs +++ b/parquet/src/arrow/buffer/converter.rs @@ -17,9 +17,9 @@ use crate::data_type::{ByteArray, FixedLenByteArray, Int96}; use arrow::array::{ - Array, ArrayRef, BasicDecimalArray, Decimal128Array, FixedSizeBinaryArray, - FixedSizeBinaryBuilder, IntervalDayTimeArray, IntervalDayTimeBuilder, - IntervalYearMonthArray, IntervalYearMonthBuilder, TimestampNanosecondArray, + Array, ArrayRef, Decimal128Array, FixedSizeBinaryArray, FixedSizeBinaryBuilder, + IntervalDayTimeArray, IntervalDayTimeBuilder, IntervalYearMonthArray, + IntervalYearMonthBuilder, TimestampNanosecondArray, }; use std::sync::Arc;