From 82219b5371410d78ce0c3ab7825690e5fad9b167 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Mon, 4 Nov 2024 14:41:18 +0100 Subject: [PATCH] fix: Fix stddev indeterministically producing NAN In the VarianceGroupAccumulator we were missing a `count == 0` check that is present in the normal Accumulator. This mostly does not matter except for the case where the first state to be merge has `count == 0` then the `merge` funciton will calculate a new m2 of NAN which will propagate to the final result. This fixes the bug bu adding the missing `count == 0` check. --- .../functions-aggregate/src/variance.rs | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 810247a2884a9..41e0048ba7a9d 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -316,6 +316,7 @@ fn merge( mean2: f64, m22: f64, ) -> (u64, f64, f64) { + debug_assert!(count != 0 || count2 != 0, "Cannot merge two empty states"); let new_count = count + count2; let new_mean = mean * count as f64 / new_count as f64 + mean2 * count2 as f64 / new_count as f64; @@ -573,6 +574,9 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { partial_m2s, opt_filter, |group_index, partial_count, partial_mean, partial_m2| { + if partial_count == 0 { + return; + } let (new_count, new_mean, new_m2) = merge( self.counts[group_index], self.means[group_index], @@ -612,3 +616,32 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { + self.counts.capacity() * size_of::() } } + +#[cfg(test)] +mod tests { + use datafusion_expr::EmitTo; + + use super::*; + + #[test] + fn test_groups_accumulator_merge_empty_states() -> Result<()> { + let state_1 = vec![ + Arc::new(UInt64Array::from(vec![0])) as ArrayRef, + Arc::new(Float64Array::from(vec![0.0])), + Arc::new(Float64Array::from(vec![0.0])), + ]; + let state_2 = vec![ + Arc::new(UInt64Array::from(vec![2])) as ArrayRef, + Arc::new(Float64Array::from(vec![1.0])), + Arc::new(Float64Array::from(vec![0.0])), + ]; + let mut acc = VarianceGroupsAccumulator::new(StatsType::Sample); + acc.merge_batch(&state_1, &[0], None, 1)?; + acc.merge_batch(&state_2, &[0], None, 1)?; + let result = acc.evaluate(EmitTo::All)?; + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result.value(0), 0.0); + Ok(()) + } +}