diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs index 3dbf1679e230d..63a4c85f9e80a 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs +++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs @@ -34,7 +34,7 @@ use datafusion_common::{ ScalarValue, }; use datafusion_expr::{Accumulator, ColumnarValue}; -use std::{any::Any, iter, sync::Arc}; +use std::{any::Any, sync::Arc}; /// APPROX_PERCENTILE_CONT aggregate expression #[derive(Debug)] @@ -284,7 +284,8 @@ impl ApproxPercentileAccumulator { } pub(crate) fn merge_digests(&mut self, digests: &[TDigest]) { - self.digest = TDigest::merge_digests(digests); + let digests = digests.iter().chain(std::iter::once(&self.digest)); + self.digest = TDigest::merge_digests(digests) } pub(crate) fn convert_to_float(values: &ArrayRef) -> Result> { @@ -425,7 +426,6 @@ impl Accumulator for ApproxPercentileAccumulator { .collect::>>() .map(|state| TDigest::from_scalar_state(&state)) }) - .chain(iter::once(Ok(self.digest.clone()))) .collect::>>()?; self.merge_digests(&states); @@ -440,3 +440,34 @@ impl Accumulator for ApproxPercentileAccumulator { - std::mem::size_of_val(&self.return_type) } } + +#[cfg(test)] +mod tests { + use crate::aggregate::approx_percentile_cont::ApproxPercentileAccumulator; + use crate::aggregate::tdigest::TDigest; + use arrow_schema::DataType; + + #[test] + fn test_combine_approx_percentile_accumulator() { + let mut digests: Vec = Vec::new(); + + // one TDigest with 50_000 values from 1 to 1_000 + for _ in 1..=50 { + let t = TDigest::new(100); + let values: Vec<_> = (1..=1_000).map(f64::from).collect(); + let t = t.merge_unsorted_f64(values); + digests.push(t) + } + + let t1 = TDigest::merge_digests(&digests); + let t2 = TDigest::merge_digests(&digests); + + let mut accumulator = + ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100); + + accumulator.merge_digests(&[t1]); + assert_eq!(accumulator.digest.count(), 50_000.0); + accumulator.merge_digests(&[t2]); + assert_eq!(accumulator.digest.count(), 100_000.0); + } +} diff --git a/datafusion/physical-expr/src/aggregate/tdigest.rs b/datafusion/physical-expr/src/aggregate/tdigest.rs index 78708df94c25a..e3b23b91d0ffe 100644 --- a/datafusion/physical-expr/src/aggregate/tdigest.rs +++ b/datafusion/physical-expr/src/aggregate/tdigest.rs @@ -370,7 +370,10 @@ impl TDigest { } // Merge multiple T-Digests - pub(crate) fn merge_digests(digests: &[TDigest]) -> TDigest { + pub(crate) fn merge_digests<'a>( + digests: impl IntoIterator, + ) -> TDigest { + let digests = digests.into_iter().collect::>(); let n_centroids: usize = digests.iter().map(|d| d.centroids.len()).sum(); if n_centroids == 0 { return TDigest::default();