Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove temporal to kernels_arrow #6069

Merged
merged 10 commits into from
Apr 24, 2023
52 changes: 41 additions & 11 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,15 @@ use kernels::{
bitwise_xor, bitwise_xor_scalar,
};
use kernels_arrow::{
add_decimal_dyn_scalar, add_dyn_decimal, divide_decimal_dyn_scalar,
divide_dyn_opt_decimal, is_distinct_from, is_distinct_from_bool,
is_distinct_from_decimal, is_distinct_from_f32, is_distinct_from_f64,
is_distinct_from_null, is_distinct_from_utf8, is_not_distinct_from,
is_not_distinct_from_bool, is_not_distinct_from_decimal, is_not_distinct_from_f32,
is_not_distinct_from_f64, is_not_distinct_from_null, is_not_distinct_from_utf8,
modulus_decimal_dyn_scalar, modulus_dyn_decimal, multiply_decimal_dyn_scalar,
multiply_dyn_decimal, subtract_decimal_dyn_scalar, subtract_dyn_decimal,
add_decimal_dyn_scalar, add_dyn_decimal, add_dyn_temporal, add_dyn_temporal_scalar,
divide_decimal_dyn_scalar, divide_dyn_opt_decimal, is_distinct_from,
is_distinct_from_bool, is_distinct_from_decimal, is_distinct_from_f32,
is_distinct_from_f64, is_distinct_from_null, is_distinct_from_utf8,
is_not_distinct_from, is_not_distinct_from_bool, is_not_distinct_from_decimal,
is_not_distinct_from_f32, is_not_distinct_from_f64, is_not_distinct_from_null,
is_not_distinct_from_utf8, modulus_decimal_dyn_scalar, modulus_dyn_decimal,
multiply_decimal_dyn_scalar, multiply_dyn_decimal, subtract_decimal_dyn_scalar,
subtract_dyn_decimal, subtract_dyn_temporal, subtract_dyn_temporal_scalar,
};

use arrow::datatypes::{DataType, Schema, TimeUnit};
Expand Down Expand Up @@ -1312,10 +1313,39 @@ macro_rules! sub_timestamp_macro {
Arc::new(ret) as ArrayRef
}};
}

pub fn resolve_temporal_op(
lhs: &ArrayRef,
sign: i32,
rhs: &ArrayRef,
) -> Result<ArrayRef> {
match sign {
1 => add_dyn_temporal(lhs, rhs),
-1 => subtract_dyn_temporal(lhs, rhs),
other => Err(DataFusionError::Internal(format!(
"Undefined operation for temporal types {other}"
))),
}
}

pub fn resolve_temporal_op_scalar(
lhs: &ArrayRef,
sign: i32,
rhs: &ScalarValue,
) -> Result<ColumnarValue> {
match sign {
1 => add_dyn_temporal_scalar(lhs, rhs),
-1 => subtract_dyn_temporal_scalar(lhs, rhs),
other => Err(DataFusionError::Internal(format!(
"Undefined operation for temporal types {other}"
))),
}
}

/// This function handles the Timestamp - Timestamp operations,
/// where the first one is an array, and the second one is a scalar,
/// hence the result is also an array.
pub fn ts_scalar_ts_op(array: ArrayRef, scalar: &ScalarValue) -> Result<ColumnarValue> {
pub fn ts_scalar_ts_op(array: &ArrayRef, scalar: &ScalarValue) -> Result<ColumnarValue> {
let ret = match (array.data_type(), scalar) {
(
DataType::Timestamp(TimeUnit::Second, opt_tz_lhs),
Expand Down Expand Up @@ -1410,7 +1440,7 @@ macro_rules! sub_timestamp_interval_macro {
/// where the first one is an array, and the second one is a scalar,
/// hence the result is also an array.
pub fn ts_scalar_interval_op(
array: ArrayRef,
array: &ArrayRef,
sign: i32,
scalar: &ScalarValue,
) -> Result<ColumnarValue> {
Expand Down Expand Up @@ -1494,7 +1524,7 @@ macro_rules! sub_interval_cross_macro {
/// where the first one is an array, and the second one is a scalar,
/// hence the result is also an interval array.
pub fn interval_scalar_interval_op(
array: ArrayRef,
array: &ArrayRef,
sign: i32,
scalar: &ScalarValue,
) -> Result<ColumnarValue> {
Expand Down
123 changes: 119 additions & 4 deletions datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,24 @@
use arrow::compute::{
add_dyn, add_scalar_dyn, divide_dyn_opt, divide_scalar_dyn, modulus_dyn,
modulus_scalar_dyn, multiply_dyn, multiply_scalar_dyn, subtract_dyn,
subtract_scalar_dyn,
subtract_scalar_dyn, try_unary,
};
use arrow::datatypes::Decimal128Type;
use arrow::datatypes::{Date32Type, Date64Type, Decimal128Type};
use arrow::{array::*, datatypes::ArrowNumericType, downcast_dictionary_array};
use arrow_schema::DataType;
use datafusion_common::cast::as_decimal128_array;
use datafusion_common::{DataFusionError, Result};
use datafusion_common::cast::{as_date32_array, as_date64_array, as_decimal128_array};
use datafusion_common::scalar::{date32_add, date64_add};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::type_coercion::binary::decimal_op_mathematics_type;
use datafusion_expr::ColumnarValue;
use datafusion_expr::Operator;
use std::sync::Arc;

use super::{
interval_array_op, interval_scalar_interval_op, ts_array_op, ts_interval_array_op,
ts_scalar_interval_op, ts_scalar_ts_op,
};

// Simple (low performance) kernels until optimized kernels are added to arrow
// See https://github.com/apache/arrow-rs/issues/960

Expand Down Expand Up @@ -280,6 +287,57 @@ pub(crate) fn add_decimal_dyn_scalar(
decimal_array_with_precision_scale(array, precision, scale)
}

pub(crate) fn add_dyn_temporal(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef> {
match (left.data_type(), right.data_type()) {
(DataType::Timestamp(..), DataType::Timestamp(..)) => ts_array_op(left, right),
(DataType::Interval(..), DataType::Interval(..)) => {
interval_array_op(left, right, 1)
}
(DataType::Timestamp(..), DataType::Interval(..)) => {
ts_interval_array_op(left, 1, right)
}
(DataType::Interval(..), DataType::Timestamp(..)) => {
ts_interval_array_op(right, 1, left)
}
_ => {
// fall back to kernels in arrow-rs
Ok(add_dyn(left, right)?)
}
}
}

pub(crate) fn add_dyn_temporal_scalar(
left: &ArrayRef,
right: &ScalarValue,
) -> Result<ColumnarValue> {
match (left.data_type(), right.get_datatype()) {
(DataType::Date32, DataType::Interval(..)) => {
let left = as_date32_array(&left)?;
let ret = Arc::new(try_unary::<Date32Type, _, Date32Type>(left, |days| {
Ok(date32_add(days, right, 1)?)
})?) as ArrayRef;
Ok(ColumnarValue::Array(ret))
}
(DataType::Date64, DataType::Interval(..)) => {
let left = as_date64_array(&left)?;
let ret = Arc::new(try_unary::<Date64Type, _, Date64Type>(left, |ms| {
Ok(date64_add(ms, right, 1)?)
})?) as ArrayRef;
Ok(ColumnarValue::Array(ret))
}
(DataType::Interval(..), DataType::Interval(..)) => {
interval_scalar_interval_op(left, 1, right)
}
(DataType::Timestamp(..), DataType::Interval(..)) => {
ts_scalar_interval_op(left, 1, right)
}
_ => {
// fall back to kernels in arrow-rs
Ok(ColumnarValue::Array(add_dyn(left, &right.to_array())?))
}
}
}

pub(crate) fn subtract_decimal_dyn_scalar(
left: &dyn Array,
right: i128,
Expand All @@ -291,6 +349,63 @@ pub(crate) fn subtract_decimal_dyn_scalar(
decimal_array_with_precision_scale(array, precision, scale)
}

pub(crate) fn subtract_dyn_temporal(
left: &ArrayRef,
right: &ArrayRef,
) -> Result<ArrayRef> {
match (left.data_type(), right.data_type()) {
(DataType::Timestamp(..), DataType::Timestamp(..)) => ts_array_op(left, right),
(DataType::Interval(..), DataType::Interval(..)) => {
interval_array_op(left, right, -1)
}
(DataType::Timestamp(..), DataType::Interval(..)) => {
ts_interval_array_op(left, -1, right)
}
(DataType::Interval(..), DataType::Timestamp(..)) => {
ts_interval_array_op(right, -1, left)
}
_ => {
// fall back to kernels in arrow-rs
Ok(subtract_dyn(left, right)?)
}
}
}

pub(crate) fn subtract_dyn_temporal_scalar(
left: &ArrayRef,
right: &ScalarValue,
) -> Result<ColumnarValue> {
match (left.data_type(), right.get_datatype()) {
(DataType::Date32, DataType::Interval(..)) => {
let left = as_date32_array(&left)?;
let ret = Arc::new(try_unary::<Date32Type, _, Date32Type>(left, |days| {
Ok(date32_add(days, right, -1)?)
})?) as ArrayRef;
Ok(ColumnarValue::Array(ret))
}
(DataType::Date64, DataType::Interval(..)) => {
let left = as_date64_array(&left)?;
let ret = Arc::new(try_unary::<Date64Type, _, Date64Type>(left, |ms| {
Ok(date64_add(ms, right, -1)?)
})?) as ArrayRef;
Ok(ColumnarValue::Array(ret))
}
(DataType::Timestamp(..), DataType::Timestamp(..)) => {
ts_scalar_ts_op(left, right)
}
(DataType::Interval(..), DataType::Interval(..)) => {
interval_scalar_interval_op(left, -1, right)
}
(DataType::Timestamp(..), DataType::Interval(..)) => {
ts_scalar_interval_op(left, -1, right)
}
_ => {
// fall back to kernels in arrow-rs
Ok(ColumnarValue::Array(subtract_dyn(left, &right.to_array())?))
}
}
}

fn get_precision_scale(data_type: &DataType) -> Result<(u8, i8)> {
match data_type {
DataType::Decimal128(precision, scale) => Ok((*precision, *scale)),
Expand Down
Loading