Skip to content

Commit

Permalink
fix: use total ordering in the min & max accumulator for floats (#10627)
Browse files Browse the repository at this point in the history
* fix: use total ordering in the min & max accumulator for floats to match the ordering used by arrow kernels

* change unit test to expect min to be nan

* changed behavior again since the partial_cmp approach doesn't handle nulls correctly

* Revert change to describe test.  It was not originating from a nan/finite discrepency but from a null/defined discrepency and we don't want that behavior to change

* Update the test to check the min function and also verify the result
  • Loading branch information
westonpace authored Jun 7, 2024
1 parent cb9068c commit 5bb6b35
Showing 1 changed file with 56 additions and 4 deletions.
60 changes: 56 additions & 4 deletions datafusion/physical-expr/src/aggregate/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,20 @@ macro_rules! typed_min_max {
}};
}

macro_rules! typed_min_max_float {
($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
ScalarValue::$SCALAR(match ($VALUE, $DELTA) {
(None, None) => None,
(Some(a), None) => Some(*a),
(None, Some(b)) => Some(*b),
(Some(a), Some(b)) => match a.total_cmp(b) {
choose_min_max!($OP) => Some(*b),
_ => Some(*a),
},
})
}};
}

// min/max of two scalar string values.
macro_rules! typed_min_max_string {
($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
Expand All @@ -500,7 +514,7 @@ macro_rules! typed_min_max_string {
}};
}

macro_rules! interval_choose_min_max {
macro_rules! choose_min_max {
(min) => {
std::cmp::Ordering::Greater
};
Expand All @@ -512,7 +526,7 @@ macro_rules! interval_choose_min_max {
macro_rules! interval_min_max {
($OP:tt, $LHS:expr, $RHS:expr) => {{
match $LHS.partial_cmp(&$RHS) {
Some(interval_choose_min_max!($OP)) => $RHS.clone(),
Some(choose_min_max!($OP)) => $RHS.clone(),
Some(_) => $LHS.clone(),
None => {
return internal_err!("Comparison error while computing interval min/max")
Expand Down Expand Up @@ -555,10 +569,10 @@ macro_rules! min_max {
typed_min_max!(lhs, rhs, Boolean, $OP)
}
(ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
typed_min_max!(lhs, rhs, Float64, $OP)
typed_min_max_float!(lhs, rhs, Float64, $OP)
}
(ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
typed_min_max!(lhs, rhs, Float32, $OP)
typed_min_max_float!(lhs, rhs, Float32, $OP)
}
(ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
typed_min_max!(lhs, rhs, UInt64, $OP)
Expand Down Expand Up @@ -1103,3 +1117,41 @@ impl Accumulator for SlidingMinAccumulator {
std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size()
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn float_min_max_with_nans() {
let pos_nan = f32::NAN;
let zero = 0_f32;
let neg_inf = f32::NEG_INFINITY;

let check = |acc: &mut dyn Accumulator, values: &[&[f32]], expected: f32| {
for batch in values.iter() {
let batch =
Arc::new(Float32Array::from_iter_values(batch.iter().copied()));
acc.update_batch(&[batch]).unwrap();
}
let result = acc.evaluate().unwrap();
assert_eq!(result, ScalarValue::Float32(Some(expected)));
};

// This test checks both comparison between batches (which uses the min_max macro
// defined above) and within a batch (which uses the arrow min/max compute function
// and verifies both respect the total order comparison for floats)

let min = || MinAccumulator::try_new(&DataType::Float32).unwrap();
let max = || MaxAccumulator::try_new(&DataType::Float32).unwrap();

check(&mut min(), &[&[zero], &[pos_nan]], zero);
check(&mut min(), &[&[zero, pos_nan]], zero);
check(&mut min(), &[&[zero], &[neg_inf]], neg_inf);
check(&mut min(), &[&[zero, neg_inf]], neg_inf);
check(&mut max(), &[&[zero], &[pos_nan]], pos_nan);
check(&mut max(), &[&[zero, pos_nan]], pos_nan);
check(&mut max(), &[&[zero], &[neg_inf]], zero);
check(&mut max(), &[&[zero, neg_inf]], zero);
}
}

0 comments on commit 5bb6b35

Please sign in to comment.