Skip to content

Commit

Permalink
fix: Fix stddev nondeterministically producing NaN
Browse files Browse the repository at this point in the history
In the VarianceGroupsAccumulator we were missing a `count == 0` check
that is present in the normal Accumulator. This mostly does not matter,
except for the case where the first state to be merged has `count == 0`:
the combined count is then also zero, so the `merge` function divides by
zero and calculates a new m2 of NaN, which will propagate to the final
result.

This fixes the bug by adding the missing `count == 0` check.
  • Loading branch information
eejbyfeldt committed Nov 4, 2024
1 parent 2482ff4 commit 82219b5
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions datafusion/functions-aggregate/src/variance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ fn merge(
mean2: f64,
m22: f64,
) -> (u64, f64, f64) {
debug_assert!(count != 0 || count2 != 0, "Cannot merge two empty states");
let new_count = count + count2;
let new_mean =
mean * count as f64 / new_count as f64 + mean2 * count2 as f64 / new_count as f64;
Expand Down Expand Up @@ -573,6 +574,9 @@ impl GroupsAccumulator for VarianceGroupsAccumulator {
partial_m2s,
opt_filter,
|group_index, partial_count, partial_mean, partial_m2| {
if partial_count == 0 {
return;
}
let (new_count, new_mean, new_m2) = merge(
self.counts[group_index],
self.means[group_index],
Expand Down Expand Up @@ -612,3 +616,32 @@ impl GroupsAccumulator for VarianceGroupsAccumulator {
+ self.counts.capacity() * size_of::<u64>()
}
}

#[cfg(test)]
mod tests {
    use datafusion_expr::EmitTo;

    use super::*;

    /// Regression test: merge a partial state whose count is zero into a
    /// fresh group, then merge a populated state into the same group, and
    /// check that the evaluated sample variance is `0.0`.
    ///
    /// A NaN result would fail the final equality assertion, since NaN
    /// compares unequal to every value.
    #[test]
    fn test_groups_accumulator_merge_empty_states() -> Result<()> {
        // Each partial state is three single-row columns: count, mean, m2.
        let empty_state: Vec<ArrayRef> = vec![
            Arc::new(UInt64Array::from(vec![0])),
            Arc::new(Float64Array::from(vec![0.0])),
            Arc::new(Float64Array::from(vec![0.0])),
        ];
        let populated_state: Vec<ArrayRef> = vec![
            Arc::new(UInt64Array::from(vec![2])),
            Arc::new(Float64Array::from(vec![1.0])),
            Arc::new(Float64Array::from(vec![0.0])),
        ];

        let mut accumulator = VarianceGroupsAccumulator::new(StatsType::Sample);

        // Order matters: the zero-count state is merged first, then the
        // populated one, both into group 0 of a single-group accumulator.
        for state in [&empty_state, &populated_state] {
            accumulator.merge_batch(state, &[0], None, 1)?;
        }

        let evaluated = accumulator.evaluate(EmitTo::All)?;
        let variances = evaluated
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();

        assert_eq!(variances.len(), 1);
        assert_eq!(variances.value(0), 0.0);
        Ok(())
    }
}

0 comments on commit 82219b5

Please sign in to comment.