From 8703f555c1b412afd3e67c53d3856e761d1b7a4c Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Mon, 12 Jun 2023 17:51:14 +0100 Subject: [PATCH] Add Datum based arithmetic kernels (#3999) --- arrow-arith/src/lib.rs | 1 + arrow-arith/src/operation.rs | 419 +++++++++++++++++++++++ arrow-array/src/array/primitive_array.rs | 9 + 3 files changed, 429 insertions(+) create mode 100644 arrow-arith/src/operation.rs diff --git a/arrow-arith/src/lib.rs b/arrow-arith/src/lib.rs index 60d31c972b66..a2eacdae772a 100644 --- a/arrow-arith/src/lib.rs +++ b/arrow-arith/src/lib.rs @@ -22,4 +22,5 @@ pub mod arithmetic; pub mod arity; pub mod bitwise; pub mod boolean; +pub mod operation; pub mod temporal; diff --git a/arrow-arith/src/operation.rs b/arrow-arith/src/operation.rs new file mode 100644 index 000000000000..6a5dc3aa9832 --- /dev/null +++ b/arrow-arith/src/operation.rs @@ -0,0 +1,419 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Arrow arithmetic operations + +use crate::arity::{binary, try_binary}; +use arrow_array::cast::AsArray; +use arrow_array::types::*; +use arrow_array::*; +use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit}; +use std::sync::Arc; + +/// Perform addition between two `Datum` +/// +/// An error will be returned if this results in overflow +pub fn add(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Add, lhs, rhs) +} + +/// Perform addition between two `Datum` +/// +/// Unlike [`add`] this will not return an error for integer overflow +pub fn add_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::AddWrapping, lhs, rhs) +} + +/// Perform subtraction between two `Datum` +/// +/// An error will be returned if this results in overflow +pub fn sub(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Sub, lhs, rhs) +} + +/// Perform subtraction between two `Datum` +/// +/// Unlike [`sub`] this will not return an error for integer overflow +pub fn sub_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::SubWrapping, lhs, rhs) +} + +/// Perform multiplication between two `Datum` +/// +/// An error will be returned if this results in overflow +pub fn mul(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Mul, lhs, rhs) +} + +/// Perform multiplication between two `Datum` +/// +/// Unlike [`mul`] this will not return an error for integer overflow +pub fn mul_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::MulWrapping, lhs, rhs) +} + +/// Perform division between two `Datum` +/// +/// An error will be returned if this results in overflow or would divide by zero +pub fn div(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Div, lhs, rhs) +} + +/// Compute the remainder of division of two `Datum` +/// +/// An error will be returned if this results in overflow or would divide by zero +pub fn rem(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + 
arithmetic_op(Op::Rem, lhs, rhs) +} + +/// An enumeration of arithmetic operations +/// +/// This allows sharing the type dispatch logic across the various kernels +#[derive(Debug, Copy, Clone)] +enum Op { + AddWrapping, + Add, + SubWrapping, + Sub, + MulWrapping, + Mul, + Div, + Rem, +} + +/// Dispatch the given `op` to the specialized kernel for the operand data types, or return an `InvalidArgumentError` if the combination is unsupported +fn arithmetic_op( + op: Op, + lhs: &dyn Datum, + rhs: &dyn Datum, +) -> Result { + use DataType::*; + use TimeUnit::*; + + macro_rules! integer_helper { + ($t:ty, $op:ident, $l:ident, $l_scalar:ident, $r:ident, $r_scalar:ident) => { + integer_op::<$t>($op, $l, $l_scalar, $r, $r_scalar) + }; + } + + let (l, l_scalar) = lhs.get(); + let (r, r_scalar) = rhs.get(); + downcast_integer! { + l.data_type(), r.data_type() => (integer_helper, op, l, l_scalar, r, r_scalar), + (Float16, Float16) => float_op::(op, l, l_scalar, r, r_scalar), + (Float32, Float32) => float_op::(op, l, l_scalar, r, r_scalar), + (Float64, Float64) => float_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Second, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Millisecond, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Microsecond, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Nanosecond, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Date32, _) => date_op::(op, l, l_scalar, r, r_scalar), + (Date64, _) => date_op::(op, l, l_scalar, r, r_scalar), + (l_t, r_t) => Err(ArrowError::InvalidArgumentError( + format!("Invalid arithmetic operation: {l_t} {op:?} {r_t}") + )) + } +} + +/// Perform an infallible binary operation on potentially scalar inputs +macro_rules! 
op { + ($l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => { + match ($l_s, $r_s) { + (true, true) | (false, false) => binary($l, $r, |$l, $r| $op)?, + (true, false) => match ($l.null_count() == 0).then(|| $l.value(0)) { + None => PrimitiveArray::new_null($r.len()), + Some($l) => $r.unary(|$r| $op), + }, + (false, true) => match ($r.null_count() == 0).then(|| $r.value(0)) { + None => PrimitiveArray::new_null($l.len()), + Some($r) => $l.unary(|$l| $op), + }, + } + }; +} + +/// Same as `op` but with a type hint for the returned array, which is wrapped in an `Arc` +macro_rules! op_ref { + ($t:ty, $l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {{ + let array: PrimitiveArray<$t> = op!($l, $l_s, $r, $r_s, $op); + Arc::new(array) + }}; +} + +/// Perform a fallible binary operation on potentially scalar inputs (a null scalar paired with an array yields an all-null array) +macro_rules! try_op { + ($l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => { + match ($l_s, $r_s) { + (true, true) | (false, false) => try_binary($l, $r, |$l, $r| $op)?, + (true, false) => match ($l.null_count() == 0).then(|| $l.value(0)) { + None => PrimitiveArray::new_null($r.len()), + Some($l) => $r.try_unary(|$r| $op)?, + }, + (false, true) => match ($r.null_count() == 0).then(|| $r.value(0)) { + None => PrimitiveArray::new_null($l.len()), + Some($r) => $l.try_unary(|$l| $op)?, + }, + } + }; +} + +/// Same as `try_op` but with a type hint for the returned array +macro_rules! 
try_op_ref { + ($t:ty, $l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {{ + let array: PrimitiveArray<$t> = try_op!($l, $l_s, $r, $r_s, $op); + Arc::new(array) + }}; +} + +/// Perform an arithmetic operation on integers; `*Wrapping` ops are infallible, all others are checked +fn integer_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + let array: PrimitiveArray = match op { + Op::AddWrapping => op!(l, l_s, r, r_s, l.add_wrapping(r)), + Op::Add => try_op!(l, l_s, r, r_s, l.add_checked(r)), + Op::SubWrapping => op!(l, l_s, r, r_s, l.sub_wrapping(r)), + Op::Sub => try_op!(l, l_s, r, r_s, l.sub_checked(r)), + Op::MulWrapping => op!(l, l_s, r, r_s, l.mul_wrapping(r)), + Op::Mul => try_op!(l, l_s, r, r_s, l.mul_checked(r)), + Op::Div => try_op!(l, l_s, r, r_s, l.div_checked(r)), + Op::Rem => try_op!(l, l_s, r, r_s, l.mod_checked(r)), + }; + Ok(Arc::new(array)) +} + +/// Perform an arithmetic operation on floats; only `Div` and `Rem` are checked +fn float_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + let array: PrimitiveArray = match op { + Op::AddWrapping | Op::Add => op!(l, l_s, r, r_s, l.add_wrapping(r)), + Op::SubWrapping | Op::Sub => op!(l, l_s, r, r_s, l.sub_wrapping(r)), + Op::MulWrapping | Op::Mul => op!(l, l_s, r, r_s, l.mul_wrapping(r)), + Op::Div => try_op!(l, l_s, r, r_s, l.div_checked(r)), + Op::Rem => try_op!(l, l_s, r, r_s, l.mod_checked(r)), + }; + Ok(Arc::new(array)) +} + +/// Arithmetic trait for timestamp arrays +trait TimestampOp: ArrowTimestampType { + type Duration: ArrowPrimitiveType; + + fn add_year_month(timestamp: i64, delta: i32) -> Result; + fn add_day_time(timestamp: i64, delta: i64) -> Result; + fn add_month_day_nano(timestamp: i64, delta: i128) -> Result; + + fn sub_year_month(timestamp: i64, delta: i32) -> Result; + fn sub_day_time(timestamp: i64, delta: i64) -> Result; + fn sub_month_day_nano(timestamp: i64, delta: 
i128) -> Result; +} + +macro_rules! timestamp { + ($t:ty, $d:ty) => { + impl TimestampOp for $t { + type Duration = $d; + + fn add_year_month(left: i64, right: i32) -> Result { + Self::add_year_months(left, right) + } + + fn add_day_time(left: i64, right: i64) -> Result { + Self::add_day_time(left, right) + } + + fn add_month_day_nano(left: i64, right: i128) -> Result { + Self::add_month_day_nano(left, right) + } + + fn sub_year_month(left: i64, right: i32) -> Result { + Self::subtract_year_months(left, right) + } + + fn sub_day_time(left: i64, right: i64) -> Result { + Self::subtract_day_time(left, right) + } + + fn sub_month_day_nano(left: i64, right: i128) -> Result { + Self::subtract_month_day_nano(left, right) + } + } + }; +} +timestamp!(TimestampSecondType, DurationSecondType); +timestamp!(TimestampMillisecondType, DurationMillisecondType); +timestamp!(TimestampMicrosecondType, DurationMicrosecondType); +timestamp!(TimestampNanosecondType, DurationNanosecondType); + +/// Perform arithmetic operation on a timestamp array: subtracting a timestamp of the same unit yields a duration, adding or subtracting an interval yields a timestamp +fn timestamp_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + use DataType::*; + use IntervalUnit::*; + + // Note: interval arithmetic should account for timezones (#4457) + let l = l.as_primitive::(); + match (op, r.data_type()) { + (Op::Sub | Op::SubWrapping, Timestamp(unit, _)) if unit == &T::UNIT => { + let r = r.as_primitive::(); + Ok(try_op_ref!(T::Duration, l, l_s, r, r_s, l.sub_checked(r))) + } + (Op::Add | Op::AddWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + Ok(try_op_ref!(T, l, l_s, r, r_s, T::add_year_month(l, r))) + } + (Op::Sub | Op::SubWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + Ok(try_op_ref!(T, l, l_s, r, r_s, T::sub_year_month(l, r))) + } + + (Op::Add | Op::AddWrapping, Interval(DayTime)) => { + let r = r.as_primitive::(); + Ok(try_op_ref!(T, l, l_s, r, r_s, T::add_day_time(l, r))) + } + (Op::Sub | Op::SubWrapping, Interval(DayTime)) => { 
+ let r = r.as_primitive::(); + Ok(try_op_ref!(T, l, l_s, r, r_s, T::sub_day_time(l, r))) + } + + (Op::Add | Op::AddWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + Ok(try_op_ref!(T, l, l_s, r, r_s, T::add_month_day_nano(l, r))) + } + (Op::Sub | Op::SubWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + Ok(try_op_ref!(T, l, l_s, r, r_s, T::sub_month_day_nano(l, r))) + } + + _ => Err(ArrowError::InvalidArgumentError(format!( + "Invalid timestamp arithmetic operation: {} {op:?} {}", + l.data_type(), + r.data_type() + ))), + } +} + +/// Arithmetic trait for date arrays: each method adds or subtracts an interval `delta` on a native date value (the first parameter is named `timestamp` but holds a date) +/// +/// Note: these should be fallible (#4456) +trait DateOp: ArrowTemporalType { + fn add_year_month(timestamp: Self::Native, delta: i32) -> Self::Native; + fn add_day_time(timestamp: Self::Native, delta: i64) -> Self::Native; + fn add_month_day_nano(timestamp: Self::Native, delta: i128) -> Self::Native; + + fn sub_year_month(timestamp: Self::Native, delta: i32) -> Self::Native; + fn sub_day_time(timestamp: Self::Native, delta: i64) -> Self::Native; + fn sub_month_day_nano(timestamp: Self::Native, delta: i128) -> Self::Native; +} + +macro_rules! 
date { + ($t:ty) => { + impl DateOp for $t { + fn add_year_month(left: Self::Native, right: i32) -> Self::Native { + Self::add_year_months(left, right) + } + + fn add_day_time(left: Self::Native, right: i64) -> Self::Native { + Self::add_day_time(left, right) + } + + fn add_month_day_nano(left: Self::Native, right: i128) -> Self::Native { + Self::add_month_day_nano(left, right) + } + + fn sub_year_month(left: Self::Native, right: i32) -> Self::Native { + Self::subtract_year_months(left, right) + } + + fn sub_day_time(left: Self::Native, right: i64) -> Self::Native { + Self::subtract_day_time(left, right) + } + + fn sub_month_day_nano(left: Self::Native, right: i128) -> Self::Native { + Self::subtract_month_day_nano(left, right) + } + } + }; +} +date!(Date32Type); +date!(Date64Type); + +/// Perform arithmetic operation on a date array +fn date_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + use DataType::*; + use IntervalUnit::*; + + // Note: interval arithmetic should account for timezones (#4457) + let l = l.as_primitive::(); + match (op, r.data_type()) { + (Op::Add | Op::AddWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::add_year_month(l, r))) + } + (Op::Sub | Op::SubWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::sub_year_month(l, r))) + } + + (Op::Add | Op::AddWrapping, Interval(DayTime)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::add_day_time(l, r))) + } + (Op::Sub | Op::SubWrapping, Interval(DayTime)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::sub_day_time(l, r))) + } + + (Op::Add | Op::AddWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::add_month_day_nano(l, r))) + } + (Op::Sub | Op::SubWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, 
T::sub_month_day_nano(l, r))) + } + + _ => Err(ArrowError::InvalidArgumentError(format!( + "Invalid date arithmetic operation: {} {op:?} {}", + l.data_type(), + r.data_type() + ))), + } +} diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index 576f645b0375..8337326370dd 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -517,6 +517,15 @@ impl PrimitiveArray { Self::try_new(values, nulls).unwrap() } + /// Create a new [`PrimitiveArray`] of the given length where all values are null, backed by a zero-filled values buffer + pub fn new_null(length: usize) -> Self { + Self { + data_type: T::DATA_TYPE, + values: vec![T::Native::usize_as(0); length].into(), + nulls: Some(NullBuffer::new_null(length)), + } + } + /// Create a new [`PrimitiveArray`] from the provided values and nulls /// /// # Errors