Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Added dynamic version of neg. (#685)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao authored Dec 17, 2021
1 parent 3685ae8 commit e84a140
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 3 deletions.
68 changes: 67 additions & 1 deletion src/compute/arithmetics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub mod decimal;
pub mod time;

use crate::{
array::{Array, PrimitiveArray},
array::{Array, DictionaryArray, PrimitiveArray},
bitmap::Bitmap,
datatypes::{DataType, IntervalUnit, TimeUnit},
scalar::{PrimitiveScalar, Scalar},
Expand Down Expand Up @@ -400,6 +400,72 @@ pub fn can_rem(lhs: &DataType, rhs: &DataType) -> bool {
)
}

macro_rules! with_match_negatable {(
$key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use crate::datatypes::PrimitiveType::*;
use crate::types::{days_ms, months_days_ns};
match $key_type {
Int8 => __with_ty__! { i8 },
Int16 => __with_ty__! { i16 },
Int32 => __with_ty__! { i32 },
Int64 => __with_ty__! { i64 },
Int128 => __with_ty__! { i128 },
DaysMs => __with_ty__! { days_ms },
MonthDayNano => __with_ty__! { months_days_ns },
UInt8 | UInt16 | UInt32 | UInt64=> todo!(),
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
}
})}

/// Negates an [`Array`].
/// # Panic
/// This function panics iff either
/// * the opertion is not supported for the logical type (use [`can_neg`] to check)
/// * the operation overflows
pub fn neg(array: &dyn Array) -> Box<dyn Array> {
use crate::datatypes::PhysicalType::*;
match array.data_type().to_physical_type() {
Primitive(primitive) => with_match_negatable!(primitive, |$T| {
let array = array.as_any().downcast_ref().unwrap();

let result = basic::negate::<$T>(array);
Box::new(result) as Box<dyn Array>
}),
Dictionary(key) => match_integer_type!(key, |$T| {
let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();

let values = neg(array.values().as_ref()).into();

Box::new(DictionaryArray::<$T>::from_data(array.keys().clone(), values)) as Box<dyn Array>
}),
_ => todo!(),
}
}

/// Whether [`neg`] is supported for a given [`DataType`]
pub fn can_neg(data_type: &DataType) -> bool {
if let DataType::Dictionary(_, values) = data_type.to_logical_type() {
return can_neg(values.as_ref());
}

use crate::datatypes::PhysicalType::*;
use crate::datatypes::PrimitiveType::*;
matches!(
data_type.to_physical_type(),
Primitive(Int8)
| Primitive(Int16)
| Primitive(Int32)
| Primitive(Int64)
| Primitive(Float64)
| Primitive(Float32)
| Primitive(DaysMs)
| Primitive(MonthDayNano)
)
}

/// Defines basic addition operation for primitive arrays
pub trait ArrayAdd<Rhs>: Sized {
/// Adds itself to `rhs`
Expand Down
20 changes: 19 additions & 1 deletion src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//! represent chunks of bits (e.g. `u8`, `u16`), and [`BitChunkIter`], that can be used to
//! iterate over bitmaps in [`BitChunk`]s.
//! Finally, this module also contains traits used to compile code optimized for SIMD instructions at [`mod@simd`].
use std::convert::TryFrom;
use std::{convert::TryFrom, ops::Neg};

mod bit_chunk;
pub use bit_chunk::{BitChunk, BitChunkIter};
Expand Down Expand Up @@ -399,3 +399,21 @@ impl months_days_ns {
self.2
}
}

impl Neg for days_ms {
type Output = Self;

#[inline(always)]
fn neg(self) -> Self::Output {
Self([-self.0[0], -self.0[0]])
}
}

impl Neg for months_days_ns {
type Output = Self;

#[inline(always)]
fn neg(self) -> Self::Output {
Self(-self.0, -self.1, -self.2)
}
}
24 changes: 23 additions & 1 deletion tests/it/compute/arithmetics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mod basic;
mod decimal;
mod time;

use arrow2::array::{new_empty_array, Int32Array};
use arrow2::array::*;
use arrow2::compute::arithmetics::*;
use arrow2::datatypes::DataType::*;
use arrow2::datatypes::{IntervalUnit, TimeUnit};
Expand Down Expand Up @@ -84,3 +84,25 @@ fn consistency() {
}
});
}

#[test]
fn test_neg() {
let a = Int32Array::from(&[None, Some(6), None, Some(6)]);
let result = neg(&a);
let expected = Int32Array::from(&[None, Some(-6), None, Some(-6)]);
assert_eq!(expected, result.as_ref());
}

#[test]
fn test_neg_dict() {
let a = DictionaryArray::<u8>::from_data(
UInt8Array::from_slice(&[0, 0, 1]),
std::sync::Arc::new(Int8Array::from_slice(&[1, 2])),
);
let result = neg(&a);
let expected = DictionaryArray::<u8>::from_data(
UInt8Array::from_slice(&[0, 0, 1]),
std::sync::Arc::new(Int8Array::from_slice(&[-1, -2])),
);
assert_eq!(expected, result.as_ref());
}

0 comments on commit e84a140

Please sign in to comment.