Skip to content

Commit

Permalink
Verify ArrayData::data_type compatible in PrimitiveArray::from
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Jan 3, 2023
1 parent 17b3210 commit e03afed
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 17 deletions.
8 changes: 3 additions & 5 deletions arrow-arith/src/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,7 @@ where
T: ArrowPrimitiveType,
F: Fn(T::Native) -> Result<T::Native, ArrowError>,
{
if std::mem::discriminant(&array.value_type())
!= std::mem::discriminant(&T::DATA_TYPE)
{
if !PrimitiveArray::<T>::is_compatible(&array.value_type()) {
return Err(ArrowError::CastError(format!(
"Cannot perform the unary operation of type {} on dictionary array of value type {}",
T::DATA_TYPE,
Expand All @@ -138,7 +136,7 @@ where
downcast_dictionary_array! {
array => unary_dict::<_, F, T>(array, op),
t => {
if std::mem::discriminant(t) == std::mem::discriminant(&T::DATA_TYPE) {
if PrimitiveArray::<T>::is_compatible(t) {
Ok(Arc::new(unary::<T, F, T>(
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
op,
Expand Down Expand Up @@ -170,7 +168,7 @@ where
)))
},
t => {
if std::mem::discriminant(t) == std::mem::discriminant(&T::DATA_TYPE) {
if PrimitiveArray::<T>::is_compatible(t) {
Ok(Arc::new(try_unary::<T, F, T>(
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
op,
Expand Down
30 changes: 26 additions & 4 deletions arrow-array/src/array/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,21 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
PrimitiveBuilder::<T>::with_capacity(capacity)
}

/// Returns if this [`PrimitiveArray`] is compatible with the provided [`DataType`]
///
/// This is equivalent to `data_type == T::DATA_TYPE`, however ignores timestamp
/// timezones and decimal precision and scale
pub fn is_compatible(data_type: &DataType) -> bool {
match T::DATA_TYPE {
DataType::Timestamp(t1, _) => {
matches!(data_type, DataType::Timestamp(t2, _) if &t1 == t2)
}
DataType::Decimal128(_, _) => matches!(data_type, DataType::Decimal128(_, _)),
DataType::Decimal256(_, _) => matches!(data_type, DataType::Decimal256(_, _)),
_ => T::DATA_TYPE.eq(data_type),
}
}

/// Returns the primitive value at index `i`.
///
/// # Safety
Expand Down Expand Up @@ -1042,10 +1057,8 @@ impl<T: ArrowTimestampType> PrimitiveArray<T> {
/// Constructs a `PrimitiveArray` from an array data reference.
impl<T: ArrowPrimitiveType> From<ArrayData> for PrimitiveArray<T> {
fn from(data: ArrayData) -> Self {
// Use discriminant to allow for decimals
assert_eq!(
std::mem::discriminant(&T::DATA_TYPE),
std::mem::discriminant(data.data_type()),
assert!(
Self::is_compatible(data.data_type()),
"PrimitiveArray expected ArrayData with type {} got {}",
T::DATA_TYPE,
data.data_type()
Expand Down Expand Up @@ -2205,4 +2218,13 @@ mod tests {
let c = array.unary_mut(|x| x * 2 + 1).unwrap();
assert_eq!(c, Int32Array::from(vec![Some(11), Some(15), None]));
}

#[test]
#[should_panic(
expected = "PrimitiveArray expected ArrayData with type Interval(MonthDayNano) got Interval(DayTime)"
)]
fn test_invalid_interval_type() {
let array = IntervalDayTimeArray::from(vec![1, 2, 3]);
let _ = IntervalMonthDayNanoArray::from(array.into_data());
}
}
5 changes: 1 addition & 4 deletions arrow-row/src/dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,10 +270,7 @@ fn decode_primitive<T: ArrowPrimitiveType>(
where
T::Native: FixedLengthEncoding,
{
assert_eq!(
std::mem::discriminant(&T::DATA_TYPE),
std::mem::discriminant(&data_type),
);
assert!(PrimitiveArray::<T>::is_compatible(&data_type));

// SAFETY:
// Validated data type above
Expand Down
5 changes: 1 addition & 4 deletions arrow-row/src/fixed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,7 @@ pub fn decode_primitive<T: ArrowPrimitiveType>(
where
T::Native: FixedLengthEncoding,
{
assert_eq!(
std::mem::discriminant(&T::DATA_TYPE),
std::mem::discriminant(&data_type),
);
assert!(PrimitiveArray::<T>::is_compatible(&data_type));
// SAFETY:
// Validated data type above
unsafe { decode_fixed::<T::Native>(rows, data_type, options).into() }
Expand Down

0 comments on commit e03afed

Please sign in to comment.