Skip to content

Commit

Permalink
Add PrimitiveBuilder type constructors (#4401)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold authored Jun 12, 2023
1 parent 481c197 commit c1283f1
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 49 deletions.
51 changes: 3 additions & 48 deletions arrow-array/src/array/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1388,64 +1388,19 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
/// Returns a Decimal array with the same data as self, with the
/// specified precision and scale.
///
/// Returns an Error if:
/// - `precision` is zero
/// - `precision` is larger than `T:MAX_PRECISION`
/// - `scale` is larger than `T::MAX_SCALE`
/// - `scale` is > `precision`
/// See [`validate_decimal_precision_and_scale`]
pub fn with_precision_and_scale(
self,
precision: u8,
scale: i8,
) -> Result<Self, ArrowError>
where
Self: Sized,
{
// validate precision and scale
self.validate_precision_scale(precision, scale)?;

// safety: self.data is valid DataType::Decimal as checked above
) -> Result<Self, ArrowError> {
validate_decimal_precision_and_scale::<T>(precision, scale)?;
Ok(Self {
data_type: T::TYPE_CONSTRUCTOR(precision, scale),
..self
})
}

// validate that the new precision and scale are valid or not
fn validate_precision_scale(
&self,
precision: u8,
scale: i8,
) -> Result<(), ArrowError> {
if precision == 0 {
return Err(ArrowError::InvalidArgumentError(format!(
"precision cannot be 0, has to be between [1, {}]",
T::MAX_PRECISION
)));
}
if precision > T::MAX_PRECISION {
return Err(ArrowError::InvalidArgumentError(format!(
"precision {} is greater than max {}",
precision,
T::MAX_PRECISION
)));
}
if scale > T::MAX_SCALE {
return Err(ArrowError::InvalidArgumentError(format!(
"scale {} is greater than max {}",
scale,
T::MAX_SCALE
)));
}
if scale > 0 && scale as u8 > precision {
return Err(ArrowError::InvalidArgumentError(format!(
"scale {scale} is greater than precision {precision}"
)));
}

Ok(())
}

/// Validates values in this array can be properly interpreted
/// with the specified precision.
pub fn validate_decimal_precision(&self, precision: u8) -> Result<(), ArrowError> {
Expand Down
32 changes: 31 additions & 1 deletion arrow-array/src/builder/primitive_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::{ArrayRef, ArrowPrimitiveType, PrimitiveArray};
use arrow_buffer::NullBufferBuilder;
use arrow_buffer::{Buffer, MutableBuffer};
use arrow_data::ArrayData;
use arrow_schema::DataType;
use arrow_schema::{ArrowError, DataType};
use std::any::Any;
use std::sync::Arc;

Expand Down Expand Up @@ -331,6 +331,36 @@ impl<T: ArrowPrimitiveType> PrimitiveBuilder<T> {
}
}

impl<P: DecimalType> PrimitiveBuilder<P> {
/// Sets the precision and scale
pub fn with_precision_and_scale(
self,
precision: u8,
scale: i8,
) -> Result<Self, ArrowError> {
validate_decimal_precision_and_scale::<P>(precision, scale)?;
Ok(Self {
data_type: P::TYPE_CONSTRUCTOR(precision, scale),
..self
})
}
}

impl<P: ArrowTimestampType> PrimitiveBuilder<P> {
/// Sets the timezone
pub fn with_timezone(self, timezone: impl Into<Arc<str>>) -> Self {
self.with_timezone_opt(Some(timezone.into()))
}

/// Sets an optional timezone
pub fn with_timezone_opt<S: Into<Arc<str>>>(self, timezone: Option<S>) -> Self {
Self {
data_type: DataType::Timestamp(P::UNIT, timezone.map(Into::into)),
..self
}
}
}

impl<P: ArrowPrimitiveType> Extend<Option<P::Native>> for PrimitiveBuilder<P> {
#[inline]
fn extend<T: IntoIterator<Item = Option<P::Native>>>(&mut self, iter: T) {
Expand Down
40 changes: 40 additions & 0 deletions arrow-array/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1403,6 +1403,46 @@ pub trait DecimalType:
) -> Result<(), ArrowError>;
}

/// Validate that `precision` and `scale` are valid for `T`
///
/// Returns an Error if:
/// - `precision` is zero
/// - `precision` is larger than `T:MAX_PRECISION`
/// - `scale` is larger than `T::MAX_SCALE`
/// - `scale` is > `precision`
pub fn validate_decimal_precision_and_scale<T: DecimalType>(
precision: u8,
scale: i8,
) -> Result<(), ArrowError> {
if precision == 0 {
return Err(ArrowError::InvalidArgumentError(format!(
"precision cannot be 0, has to be between [1, {}]",
T::MAX_PRECISION
)));
}
if precision > T::MAX_PRECISION {
return Err(ArrowError::InvalidArgumentError(format!(
"precision {} is greater than max {}",
precision,
T::MAX_PRECISION
)));
}
if scale > T::MAX_SCALE {
return Err(ArrowError::InvalidArgumentError(format!(
"scale {} is greater than max {}",
scale,
T::MAX_SCALE
)));
}
if scale > 0 && scale as u8 > precision {
return Err(ArrowError::InvalidArgumentError(format!(
"scale {scale} is greater than precision {precision}"
)));
}

Ok(())
}

/// The decimal type for a Decimal128Array
#[derive(Debug)]
pub struct Decimal128Type {}
Expand Down

0 comments on commit c1283f1

Please sign in to comment.