Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate-common/src/tdigest.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one or more
2
// contributor license agreements.  See the NOTICE file distributed with this
3
// work for additional information regarding copyright ownership.  The ASF
4
// licenses this file to you under the Apache License, Version 2.0 (the
5
// "License"); you may not use this file except in compliance with the License.
6
// You may obtain a copy of the License at
7
//
8
//   http://www.apache.org/licenses/LICENSE-2.0
9
//
10
// Unless required by applicable law or agreed to in writing, software
11
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
13
// License for the specific language governing permissions and limitations under
14
// the License.
15
16
//! An implementation of the [TDigest sketch algorithm] providing approximate
17
//! quantile calculations.
18
//!
19
//! The TDigest code in this module is modified from
20
//! <https://github.com/MnO2/t-digest>, itself a rust reimplementation of
21
//! [Facebook's Folly TDigest] implementation.
22
//!
23
//! Alterations include reduction of runtime heap allocations, broader type
24
//! support, (de-)serialisation support, reduced type conversions and null value
25
//! tolerance.
26
//!
27
//! [TDigest sketch algorithm]: https://arxiv.org/abs/1902.04023
28
//! [Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h
29
30
use arrow::datatypes::DataType;
31
use arrow::datatypes::Float64Type;
32
use datafusion_common::cast::as_primitive_array;
33
use datafusion_common::Result;
34
use datafusion_common::ScalarValue;
35
use std::cmp::Ordering;
36
37
/// Default compression parameter (`max_size`) for a [`TDigest`].
pub const DEFAULT_MAX_SIZE: usize = 100;
38
39
// Extract the payload of a non-null [`ScalarValue::Float64`] as an [`f64`].
// Any other value (including `Float64(None)`) causes a panic.
macro_rules! cast_scalar_f64 {
    ($value:expr) => {
        match &$value {
            ScalarValue::Float64(Some(inner)) => *inner,
            other => panic!("invalid type {:?}", other),
        }
    };
}
49
50
// Extract the payload of a non-null [`ScalarValue::UInt64`] as a [`u64`].
// Any other value (including `UInt64(None)`) causes a panic.
macro_rules! cast_scalar_u64 {
    ($value:expr) => {
        match &$value {
            ScalarValue::UInt64(Some(inner)) => *inner,
            other => panic!("invalid type {:?}", other),
        }
    };
}
60
61
/// This trait is implemented for each type a [`TDigest`] can operate on,
62
/// allowing it to support both numerical rust types (obtained from
63
/// `PrimitiveArray` instances), and [`ScalarValue`] instances.
64
pub trait TryIntoF64 {
65
    /// A fallible conversion of a possibly null `self` into a [`f64`].
66
    ///
67
    /// If `self` is null, this method must return `Ok(None)`.
68
    ///
69
    /// If `self` cannot be coerced to the desired type, this method must return
70
    /// an `Err` variant.
71
    fn try_as_f64(&self) -> Result<Option<f64>>;
72
}
73
74
/// Generate an infallible conversion from `type` to an [`f64`].
75
macro_rules! impl_try_ordered_f64 {
76
    ($type:ty) => {
77
        impl TryIntoF64 for $type {
78
0
            fn try_as_f64(&self) -> Result<Option<f64>> {
79
0
                Ok(Some(*self as f64))
80
0
            }
81
        }
82
    };
83
}
84
85
impl_try_ordered_f64!(f64);
86
impl_try_ordered_f64!(f32);
87
impl_try_ordered_f64!(i64);
88
impl_try_ordered_f64!(i32);
89
impl_try_ordered_f64!(i16);
90
impl_try_ordered_f64!(i8);
91
impl_try_ordered_f64!(u64);
92
impl_try_ordered_f64!(u32);
93
impl_try_ordered_f64!(u16);
94
impl_try_ordered_f64!(u8);
95
96
/// Centroid implementation to the cluster mentioned in the paper.
///
/// A centroid summarises a cluster of values by their `mean` and total
/// `weight`.
///
/// Ordering (`Ord`/`PartialOrd`) compares `mean` only, using
/// [`f64::total_cmp`]. Equality (`PartialEq`, derived) compares both `mean`
/// and `weight` with `f64`'s `==`. The two do not agree: centroids with
/// equal means but different weights compare `Equal` yet are not `==`.
#[derive(Debug, PartialEq, Clone)]
pub struct Centroid {
    /// Mean of the values summarised by this centroid.
    mean: f64,
    /// Total weight of the values summarised by this centroid.
    weight: f64,
}

impl PartialOrd for Centroid {
    fn partial_cmp(&self, other: &Centroid) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Eq for Centroid {}

impl Ord for Centroid {
    /// Orders centroids by `mean` under [`f64::total_cmp`].
    fn cmp(&self, other: &Centroid) -> Ordering {
        f64::total_cmp(&self.mean, &other.mean)
    }
}

impl Centroid {
    /// Creates a centroid with the given `mean` and `weight`.
    pub fn new(mean: f64, weight: f64) -> Self {
        Self { mean, weight }
    }

    /// Mean of the summarised values.
    #[inline]
    pub fn mean(&self) -> f64 {
        self.mean
    }

    /// Total weight of the summarised values.
    #[inline]
    pub fn weight(&self) -> f64 {
        self.weight
    }

    /// Absorbs `weight` worth of values whose weighted sum is `sum`.
    ///
    /// Updates `weight` and recomputes `mean` from the combined sum. Returns
    /// the combined weighted sum, i.e. `sum + old_weight * old_mean`.
    pub fn add(&mut self, sum: f64, weight: f64) -> f64 {
        let combined_sum = sum + self.weight * self.mean;
        self.weight += weight;
        self.mean = combined_sum / self.weight;
        combined_sum
    }
}

impl Default for Centroid {
    /// A centroid with mean `0.0` and weight `1.0`.
    fn default() -> Self {
        Self::new(0.0, 1.0)
    }
}
149
150
/// T-Digest to be operated on.
151
#[derive(Debug, PartialEq, Clone)]
152
pub struct TDigest {
153
    centroids: Vec<Centroid>,
154
    max_size: usize,
155
    sum: f64,
156
    count: u64,
157
    max: f64,
158
    min: f64,
159
}
160
161
impl TDigest {
162
0
    pub fn new(max_size: usize) -> Self {
163
0
        TDigest {
164
0
            centroids: Vec::new(),
165
0
            max_size,
166
0
            sum: 0_f64,
167
0
            count: 0,
168
0
            max: f64::NAN,
169
0
            min: f64::NAN,
170
0
        }
171
0
    }
172
173
0
    pub fn new_with_centroid(max_size: usize, centroid: Centroid) -> Self {
174
0
        TDigest {
175
0
            centroids: vec![centroid.clone()],
176
0
            max_size,
177
0
            sum: centroid.mean * centroid.weight,
178
0
            count: 1,
179
0
            max: centroid.mean,
180
0
            min: centroid.mean,
181
0
        }
182
0
    }
183
184
    #[inline]
185
0
    pub fn count(&self) -> u64 {
186
0
        self.count
187
0
    }
188
189
    #[inline]
190
0
    pub fn max(&self) -> f64 {
191
0
        self.max
192
0
    }
193
194
    #[inline]
195
0
    pub fn min(&self) -> f64 {
196
0
        self.min
197
0
    }
198
199
    #[inline]
200
0
    pub fn max_size(&self) -> usize {
201
0
        self.max_size
202
0
    }
203
204
    /// Size in bytes including `Self`.
205
0
    pub fn size(&self) -> usize {
206
0
        std::mem::size_of_val(self)
207
0
            + (std::mem::size_of::<Centroid>() * self.centroids.capacity())
208
0
    }
209
}
210
211
impl Default for TDigest {
212
0
    fn default() -> Self {
213
0
        TDigest {
214
0
            centroids: Vec::new(),
215
0
            max_size: 100,
216
0
            sum: 0_f64,
217
0
            count: 0,
218
0
            max: f64::NAN,
219
0
            min: f64::NAN,
220
0
        }
221
0
    }
222
}
223
224
impl TDigest {
225
0
    /// Maps scale index `k` (out of `d`) to a quantile with the quadratic
    /// scale function.
    ///
    /// Returns `2 (k/d)^2` below the midpoint and `1 - 2 (1 - k/d)^2` at or
    /// above it.
    fn k_to_q(k: u64, d: usize) -> f64 {
        let ratio = k as f64 / d as f64;
        if ratio < 0.5 {
            return 2.0 * ratio * ratio;
        }
        let tail = 1.0 - ratio;
        1.0 - 2.0 * tail * tail
    }
234
235
0
    /// Clamps `v` into the closed interval spanned by `lo` and `hi`.
    ///
    /// If either bound is NaN, `v` is returned unchanged.
    ///
    /// The bounds come from neighbouring centroid means and the digest's
    /// `min`/`max`. Centroid means are computed as `sum / weight` and can round
    /// marginally past a neighbouring bound, leaving `lo > hi`. `f64::clamp`
    /// panics when `min > max`, so the bounds are reordered instead of
    /// trusting their order.
    fn clamp(v: f64, lo: f64, hi: f64) -> f64 {
        if lo.is_nan() || hi.is_nan() {
            return v;
        }
        let (lower, upper) = if lo <= hi { (lo, hi) } else { (hi, lo) };
        v.clamp(lower, upper)
    }
241
242
    // public for testing in other modules
243
0
    pub fn merge_unsorted_f64(&self, unsorted_values: Vec<f64>) -> TDigest {
244
0
        let mut values = unsorted_values;
245
0
        values.sort_by(|a, b| a.total_cmp(b));
246
0
        self.merge_sorted_f64(&values)
247
0
    }
248
249
0
    pub fn merge_sorted_f64(&self, sorted_values: &[f64]) -> TDigest {
250
0
        #[cfg(debug_assertions)]
251
0
        debug_assert!(is_sorted(sorted_values), "unsorted input to TDigest");
252
253
0
        if sorted_values.is_empty() {
254
0
            return self.clone();
255
0
        }
256
0
257
0
        let mut result = TDigest::new(self.max_size());
258
0
        result.count = self.count() + sorted_values.len() as u64;
259
0
260
0
        let maybe_min = *sorted_values.first().unwrap();
261
0
        let maybe_max = *sorted_values.last().unwrap();
262
0
263
0
        if self.count() > 0 {
264
0
            result.min = self.min.min(maybe_min);
265
0
            result.max = self.max.max(maybe_max);
266
0
        } else {
267
0
            result.min = maybe_min;
268
0
            result.max = maybe_max;
269
0
        }
270
271
0
        let mut compressed: Vec<Centroid> = Vec::with_capacity(self.max_size);
272
0
273
0
        let mut k_limit: u64 = 1;
274
0
        let mut q_limit_times_count =
275
0
            Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
276
0
        k_limit += 1;
277
0
278
0
        let mut iter_centroids = self.centroids.iter().peekable();
279
0
        let mut iter_sorted_values = sorted_values.iter().peekable();
280
281
0
        let mut curr: Centroid = if let Some(c) = iter_centroids.peek() {
282
0
            let curr = **iter_sorted_values.peek().unwrap();
283
0
            if c.mean() < curr {
284
0
                iter_centroids.next().unwrap().clone()
285
            } else {
286
0
                Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
287
            }
288
        } else {
289
0
            Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
290
        };
291
292
0
        let mut weight_so_far = curr.weight();
293
0
294
0
        let mut sums_to_merge = 0_f64;
295
0
        let mut weights_to_merge = 0_f64;
296
297
0
        while iter_centroids.peek().is_some() || iter_sorted_values.peek().is_some() {
298
0
            let next: Centroid = if let Some(c) = iter_centroids.peek() {
299
0
                if iter_sorted_values.peek().is_none()
300
0
                    || c.mean() < **iter_sorted_values.peek().unwrap()
301
                {
302
0
                    iter_centroids.next().unwrap().clone()
303
                } else {
304
0
                    Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
305
                }
306
            } else {
307
0
                Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
308
            };
309
310
0
            let next_sum = next.mean() * next.weight();
311
0
            weight_so_far += next.weight();
312
0
313
0
            if weight_so_far <= q_limit_times_count {
314
0
                sums_to_merge += next_sum;
315
0
                weights_to_merge += next.weight();
316
0
            } else {
317
0
                result.sum += curr.add(sums_to_merge, weights_to_merge);
318
0
                sums_to_merge = 0_f64;
319
0
                weights_to_merge = 0_f64;
320
0
321
0
                compressed.push(curr.clone());
322
0
                q_limit_times_count =
323
0
                    Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
324
0
                k_limit += 1;
325
0
                curr = next;
326
0
            }
327
        }
328
329
0
        result.sum += curr.add(sums_to_merge, weights_to_merge);
330
0
        compressed.push(curr);
331
0
        compressed.shrink_to_fit();
332
0
        compressed.sort();
333
0
334
0
        result.centroids = compressed;
335
0
        result
336
0
    }
337
338
0
    fn external_merge(
339
0
        centroids: &mut [Centroid],
340
0
        first: usize,
341
0
        middle: usize,
342
0
        last: usize,
343
0
    ) {
344
0
        let mut result: Vec<Centroid> = Vec::with_capacity(centroids.len());
345
0
346
0
        let mut i = first;
347
0
        let mut j = middle;
348
349
0
        while i < middle && j < last {
350
0
            match centroids[i].cmp(&centroids[j]) {
351
0
                Ordering::Less => {
352
0
                    result.push(centroids[i].clone());
353
0
                    i += 1;
354
0
                }
355
0
                Ordering::Greater => {
356
0
                    result.push(centroids[j].clone());
357
0
                    j += 1;
358
0
                }
359
0
                Ordering::Equal => {
360
0
                    result.push(centroids[i].clone());
361
0
                    i += 1;
362
0
                }
363
            }
364
        }
365
366
0
        while i < middle {
367
0
            result.push(centroids[i].clone());
368
0
            i += 1;
369
0
        }
370
371
0
        while j < last {
372
0
            result.push(centroids[j].clone());
373
0
            j += 1;
374
0
        }
375
376
0
        i = first;
377
0
        for centroid in result.into_iter() {
378
0
            centroids[i] = centroid;
379
0
            i += 1;
380
0
        }
381
0
    }
382
383
    // Merge multiple T-Digests
384
0
    pub fn merge_digests<'a>(digests: impl IntoIterator<Item = &'a TDigest>) -> TDigest {
385
0
        let digests = digests.into_iter().collect::<Vec<_>>();
386
0
        let n_centroids: usize = digests.iter().map(|d| d.centroids.len()).sum();
387
0
        if n_centroids == 0 {
388
0
            return TDigest::default();
389
0
        }
390
0
391
0
        let max_size = digests.first().unwrap().max_size;
392
0
        let mut centroids: Vec<Centroid> = Vec::with_capacity(n_centroids);
393
0
        let mut starts: Vec<usize> = Vec::with_capacity(digests.len());
394
0
395
0
        let mut count = 0;
396
0
        let mut min = f64::INFINITY;
397
0
        let mut max = f64::NEG_INFINITY;
398
0
399
0
        let mut start: usize = 0;
400
0
        for digest in digests.iter() {
401
0
            starts.push(start);
402
0
403
0
            let curr_count = digest.count();
404
0
            if curr_count > 0 {
405
0
                min = min.min(digest.min);
406
0
                max = max.max(digest.max);
407
0
                count += curr_count;
408
0
                for centroid in &digest.centroids {
409
0
                    centroids.push(centroid.clone());
410
0
                    start += 1;
411
0
                }
412
0
            }
413
        }
414
415
0
        let mut digests_per_block: usize = 1;
416
0
        while digests_per_block < starts.len() {
417
0
            for i in (0..starts.len()).step_by(digests_per_block * 2) {
418
0
                if i + digests_per_block < starts.len() {
419
0
                    let first = starts[i];
420
0
                    let middle = starts[i + digests_per_block];
421
0
                    let last = if i + 2 * digests_per_block < starts.len() {
422
0
                        starts[i + 2 * digests_per_block]
423
                    } else {
424
0
                        centroids.len()
425
                    };
426
427
0
                    debug_assert!(first <= middle && middle <= last);
428
0
                    Self::external_merge(&mut centroids, first, middle, last);
429
0
                }
430
            }
431
432
0
            digests_per_block *= 2;
433
        }
434
435
0
        let mut result = TDigest::new(max_size);
436
0
        let mut compressed: Vec<Centroid> = Vec::with_capacity(max_size);
437
0
438
0
        let mut k_limit = 1;
439
0
        let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;
440
0
441
0
        let mut iter_centroids = centroids.iter_mut();
442
0
        let mut curr = iter_centroids.next().unwrap();
443
0
        let mut weight_so_far = curr.weight();
444
0
        let mut sums_to_merge = 0_f64;
445
0
        let mut weights_to_merge = 0_f64;
446
447
0
        for centroid in iter_centroids {
448
0
            weight_so_far += centroid.weight();
449
0
450
0
            if weight_so_far <= q_limit_times_count {
451
0
                sums_to_merge += centroid.mean() * centroid.weight();
452
0
                weights_to_merge += centroid.weight();
453
0
            } else {
454
0
                result.sum += curr.add(sums_to_merge, weights_to_merge);
455
0
                sums_to_merge = 0_f64;
456
0
                weights_to_merge = 0_f64;
457
0
                compressed.push(curr.clone());
458
0
                q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;
459
0
                k_limit += 1;
460
0
                curr = centroid;
461
0
            }
462
        }
463
464
0
        result.sum += curr.add(sums_to_merge, weights_to_merge);
465
0
        compressed.push(curr.clone());
466
0
        compressed.shrink_to_fit();
467
0
        compressed.sort();
468
0
469
0
        result.count = count;
470
0
        result.min = min;
471
0
        result.max = max;
472
0
        result.centroids = compressed;
473
0
        result
474
0
    }
475
476
    /// To estimate the value located at `q` quantile
477
0
    pub fn estimate_quantile(&self, q: f64) -> f64 {
478
0
        if self.centroids.is_empty() {
479
0
            return 0.0;
480
0
        }
481
0
482
0
        let rank = q * self.count as f64;
483
0
484
0
        let mut pos: usize;
485
0
        let mut t;
486
0
        if q > 0.5 {
487
0
            if q >= 1.0 {
488
0
                return self.max();
489
0
            }
490
0
491
0
            pos = 0;
492
0
            t = self.count as f64;
493
494
0
            for (k, centroid) in self.centroids.iter().enumerate().rev() {
495
0
                t -= centroid.weight();
496
0
497
0
                if rank >= t {
498
0
                    pos = k;
499
0
                    break;
500
0
                }
501
            }
502
        } else {
503
0
            if q <= 0.0 {
504
0
                return self.min();
505
0
            }
506
0
507
0
            pos = self.centroids.len() - 1;
508
0
            t = 0_f64;
509
510
0
            for (k, centroid) in self.centroids.iter().enumerate() {
511
0
                if rank < t + centroid.weight() {
512
0
                    pos = k;
513
0
                    break;
514
0
                }
515
0
516
0
                t += centroid.weight();
517
            }
518
        }
519
520
0
        let mut delta = 0_f64;
521
0
        let mut min = self.min;
522
0
        let mut max = self.max;
523
0
524
0
        if self.centroids.len() > 1 {
525
0
            if pos == 0 {
526
0
                delta = self.centroids[pos + 1].mean() - self.centroids[pos].mean();
527
0
                max = self.centroids[pos + 1].mean();
528
0
            } else if pos == (self.centroids.len() - 1) {
529
0
                delta = self.centroids[pos].mean() - self.centroids[pos - 1].mean();
530
0
                min = self.centroids[pos - 1].mean();
531
0
            } else {
532
0
                delta = (self.centroids[pos + 1].mean() - self.centroids[pos - 1].mean())
533
0
                    / 2.0;
534
0
                min = self.centroids[pos - 1].mean();
535
0
                max = self.centroids[pos + 1].mean();
536
0
            }
537
0
        }
538
539
0
        let value = self.centroids[pos].mean()
540
0
            + ((rank - t) / self.centroids[pos].weight() - 0.5) * delta;
541
0
542
0
        // In `merge_digests()`: `min` is initialized to Inf, `max` is initialized to -Inf
543
0
        // and gets updated according to different `TDigest`s
544
0
        // However, `min`/`max` won't get updated if there is only one `NaN` within `TDigest`
545
0
        // The following two checks is for such edge case
546
0
        if !min.is_finite() && min.is_sign_positive() {
547
0
            min = f64::NEG_INFINITY;
548
0
        }
549
550
0
        if !max.is_finite() && max.is_sign_negative() {
551
0
            max = f64::INFINITY;
552
0
        }
553
554
0
        Self::clamp(value, min, max)
555
0
    }
556
557
    /// This method decomposes the [`TDigest`] and its [`Centroid`] instances
558
    /// into a series of primitive scalar values.
559
    ///
560
    /// First the values of the TDigest are packed, followed by the variable
561
    /// number of centroids packed into a [`ScalarValue::List`] of
562
    /// [`ScalarValue::Float64`]:
563
    ///
564
    /// ```text
565
    ///
566
    ///    ┌────────┬────────┬────────┬───────┬────────┬────────┐
567
    ///    │max_size│  sum   │ count  │  max  │  min   │centroid│
568
    ///    └────────┴────────┴────────┴───────┴────────┴────────┘
569
    ///                                                     │
570
    ///                               ┌─────────────────────┘
571
    ///                               ▼
572
    ///                          ┌ List ───┐
573
    ///                          │┌ ─ ─ ─ ┐│
574
    ///                          │  mean   │
575
    ///                          │├ ─ ─ ─ ┼│─ ─ Centroid 1
576
    ///                          │ weight  │
577
    ///                          │└ ─ ─ ─ ┘│
578
    ///                          │         │
579
    ///                          │┌ ─ ─ ─ ┐│
580
    ///                          │  mean   │
581
    ///                          │├ ─ ─ ─ ┼│─ ─ Centroid 2
582
    ///                          │ weight  │
583
    ///                          │└ ─ ─ ─ ┘│
584
    ///                          │         │
585
    ///                              ...
586
    ///
587
    /// ```
588
    ///
589
    /// The [`TDigest::from_scalar_state()`] method reverses this processes,
590
    /// consuming the output of this method and returning an unpacked
591
    /// [`TDigest`].
592
0
    pub fn to_scalar_state(&self) -> Vec<ScalarValue> {
593
0
        // Gather up all the centroids
594
0
        let centroids: Vec<ScalarValue> = self
595
0
            .centroids
596
0
            .iter()
597
0
            .flat_map(|c| [c.mean(), c.weight()])
598
0
            .map(|v| ScalarValue::Float64(Some(v)))
599
0
            .collect();
600
0
601
0
        let arr = ScalarValue::new_list_nullable(&centroids, &DataType::Float64);
602
0
603
0
        vec![
604
0
            ScalarValue::UInt64(Some(self.max_size as u64)),
605
0
            ScalarValue::Float64(Some(self.sum)),
606
0
            ScalarValue::UInt64(Some(self.count)),
607
0
            ScalarValue::Float64(Some(self.max)),
608
0
            ScalarValue::Float64(Some(self.min)),
609
0
            ScalarValue::List(arr),
610
0
        ]
611
0
    }
612
613
    /// Unpack the serialised state of a [`TDigest`] produced by
614
    /// [`Self::to_scalar_state()`].
615
    ///
616
    /// # Correctness
617
    ///
618
    /// Providing input to this method that was not obtained from
619
    /// [`Self::to_scalar_state()`] results in undefined behaviour and may
620
    /// panic.
621
0
    pub fn from_scalar_state(state: &[ScalarValue]) -> Self {
622
0
        assert_eq!(state.len(), 6, "invalid TDigest state");
623
624
0
        let max_size = match &state[0] {
625
0
            ScalarValue::UInt64(Some(v)) => *v as usize,
626
0
            v => panic!("invalid max_size type {v:?}"),
627
        };
628
629
0
        let centroids: Vec<_> = match &state[5] {
630
0
            ScalarValue::List(arr) => {
631
0
                let array = arr.values();
632
0
633
0
                let f64arr =
634
0
                    as_primitive_array::<Float64Type>(array).expect("expected f64 array");
635
0
                f64arr
636
0
                    .values()
637
0
                    .chunks(2)
638
0
                    .map(|v| Centroid::new(v[0], v[1]))
639
0
                    .collect()
640
            }
641
0
            v => panic!("invalid centroids type {v:?}"),
642
        };
643
644
0
        let max = cast_scalar_f64!(&state[3]);
645
0
        let min = cast_scalar_f64!(&state[4]);
646
647
0
        assert!(max.total_cmp(&min).is_ge());
648
649
        Self {
650
0
            max_size,
651
0
            sum: cast_scalar_f64!(state[1]),
652
0
            count: cast_scalar_u64!(&state[2]),
653
0
            max,
654
0
            min,
655
0
            centroids,
656
0
        }
657
0
    }
658
}
659
660
/// Returns `true` when `values` is in non-decreasing order under
/// [`f64::total_cmp`], so positive NaNs belong at the end.
///
/// Only compiled with debug assertions, where it backs the sortedness check
/// in `merge_sorted_f64`.
#[cfg(debug_assertions)]
fn is_sorted(values: &[f64]) -> bool {
    !values
        .windows(2)
        .any(|pair| pair[0].total_cmp(&pair[1]).is_gt())
}
664
665
#[cfg(test)]
mod tests {
    use super::*;

    // Asserts that the `quantile` estimated by `t` is within the allowable
    // relative error of `want`. `allowable_error` is a fraction (default
    // 0.01, i.e. 1%). The failure message reports both values in percent.
    macro_rules! assert_error_bounds {
        ($t:ident, quantile = $quantile:literal, want = $want:literal) => {
            assert_error_bounds!(
                $t,
                quantile = $quantile,
                want = $want,
                allowable_error = 0.01
            )
        };
        ($t:ident, quantile = $quantile:literal, want = $want:literal, allowable_error = $re:literal) => {
            let ans = $t.estimate_quantile($quantile);
            let expected: f64 = $want;
            let relative_error: f64 = (expected - ans).abs() / expected;
            let allowable_error: f64 = $re;
            assert!(
                relative_error < allowable_error,
                "relative error {}% is more than {}% (got quantile {}, want {})",
                relative_error * 100.0,
                allowable_error * 100.0,
                ans,
                expected
            );
        };
    }

    // Asserts that serialising `t` and deserialising it yields an equal digest.
    macro_rules! assert_state_roundtrip {
        ($t:ident) => {
            let state = $t.to_scalar_state();
            let other = TDigest::from_scalar_state(&state);
            assert_eq!($t, other);
        };
    }

    #[test]
    fn test_int64_uniform() {
        let values = (1i64..=1000).map(|v| v as f64).collect();

        let t = TDigest::new(100);
        let t = t.merge_unsorted_f64(values);

        assert_error_bounds!(t, quantile = 0.1, want = 100.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500.0);
        assert_error_bounds!(t, quantile = 0.9, want = 900.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_centroid_addition_regression() {
        // https://github.com/MnO2/t-digest/pull/1

        let vals = vec![1.0, 1.0, 1.0, 2.0, 1.0, 1.0];
        let mut t = TDigest::new(10);

        for v in vals {
            t = t.merge_unsorted_f64(vec![v]);
        }

        assert_error_bounds!(t, quantile = 0.5, want = 1.0);
        assert_error_bounds!(t, quantile = 0.95, want = 2.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_merge_unsorted_against_uniform_distro() {
        let t = TDigest::new(100);
        let values: Vec<_> = (1..=1_000_000).map(f64::from).collect();

        let t = t.merge_unsorted_f64(values);

        assert_error_bounds!(t, quantile = 1.0, want = 1_000_000.0);
        assert_error_bounds!(t, quantile = 0.99, want = 990_000.0);
        assert_error_bounds!(t, quantile = 0.01, want = 10_000.0);
        assert_error_bounds!(t, quantile = 0.0, want = 1.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500_000.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_merge_unsorted_against_skewed_distro() {
        let t = TDigest::new(100);
        let mut values: Vec<_> = (1..=600_000).map(f64::from).collect();
        values.resize(1_000_000, 1_000_000_f64);

        let t = t.merge_unsorted_f64(values);

        assert_error_bounds!(t, quantile = 0.99, want = 1_000_000.0);
        assert_error_bounds!(t, quantile = 0.01, want = 10_000.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500_000.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_merge_digests() {
        let mut digests: Vec<TDigest> = Vec::new();

        for _ in 1..=100 {
            let t = TDigest::new(100);
            let values: Vec<_> = (1..=1_000).map(f64::from).collect();
            let t = t.merge_unsorted_f64(values);
            digests.push(t)
        }

        let t = TDigest::merge_digests(&digests);

        assert_error_bounds!(t, quantile = 1.0, want = 1000.0);
        assert_error_bounds!(t, quantile = 0.99, want = 990.0);
        assert_error_bounds!(t, quantile = 0.01, want = 10.0, allowable_error = 0.2);
        assert_error_bounds!(t, quantile = 0.0, want = 1.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_size() {
        let t = TDigest::new(10);
        let t = t.merge_unsorted_f64(vec![0.0, 1.0]);

        assert_eq!(t.size(), 96);
    }
}