Skip to content

Commit

Permalink
feat: negation of Intervals (#6312)
Browse files Browse the repository at this point in the history
* feat: negation of Intervals

* fix: clippy
  • Loading branch information
izveigor authored and Ted-Jiang committed May 11, 2023
1 parent f4d5a69 commit 78d7a17
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 5 deletions.
35 changes: 35 additions & 0 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1923,6 +1923,19 @@ impl ScalarValue {
ScalarValue::Int16(Some(v)) => Ok(ScalarValue::Int16(Some(-v))),
ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(-v))),
ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(-v))),
ScalarValue::IntervalYearMonth(Some(v)) => {
Ok(ScalarValue::IntervalYearMonth(Some(-v)))
}
ScalarValue::IntervalDayTime(Some(v)) => {
let (days, ms) = IntervalDayTimeType::to_parts(*v);
let val = IntervalDayTimeType::make_value(-days, -ms);
Ok(ScalarValue::IntervalDayTime(Some(val)))
}
ScalarValue::IntervalMonthDayNano(Some(v)) => {
let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v);
let val = IntervalMonthDayNanoType::make_value(-months, -days, -nanos);
Ok(ScalarValue::IntervalMonthDayNano(Some(val)))
}
ScalarValue::Decimal128(Some(v), precision, scale) => {
Ok(ScalarValue::Decimal128(Some(-v), *precision, *scale))
}
Expand Down Expand Up @@ -5430,6 +5443,28 @@ mod tests {
}
}

/// Verifies `ScalarValue::arithmetic_negate` for the three interval scalar
/// kinds: each input must negate to the scalar produced by the same
/// constructor with every argument negated.
#[test]
fn test_scalar_interval_negate() {
    // Negates `expr` and compares it with `expected`; the failure message
    // echoes the input so the offending case is easy to identify.
    let check = |expr: ScalarValue, expected: ScalarValue| {
        let result = expr.arithmetic_negate().unwrap();
        assert_eq!(expected, result, "-expr:{expr:?}");
    };

    // Year-month interval.
    check(
        ScalarValue::new_interval_ym(1, 12),
        ScalarValue::new_interval_ym(-1, -12),
    );
    // Day-time interval.
    check(
        ScalarValue::new_interval_dt(1, 999),
        ScalarValue::new_interval_dt(-1, -999),
    );
    // Month-day-nano interval.
    check(
        ScalarValue::new_interval_mdn(12, 15, 123_456),
        ScalarValue::new_interval_mdn(-12, -15, -123_456),
    );
}

#[test]
fn test_scalar_interval_add() {
let cases = [
Expand Down
24 changes: 24 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/interval.slt
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,30 @@ select interval '1 year' - '1 month' - '1 day' - '1 hour' - '1 minute' - '1 seco
----
0 years 11 mons -1 days -1 hours -1 mins -1.001001001 secs

# Interval with string literal negation and leading field
query ?
select -interval '5' - '1' - '2' year;
----
0 years -24 mons 0 days 0 hours 0 mins 0.000000000 secs

# Interval with nested string literal negation
query ?
select -interval '1 month' + '1 day' + '1 hour';
----
0 years -1 mons -1 days -1 hours 0 mins 0.000000000 secs

# Interval with nested string literal negation and leading field
query ?
select -interval '10' - '1' - '1' month;
----
0 years -8 mons 0 days 0 hours 0 mins 0.000000000 secs

# Interval mega nested string literal negation
query ?
select -interval '1 year' - '1 month' - '1 day' - '1 hour' - '1 minute' - '1 second' - '1 millisecond' - '1 microsecond' - '1 nanosecond'
----
0 years -11 mons 1 days 1 hours 1 mins 1.001001001 secs

# Interval string literal + date
query D
select interval '1 month' + '1 day' + '2012-01-01'::date;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/type_coercion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub mod functions;
pub mod other;

use arrow::datatypes::DataType;
/// Determine whether the given data type `dt` represents unsigned numeric values.
/// Determine whether the given data type `dt` represents signed numeric values.
pub fn is_signed_numeric(dt: &DataType) -> bool {
matches!(
dt,
Expand Down
91 changes: 87 additions & 4 deletions datafusion/physical-expr/src/expressions/negative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,19 @@ use std::sync::Arc;
use arrow::array::ArrayRef;
use arrow::compute::kernels::arithmetic::negate;
use arrow::{
array::{Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array},
datatypes::{DataType, Schema},
array::{
Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray,
},
datatypes::{DataType, IntervalUnit, Schema},
record_batch::RecordBatch,
};

use crate::physical_expr::down_cast_any_ref;
use crate::PhysicalExpr;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{
type_coercion::{is_null, is_signed_numeric},
type_coercion::{is_interval, is_null, is_signed_numeric},
ColumnarValue,
};

Expand Down Expand Up @@ -98,6 +101,9 @@ impl PhysicalExpr for NegativeExpr {
DataType::Int64 => compute_op!(array, negate, Int64Array),
DataType::Float32 => compute_op!(array, negate, Float32Array),
DataType::Float64 => compute_op!(array, negate, Float64Array),
DataType::Interval(IntervalUnit::YearMonth) => compute_op!(array, negate, IntervalYearMonthArray),
DataType::Interval(IntervalUnit::DayTime) => compute_op!(array, negate, IntervalDayTimeArray),
DataType::Interval(IntervalUnit::MonthDayNano) => compute_op!(array, negate, IntervalMonthDayNanoArray),
_ => Err(DataFusionError::Internal(format!(
"(- '{:?}') can't be evaluated because the expression's type is {:?}, not signed numeric",
self,
Expand Down Expand Up @@ -145,11 +151,88 @@ pub fn negative(
let data_type = arg.data_type(input_schema)?;
if is_null(&data_type) {
Ok(arg)
} else if !is_signed_numeric(&data_type) {
} else if !is_signed_numeric(&data_type) && !is_interval(&data_type) {
Err(DataFusionError::Internal(
format!("Can't create negative physical expr for (- '{arg:?}'), the type of child expr is {data_type}, not signed numeric"),
))
} else {
Ok(Arc::new(NegativeExpr::new(arg)))
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::col;
    // The glob is kept so that the `paste!`-generated `<Type>Array` identifiers
    // always resolve; presumably it overlaps with names already re-exported by
    // `super::*`, hence the lint suppression.
    #[allow(unused_imports)]
    use arrow::array::*;
    use arrow::datatypes::*;
    use datafusion_common::{cast::as_primitive_array, Result};
    use paste::paste;

    /// Builds a single nullable column `a` of `DataType::$DATA_TY`, wraps it in
    /// a `negative` expression and checks that evaluating the expression yields
    /// `-value` for every supplied value and `None` for the trailing null row.
    ///
    /// `$DATA_TY` must be a `DataType` variant name for which a matching
    /// `<$DATA_TY>Array` type exists (e.g. `Int8` -> `Int8Array`). At least one
    /// value is required. The macro uses `?`, so it must be invoked inside a
    /// function returning `Result`.
    macro_rules! test_array_negative_op {
        ($DATA_TY:tt, $($VALUE:expr),+ ) => {
            let schema = Schema::new(vec![Field::new("a", DataType::$DATA_TY, true)]);
            let expr = negative(col("a", &schema)?, &schema)?;
            // Negation must preserve both the data type and the nullability.
            assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TY);
            assert!(expr.nullable(&schema)?);
            // Input rows followed by one null row; expected rows mirror them.
            let arr = vec![$(Some($VALUE)),+, None];
            let arr_expected = vec![$(Some(-$VALUE)),+, None];
            let input = paste!{[<$DATA_TY Array>]::from(arr)};
            let expected = &paste!{[<$DATA_TY Array>]::from(arr_expected)};
            let batch =
                RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?;
            let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
            // The message is assembled at compile time, so nothing is
            // formatted or allocated on the success path.
            let result = as_primitive_array(&result)
                .expect(concat!("failed to downcast to ", stringify!($DATA_TY), "Array"));
            assert_eq!(result, expected);
        };
    }

    /// Same check as `test_array_negative_op`, but for interval columns:
    /// `$DATA_TY` is an `IntervalUnit` variant name and the array type is
    /// `Interval<$DATA_TY>Array` (e.g. `YearMonth` -> `IntervalYearMonthArray`).
    ///
    /// NOTE(review): the expected rows are `-$VALUE` applied to the native
    /// value handed to the array constructor. Confirm this is the intended
    /// semantics for `DayTime` and `MonthDayNano`, and that it agrees with the
    /// component-wise negation performed by `ScalarValue::arithmetic_negate`.
    macro_rules! test_array_negative_op_intervals {
        ($DATA_TY:tt, $($VALUE:expr),+ ) => {
            let schema = Schema::new(vec![Field::new("a", DataType::Interval(IntervalUnit::$DATA_TY), true)]);
            let expr = negative(col("a", &schema)?, &schema)?;
            // Negation must preserve both the data type and the nullability.
            assert_eq!(expr.data_type(&schema)?, DataType::Interval(IntervalUnit::$DATA_TY));
            assert!(expr.nullable(&schema)?);
            // Input rows followed by one null row; expected rows mirror them.
            let arr = vec![$(Some($VALUE)),+, None];
            let arr_expected = vec![$(Some(-$VALUE)),+, None];
            let input = paste!{[<Interval $DATA_TY Array>]::from(arr)};
            let expected = &paste!{[<Interval $DATA_TY Array>]::from(arr_expected)};
            let batch =
                RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?;
            let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
            // Compile-time message naming the real array type.
            let result = as_primitive_array(&result)
                .expect(concat!("failed to downcast to Interval", stringify!($DATA_TY), "Array"));
            assert_eq!(result, expected);
        };
    }

    /// Exercises `negative` over every signed integer, float and interval
    /// array type it accepts.
    #[test]
    fn array_negative_op() -> Result<()> {
        test_array_negative_op!(Int8, 2i8, 1i8);
        test_array_negative_op!(Int16, 234i16, 123i16);
        test_array_negative_op!(Int32, 2345i32, 1234i32);
        test_array_negative_op!(Int64, 23456i64, 12345i64);
        test_array_negative_op!(Float32, 2345.0f32, 1234.0f32);
        test_array_negative_op!(Float64, 23456.0f64, 12345.0f64);
        test_array_negative_op_intervals!(YearMonth, 2345i32, 1234i32);
        test_array_negative_op_intervals!(DayTime, 23456i64, 12345i64);
        test_array_negative_op_intervals!(MonthDayNano, 234567i128, 123456i128);
        Ok(())
    }
}

0 comments on commit 78d7a17

Please sign in to comment.