Skip to content

Commit

Permalink
Intervals account for a reference timestamp
Browse files Browse the repository at this point in the history
  • Loading branch information
berkaysynnada committed Dec 7, 2023
1 parent 48bb07d commit bc4967e
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 107 deletions.
229 changes: 126 additions & 103 deletions arrow-ord/src/cmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
(Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v, r.as_binary::<i32>(), r_s, r_v),
(LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s, l_v, r.as_binary::<i64>(), r_s, r_v),
(FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
(Interval(IntervalUnit::DayTime), Interval(IntervalUnit::DayTime)) => apply(op, safer_interval_dt(l, op, true).values().as_ref(), l_s, l_v, safer_interval_dt(r, op, false).values().as_ref(), r_s, r_v),
(Interval(IntervalUnit::MonthDayNano), Interval(IntervalUnit::MonthDayNano)) => apply(op, safer_interval_mdn(l, op, true).values().as_ref(), l_s, l_v, safer_interval_mdn(r, op, false).values().as_ref(), r_s, r_v),
(Interval(IntervalUnit::DayTime), Interval(IntervalUnit::DayTime)) => apply_interval_dt(op, l, l_s, l_v, r, r_s, r_v),
(Interval(IntervalUnit::MonthDayNano), Interval(IntervalUnit::MonthDayNano)) => apply_interval_mdn(op, l, l_s, l_v, r, r_s, r_v),
(Null, Null) => None,
_ => unreachable!(),
};
Expand Down Expand Up @@ -346,6 +346,82 @@ fn apply<T: ArrayOrd>(
}
}

/// Evaluates `op` between two `Interval(DayTime)` operands, keeping only the
/// positions where the outcome does not depend on how long a day is taken to be.
///
/// Both operands are converted to a single `i64` per element twice: once with
/// `interval_dt_min` and once with `interval_dt_max`. `apply` is run on each
/// converted pair and the two results are merged by `definite_comparison`.
///
/// `l_s`/`r_s` and `l_v`/`r_v` are forwarded to `apply` untouched.
///
/// # Panics
/// Panics (inside the conversion helpers) if `l` or `r` is not an
/// `IntervalDayTimeType` primitive array.
fn apply_interval_dt(
    op: Op,
    l: &dyn Array,
    l_s: bool,
    l_v: Option<&dyn AnyDictionaryArray>,
    r: &dyn Array,
    r_s: bool,
    r_v: Option<&dyn AnyDictionaryArray>,
) -> Option<BooleanBuffer> {
    // Comparison using the shortest-day conversion of both sides.
    let with_min = {
        let (lhs, rhs) = (interval_dt_min(l), interval_dt_min(r));
        apply(
            op,
            lhs.values().as_ref(),
            l_s,
            l_v,
            rhs.values().as_ref(),
            r_s,
            r_v,
        )
    };

    // Comparison using the longest-day conversion of both sides.
    let with_max = {
        let (lhs, rhs) = (interval_dt_max(l), interval_dt_max(r));
        apply(
            op,
            lhs.values().as_ref(),
            l_s,
            l_v,
            rhs.values().as_ref(),
            r_s,
            r_v,
        )
    };

    definite_comparison(with_min, with_max)
}

/// Evaluates `op` between two `Interval(MonthDayNano)` operands, keeping only
/// the positions where the outcome holds under both extreme conversions.
///
/// Both operands are converted to a single `i128` per element twice: once with
/// `interval_mdn_min` and once with `interval_mdn_max`. `apply` is run on each
/// converted pair and the two results are merged by `definite_comparison`.
///
/// `l_s`/`r_s` and `l_v`/`r_v` are forwarded to `apply` untouched.
///
/// # Panics
/// Panics (inside the conversion helpers) if `l` or `r` is not an
/// `IntervalMonthDayNanoType` primitive array.
fn apply_interval_mdn(
    op: Op,
    l: &dyn Array,
    l_s: bool,
    l_v: Option<&dyn AnyDictionaryArray>,
    r: &dyn Array,
    r_s: bool,
    r_v: Option<&dyn AnyDictionaryArray>,
) -> Option<BooleanBuffer> {
    // Comparison using the minimum conversion of both sides.
    let with_min = {
        let (lhs, rhs) = (interval_mdn_min(l), interval_mdn_min(r));
        apply(
            op,
            lhs.values().as_ref(),
            l_s,
            l_v,
            rhs.values().as_ref(),
            r_s,
            r_v,
        )
    };

    // Comparison using the maximum conversion of both sides.
    let with_max = {
        let (lhs, rhs) = (interval_mdn_max(l), interval_mdn_max(r));
        apply(
            op,
            lhs.values().as_ref(),
            l_s,
            l_v,
            rhs.values().as_ref(),
            r_s,
            r_v,
        )
    };

    definite_comparison(with_min, with_max)
}

fn definite_comparison(
min: Option<BooleanBuffer>,
max: Option<BooleanBuffer>,
) -> Option<BooleanBuffer> {
min.and_then(|min_values| {
max.map(|max_values| {
BooleanBuffer::from_iter(
min_values
.into_iter()
.zip(max_values.into_iter())
.map(|(min, max)| min & max),
)
})
})
}

/// Perform a take operation on `buffer` with the given dictionary
fn take_bits(v: &dyn AnyDictionaryArray, buffer: BooleanBuffer) -> BooleanBuffer {
let array = take(&BooleanArray::new(buffer, None), v.keys(), None).unwrap();
Expand Down Expand Up @@ -545,119 +621,66 @@ impl<'a> ArrayOrd for &'a FixedSizeBinaryArray {
}
}

/// Computes max or min milliseconds from a `PrimitiveArray<IntervalDayTimeType>` based on
/// the comparison operator (`op`) and operand side (`lhs`). This function is essential for
/// accurate interval comparison operations by considering the leap seconds.
///
/// # Arguments
/// * `dt` - Reference to an array, expected to be `PrimitiveArray<IntervalDayTimeType>`.
/// * `op` - Comparison operator.
/// * `lhs` - Boolean indicating if the array is on the left-hand side of the operator.
///
/// # Returns
/// A `PrimitiveArray<Int64Type>` with computed milliseconds values.
///
/// # Panics
/// If `dt` is not a `PrimitiveArray<IntervalDayTimeType>` or if an invalid operator is used.
#[inline]
fn safer_interval_dt(dt: &dyn Array, op: Op, lhs: bool) -> PrimitiveArray<Int64Type> {
match dt.as_primitive_opt::<IntervalDayTimeType>() {
Some(dt) => match (op, lhs) {
(Op::Less | Op::LessEqual, true) | (Op::Greater | Op::GreaterEqual, false) => {
PrimitiveArray::<Int64Type>::from_iter(dt.iter().map(|dt| dt.map(dt_in_millis_max)))
}
(Op::Greater | Op::GreaterEqual, true) | (Op::Less | Op::LessEqual, false) => {
PrimitiveArray::<Int64Type>::from_iter(dt.iter().map(|dt| dt.map(dt_in_millis_min)))
}
(Op::Equal | Op::NotEqual, _) => PrimitiveArray::<Int64Type>::from_iter(dt.iter()),
_ => {
panic!(
"Invalid operator {:?} for Interval(IntervalDayTime) comparison",
op
)
}
},
_ => {
panic!("Invalid datatype for Interval(IntervalDayTime) comparison")
}
/// Converts an `IntervalDayTimeType` array into its minimum length in
/// milliseconds, one `i64` per element.
///
/// Each packed value is split into its upper 32 bits (days) and its lower
/// 32 bits (signed milliseconds); every day counts as exactly 86_400_000 ms.
/// Null slots stay null.
///
/// # Panics
/// Panics if `dt` is not a `PrimitiveArray<IntervalDayTimeType>`.
fn interval_dt_min(dt: &dyn Array) -> PrimitiveArray<Int64Type> {
    const MILLIS_PER_DAY: i64 = 86_400_000;

    let intervals = dt
        .as_primitive_opt::<IntervalDayTimeType>()
        .unwrap_or_else(|| panic!("Invalid datatype for Interval(IntervalDayTime) comparison"));

    intervals
        .iter()
        .map(|slot| {
            slot.map(|packed| {
                let days = packed >> 32;
                // Truncate to the low 32 bits, then sign-extend back to i64.
                let millis = packed as i32 as i64;
                days * MILLIS_PER_DAY + millis
            })
        })
        .collect()
}

/// Computes max or min nanoseconds from a `PrimitiveArray<IntervalMonthDayNanoType>` based on
/// the comparison operator (`op`) and operand side (`lhs`). This function is crucial for
/// precise interval comparison operations involving months and days, which can result in different
/// number of nanoseconds depending on the timestamp.
///
/// # Arguments
/// * `mdn` - Reference to an array, expected to be `PrimitiveArray<IntervalMonthDayNanoType>`.
/// * `op` - Comparison operator.
/// * `lhs` - Boolean indicating if the array is on the left-hand side of the operator.
///
/// # Returns
/// A `PrimitiveArray<Int128Type>` with computed nanoseconds values.
///
/// # Panics
/// If `mdn` is not a `PrimitiveArray<IntervalMonthDayNanoType>` or if an invalid operator is used.
#[inline]
fn safer_interval_mdn(mdn: &dyn Array, op: Op, lhs: bool) -> PrimitiveArray<Int128Type> {
match mdn.as_primitive_opt::<IntervalMonthDayNanoType>() {
Some(mdn) => match (op, lhs) {
(Op::Less | Op::LessEqual, true) | (Op::Greater | Op::GreaterEqual, false) => {
PrimitiveArray::<Int128Type>::from_iter(
mdn.iter().map(|mdn| mdn.map(mdn_in_nanos_max)),
)
}
(Op::Greater | Op::GreaterEqual, true) | (Op::Less | Op::LessEqual, false) => {
PrimitiveArray::<Int128Type>::from_iter(
mdn.iter().map(|mdn| mdn.map(mdn_in_nanos_min)),
)
}
(Op::Equal | Op::NotEqual, _) => PrimitiveArray::<Int128Type>::from_iter(mdn.iter()),
_ => {
panic!("Invalid operator for Interval(IntervalMonthDayNano) comparison")
}
},
_ => {
panic!("Invalid datatype for Interval(IntervalMonthDayNano) comparison")
}
/// Converts an `IntervalDayTimeType` array into its maximum length in
/// milliseconds, one `i64` per element.
///
/// Each packed value is split into its upper 32 bits (days) and its lower
/// 32 bits (signed milliseconds); every day counts as 86_400_000 ms plus one
/// extra second (1_000 ms). Null slots stay null.
///
/// # Panics
/// Panics if `dt` is not a `PrimitiveArray<IntervalDayTimeType>`.
fn interval_dt_max(dt: &dyn Array) -> PrimitiveArray<Int64Type> {
    const MILLIS_PER_LONGEST_DAY: i64 = 86_400_000 + 1_000;

    let intervals = dt
        .as_primitive_opt::<IntervalDayTimeType>()
        .unwrap_or_else(|| panic!("Invalid datatype for Interval(IntervalDayTime) comparison"));

    intervals
        .iter()
        .map(|slot| {
            slot.map(|packed| {
                let days = packed >> 32;
                // Truncate to the low 32 bits, then sign-extend back to i64.
                let millis = packed as i32 as i64;
                days * MILLIS_PER_LONGEST_DAY + millis
            })
        })
        .collect()
}

/// Returns the largest number of milliseconds a packed day-time interval can
/// span.
///
/// The upper 32 bits of `dt` are the day count and the lower 32 bits are a
/// signed millisecond offset. Each day is budgeted one extra second
/// (1_000 ms) on top of 86_400_000 ms to allow for leap seconds.
#[inline]
fn dt_in_millis_max(dt: i64) -> i64 {
    const MILLIS_PER_LONGEST_DAY: i64 = 86_400_000 + 1_000;

    let days = dt >> 32;
    // Truncate to the low 32 bits, then sign-extend back to i64.
    let millis = i64::from(dt as i32);
    days * MILLIS_PER_LONGEST_DAY + millis
}

/// Returns the smallest number of milliseconds a packed day-time interval can
/// span.
///
/// The upper 32 bits of `dt` are the day count and the lower 32 bits are a
/// signed millisecond offset. Each day counts as exactly 86_400_000 ms, with
/// no allowance for leap seconds.
#[inline]
fn dt_in_millis_min(dt: i64) -> i64 {
    const MILLIS_PER_DAY: i64 = 86_400_000;

    let days = dt >> 32;
    // Truncate to the low 32 bits, then sign-extend back to i64.
    let millis = i64::from(dt as i32);
    days * MILLIS_PER_DAY + millis
}

/// Calculates the maximum nanoseconds for an `IntervalMonthDayNanoType` interval, assuming
/// 31 days per month and adding extra nanoseconds for longer days.
#[inline]
fn mdn_in_nanos_max(mdn: i128) -> i128 {
let m = (mdn >> 96) as i32;
let d = (mdn >> 64) as i32;
let n = mdn as i64;
((m as i128 * 31) + d as i128) * (86_400_000_000_000 + 1_000_000_000) + n as i128
/// Converts an `IntervalMonthDayNanoType` array into nanoseconds using the
/// minimum conversion factors, one `i128` per element.
///
/// Each packed value is split into months (bits 96..128), days (bits 64..96)
/// and nanoseconds (low 64 bits). A month counts as 28 days and a day as
/// exactly 86_400_000_000_000 ns. Null slots stay null.
///
/// # Panics
/// Panics if `mdn` is not a `PrimitiveArray<IntervalMonthDayNanoType>`.
fn interval_mdn_min(mdn: &dyn Array) -> PrimitiveArray<Int128Type> {
    const DAYS_PER_MONTH: i128 = 28;
    const NANOS_PER_DAY: i128 = 86_400_000_000_000;

    let intervals = mdn
        .as_primitive_opt::<IntervalMonthDayNanoType>()
        .unwrap_or_else(|| {
            panic!("Invalid datatype for Interval(IntervalMonthDayNano) comparison")
        });

    intervals
        .iter()
        .map(|slot| {
            slot.map(|packed| {
                // Each cast truncates to the field's width and keeps its sign.
                let months = (packed >> 96) as i32;
                let days = (packed >> 64) as i32;
                let nanos = packed as i64;
                (i128::from(months) * DAYS_PER_MONTH + i128::from(days)) * NANOS_PER_DAY
                    + i128::from(nanos)
            })
        })
        .collect()
}

/// Calculates the minimum nanoseconds for an `IntervalMonthDayNanoType` interval, assuming
/// 28 days per month and excluding additional nanoseconds for longer days.
#[inline]
fn mdn_in_nanos_min(mdn: i128) -> i128 {
let m = (mdn >> 96) as i32;
let d = (mdn >> 64) as i32;
let n = mdn as i64;
((m as i128 * 28) + d as i128) * (86_400_000_000_000) + n as i128
/// Converts an `IntervalMonthDayNanoType` array into nanoseconds using the
/// maximum conversion factors, one `i128` per element.
///
/// Each packed value is split into months (bits 96..128), days (bits 64..96)
/// and nanoseconds (low 64 bits). A month counts as 31 days and a day as
/// 86_400_000_000_000 ns plus one extra second (1_000_000_000 ns). Null slots
/// stay null.
///
/// # Panics
/// Panics if `mdn` is not a `PrimitiveArray<IntervalMonthDayNanoType>`.
fn interval_mdn_max(mdn: &dyn Array) -> PrimitiveArray<Int128Type> {
    const DAYS_PER_MONTH: i128 = 31;
    const NANOS_PER_LONGEST_DAY: i128 = 86_400_000_000_000 + 1_000_000_000;

    let intervals = mdn
        .as_primitive_opt::<IntervalMonthDayNanoType>()
        .unwrap_or_else(|| {
            panic!("Invalid datatype for Interval(IntervalMonthDayNano) comparison")
        });

    intervals
        .iter()
        .map(|slot| {
            slot.map(|packed| {
                // Each cast truncates to the field's width and keeps its sign.
                let months = (packed >> 96) as i32;
                let days = (packed >> 64) as i32;
                let nanos = packed as i64;
                (i128::from(months) * DAYS_PER_MONTH + i128::from(days)) * NANOS_PER_LONGEST_DAY
                    + i128::from(nanos)
            })
        })
        .collect()
}

#[cfg(test)]
Expand Down
8 changes: 4 additions & 4 deletions arrow-ord/src/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2047,21 +2047,21 @@ mod tests {
assert_eq!(res, res_eq);
assert_eq!(
&res,
&BooleanArray::from(vec![ Some(true), Some(true), Some(true), Some(true), Some(false)])
&BooleanArray::from(vec![ Some(true), Some(true), Some(true), Some(true), Some(true)])
);
let res = lt(&b, &a).unwrap();
let res_eq = lt_eq(&b, &a).unwrap();
assert_eq!(res, res_eq);
assert_eq!(
&res,
&BooleanArray::from(vec![ Some(true), Some(true), Some(true), Some(true), Some(false)])
&BooleanArray::from(vec![ Some(true), Some(true), Some(true), Some(true), Some(true)])
);

let a = IntervalMonthDayNanoArray::from(
vec![Some(IntervalMonthDayNanoType::make_value(0, 0, 1)),Some(IntervalMonthDayNanoType::make_value(0, 1, -1_000_000_000)),Some(IntervalMonthDayNanoType::make_value(3, 2, -100_000_000_000)),Some(IntervalMonthDayNanoType::make_value(0, 1, 1)),Some(IntervalMonthDayNanoType::make_value(1, 28, 0)), Some(IntervalMonthDayNanoType::make_value(10, 0, -1_000_000_000_000))],
vec![Some(IntervalMonthDayNanoType::make_value(0, 0, 1)),Some(IntervalMonthDayNanoType::make_value(0, 1, -1_000_000_000)),Some(IntervalMonthDayNanoType::make_value(3, 2, -100_000_000_000)),Some(IntervalMonthDayNanoType::make_value(0, 1, 86_400_000_000_999)),Some(IntervalMonthDayNanoType::make_value(1, 28, 0)), Some(IntervalMonthDayNanoType::make_value(10, 0, -1_000_000_000_000))],
);
let b = IntervalMonthDayNanoArray::from(
vec![Some(IntervalMonthDayNanoType::make_value(0, 0,0)),Some(IntervalMonthDayNanoType::make_value(0, 1, -8_000_000_000)),Some(IntervalMonthDayNanoType::make_value(1, 25, 100_000_000_000)),Some(IntervalMonthDayNanoType::make_value(0, 1, 0)),Some(IntervalMonthDayNanoType::make_value(2, 0, 0)), Some(IntervalMonthDayNanoType::make_value(5, 150, 1_000_000_000_000))],
vec![Some(IntervalMonthDayNanoType::make_value(0, 0,0)),Some(IntervalMonthDayNanoType::make_value(0, 1, -8_000_000_000)),Some(IntervalMonthDayNanoType::make_value(1, 25, 100_000_000_000)),Some(IntervalMonthDayNanoType::make_value(0, 2, 0)),Some(IntervalMonthDayNanoType::make_value(2, 0, 0)), Some(IntervalMonthDayNanoType::make_value(5, 150, 1_000_000_000_000))],
);
let res = gt(&a, &b).unwrap();
let res_eq = gt_eq(&a, &b).unwrap();
Expand Down

0 comments on commit bc4967e

Please sign in to comment.