Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interval Comparison #5180

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions arrow-array/src/array/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,7 @@ def_from_for_primitive!(Int8Type, i8);
def_from_for_primitive!(Int16Type, i16);
def_from_for_primitive!(Int32Type, i32);
def_from_for_primitive!(Int64Type, i64);
def_from_for_primitive!(Int128Type, i128);
def_from_for_primitive!(UInt8Type, u8);
def_from_for_primitive!(UInt16Type, u16);
def_from_for_primitive!(UInt32Type, u32);
Expand Down
70 changes: 70 additions & 0 deletions arrow-array/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,54 @@ macro_rules! downcast_primitive {
};
}

/// Comparison-oriented variant of [`downcast_primitive`].
///
/// Given one or more [`arrow_schema::DataType`] expressions, this macro invokes `$m` with the
/// matching primitive type (followed by any extra `$args`). Dispatch is attempted in this order:
///
/// 1. the integer types handled by [`downcast_integer`],
/// 2. `Float16`, `Float32` and `Float64`,
/// 3. `Decimal128` and `Decimal256`,
/// 4. `Interval(YearMonth)`,
/// 5. the four `Duration` units,
/// 6. the temporal types handled by [`downcast_temporal`],
/// 7. finally the caller-supplied `$p => $fallback` arms.
///
/// Unlike [`downcast_primitive`], it deliberately has no arm for
/// [`arrow_schema::IntervalUnit::DayTime`] and [`arrow_schema::IntervalUnit::MonthDayNano`]
/// because they cannot be simply cast to primitive types during a comparison operation.
/// Callers that need those types must handle them in their own fallback arms.
#[macro_export]
macro_rules! downcast_primitive_cmp {
($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat => $fallback:expr $(,)*)*) => {
$crate::downcast_integer! {
$($data_type),+ => ($m $(, $args)*),
$crate::repeat_pat!(arrow_schema::DataType::Float16, $($data_type),+) => {
$m!($crate::types::Float16Type $(, $args)*)
}
$crate::repeat_pat!(arrow_schema::DataType::Float32, $($data_type),+) => {
$m!($crate::types::Float32Type $(, $args)*)
}
$crate::repeat_pat!(arrow_schema::DataType::Float64, $($data_type),+) => {
$m!($crate::types::Float64Type $(, $args)*)
}
$crate::repeat_pat!(arrow_schema::DataType::Decimal128(_, _), $($data_type),+) => {
$m!($crate::types::Decimal128Type $(, $args)*)
}
$crate::repeat_pat!(arrow_schema::DataType::Decimal256(_, _), $($data_type),+) => {
$m!($crate::types::Decimal256Type $(, $args)*)
}
$crate::repeat_pat!(arrow_schema::DataType::Interval(arrow_schema::IntervalUnit::YearMonth), $($data_type),+) => {
$m!($crate::types::IntervalYearMonthType $(, $args)*)
}
$crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Second), $($data_type),+) => {
$m!($crate::types::DurationSecondType $(, $args)*)
}
$crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Millisecond), $($data_type),+) => {
$m!($crate::types::DurationMillisecondType $(, $args)*)
}
$crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Microsecond), $($data_type),+) => {
$m!($crate::types::DurationMicrosecondType $(, $args)*)
}
$crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Nanosecond), $($data_type),+) => {
$m!($crate::types::DurationNanosecondType $(, $args)*)
}
// Nothing above matched: try the temporal types, then the caller's own fallback arms.
_ => {
$crate::downcast_temporal! {
$($data_type),+ => ($m $(, $args)*),
$($p => $fallback,)*
}
}
}
};
}

#[macro_export]
#[doc(hidden)]
macro_rules! downcast_primitive_array_helper {
Expand Down Expand Up @@ -383,6 +431,28 @@ macro_rules! downcast_primitive_array {
};
}

/// This macro serves a similar function to [`downcast_primitive_array`], but it
/// incorporates [`downcast_primitive_cmp`]. [`downcast_primitive_cmp`] is a specialized
/// version of [`downcast_primitive`] designed specifically for comparison operations.
///
/// Four invocation shapes are accepted; the first three only normalise their input and
/// recurse until the last rule is reached:
///
/// * `values => expr, fallbacks...` — a single array identifier with an expression body,
/// * `(a, b, ...) => expr, fallbacks...` — several array identifiers with an expression body,
/// * `a, b, ... => { block } fallbacks...` — unparenthesised identifiers with a block body,
/// * `(a, b, ...) => { block } fallbacks...` — the canonical form, which dispatches on
///   `data_type()` of every listed array through [`downcast_primitive_cmp`].
///
/// The `fallbacks` are `pattern => expression` arms forwarded unchanged to
/// [`downcast_primitive_cmp`].
#[macro_export]
macro_rules! downcast_primitive_array_cmp {
// Single identifier, expression body: wrap the expression in a block and recurse.
($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => {
$crate::downcast_primitive_array_cmp!($values => {$e} $($p => $fallback)*)
};
// Parenthesised identifiers, expression body: wrap the expression in a block and recurse.
(($($values:ident),+) => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => {
$crate::downcast_primitive_array_cmp!($($values),+ => {$e} $($p => $fallback)*)
};
// Bare identifiers, block body: parenthesise the identifiers to reach the canonical rule.
($($values:ident),+ => $e:block $($p:pat => $fallback:expr $(,)*)*) => {
$crate::downcast_primitive_array_cmp!(($($values),+) => $e $($p => $fallback)*)
};
// Canonical form: dispatch on the data type of every array.
(($($values:ident),+) => $e:block $($p:pat => $fallback:expr $(,)*)*) => {
$crate::downcast_primitive_cmp!{
$($values.data_type()),+ => ($crate::downcast_primitive_array_helper, $($values),+, $e),
$($p => $fallback,)*
}
};
}

/// Force downcast of an [`Array`], such as an [`ArrayRef`], to
/// [`PrimitiveArray<T>`], panic'ing on failure.
///
Expand Down
6 changes: 6 additions & 0 deletions arrow-array/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ make_type!(
DataType::Int64,
"A signed 64-bit integer type."
);
make_type!(
Int128Type,
i128,
DataType::Int128,
"A signed 128-bit integer type."
);
make_type!(
UInt8Type,
u8,
Expand Down
2 changes: 2 additions & 0 deletions arrow-data/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Int128
| DataType::Float16
| DataType::Float32
| DataType::Float64
Expand Down Expand Up @@ -1509,6 +1510,7 @@ pub fn layout(data_type: &DataType) -> DataTypeLayout {
DataType::Int16 => DataTypeLayout::new_fixed_width::<i16>(),
DataType::Int32 => DataTypeLayout::new_fixed_width::<i32>(),
DataType::Int64 => DataTypeLayout::new_fixed_width::<i64>(),
DataType::Int128 => DataTypeLayout::new_fixed_width::<i128>(),
DataType::UInt8 => DataTypeLayout::new_fixed_width::<u8>(),
DataType::UInt16 => DataTypeLayout::new_fixed_width::<u16>(),
DataType::UInt32 => DataTypeLayout::new_fixed_width::<u32>(),
Expand Down
1 change: 1 addition & 0 deletions arrow-data/src/equal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ fn equal_values(
DataType::Int16 => primitive_equal::<i16>(lhs, rhs, lhs_start, rhs_start, len),
DataType::Int32 => primitive_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len),
DataType::Int64 => primitive_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len),
DataType::Int128 => primitive_equal::<i128>(lhs, rhs, lhs_start, rhs_start, len),
DataType::Float32 => primitive_equal::<f32>(lhs, rhs, lhs_start, rhs_start, len),
DataType::Float64 => primitive_equal::<f64>(lhs, rhs, lhs_start, rhs_start, len),
DataType::Decimal128(_, _) => primitive_equal::<i128>(lhs, rhs, lhs_start, rhs_start, len),
Expand Down
3 changes: 3 additions & 0 deletions arrow-data/src/transform/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ fn build_extend(array: &ArrayData) -> Extend {
DataType::Int16 => primitive::build_extend::<i16>(array),
DataType::Int32 => primitive::build_extend::<i32>(array),
DataType::Int64 => primitive::build_extend::<i64>(array),
DataType::Int128 => primitive::build_extend::<i128>(array),
DataType::Float32 => primitive::build_extend::<f32>(array),
DataType::Float64 => primitive::build_extend::<f64>(array),
DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => {
Expand Down Expand Up @@ -251,6 +252,7 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls {
DataType::Int16 => primitive::extend_nulls::<i16>,
DataType::Int32 => primitive::extend_nulls::<i32>,
DataType::Int64 => primitive::extend_nulls::<i64>,
DataType::Int128 => primitive::extend_nulls::<i128>,
DataType::Float32 => primitive::extend_nulls::<f32>,
DataType::Float64 => primitive::extend_nulls::<f64>,
DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => {
Expand Down Expand Up @@ -404,6 +406,7 @@ impl<'a> MutableArrayData<'a> {
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Int128
| DataType::Float16
| DataType::Float32
| DataType::Float64
Expand Down
1 change: 1 addition & 0 deletions arrow-integration-test/src/datatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ pub fn data_type_to_json(data_type: &DataType) -> serde_json::Value {
DataType::Int16 => json!({"name": "int", "bitWidth": 16, "isSigned": true}),
DataType::Int32 => json!({"name": "int", "bitWidth": 32, "isSigned": true}),
DataType::Int64 => json!({"name": "int", "bitWidth": 64, "isSigned": true}),
DataType::Int128 => json!({"name": "int", "bitWidth": 128, "isSigned": true}),
DataType::UInt8 => json!({"name": "int", "bitWidth": 8, "isSigned": false}),
DataType::UInt16 => json!({"name": "int", "bitWidth": 16, "isSigned": false}),
DataType::UInt32 => json!({"name": "int", "bitWidth": 32, "isSigned": false}),
Expand Down
3 changes: 2 additions & 1 deletion arrow-ipc/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ pub(crate) fn get_fb_field_type<'a>(
children: Some(children),
}
}
Int8 | Int16 | Int32 | Int64 => {
Int8 | Int16 | Int32 | Int64 | Int128 => {
let children = fbb.create_vector(&empty_fields[..]);
let mut builder = crate::IntBuilder::new(fbb);
builder.add_is_signed(true);
Expand All @@ -508,6 +508,7 @@ pub(crate) fn get_fb_field_type<'a>(
Int16 => builder.add_bitWidth(16),
Int32 => builder.add_bitWidth(32),
Int64 => builder.add_bitWidth(64),
Int128 => builder.add_bitWidth(128),
_ => {}
};
FBFieldType {
Expand Down
132 changes: 126 additions & 6 deletions arrow-ord/src/cmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,20 @@
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
//!

use std::ops::Not;

use arrow_array::cast::AsArray;
use arrow_array::types::ByteArrayType;
use arrow_array::types::{
ByteArrayType, Int128Type, Int64Type, IntervalDayTimeType, IntervalMonthDayNanoType,
};
use arrow_array::{
downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, Datum,
FixedSizeBinaryArray, GenericByteArray,
downcast_primitive_array_cmp, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray,
Datum, FixedSizeBinaryArray, GenericByteArray, PrimitiveArray,
};
use arrow_buffer::bit_util::ceil;
use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer};
use arrow_schema::ArrowError;
use arrow_schema::{ArrowError, IntervalUnit};
use arrow_select::take::take;
use std::ops::Not;

#[derive(Debug, Copy, Clone)]
enum Op {
Expand Down Expand Up @@ -206,14 +209,16 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,

// Defer computation as may not be necessary
let values = || -> BooleanBuffer {
let d = downcast_primitive_array! {
let d = downcast_primitive_array_cmp! {
(l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v),
(Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v),
(Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v, r.as_string::<i32>(), r_s, r_v),
(LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v, r.as_string::<i64>(), r_s, r_v),
(Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v, r.as_binary::<i32>(), r_s, r_v),
(LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s, l_v, r.as_binary::<i64>(), r_s, r_v),
(FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
(Interval(IntervalUnit::DayTime), Interval(IntervalUnit::DayTime)) => apply(op, safer_interval_dt(l, op, true).values().as_ref(), l_s, l_v, safer_interval_dt(r, op, false).values().as_ref(), r_s, r_v),
(Interval(IntervalUnit::MonthDayNano), Interval(IntervalUnit::MonthDayNano)) => apply(op, safer_interval_mdn(l, op, true).values().as_ref(), l_s, l_v, safer_interval_mdn(r, op, false).values().as_ref(), r_s, r_v),
(Null, Null) => None,
_ => unreachable!(),
};
Expand Down Expand Up @@ -540,6 +545,121 @@ impl<'a> ArrayOrd for &'a FixedSizeBinaryArray {
}
}

/// Converts a day-time interval array into an `Int64` array of millisecond values suited
/// to one side of a comparison.
///
/// Which conversion is applied depends on the operator and on the operand side:
///
/// * `<` / `<=` on the left, or `>` / `>=` on the right: each value goes through
///   `dt_in_millis_max`.
/// * `>` / `>=` on the left, or `<` / `<=` on the right: each value goes through
///   `dt_in_millis_min`.
/// * `==` / `!=` on either side: the raw native values are copied unchanged.
///
/// Null slots stay null in the output.
///
/// # Arguments
/// * `dt` - Array expected to be a `PrimitiveArray<IntervalDayTimeType>`.
/// * `op` - Comparison operator being evaluated.
/// * `lhs` - `true` when `dt` is the left-hand operand, `false` for the right-hand one.
///
/// # Returns
/// A `PrimitiveArray<Int64Type>` holding the converted values.
///
/// # Panics
/// Panics when `dt` cannot be viewed as `PrimitiveArray<IntervalDayTimeType>`, or when `op`
/// is not one of the six operators listed above.
#[inline]
fn safer_interval_dt(dt: &dyn Array, op: Op, lhs: bool) -> PrimitiveArray<Int64Type> {
    // Resolve the concrete array first so a datatype mismatch is reported before the
    // operator is inspected.
    let intervals = match dt.as_primitive_opt::<IntervalDayTimeType>() {
        Some(intervals) => intervals,
        None => panic!("Invalid datatype for Interval(IntervalDayTime) comparison"),
    };

    // Choose the per-value conversion; equality operators bypass conversion entirely.
    let to_millis: fn(i64) -> i64 = match (op, lhs) {
        (Op::Equal | Op::NotEqual, _) => {
            return intervals.iter().collect::<PrimitiveArray<Int64Type>>();
        }
        (Op::Less | Op::LessEqual, true) | (Op::Greater | Op::GreaterEqual, false) => {
            dt_in_millis_max
        }
        (Op::Greater | Op::GreaterEqual, true) | (Op::Less | Op::LessEqual, false) => {
            dt_in_millis_min
        }
        _ => {
            panic!(
                "Invalid operator {:?} for Interval(IntervalDayTime) comparison",
                op
            )
        }
    };

    intervals
        .iter()
        .map(|value| value.map(to_millis))
        .collect::<PrimitiveArray<Int64Type>>()
}

/// Converts a month-day-nano interval array into an `Int128` array of nanosecond values
/// suited to one side of a comparison.
///
/// Which conversion is applied depends on the operator and on the operand side:
///
/// * `<` / `<=` on the left, or `>` / `>=` on the right: each value goes through
///   `mdn_in_nanos_max`.
/// * `>` / `>=` on the left, or `<` / `<=` on the right: each value goes through
///   `mdn_in_nanos_min`.
/// * `==` / `!=` on either side: the raw native values are copied unchanged.
///
/// Null slots stay null in the output.
///
/// # Arguments
/// * `mdn` - Array expected to be a `PrimitiveArray<IntervalMonthDayNanoType>`.
/// * `op` - Comparison operator being evaluated.
/// * `lhs` - `true` when `mdn` is the left-hand operand, `false` for the right-hand one.
///
/// # Returns
/// A `PrimitiveArray<Int128Type>` holding the converted values.
///
/// # Panics
/// Panics when `mdn` cannot be viewed as `PrimitiveArray<IntervalMonthDayNanoType>`, or when
/// `op` is not one of the six operators listed above.
#[inline]
fn safer_interval_mdn(mdn: &dyn Array, op: Op, lhs: bool) -> PrimitiveArray<Int128Type> {
    // Resolve the concrete array first so a datatype mismatch is reported before the
    // operator is inspected.
    let intervals = match mdn.as_primitive_opt::<IntervalMonthDayNanoType>() {
        Some(intervals) => intervals,
        None => panic!("Invalid datatype for Interval(IntervalMonthDayNano) comparison"),
    };

    // Choose the per-value conversion; equality operators bypass conversion entirely.
    let to_nanos: fn(i128) -> i128 = match (op, lhs) {
        (Op::Equal | Op::NotEqual, _) => {
            return intervals.iter().collect::<PrimitiveArray<Int128Type>>();
        }
        (Op::Less | Op::LessEqual, true) | (Op::Greater | Op::GreaterEqual, false) => {
            mdn_in_nanos_max
        }
        (Op::Greater | Op::GreaterEqual, true) | (Op::Less | Op::LessEqual, false) => {
            mdn_in_nanos_min
        }
        _ => {
            panic!("Invalid operator for Interval(IntervalMonthDayNano) comparison")
        }
    };

    intervals
        .iter()
        .map(|value| value.map(to_nanos))
        .collect::<PrimitiveArray<Int128Type>>()
}

/// Calculates an upper bound, in milliseconds, for a packed `IntervalDayTimeType` value.
///
/// `dt` is decoded as a signed day count in the upper 32 bits and a signed millisecond
/// count in the lower 32 bits.
///
/// Each day may be up to one leap second (1000 ms) longer than 86 400 000 ms. That
/// allowance only makes the interval *larger* when the day count is non-negative; for a
/// negative day count the largest (least negative) value is reached with plain-length days,
/// so no allowance is applied there. This keeps the result greater than or equal to
/// `dt_in_millis_min` for every input.
#[inline]
fn dt_in_millis_max(dt: i64) -> i64 {
    const MILLIS_PER_DAY: i64 = 86_400_000;
    const LEAP_SECOND_MILLIS: i64 = 1_000;

    // Arithmetic shift keeps the sign of the day component.
    let days = dt >> 32;
    // Truncate to the low 32 bits, then sign-extend the millisecond component.
    let millis = i64::from(dt as i32);

    let millis_per_day = if days >= 0 {
        MILLIS_PER_DAY + LEAP_SECOND_MILLIS
    } else {
        MILLIS_PER_DAY
    };
    days * millis_per_day + millis
}

/// Calculates a lower bound, in milliseconds, for a packed `IntervalDayTimeType` value.
///
/// `dt` is decoded as a signed day count in the upper 32 bits and a signed millisecond
/// count in the lower 32 bits.
///
/// For a non-negative day count the smallest value is reached with plain 86 400 000 ms
/// days. For a negative day count the smallest (most negative) value is reached when every
/// day also contains a leap second (1000 ms), so the allowance is applied in that case.
/// This keeps the result less than or equal to `dt_in_millis_max` for every input.
#[inline]
fn dt_in_millis_min(dt: i64) -> i64 {
    const MILLIS_PER_DAY: i64 = 86_400_000;
    const LEAP_SECOND_MILLIS: i64 = 1_000;

    // Arithmetic shift keeps the sign of the day component.
    let days = dt >> 32;
    // Truncate to the low 32 bits, then sign-extend the millisecond component.
    let millis = i64::from(dt as i32);

    let millis_per_day = if days >= 0 {
        MILLIS_PER_DAY
    } else {
        MILLIS_PER_DAY + LEAP_SECOND_MILLIS
    };
    days * millis_per_day + millis
}

/// Calculates an upper bound, in nanoseconds, for a packed `IntervalMonthDayNanoType` value.
///
/// `mdn` is decoded as a signed month count in bits 96..128, a signed day count in bits
/// 64..96 and a signed nanosecond count in the low 64 bits.
///
/// The bound is built in two steps, each picking whichever extreme makes the result larger:
///
/// * months become days using 31 days per month when the month count is non-negative and
///   28 days per month when it is negative;
/// * the resulting day total becomes nanoseconds using a day that is one second longer
///   than 86 400 s when the total is non-negative, and a plain day when it is negative.
///
/// This keeps the result greater than or equal to `mdn_in_nanos_min` for every input.
#[inline]
fn mdn_in_nanos_max(mdn: i128) -> i128 {
    const NANOS_PER_DAY: i128 = 86_400_000_000_000;
    const LEAP_SECOND_NANOS: i128 = 1_000_000_000;

    // Truncating casts pick out each packed field.
    let months = i128::from((mdn >> 96) as i32);
    let days = i128::from((mdn >> 64) as i32);
    let nanos = i128::from(mdn as i64);

    let days_per_month = if months >= 0 { 31 } else { 28 };
    let total_days = months * days_per_month + days;

    let nanos_per_day = if total_days >= 0 {
        NANOS_PER_DAY + LEAP_SECOND_NANOS
    } else {
        NANOS_PER_DAY
    };
    total_days * nanos_per_day + nanos
}

/// Calculates a lower bound, in nanoseconds, for a packed `IntervalMonthDayNanoType` value.
///
/// `mdn` is decoded as a signed month count in bits 96..128, a signed day count in bits
/// 64..96 and a signed nanosecond count in the low 64 bits.
///
/// The bound is built in two steps, each picking whichever extreme makes the result smaller:
///
/// * months become days using 28 days per month when the month count is non-negative and
///   31 days per month when it is negative;
/// * the resulting day total becomes nanoseconds using a plain 86 400 s day when the total
///   is non-negative, and a day that is one second longer when it is negative.
///
/// This keeps the result less than or equal to `mdn_in_nanos_max` for every input.
#[inline]
fn mdn_in_nanos_min(mdn: i128) -> i128 {
    const NANOS_PER_DAY: i128 = 86_400_000_000_000;
    const LEAP_SECOND_NANOS: i128 = 1_000_000_000;

    // Truncating casts pick out each packed field.
    let months = i128::from((mdn >> 96) as i32);
    let days = i128::from((mdn >> 64) as i32);
    let nanos = i128::from(mdn as i64);

    let days_per_month = if months >= 0 { 28 } else { 31 };
    let total_days = months * days_per_month + days;

    let nanos_per_day = if total_days >= 0 {
        NANOS_PER_DAY
    } else {
        NANOS_PER_DAY + LEAP_SECOND_NANOS
    };
    total_days * nanos_per_day + nanos
}

#[cfg(test)]
mod tests {
use std::sync::Arc;
Expand Down
47 changes: 47 additions & 0 deletions arrow-ord/src/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2036,6 +2036,53 @@ mod tests {
);
}

/// Checks that ordering comparisons on `IntervalDayTime` and `IntervalMonthDayNano` arrays
/// give the same answer for the strict and non-strict operators on these fixtures, and that
/// `gt(a, b)` agrees with `lt(b, a)`.
#[test]
fn test_interval_array_unit_aware() {
    // Builders that turn plain tuples into interval arrays with no null slots.
    let day_time = |rows: &[(i32, i32)]| {
        IntervalDayTimeArray::from(
            rows.iter()
                .map(|&(days, millis)| Some(IntervalDayTimeType::make_value(days, millis)))
                .collect::<Vec<_>>(),
        )
    };
    let month_day_nano = |rows: &[(i32, i32, i64)]| {
        IntervalMonthDayNanoArray::from(
            rows.iter()
                .map(|&(months, days, nanos)| {
                    Some(IntervalMonthDayNanoType::make_value(months, days, nanos))
                })
                .collect::<Vec<_>>(),
        )
    };

    // --- IntervalDayTime: (days, milliseconds) ---
    let a = day_time(&[(0, -5), (3, -1_000_000), (4, -1000), (10, 20), (1, 2)]);
    let b = day_time(&[(0, -10), (3, -2_000_000), (2, 1000), (5, 6), (1, 1)]);
    let expected_dt = BooleanArray::from(vec![
        Some(true),
        Some(true),
        Some(true),
        Some(true),
        Some(false),
    ]);

    let res = gt(&a, &b).unwrap();
    let res_eq = gt_eq(&a, &b).unwrap();
    assert_eq!(res, res_eq);
    assert_eq!(&res, &expected_dt);

    let res = lt(&b, &a).unwrap();
    let res_eq = lt_eq(&b, &a).unwrap();
    assert_eq!(res, res_eq);
    assert_eq!(&res, &expected_dt);

    // --- IntervalMonthDayNano: (months, days, nanoseconds) ---
    let a = month_day_nano(&[
        (0, 0, 1),
        (0, 1, -1_000_000_000),
        (3, 2, -100_000_000_000),
        (0, 1, 1),
        (1, 28, 0),
        (10, 0, -1_000_000_000_000),
    ]);
    let b = month_day_nano(&[
        (0, 0, 0),
        (0, 1, -8_000_000_000),
        (1, 25, 100_000_000_000),
        (0, 1, 0),
        (2, 0, 0),
        (5, 150, 1_000_000_000_000),
    ]);
    let expected_mdn = BooleanArray::from(vec![
        Some(true),
        Some(true),
        Some(true),
        Some(false),
        Some(false),
        Some(false),
    ]);

    let res = gt(&a, &b).unwrap();
    let res_eq = gt_eq(&a, &b).unwrap();
    assert_eq!(res, res_eq);
    assert_eq!(&res, &expected_mdn);

    let res = lt(&b, &a).unwrap();
    let res_eq = lt_eq(&b, &a).unwrap();
    assert_eq!(res, res_eq);
    assert_eq!(&res, &expected_mdn);
}

macro_rules! test_binary {
($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => {
#[test]
Expand Down
Loading
Loading