diff --git a/arrow-arith/src/aggregate.rs b/arrow-arith/src/aggregate.rs index 4961d7efc0f2..04417c666c85 100644 --- a/arrow-arith/src/aggregate.rs +++ b/arrow-arith/src/aggregate.rs @@ -867,8 +867,8 @@ where #[cfg(test)] mod tests { use super::*; - use crate::arithmetic::add; use arrow_array::types::*; + use arrow_buffer::NullBuffer; use std::sync::Arc; #[test] @@ -897,54 +897,35 @@ mod tests { #[test] fn test_primitive_array_sum_large_64() { - let a: Int64Array = (1..=100) - .map(|i| if i % 3 == 0 { Some(i) } else { None }) - .collect(); - let b: Int64Array = (1..=100) - .map(|i| if i % 3 == 0 { Some(0) } else { Some(i) }) - .collect(); // create an array that actually has non-zero values at the invalid indices - let c = add(&a, &b).unwrap(); + let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect()); + let c = Int64Array::new((1..=100).collect(), Some(validity)); + assert_eq!(Some((1..=100).filter(|i| i % 3 == 0).sum()), sum(&c)); } #[test] fn test_primitive_array_sum_large_32() { - let a: Int32Array = (1..=100) - .map(|i| if i % 3 == 0 { Some(i) } else { None }) - .collect(); - let b: Int32Array = (1..=100) - .map(|i| if i % 3 == 0 { Some(0) } else { Some(i) }) - .collect(); // create an array that actually has non-zero values at the invalid indices - let c = add(&a, &b).unwrap(); + let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect()); + let c = Int32Array::new((1..=100).collect(), Some(validity)); assert_eq!(Some((1..=100).filter(|i| i % 3 == 0).sum()), sum(&c)); } #[test] fn test_primitive_array_sum_large_16() { - let a: Int16Array = (1..=100) - .map(|i| if i % 3 == 0 { Some(i) } else { None }) - .collect(); - let b: Int16Array = (1..=100) - .map(|i| if i % 3 == 0 { Some(0) } else { Some(i) }) - .collect(); // create an array that actually has non-zero values at the invalid indices - let c = add(&a, &b).unwrap(); + let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect()); + let c = Int16Array::new((1..=100).collect(), Some(validity)); assert_eq!(Some((1..=100).filter(|i| i % 3 == 0).sum()), sum(&c)); } #[test] fn test_primitive_array_sum_large_8() { // include fewer values than other large tests so the result does not overflow the u8 - let a: UInt8Array = (1..=100) - .map(|i| if i % 33 == 0 { Some(i) } else { None }) - .collect(); - let b: UInt8Array = (1..=100) - .map(|i| if i % 33 == 0 { Some(0) } else { Some(i) }) - .collect(); // create an array that actually has non-zero values at the invalid indices - let c = add(&a, &b).unwrap(); + let validity = NullBuffer::new((1..=100).map(|x| x % 33 == 0).collect()); + let c = UInt8Array::new((1..=100).collect(), Some(validity)); assert_eq!(Some((1..=100).filter(|i| i % 33 == 0).sum()), sum(&c)); } diff --git a/arrow-arith/src/arithmetic.rs b/arrow-arith/src/arithmetic.rs index 8e7ab44042cf..4f6ecc78dc58 100644 --- a/arrow-arith/src/arithmetic.rs +++ b/arrow-arith/src/arithmetic.rs @@ -23,7 +23,6 @@ //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. use crate::arity::*; -use arrow_array::cast::*; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::i256; @@ -39,6 +38,7 @@ use std::sync::Arc; /// # Errors /// /// This function errors if the arrays have different lengths +#[deprecated(note = "Use arrow_arith::arity::binary")] pub fn math_op( left: &PrimitiveArray, right: &PrimitiveArray, @@ -52,43 +52,6 @@ where binary(left, right, op) } -/// This is similar to `math_op` as it performs given operation between two input primitive arrays. -/// But the given operation can return `Err` if overflow is detected. For the case, this function -/// returns an `Err`. -fn math_checked_op( - left: &PrimitiveArray, - right: &PrimitiveArray, - op: F, -) -> Result, ArrowError> -where - LT: ArrowNumericType, - RT: ArrowNumericType, - F: Fn(LT::Native, RT::Native) -> Result, -{ - try_binary(left, right, op) -} - -/// Helper function for operations where a valid `0` on the right array should -/// result in an [ArrowError::DivideByZero], namely the division and modulo operations -/// -/// # Errors -/// -/// This function errors if: -/// * the arrays have different lengths -/// * there is an element where both left and right values are valid and the right value is `0` -fn math_checked_divide_op( - left: &PrimitiveArray, - right: &PrimitiveArray, - op: F, -) -> Result, ArrowError> -where - LT: ArrowNumericType, - RT: ArrowNumericType, - F: Fn(LT::Native, RT::Native) -> Result, -{ - math_checked_op(left, right, op) -} - /// Calculates the modulus operation `left % right` on two SIMD inputs. /// The lower-most bits of `valid_mask` specify which vector lanes are considered as valid. /// @@ -335,11 +298,12 @@ where /// /// This doesn't detect overflow. Once overflowing, the result will wrap around. /// For an overflow-checking variant, use `add_checked` instead. +#[deprecated(note = "Use arrow_arith::numeric::add_wrapping")] pub fn add( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result, ArrowError> { - math_op(left, right, |a, b| a.add_wrapping(b)) + binary(left, right, |a, b| a.add_wrapping(b)) } /// Perform `left + right` operation on two arrays. If either left or right value is null @@ -347,11 +311,12 @@ pub fn add( /// /// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, /// use `add` instead. +#[deprecated(note = "Use arrow_arith::numeric::add")] pub fn add_checked( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result, ArrowError> { - math_checked_op(left, right, |a, b| a.add_checked(b)) + try_binary(left, right, |a, b| a.add_checked(b)) } /// Perform `left + right` operation on two arrays. If either left or right value is null @@ -359,176 +324,9 @@ pub fn add_checked( /// /// This doesn't detect overflow. Once overflowing, the result will wrap around. /// For an overflow-checking variant, use `add_dyn_checked` instead. +#[deprecated(note = "Use arrow_arith::numeric::add_wrapping")] pub fn add_dyn(left: &dyn Array, right: &dyn Array) -> Result { - match left.data_type() { - DataType::Date32 => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::add_year_months)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::add_day_time)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::add_month_day_nano)?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Date64 => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::add_year_months)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::add_day_time)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::add_month_day_nano)?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Timestamp(TimeUnit::Second, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampSecondType::add_year_months)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampSecondType::add_day_time)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampSecondType::add_month_day_nano)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - - DataType::Timestamp(TimeUnit::Microsecond, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMicrosecondType::add_year_months)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMicrosecondType::add_day_time)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMicrosecondType::add_month_day_nano)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - - DataType::Timestamp(TimeUnit::Millisecond, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMillisecondType::add_year_months)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMillisecondType::add_day_time)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMillisecondType::add_month_day_nano)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampNanosecondType::add_year_months)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampNanosecondType::add_day_time)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampNanosecondType::add_month_day_nano)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - - DataType::Interval(_) - if matches!( - right.data_type(), - DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _) - ) => - { - add_dyn(right, left) - } - _ => { - downcast_primitive_array!( - (left, right) => { - math_op(left, right, |a, b| a.add_wrapping(b)).map(|a| Arc::new(a) as ArrayRef) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported data type {}, {}", - left.data_type(), right.data_type() - ))) - ) - } - } + crate::numeric::add_wrapping(&left, &right) } /// Perform `left + right` operation on two arrays. If either left or right value is null @@ -536,71 +334,12 @@ pub fn add_dyn(left: &dyn Array, right: &dyn Array) -> Result Result { - match left.data_type() { - DataType::Date32 => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::add_year_months)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::add_day_time)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::add_month_day_nano)?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Date64 => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::add_year_months)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::add_day_time)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::add_month_day_nano)?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - _ => { - downcast_primitive_array!( - (left, right) => { - math_checked_op(left, right, |a, b| a.add_checked(b)).map(|a| Arc::new(a) as ArrayRef) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported data type {}, {}", - left.data_type(), right.data_type() - ))) - ) - } - } + crate::numeric::add(&left, &right) } /// Add every value in an array by a scalar. If any value in the array is null then the @@ -608,6 +347,7 @@ pub fn add_dyn_checked( /// /// This doesn't detect overflow. Once overflowing, the result will wrap around. /// For an overflow-checking variant, use `add_scalar_checked` instead. +#[deprecated(note = "Use arrow_arith::numeric::add_wrapping")] pub fn add_scalar( array: &PrimitiveArray, scalar: T::Native, @@ -620,6 +360,7 @@ pub fn add_scalar( /// /// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, /// use `add_scalar` instead. +#[deprecated(note = "Use arrow_arith::numeric::add")] pub fn add_scalar_checked( array: &PrimitiveArray, scalar: T::Native, @@ -635,6 +376,7 @@ pub fn add_scalar_checked( /// For an overflow-checking variant, use `add_scalar_checked_dyn` instead. /// /// This returns an `Err` when the input array is not supported for adding operation. +#[deprecated(note = "Use arrow_arith::numeric::add_wrapping")] pub fn add_scalar_dyn( array: &dyn Array, scalar: T::Native, @@ -651,6 +393,7 @@ pub fn add_scalar_dyn( /// /// As this kernel has the branching costs and also prevents LLVM from vectorising it correctly, /// it is usually much slower than non-checking variant. +#[deprecated(note = "Use arrow_arith::numeric::add")] pub fn add_scalar_checked_dyn( array: &dyn Array, scalar: T::Native, @@ -664,11 +407,12 @@ pub fn add_scalar_checked_dyn( /// /// This doesn't detect overflow. Once overflowing, the result will wrap around. /// For an overflow-checking variant, use `subtract_checked` instead. +#[deprecated(note = "Use arrow_arith::numeric::sub_wrapping")] pub fn subtract( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result, ArrowError> { - math_op(left, right, |a, b| a.sub_wrapping(b)) + binary(left, right, |a, b| a.sub_wrapping(b)) } /// Perform `left - right` operation on two arrays. If either left or right value is null @@ -676,11 +420,12 @@ pub fn subtract( /// /// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, /// use `subtract` instead. +#[deprecated(note = "Use arrow_arith::numeric::sub")] pub fn subtract_checked( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result, ArrowError> { - math_checked_op(left, right, |a, b| a.sub_checked(b)) + try_binary(left, right, |a, b| a.sub_checked(b)) } /// Perform `left - right` operation on two arrays. If either left or right value is null @@ -688,184 +433,9 @@ pub fn subtract_checked( /// /// This doesn't detect overflow. Once overflowing, the result will wrap around. /// For an overflow-checking variant, use `subtract_dyn_checked` instead. +#[deprecated(note = "Use arrow_arith::numeric::sub_wrapping")] pub fn subtract_dyn(left: &dyn Array, right: &dyn Array) -> Result { - match left.data_type() { - DataType::Date32 => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::subtract_year_months)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::subtract_day_time)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::subtract_month_day_nano)?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Date64 => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::subtract_year_months)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::subtract_day_time)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::subtract_month_day_nano)?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Timestamp(TimeUnit::Second, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampSecondType::subtract_year_months)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampSecondType::subtract_day_time)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampSecondType::subtract_month_day_nano)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Timestamp(TimeUnit::Second, _) => { - let r = right.as_primitive::(); - let res: PrimitiveArray = binary(l, r, |a, b| a.wrapping_sub(b))?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMicrosecondType::subtract_year_months)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMicrosecondType::subtract_day_time)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMicrosecondType::subtract_month_day_nano)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - let r = right.as_primitive::(); - let res: PrimitiveArray = binary(l, r, |a, b| a.wrapping_sub(b))?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMillisecondType::subtract_year_months)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMillisecondType::subtract_day_time)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMillisecondType::subtract_month_day_nano)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - let r = right.as_primitive::(); - let res: PrimitiveArray = binary(l, r, |a, b| a.wrapping_sub(b))?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampNanosecondType::subtract_year_months)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampNanosecondType::subtract_day_time)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampNanosecondType::subtract_month_day_nano)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - let r = right.as_primitive::(); - let res: PrimitiveArray = binary(l, r, |a, b| a.wrapping_sub(b))?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - _ => { - downcast_primitive_array!( - (left, right) => { - math_op(left, right, |a, b| a.sub_wrapping(b)).map(|a| Arc::new(a) as ArrayRef) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported data type {}, {}", - left.data_type(), right.data_type() - ))) - ) - } - } + crate::numeric::sub_wrapping(&left, &right) } /// Perform `left - right` operation on two arrays. If either left or right value is null @@ -873,127 +443,12 @@ pub fn subtract_dyn(left: &dyn Array, right: &dyn Array) -> Result Result { - match left.data_type() { - DataType::Date32 => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::subtract_year_months)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::subtract_day_time)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::subtract_month_day_nano)?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Date64 => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::subtract_year_months)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::subtract_day_time)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::subtract_month_day_nano)?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Timestamp(TimeUnit::Second, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Timestamp(TimeUnit::Second, _) => { - let r = right.as_primitive::(); - let res: PrimitiveArray = try_binary(l, r, |a, b| a.sub_checked(b))?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Timestamp(TimeUnit::Microsecond, _) => { - let r = right.as_primitive::(); - let res: PrimitiveArray = try_binary(l, r, |a, b| a.sub_checked(b))?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Timestamp(TimeUnit::Millisecond, _) => { - let r = right.as_primitive::(); - let res: PrimitiveArray = try_binary(l, r, |a, b| a.sub_checked(b))?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - let r = right.as_primitive::(); - let res: PrimitiveArray = try_binary(l, r, |a, b| a.sub_checked(b))?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - _ => { - downcast_primitive_array!( - (left, right) => { - math_checked_op(left, right, |a, b| a.sub_checked(b)).map(|a| Arc::new(a) as ArrayRef) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported data type {}, {}", - left.data_type(), right.data_type() - ))) - ) - } - } + crate::numeric::sub(&left, &right) } /// Subtract every value in an array by a scalar. If any value in the array is null then the @@ -1001,6 +456,7 @@ pub fn subtract_dyn_checked( /// /// This doesn't detect overflow. Once overflowing, the result will wrap around. /// For an overflow-checking variant, use `subtract_scalar_checked` instead. +#[deprecated(note = "Use arrow_arith::numeric::sub_wrapping")] pub fn subtract_scalar( array: &PrimitiveArray, scalar: T::Native, @@ -1013,6 +469,7 @@ pub fn subtract_scalar( /// /// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, /// use `subtract_scalar` instead. +#[deprecated(note = "Use arrow_arith::numeric::sub")] pub fn subtract_scalar_checked( array: &PrimitiveArray, scalar: T::Native, @@ -1026,6 +483,7 @@ pub fn subtract_scalar_checked( /// /// This doesn't detect overflow. Once overflowing, the result will wrap around. /// For an overflow-checking variant, use `subtract_scalar_checked_dyn` instead. +#[deprecated(note = "Use arrow_arith::numeric::sub_wrapping")] pub fn subtract_scalar_dyn( array: &dyn Array, scalar: T::Native, @@ -1039,6 +497,7 @@ pub fn subtract_scalar_dyn( /// /// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, /// use `subtract_scalar_dyn` instead. +#[deprecated(note = "Use arrow_arith::numeric::sub")] pub fn subtract_scalar_checked_dyn( array: &dyn Array, scalar: T::Native, @@ -1072,11 +531,12 @@ pub fn negate_checked( /// /// This doesn't detect overflow. Once overflowing, the result will wrap around. /// For an overflow-checking variant, use `multiply_check` instead. +#[deprecated(note = "Use arrow_arith::numeric::mul_wrapping")] pub fn multiply( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result, ArrowError> { - math_op(left, right, |a, b| a.mul_wrapping(b)) + binary(left, right, |a, b| a.mul_wrapping(b)) } /// Perform `left * right` operation on two arrays. If either left or right value is null @@ -1084,11 +544,12 @@ pub fn multiply( /// /// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, /// use `multiply` instead. +#[deprecated(note = "Use arrow_arith::numeric::mul")] pub fn multiply_checked( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result, ArrowError> { - math_checked_op(left, right, |a, b| a.mul_checked(b)) + try_binary(left, right, |a, b| a.mul_checked(b)) } /// Perform `left * right` operation on two arrays. If either left or right value is null @@ -1096,16 +557,9 @@ pub fn multiply_checked( /// /// This doesn't detect overflow. Once overflowing, the result will wrap around. /// For an overflow-checking variant, use `multiply_dyn_checked` instead. +#[deprecated(note = "Use arrow_arith::numeric::mul_wrapping")] pub fn multiply_dyn(left: &dyn Array, right: &dyn Array) -> Result { - downcast_primitive_array!( - (left, right) => { - math_op(left, right, |a, b| a.mul_wrapping(b)).map(|a| Arc::new(a) as ArrayRef) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported data type {}, {}", - left.data_type(), right.data_type() - ))) - ) + crate::numeric::mul_wrapping(&left, &right) } /// Perform `left * right` operation on two arrays. If either left or right value is null @@ -1113,19 +567,12 @@ pub fn multiply_dyn(left: &dyn Array, right: &dyn Array) -> Result Result { - downcast_primitive_array!( - (left, right) => { - math_checked_op(left, right, |a, b| a.mul_checked(b)).map(|a| Arc::new(a) as ArrayRef) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported data type {}, {}", - left.data_type(), right.data_type() - ))) - ) + crate::numeric::mul(&left, &right) } /// Returns the precision and scale of the result of a multiplication of two decimal types, @@ -1210,8 +657,10 @@ pub fn multiply_fixed_point_checked( )?; if required_scale == product_scale { - return multiply_checked(left, right)? - .with_precision_and_scale(precision, required_scale); + return try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| { + a.mul_checked(b) + })? + .with_precision_and_scale(precision, required_scale); } try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| { @@ -1254,7 +703,7 @@ pub fn multiply_fixed_point( )?; if required_scale == product_scale { - return multiply(left, right)? + return binary(left, right, |a, b| a.mul_wrapping(b))? .with_precision_and_scale(precision, required_scale); } @@ -1294,6 +743,7 @@ where /// /// This doesn't detect overflow. Once overflowing, the result will wrap around. /// For an overflow-checking variant, use `multiply_scalar_checked` instead. +#[deprecated(note = "Use arrow_arith::numeric::mul_wrapping")] pub fn multiply_scalar( array: &PrimitiveArray, scalar: T::Native, @@ -1306,6 +756,7 @@ pub fn multiply_scalar( /// /// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, /// use `multiply_scalar` instead. +#[deprecated(note = "Use arrow_arith::numeric::mul")] pub fn multiply_scalar_checked( array: &PrimitiveArray, scalar: T::Native, @@ -1319,6 +770,7 @@ pub fn multiply_scalar_checked( /// /// This doesn't detect overflow. Once overflowing, the result will wrap around. /// For an overflow-checking variant, use `multiply_scalar_checked_dyn` instead. +#[deprecated(note = "Use arrow_arith::numeric::mul_wrapping")] pub fn multiply_scalar_dyn( array: &dyn Array, scalar: T::Native, @@ -1332,6 +784,7 @@ pub fn multiply_scalar_dyn( /// /// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, /// use `multiply_scalar_dyn` instead. +#[deprecated(note = "Use arrow_arith::numeric::mul")] pub fn multiply_scalar_checked_dyn( array: &dyn Array, scalar: T::Native, @@ -1343,6 +796,7 @@ pub fn multiply_scalar_checked_dyn( /// Perform `left % right` operation on two arrays. If either left or right value is null /// then the result is also null. If any right hand value is zero then the result of this /// operation will be `Err(ArrowError::DivideByZero)`. +#[deprecated(note = "Use arrow_arith::numeric::rem")] pub fn modulus( left: &PrimitiveArray, right: &PrimitiveArray, @@ -1364,22 +818,9 @@ pub fn modulus( /// Perform `left % right` operation on two arrays. If either left or right value is null /// then the result is also null. If any right hand value is zero then the result of this /// operation will be `Err(ArrowError::DivideByZero)`. +#[deprecated(note = "Use arrow_arith::numeric::rem")] pub fn modulus_dyn(left: &dyn Array, right: &dyn Array) -> Result { - downcast_primitive_array!( - (left, right) => { - math_checked_divide_op(left, right, |a, b| { - if b.is_zero() { - Err(ArrowError::DivideByZero) - } else { - Ok(a.mod_wrapping(b)) - } - }).map(|a| Arc::new(a) as ArrayRef) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported data type {}, {}", - left.data_type(), right.data_type() - ))) - ) + crate::numeric::rem(&left, &right) } /// Perform `left / right` operation on two arrays. If either left or right value is null @@ -1388,6 +829,7 @@ pub fn modulus_dyn(left: &dyn Array, right: &dyn Array) -> Result( left: &PrimitiveArray, right: &PrimitiveArray, @@ -1397,7 +839,7 @@ pub fn divide_checked( a.div_wrapping(b) }); #[cfg(not(feature = "simd"))] - return math_checked_divide_op(left, right, |a, b| a.div_checked(b)); + return try_binary(left, right, |a, b| a.div_checked(b)); } /// Perform `left / right` operation on two arrays. If either left or right value is null @@ -1414,6 +856,7 @@ pub fn divide_checked( /// /// For integer types overflow will wrap around. /// +#[deprecated(note = "Use arrow_arith::numeric::div")] pub fn divide_opt( left: &PrimitiveArray, right: &PrimitiveArray, @@ -1433,17 +876,23 @@ pub fn divide_opt( /// /// This doesn't detect overflow. Once overflowing, the result will wrap around. /// For an overflow-checking variant, use `divide_dyn_checked` instead. +#[deprecated(note = "Use arrow_arith::numeric::div")] pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> Result { + fn divide_op( + left: &PrimitiveArray, + right: &PrimitiveArray, + ) -> Result, ArrowError> { + try_binary(left, right, |a, b| { + if b.is_zero() { + Err(ArrowError::DivideByZero) + } else { + Ok(a.div_wrapping(b)) + } + }) + } + downcast_primitive_array!( - (left, right) => { - math_checked_divide_op(left, right, |a, b| { - if b.is_zero() { - Err(ArrowError::DivideByZero) - } else { - Ok(a.div_wrapping(b)) - } - }).map(|a| Arc::new(a) as ArrayRef) - } + (left, right) => divide_op(left, right).map(|a| Arc::new(a) as ArrayRef), _ => Err(ArrowError::CastError(format!( "Unsupported data type {}, {}", left.data_type(), right.data_type() @@ -1457,19 +906,12 @@ pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> Result Result { - downcast_primitive_array!( - (left, right) => { - math_checked_divide_op(left, right, |a, b| a.div_checked(b)).map(|a| Arc::new(a) as ArrayRef) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported data type {}, {}", - left.data_type(), right.data_type() - ))) - ) + crate::numeric::div(&left, &right) } /// Perform `left / right` operation on two arrays. If either left or right value is null @@ -1481,6 +923,7 @@ pub fn divide_dyn_checked( /// Unlike `divide_dyn` or `divide_dyn_checked`, division by zero will get a null value instead /// returning an `Err`, this also doesn't check overflowing, overflowing will just wrap /// the result around. +#[deprecated(note = "Use arrow_arith::numeric::div")] pub fn divide_dyn_opt( left: &dyn Array, right: &dyn Array, @@ -1513,18 +956,20 @@ pub fn divide_dyn_opt( /// If either left or right value is null then the result is also null. /// /// For an overflow-checking variant, use `divide_checked` instead. +#[deprecated(note = "Use arrow_arith::numeric::div")] pub fn divide( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result, ArrowError> { // TODO: This is incorrect as div_wrapping has side-effects for integer types // and so may panic on null values (#2647) - math_op(left, right, |a, b| a.div_wrapping(b)) + binary(left, right, |a, b| a.div_wrapping(b)) } /// Modulus every value in an array by a scalar. If any value in the array is null then the /// result is also null. If the scalar is zero then the result of this operation will be /// `Err(ArrowError::DivideByZero)`. +#[deprecated(note = "Use arrow_arith::numeric::rem")] pub fn modulus_scalar( array: &PrimitiveArray, modulo: T::Native, @@ -1539,6 +984,7 @@ pub fn modulus_scalar( /// Modulus every value in an array by a scalar. If any value in the array is null then the /// result is also null. If the scalar is zero then the result of this operation will be /// `Err(ArrowError::DivideByZero)`. +#[deprecated(note = "Use arrow_arith::numeric::rem")] pub fn modulus_scalar_dyn( array: &dyn Array, modulo: T::Native, @@ -1552,6 +998,7 @@ pub fn modulus_scalar_dyn( /// Divide every value in an array by a scalar. If any value in the array is null then the /// result is also null. If the scalar is zero then the result of this operation will be /// `Err(ArrowError::DivideByZero)`. +#[deprecated(note = "Use arrow_arith::numeric::div")] pub fn divide_scalar( array: &PrimitiveArray, divisor: T::Native, @@ -1569,6 +1016,7 @@ pub fn divide_scalar( /// /// This doesn't detect overflow. Once overflowing, the result will wrap around. /// For an overflow-checking variant, use `divide_scalar_checked_dyn` instead. +#[deprecated(note = "Use arrow_arith::numeric::div")] pub fn divide_scalar_dyn( array: &dyn Array, divisor: T::Native, @@ -1586,6 +1034,7 @@ pub fn divide_scalar_dyn( /// /// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, /// use `divide_scalar_dyn` instead. +#[deprecated(note = "Use arrow_arith::numeric::div")] pub fn divide_scalar_checked_dyn( array: &dyn Array, divisor: T::Native, @@ -1608,6 +1057,7 @@ pub fn divide_scalar_checked_dyn( /// Unlike `divide_scalar_dyn` or `divide_scalar_checked_dyn`, division by zero will get a /// null value instead returning an `Err`, this also doesn't check overflowing, overflowing /// will just wrap the result around. +#[deprecated(note = "Use arrow_arith::numeric::div")] pub fn divide_scalar_opt_dyn( array: &dyn Array, divisor: T::Native, @@ -1625,11 +1075,13 @@ pub fn divide_scalar_opt_dyn( } #[cfg(test)] +#[allow(deprecated)] mod tests { use super::*; use arrow_array::builder::{ BooleanBufferBuilder, BufferBuilder, PrimitiveDictionaryBuilder, }; + use arrow_array::cast::AsArray; use arrow_array::temporal_conversions::SECONDS_IN_DAY; use arrow_buffer::buffer::NullBuffer; use arrow_buffer::i256; @@ -1678,16 +1130,14 @@ mod tests { )]); let b = IntervalDayTimeArray::from(vec![IntervalDayTimeType::make_value(1, 2)]); let c = add_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); assert_eq!( - c.value(0), + c.as_primitive::().value(0), Date32Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 1, 2).unwrap()) ); let c = add_dyn(&b, &a).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); assert_eq!( - c.value(0), + c.as_primitive::().value(0), Date32Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 1, 2).unwrap()) ); } @@ -1702,16 +1152,14 @@ mod tests { 1, 2, 3, )]); let c = add_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); assert_eq!( - c.value(0), + c.as_primitive::().value(0), Date32Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 2, 3).unwrap()) ); let c = add_dyn(&b, &a).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); assert_eq!( - c.value(0), + c.as_primitive::().value(0), Date32Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 2, 3).unwrap()) ); } @@ -1724,16 +1172,14 @@ mod tests { let b = IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(1, 2)]); let c = add_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); assert_eq!( - c.value(0), + c.as_primitive::().value(0), Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2001, 3, 1).unwrap()) ); let c = add_dyn(&b, &a).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); assert_eq!( - c.value(0), + c.as_primitive::().value(0), Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2001, 3, 1).unwrap()) ); } @@ -1745,16 +1191,14 @@ mod tests { )]); let b = IntervalDayTimeArray::from(vec![IntervalDayTimeType::make_value(1, 2)]); let c = add_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); assert_eq!( - c.value(0), + c.as_primitive::().value(0), Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 1, 2).unwrap()) ); let c = add_dyn(&b, &a).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); assert_eq!( - c.value(0), + c.as_primitive::().value(0), Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 1, 2).unwrap()) ); } @@ -1769,16 +1213,14 @@ mod tests { 1, 2, 3, )]); let c = add_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); assert_eq!( - c.value(0), + c.as_primitive::().value(0), Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 2, 3).unwrap()) ); let c = add_dyn(&b, &a).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); assert_eq!( - c.value(0), + c.as_primitive::().value(0), Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 2, 3).unwrap()) ); } @@ -2584,11 +2026,11 @@ mod tests { } #[test] - #[should_panic(expected = "DivideByZero")] fn test_f32_array_modulus_dyn_by_zero() { let a = Float32Array::from(vec![1.5]); let b = Float32Array::from(vec![0.0]); - modulus_dyn(&a, &b).unwrap(); + let result = modulus_dyn(&a, &b).unwrap(); + assert!(result.as_primitive::().value(0).is_nan()); } #[test] @@ -3838,10 +3280,6 @@ mod tests { ::Native::MIN, ]); - // unchecked - let result = subtract_dyn(&a, &b); - assert!(!&result.is_err()); - // checked let result = subtract_dyn_checked(&a, &b); assert!(&result.is_err()); @@ -3866,16 +3304,8 @@ mod tests { #[test] fn test_timestamp_microsecond_subtract_timestamp_overflow() { - let a = TimestampMicrosecondArray::from(vec![ - ::Native::MAX, - ]); - let b = TimestampMicrosecondArray::from(vec![ - ::Native::MIN, - ]); - - // unchecked - let result = subtract_dyn(&a, &b); - assert!(!&result.is_err()); + let a = TimestampMicrosecondArray::from(vec![i64::MAX]); + let b = TimestampMicrosecondArray::from(vec![i64::MIN]); // checked let result = subtract_dyn_checked(&a, &b); @@ -3901,16 +3331,8 @@ mod tests { #[test] fn test_timestamp_millisecond_subtract_timestamp_overflow() { - let a = TimestampMillisecondArray::from(vec![ - ::Native::MAX, - ]); - let b = TimestampMillisecondArray::from(vec![ - ::Native::MIN, - ]); - - // unchecked - let result = subtract_dyn(&a, &b); - assert!(!&result.is_err()); + let a = TimestampMillisecondArray::from(vec![i64::MAX]); + let b = TimestampMillisecondArray::from(vec![i64::MIN]); // checked let result = subtract_dyn_checked(&a, &b); @@ -3943,10 +3365,6 @@ mod tests { ::Native::MIN, ]); - // unchecked - let result = subtract_dyn(&a, &b); - assert!(!&result.is_err()); - // checked let result = subtract_dyn_checked(&a, &b); assert!(&result.is_err()); diff --git a/arrow-arith/src/lib.rs b/arrow-arith/src/lib.rs index 60d31c972b66..2d5451e04dd2 100644 --- a/arrow-arith/src/lib.rs +++ b/arrow-arith/src/lib.rs @@ -18,8 +18,10 @@ //! Arrow arithmetic and aggregation kernels pub mod aggregate; +#[doc(hidden)] // Kernels to be removed in a future release pub mod arithmetic; pub mod arity; pub mod bitwise; pub mod boolean; +pub mod numeric; pub mod temporal; diff --git a/arrow-arith/src/numeric.rs b/arrow-arith/src/numeric.rs new file mode 100644 index 000000000000..816fcaa944f5 --- /dev/null +++ b/arrow-arith/src/numeric.rs @@ -0,0 +1,672 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines numeric arithmetic kernels on [`PrimitiveArray`], such as [`add`] + +use std::cmp::Ordering; +use std::sync::Arc; + +use arrow_array::cast::AsArray; +use arrow_array::types::*; +use arrow_array::*; +use arrow_buffer::ArrowNativeType; +use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit}; + +use crate::arity::{binary, try_binary}; + +/// Perform `lhs + rhs`, returning an error on overflow +pub fn add(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Add, lhs, rhs) +} + +/// Perform `lhs + rhs`, wrapping on overflow for [`DataType::is_integer`] +pub fn add_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::AddWrapping, lhs, rhs) +} + +/// Perform `lhs - rhs`, returning an error on overflow +pub fn sub(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Sub, lhs, rhs) +} + +/// Perform `lhs - rhs`, wrapping on overflow for [`DataType::is_integer`] +pub fn sub_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::SubWrapping, lhs, rhs) +} + +/// Perform `lhs * rhs`, returning an error on overflow +pub fn mul(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Mul, lhs, rhs) +} + +/// Perform `lhs * rhs`, wrapping on overflow for [`DataType::is_integer`] +pub fn mul_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::MulWrapping, lhs, rhs) +} + +/// Perform `lhs / rhs` +/// +/// Overflow or division by zero will result in an error, with exception to +/// floating point numbers, which instead follow the IEEE 754 rules +pub fn div(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Div, lhs, rhs) +} + +/// Perform `lhs % rhs` +/// +/// Overflow or division by zero will result in an error, with exception to +/// floating point numbers, which instead follow the IEEE 754 rules +pub fn rem(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Rem, lhs, rhs) +} + +/// An enumeration of arithmetic operations +/// +/// This allows sharing the type dispatch logic across the various kernels +#[derive(Debug, Copy, Clone)] +enum Op { + AddWrapping, + Add, + SubWrapping, + Sub, + MulWrapping, + Mul, + Div, + Rem, +} + +impl Op { + fn commutative(&self) -> bool { + matches!(self, Self::Add | Self::AddWrapping) + } +} + +/// Dispatch the given `op` to the appropriate specialized kernel +fn arithmetic_op( + op: Op, + lhs: &dyn Datum, + rhs: &dyn Datum, +) -> Result { + use DataType::*; + use IntervalUnit::*; + use TimeUnit::*; + + macro_rules! integer_helper { + ($t:ty, $op:ident, $l:ident, $l_scalar:ident, $r:ident, $r_scalar:ident) => { + integer_op::<$t>($op, $l, $l_scalar, $r, $r_scalar) + }; + } + + let (l, l_scalar) = lhs.get(); + let (r, r_scalar) = rhs.get(); + downcast_integer! { + l.data_type(), r.data_type() => (integer_helper, op, l, l_scalar, r, r_scalar), + (Float16, Float16) => float_op::(op, l, l_scalar, r, r_scalar), + (Float32, Float32) => float_op::(op, l, l_scalar, r, r_scalar), + (Float64, Float64) => float_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Second, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Millisecond, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Microsecond, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Nanosecond, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Duration(Second), Duration(Second)) => duration_op::(op, l, l_scalar, r, r_scalar), + (Duration(Millisecond), Duration(Millisecond)) => duration_op::(op, l, l_scalar, r, r_scalar), + (Duration(Microsecond), Duration(Microsecond)) => duration_op::(op, l, l_scalar, r, r_scalar), + (Duration(Nanosecond), Duration(Nanosecond)) => duration_op::(op, l, l_scalar, r, r_scalar), + (Interval(YearMonth), Interval(YearMonth)) => interval_op::(op, l, l_scalar, r, r_scalar), + (Interval(DayTime), Interval(DayTime)) => interval_op::(op, l, l_scalar, r, r_scalar), + (Interval(MonthDayNano), Interval(MonthDayNano)) => interval_op::(op, l, l_scalar, r, r_scalar), + (Date32, _) => date_op::(op, l, l_scalar, r, r_scalar), + (Date64, _) => date_op::(op, l, l_scalar, r, r_scalar), + (Decimal128(_, _), Decimal128(_, _)) => decimal_op::(op, l, l_scalar, r, r_scalar), + (Decimal256(_, _), Decimal256(_, _)) => decimal_op::(op, l, l_scalar, r, r_scalar), + (l_t, r_t) => match (l_t, r_t) { + (Duration(_) | Interval(_), Date32 | Date64 | Timestamp(_, _)) if op.commutative() => { + arithmetic_op(op, rhs, lhs) + } + _ => Err(ArrowError::InvalidArgumentError( + format!("Invalid arithmetic operation: {l_t} {op:?} {r_t}") + )) + } + } +} + +/// Perform an infallible binary operation on potentially scalar inputs +macro_rules! op { + ($l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => { + match ($l_s, $r_s) { + (true, true) | (false, false) => binary($l, $r, |$l, $r| $op)?, + (true, false) => match ($l.null_count() == 0).then(|| $l.value(0)) { + None => PrimitiveArray::new_null($r.len()), + Some($l) => $r.unary(|$r| $op), + }, + (false, true) => match ($r.null_count() == 0).then(|| $r.value(0)) { + None => PrimitiveArray::new_null($l.len()), + Some($r) => $l.unary(|$l| $op), + }, + } + }; +} + +/// Same as `op` but with a type hint for the returned array +macro_rules! op_ref { + ($t:ty, $l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {{ + let array: PrimitiveArray<$t> = op!($l, $l_s, $r, $r_s, $op); + Arc::new(array) + }}; +} + +/// Perform a fallible binary operation on potentially scalar inputs +macro_rules! try_op { + ($l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => { + match ($l_s, $r_s) { + (true, true) | (false, false) => try_binary($l, $r, |$l, $r| $op)?, + (true, false) => match ($l.null_count() == 0).then(|| $l.value(0)) { + None => PrimitiveArray::new_null($r.len()), + Some($l) => $r.try_unary(|$r| $op)?, + }, + (false, true) => match ($r.null_count() == 0).then(|| $r.value(0)) { + None => PrimitiveArray::new_null($l.len()), + Some($r) => $l.try_unary(|$l| $op)?, + }, + } + }; +} + +/// Same as `try_op` but with a type hint for the returned array +macro_rules! try_op_ref { + ($t:ty, $l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {{ + let array: PrimitiveArray<$t> = try_op!($l, $l_s, $r, $r_s, $op); + Arc::new(array) + }}; +} + +/// Perform an arithmetic operation on integers +fn integer_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + let array: PrimitiveArray = match op { + Op::AddWrapping => op!(l, l_s, r, r_s, l.add_wrapping(r)), + Op::Add => try_op!(l, l_s, r, r_s, l.add_checked(r)), + Op::SubWrapping => op!(l, l_s, r, r_s, l.sub_wrapping(r)), + Op::Sub => try_op!(l, l_s, r, r_s, l.sub_checked(r)), + Op::MulWrapping => op!(l, l_s, r, r_s, l.mul_wrapping(r)), + Op::Mul => try_op!(l, l_s, r, r_s, l.mul_checked(r)), + Op::Div => try_op!(l, l_s, r, r_s, l.div_checked(r)), + Op::Rem => try_op!(l, l_s, r, r_s, l.mod_checked(r)), + }; + Ok(Arc::new(array)) +} + +/// Perform an arithmetic operation on floats +fn float_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + let array: PrimitiveArray = match op { + Op::AddWrapping | Op::Add => op!(l, l_s, r, r_s, l.add_wrapping(r)), + Op::SubWrapping | Op::Sub => op!(l, l_s, r, r_s, l.sub_wrapping(r)), + Op::MulWrapping | Op::Mul => op!(l, l_s, r, r_s, l.mul_wrapping(r)), + Op::Div => op!(l, l_s, r, r_s, l.div_wrapping(r)), + Op::Rem => op!(l, l_s, r, r_s, l.mod_wrapping(r)), + }; + Ok(Arc::new(array)) +} + +/// Arithmetic trait for timestamp arrays +trait TimestampOp: ArrowTimestampType { + type Duration: ArrowPrimitiveType; + + fn add_year_month(timestamp: i64, delta: i32) -> Result; + fn add_day_time(timestamp: i64, delta: i64) -> Result; + fn add_month_day_nano(timestamp: i64, delta: i128) -> Result; + + fn sub_year_month(timestamp: i64, delta: i32) -> Result; + fn sub_day_time(timestamp: i64, delta: i64) -> Result; + fn sub_month_day_nano(timestamp: i64, delta: i128) -> Result; +} + +macro_rules! timestamp { + ($t:ty, $d:ty) => { + impl TimestampOp for $t { + type Duration = $d; + + fn add_year_month(left: i64, right: i32) -> Result { + Self::add_year_months(left, right) + } + + fn add_day_time(left: i64, right: i64) -> Result { + Self::add_day_time(left, right) + } + + fn add_month_day_nano(left: i64, right: i128) -> Result { + Self::add_month_day_nano(left, right) + } + + fn sub_year_month(left: i64, right: i32) -> Result { + Self::subtract_year_months(left, right) + } + + fn sub_day_time(left: i64, right: i64) -> Result { + Self::subtract_day_time(left, right) + } + + fn sub_month_day_nano(left: i64, right: i128) -> Result { + Self::subtract_month_day_nano(left, right) + } + } + }; +} +timestamp!(TimestampSecondType, DurationSecondType); +timestamp!(TimestampMillisecondType, DurationMillisecondType); +timestamp!(TimestampMicrosecondType, DurationMicrosecondType); +timestamp!(TimestampNanosecondType, DurationNanosecondType); + +/// Perform arithmetic operation on a timestamp array +fn timestamp_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + use DataType::*; + use IntervalUnit::*; + + // Note: interval arithmetic should account for timezones (#4457) + let l = l.as_primitive::(); + let array: PrimitiveArray = match (op, r.data_type()) { + (Op::Sub | Op::SubWrapping, Timestamp(unit, _)) if unit == &T::UNIT => { + let r = r.as_primitive::(); + return Ok(try_op_ref!(T::Duration, l, l_s, r, r_s, l.sub_checked(r))); + } + + (Op::Add | Op::AddWrapping, Duration(unit)) if unit == &T::UNIT => { + let r = r.as_primitive::(); + try_op!(l, l_s, r, r_s, l.add_checked(r)) + } + (Op::Sub | Op::SubWrapping, Duration(unit)) if unit == &T::UNIT => { + let r = r.as_primitive::(); + try_op!(l, l_s, r, r_s, l.sub_checked(r)) + } + + (Op::Add | Op::AddWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + try_op!(l, l_s, r, r_s, T::add_year_month(l, r)) + } + (Op::Sub | Op::SubWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + try_op!(l, l_s, r, r_s, T::sub_year_month(l, r)) + } + + (Op::Add | Op::AddWrapping, Interval(DayTime)) => { + let r = r.as_primitive::(); + try_op!(l, l_s, r, r_s, T::add_day_time(l, r)) + } + (Op::Sub | Op::SubWrapping, Interval(DayTime)) => { + let r = r.as_primitive::(); + try_op!(l, l_s, r, r_s, T::sub_day_time(l, r)) + } + + (Op::Add | Op::AddWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + try_op!(l, l_s, r, r_s, T::add_month_day_nano(l, r)) + } + (Op::Sub | Op::SubWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + try_op!(l, l_s, r, r_s, T::sub_month_day_nano(l, r)) + } + _ => { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid timestamp arithmetic operation: {} {op:?} {}", + l.data_type(), + r.data_type() + ))) + } + }; + Ok(Arc::new(array.with_timezone_opt(l.timezone()))) +} + +/// Arithmetic trait for date arrays +/// +/// Note: these should be fallible (#4456) +trait DateOp: ArrowTemporalType { + fn add_year_month(timestamp: Self::Native, delta: i32) -> Self::Native; + fn add_day_time(timestamp: Self::Native, delta: i64) -> Self::Native; + fn add_month_day_nano(timestamp: Self::Native, delta: i128) -> Self::Native; + + fn sub_year_month(timestamp: Self::Native, delta: i32) -> Self::Native; + fn sub_day_time(timestamp: Self::Native, delta: i64) -> Self::Native; + fn sub_month_day_nano(timestamp: Self::Native, delta: i128) -> Self::Native; +} + +macro_rules! date { + ($t:ty) => { + impl DateOp for $t { + fn add_year_month(left: Self::Native, right: i32) -> Self::Native { + Self::add_year_months(left, right) + } + + fn add_day_time(left: Self::Native, right: i64) -> Self::Native { + Self::add_day_time(left, right) + } + + fn add_month_day_nano(left: Self::Native, right: i128) -> Self::Native { + Self::add_month_day_nano(left, right) + } + + fn sub_year_month(left: Self::Native, right: i32) -> Self::Native { + Self::subtract_year_months(left, right) + } + + fn sub_day_time(left: Self::Native, right: i64) -> Self::Native { + Self::subtract_day_time(left, right) + } + + fn sub_month_day_nano(left: Self::Native, right: i128) -> Self::Native { + Self::subtract_month_day_nano(left, right) + } + } + }; +} +date!(Date32Type); +date!(Date64Type); + +/// Arithmetic trait for interval arrays +trait IntervalOp: ArrowPrimitiveType { + fn add(left: Self::Native, right: Self::Native) -> Result; + fn sub(left: Self::Native, right: Self::Native) -> Result; +} + +impl IntervalOp for IntervalYearMonthType { + fn add(left: Self::Native, right: Self::Native) -> Result { + left.add_checked(right) + } + + fn sub(left: Self::Native, right: Self::Native) -> Result { + left.sub_checked(right) + } +} + +impl IntervalOp for IntervalDayTimeType { + fn add(left: Self::Native, right: Self::Native) -> Result { + let (l_days, l_ms) = Self::to_parts(left); + let (r_days, r_ms) = Self::to_parts(right); + let days = l_days.add_checked(r_days)?; + let ms = l_ms.add_checked(r_ms)?; + Ok(Self::make_value(days, ms)) + } + + fn sub(left: Self::Native, right: Self::Native) -> Result { + let (l_days, l_ms) = Self::to_parts(left); + let (r_days, r_ms) = Self::to_parts(right); + let days = l_days.sub_checked(r_days)?; + let ms = l_ms.sub_checked(r_ms)?; + Ok(Self::make_value(days, ms)) + } +} + +impl IntervalOp for IntervalMonthDayNanoType { + fn add(left: Self::Native, right: Self::Native) -> Result { + let (l_months, l_days, l_nanos) = Self::to_parts(left); + let (r_months, r_days, r_nanos) = Self::to_parts(right); + let months = l_months.add_checked(r_months)?; + let days = l_days.add_checked(r_days)?; + let nanos = l_nanos.add_checked(r_nanos)?; + Ok(Self::make_value(months, days, nanos)) + } + + fn sub(left: Self::Native, right: Self::Native) -> Result { + let (l_months, l_days, l_nanos) = Self::to_parts(left); + let (r_months, r_days, r_nanos) = Self::to_parts(right); + let months = l_months.sub_checked(r_months)?; + let days = l_days.sub_checked(r_days)?; + let nanos = l_nanos.sub_checked(r_nanos)?; + Ok(Self::make_value(months, days, nanos)) + } +} + +/// Perform arithmetic operation on an interval array +fn interval_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + match op { + Op::Add | Op::AddWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, T::add(l, r))), + Op::Sub | Op::SubWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, T::sub(l, r))), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Invalid interval arithmetic operation: {} {op:?} {}", + l.data_type(), + r.data_type() + ))), + } +} + +fn duration_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + match op { + Op::Add | Op::AddWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, l.add_checked(r))), + Op::Sub | Op::SubWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, l.sub_checked(r))), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Invalid duration arithmetic operation: {} {op:?} {}", + l.data_type(), + r.data_type() + ))), + } +} + +/// Perform arithmetic operation on a date array +fn date_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + use DataType::*; + use IntervalUnit::*; + + // Note: interval arithmetic should account for timezones (#4457) + let l = l.as_primitive::(); + match (op, r.data_type()) { + (Op::Add | Op::AddWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::add_year_month(l, r))) + } + (Op::Sub | Op::SubWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::sub_year_month(l, r))) + } + + (Op::Add | Op::AddWrapping, Interval(DayTime)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::add_day_time(l, r))) + } + (Op::Sub | Op::SubWrapping, Interval(DayTime)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::sub_day_time(l, r))) + } + + (Op::Add | Op::AddWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::add_month_day_nano(l, r))) + } + (Op::Sub | Op::SubWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::sub_month_day_nano(l, r))) + } + + _ => Err(ArrowError::InvalidArgumentError(format!( + "Invalid date arithmetic operation: {} {op:?} {}", + l.data_type(), + r.data_type() + ))), + } +} + +/// Perform arithmetic operation on decimal arrays +fn decimal_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + + let (p1, s1, p2, s2) = match (l.data_type(), r.data_type()) { + (DataType::Decimal128(p1, s1), DataType::Decimal128(p2, s2)) => (p1, s1, p2, s2), + (DataType::Decimal256(p1, s1), DataType::Decimal256(p2, s2)) => (p1, s1, p2, s2), + _ => unreachable!(), + }; + + // Follow the Hive decimal arithmetic rules + // https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + let array: PrimitiveArray = match op { + Op::Add | Op::AddWrapping | Op::Sub | Op::SubWrapping => { + // max(s1, s2) + let result_scale = *s1.max(s2); + + // max(s1, s2) + max(p1-s1, p2-s2) + 1 + let result_precision = + (result_scale.saturating_add((*p1 as i8 - s1).max(*p2 as i8 - s2)) as u8) + .saturating_add(1) + .min(T::MAX_PRECISION); + + let l_mul = T::Native::usize_as(10).pow_wrapping((result_scale - s1) as _); + let r_mul = T::Native::usize_as(10).pow_wrapping((result_scale - s2) as _); + + match op { + Op::Add | Op::AddWrapping => { + try_op!( + l, + l_s, + r, + r_s, + l.mul_checked(l_mul)?.add_checked(r.mul_checked(r_mul)?) + ) + } + Op::Sub | Op::SubWrapping => { + try_op!( + l, + l_s, + r, + r_s, + l.mul_checked(l_mul)?.sub_checked(r.mul_checked(r_mul)?) + ) + } + _ => unreachable!(), + } + .with_precision_and_scale(result_precision, result_scale)? + } + Op::Mul | Op::MulWrapping => { + let result_precision = p1.saturating_add(p2 + 1).min(T::MAX_PRECISION); + let result_scale = s1.saturating_add(*s2); + if result_scale > T::MAX_SCALE { + // SQL standard says that if the resulting scale of a multiply operation goes + // beyond the maximum, rounding is not acceptable and thus an error occurs + return Err(ArrowError::InvalidArgumentError(format!( + "Output scale of {} {op:?} {} would exceed max scale of {}", + l.data_type(), + r.data_type(), + T::MAX_SCALE + ))); + } + + try_op!(l, l_s, r, r_s, l.mul_checked(r)) + .with_precision_and_scale(result_precision, result_scale)? + } + + Op::Div => { + // Follow postgres and MySQL adding a fixed scale increment of 4 + // s1 + 4 + let result_scale = s1.saturating_add(4).min(T::MAX_SCALE); + let mul_pow = result_scale - s1 + s2; + + // p1 - s1 + s2 + result_scale + let result_precision = + (mul_pow.saturating_add(*p1 as i8) as u8).min(T::MAX_PRECISION); + + let (l_mul, r_mul) = match mul_pow.cmp(&0) { + Ordering::Greater => ( + T::Native::usize_as(10).pow_wrapping(mul_pow as _), + T::Native::ONE, + ), + Ordering::Equal => (T::Native::ONE, T::Native::ONE), + Ordering::Less => ( + T::Native::ONE, + T::Native::usize_as(10).pow_wrapping(mul_pow.neg_wrapping() as _), + ), + }; + + try_op!( + l, + l_s, + r, + r_s, + l.mul_checked(l_mul)?.div_checked(r.mul_checked(r_mul)?) + ) + .with_precision_and_scale(result_precision, result_scale)? + } + + Op::Rem => { + // max(s1, s2) + let result_scale = *s1.max(s2); + // min(p1-s1, p2 -s2) + max( s1,s2 ) + let result_precision = + (result_scale.saturating_add((*p1 as i8 - s1).min(*p2 as i8 - s2)) as u8) + .min(T::MAX_PRECISION); + + let l_mul = T::Native::usize_as(10).pow_wrapping((result_scale - s1) as _); + let r_mul = T::Native::usize_as(10).pow_wrapping((result_scale - s2) as _); + + try_op!( + l, + l_s, + r, + r_s, + l.mul_checked(l_mul)?.mod_checked(r.mul_checked(r_mul)?) + ) + .with_precision_and_scale(result_precision, result_scale)? + } + }; + + Ok(Arc::new(array)) +} diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index 576f645b0375..8337326370dd 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -517,6 +517,15 @@ impl PrimitiveArray { Self::try_new(values, nulls).unwrap() } + /// Create a new [`PrimitiveArray`] of the given length where all values are null + pub fn new_null(length: usize) -> Self { + Self { + data_type: T::DATA_TYPE, + values: vec![T::Native::usize_as(0); length].into(), + nulls: Some(NullBuffer::new_null(length)), + } + } + /// Create a new [`PrimitiveArray`] from the provided values and nulls /// /// # Errors diff --git a/arrow-array/src/scalar.rs b/arrow-array/src/scalar.rs index e54a999f9980..c142107c5cf3 100644 --- a/arrow-array/src/scalar.rs +++ b/arrow-array/src/scalar.rs @@ -92,6 +92,12 @@ impl Datum for dyn Array { } } +impl Datum for &dyn Array { + fn get(&self) -> (&dyn Array, bool) { + (*self, false) + } +} + /// A wrapper around a single value [`Array`] indicating kernels should treat it as a scalar value /// /// See [`Datum`] for more information diff --git a/arrow/benches/arithmetic_kernels.rs b/arrow/benches/arithmetic_kernels.rs index 4ed197783b07..e982b0eb4b5f 100644 --- a/arrow/benches/arithmetic_kernels.rs +++ b/arrow/benches/arithmetic_kernels.rs @@ -15,65 +15,61 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; -use criterion::Criterion; -use rand::Rng; +use criterion::*; extern crate arrow; +use arrow::compute::kernels::numeric::*; use arrow::datatypes::Float32Type; use arrow::util::bench_util::*; -use arrow::{compute::kernels::arithmetic::*, util::test_util::seedable_rng}; +use arrow_array::Scalar; fn add_benchmark(c: &mut Criterion) { const BATCH_SIZE: usize = 64 * 1024; for null_density in [0., 0.1, 0.5, 0.9, 1.0] { let arr_a = create_primitive_array::(BATCH_SIZE, null_density); let arr_b = create_primitive_array::(BATCH_SIZE, null_density); - let scalar = seedable_rng().gen(); + let scalar_a = create_primitive_array::(1, 0.); + let scalar = Scalar::new(&scalar_a); c.bench_function(&format!("add({null_density})"), |b| { - b.iter(|| criterion::black_box(add(&arr_a, &arr_b).unwrap())) + b.iter(|| criterion::black_box(add_wrapping(&arr_a, &arr_b).unwrap())) }); c.bench_function(&format!("add_checked({null_density})"), |b| { - b.iter(|| criterion::black_box(add_checked(&arr_a, &arr_b).unwrap())) + b.iter(|| criterion::black_box(add(&arr_a, &arr_b).unwrap())) }); c.bench_function(&format!("add_scalar({null_density})"), |b| { - b.iter(|| criterion::black_box(add_scalar(&arr_a, scalar).unwrap())) + b.iter(|| criterion::black_box(add_wrapping(&arr_a, &scalar).unwrap())) }); c.bench_function(&format!("subtract({null_density})"), |b| { - b.iter(|| criterion::black_box(subtract(&arr_a, &arr_b).unwrap())) + b.iter(|| criterion::black_box(sub_wrapping(&arr_a, &arr_b).unwrap())) }); c.bench_function(&format!("subtract_checked({null_density})"), |b| { - b.iter(|| criterion::black_box(subtract_checked(&arr_a, &arr_b).unwrap())) + b.iter(|| criterion::black_box(sub(&arr_a, &arr_b).unwrap())) }); c.bench_function(&format!("subtract_scalar({null_density})"), |b| { - b.iter(|| criterion::black_box(subtract_scalar(&arr_a, scalar).unwrap())) + b.iter(|| criterion::black_box(sub_wrapping(&arr_a, &scalar).unwrap())) }); c.bench_function(&format!("multiply({null_density})"), |b| { - b.iter(|| criterion::black_box(multiply(&arr_a, &arr_b).unwrap())) + b.iter(|| criterion::black_box(mul_wrapping(&arr_a, &arr_b).unwrap())) }); c.bench_function(&format!("multiply_checked({null_density})"), |b| { - b.iter(|| criterion::black_box(multiply_checked(&arr_a, &arr_b).unwrap())) + b.iter(|| criterion::black_box(mul(&arr_a, &arr_b).unwrap())) }); c.bench_function(&format!("multiply_scalar({null_density})"), |b| { - b.iter(|| criterion::black_box(multiply_scalar(&arr_a, scalar).unwrap())) + b.iter(|| criterion::black_box(mul_wrapping(&arr_a, &scalar).unwrap())) }); c.bench_function(&format!("divide({null_density})"), |b| { - b.iter(|| criterion::black_box(divide(&arr_a, &arr_b).unwrap())) - }); - c.bench_function(&format!("divide_checked({null_density})"), |b| { - b.iter(|| criterion::black_box(divide_checked(&arr_a, &arr_b).unwrap())) + b.iter(|| criterion::black_box(div(&arr_a, &arr_b).unwrap())) }); c.bench_function(&format!("divide_scalar({null_density})"), |b| { - b.iter(|| criterion::black_box(divide_scalar(&arr_a, scalar).unwrap())) + b.iter(|| criterion::black_box(div(&arr_a, &scalar).unwrap())) }); c.bench_function(&format!("modulo({null_density})"), |b| { - b.iter(|| criterion::black_box(modulus(&arr_a, &arr_b).unwrap())) + b.iter(|| criterion::black_box(rem(&arr_a, &arr_b).unwrap())) }); c.bench_function(&format!("modulo_scalar({null_density})"), |b| { - b.iter(|| criterion::black_box(modulus_scalar(&arr_a, scalar).unwrap())) + b.iter(|| criterion::black_box(rem(&arr_a, &scalar).unwrap())) }); } } diff --git a/arrow/src/compute/kernels/mod.rs b/arrow/src/compute/kernels/mod.rs index d9c948c607bd..49eae6d3ade5 100644 --- a/arrow/src/compute/kernels/mod.rs +++ b/arrow/src/compute/kernels/mod.rs @@ -19,7 +19,9 @@ pub mod limit; -pub use arrow_arith::{aggregate, arithmetic, arity, bitwise, boolean, temporal}; +pub use arrow_arith::{ + aggregate, arithmetic, arity, bitwise, boolean, numeric, temporal, +}; pub use arrow_cast::cast; pub use arrow_cast::parse as cast_utils; pub use arrow_ord::{partition, sort}; diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs index 12aa1309c552..a392d1deec86 100644 --- a/arrow/src/ffi.rs +++ b/arrow/src/ffi.rs @@ -105,6 +105,8 @@ To export an array, create an `ArrowArray` using [ArrowArray::try_new]. use std::{mem::size_of, ptr::NonNull, sync::Arc}; +pub use arrow_data::ffi::FFI_ArrowArray; +pub use arrow_schema::ffi::{FFI_ArrowSchema, Flags}; use arrow_schema::UnionMode; use crate::array::{layout, ArrayData}; @@ -113,9 +115,6 @@ use crate::datatypes::DataType; use crate::error::{ArrowError, Result}; use crate::util::bit_util; -pub use arrow_data::ffi::FFI_ArrowArray; -pub use arrow_schema::ffi::{FFI_ArrowSchema, Flags}; - // returns the number of bits that buffer `i` (in the C data interface) is expected to have. // This is set by the Arrow specification fn bit_width(data_type: &DataType, i: usize) -> Result { @@ -412,7 +411,16 @@ impl<'a> ArrowArray<'a> { #[cfg(test)] mod tests { - use super::*; + use std::collections::HashMap; + use std::convert::TryFrom; + use std::mem::ManuallyDrop; + use std::ptr::addr_of_mut; + + use arrow_array::builder::UnionBuilder; + use arrow_array::cast::AsArray; + use arrow_array::types::{Float64Type, Int32Type}; + use arrow_array::{StructArray, UnionArray}; + use crate::array::{ make_array, Array, ArrayData, BooleanArray, Decimal128Array, DictionaryArray, DurationSecondArray, FixedSizeBinaryArray, FixedSizeListArray, @@ -421,14 +429,8 @@ mod tests { }; use crate::compute::kernels; use crate::datatypes::{Field, Int8Type}; - use arrow_array::builder::UnionBuilder; - use arrow_array::cast::AsArray; - use arrow_array::types::{Float64Type, Int32Type}; - use arrow_array::{StructArray, UnionArray}; - use std::collections::HashMap; - use std::convert::TryFrom; - use std::mem::ManuallyDrop; - use std::ptr::addr_of_mut; + + use super::*; #[test] fn test_round_trip() { @@ -440,10 +442,10 @@ mod tests { // (simulate consumer) import it let array = Int32Array::from(from_ffi(array, &schema).unwrap()); - let array = kernels::arithmetic::add(&array, &array).unwrap(); + let array = kernels::numeric::add(&array, &array).unwrap(); // verify - assert_eq!(array, Int32Array::from(vec![2, 4, 6])); + assert_eq!(array.as_ref(), &Int32Array::from(vec![2, 4, 6])); } #[test] @@ -491,10 +493,10 @@ mod tests { let array = array.as_any().downcast_ref::().unwrap(); assert_eq!(array, &Int32Array::from(vec![Some(2), None])); - let array = kernels::arithmetic::add(array, array).unwrap(); + let array = kernels::numeric::add(array, array).unwrap(); // verify - assert_eq!(array, Int32Array::from(vec![Some(4), None])); + assert_eq!(array.as_ref(), &Int32Array::from(vec![Some(4), None])); // (drop/release) Ok(())