/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate-common/src/tdigest.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one or more |
2 | | // contributor license agreements. See the NOTICE file distributed with this |
3 | | // work for additional information regarding copyright ownership. The ASF |
4 | | // licenses this file to you under the Apache License, Version 2.0 (the |
5 | | // "License"); you may not use this file except in compliance with the License. |
6 | | // You may obtain a copy of the License at |
7 | | // |
8 | | // http://www.apache.org/licenses/LICENSE-2.0 |
9 | | // |
10 | | // Unless required by applicable law or agreed to in writing, software |
11 | | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
12 | | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
13 | | // License for the specific language governing permissions and limitations under |
14 | | // the License. |
15 | | |
16 | | //! An implementation of the [TDigest sketch algorithm] providing approximate |
17 | | //! quantile calculations. |
18 | | //! |
19 | | //! The TDigest code in this module is modified from |
20 | | //! <https://github.com/MnO2/t-digest>, itself a rust reimplementation of |
21 | | //! [Facebook's Folly TDigest] implementation. |
22 | | //! |
23 | | //! Alterations include reduction of runtime heap allocations, broader type |
24 | | //! support, (de-)serialisation support, reduced type conversions and null value |
25 | | //! tolerance. |
26 | | //! |
27 | | //! [TDigest sketch algorithm]: https://arxiv.org/abs/1902.04023 |
28 | | //! [Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h |
29 | | |
30 | | use arrow::datatypes::DataType; |
31 | | use arrow::datatypes::Float64Type; |
32 | | use datafusion_common::cast::as_primitive_array; |
33 | | use datafusion_common::Result; |
34 | | use datafusion_common::ScalarValue; |
35 | | use std::cmp::Ordering; |
36 | | |
/// Default maximum number of centroids a [`TDigest`] retains (the
/// compression factor); larger values trade memory for accuracy.
pub const DEFAULT_MAX_SIZE: usize = 100;
38 | | |
// Cast a non-null [`ScalarValue::Float64`] to an [`f64`], or
// panic.
//
// Used when unpacking serialised digest state; any other variant (or a
// null Float64) indicates corrupted state, which is a bug, hence panic.
macro_rules! cast_scalar_f64 {
    ($value:expr ) => {
        match &$value {
            ScalarValue::Float64(Some(v)) => *v,
            v => panic!("invalid type {:?}", v),
        }
    };
}
49 | | |
// Cast a non-null [`ScalarValue::UInt64`] to an [`u64`], or
// panic.
//
// Counterpart of `cast_scalar_f64!` for the `count`/`max_size` fields of
// serialised digest state.
macro_rules! cast_scalar_u64 {
    ($value:expr ) => {
        match &$value {
            ScalarValue::UInt64(Some(v)) => *v,
            v => panic!("invalid type {:?}", v),
        }
    };
}
60 | | |
/// This trait is implemented for each type a [`TDigest`] can operate on,
/// allowing it to support both numerical rust types (obtained from
/// `PrimitiveArray` instances), and [`ScalarValue`] instances.
pub trait TryIntoF64 {
    /// A fallible conversion of a possibly null `self` into a [`f64`].
    ///
    /// If `self` is null, this method must return `Ok(None)`.
    ///
    /// If `self` cannot be coerced to the desired type, this method must return
    /// an `Err` variant.
    ///
    /// Implementations for the primitive numeric types are generated below by
    /// the `impl_try_ordered_f64!` macro.
    fn try_as_f64(&self) -> Result<Option<f64>>;
}
73 | | |
/// Generate an infallible conversion from `type` to an [`f64`].
macro_rules! impl_try_ordered_f64 {
    ($type:ty) => {
        impl TryIntoF64 for $type {
            // A plain numeric `as` cast: it cannot fail and the value is
            // never null, so this always returns `Ok(Some(_))`.
            fn try_as_f64(&self) -> Result<Option<f64>> {
                Ok(Some(*self as f64))
            }
        }
    };
}
84 | | |
// Infallible f64 conversions for every numeric type a `PrimitiveArray`
// may yield.
impl_try_ordered_f64!(f64);
impl_try_ordered_f64!(f32);
impl_try_ordered_f64!(i64);
impl_try_ordered_f64!(i32);
impl_try_ordered_f64!(i16);
impl_try_ordered_f64!(i8);
impl_try_ordered_f64!(u64);
impl_try_ordered_f64!(u32);
impl_try_ordered_f64!(u16);
impl_try_ordered_f64!(u8);
95 | | |
/// Centroid implementation to the cluster mentioned in the paper.
#[derive(Debug, PartialEq, Clone)]
pub struct Centroid {
    // Weighted mean of the values folded into this cluster.
    mean: f64,
    // Total weight (number of values) folded into this cluster.
    weight: f64,
}
102 | | |
impl PartialOrd for Centroid {
    // Delegates to the total order defined by `Ord::cmp` below.
    fn partial_cmp(&self, other: &Centroid) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
108 | | |
// `Ord` below is built on `f64::total_cmp`, which is a total order (NaN
// included), so the reflexivity `Eq` requires holds.
impl Eq for Centroid {}
110 | | |
impl Ord for Centroid {
    // Orders centroids by mean using the IEEE-754 `totalOrder` predicate,
    // which remains total in the presence of NaN and signed zeros.
    fn cmp(&self, other: &Centroid) -> Ordering {
        self.mean.total_cmp(&other.mean)
    }
}
116 | | |
117 | | impl Centroid { |
118 | 0 | pub fn new(mean: f64, weight: f64) -> Self { |
119 | 0 | Centroid { mean, weight } |
120 | 0 | } |
121 | | |
122 | | #[inline] |
123 | 0 | pub fn mean(&self) -> f64 { |
124 | 0 | self.mean |
125 | 0 | } |
126 | | |
127 | | #[inline] |
128 | 0 | pub fn weight(&self) -> f64 { |
129 | 0 | self.weight |
130 | 0 | } |
131 | | |
132 | 0 | pub fn add(&mut self, sum: f64, weight: f64) -> f64 { |
133 | 0 | let new_sum = sum + self.weight * self.mean; |
134 | 0 | let new_weight = self.weight + weight; |
135 | 0 | self.weight = new_weight; |
136 | 0 | self.mean = new_sum / new_weight; |
137 | 0 | new_sum |
138 | 0 | } |
139 | | } |
140 | | |
141 | | impl Default for Centroid { |
142 | 0 | fn default() -> Self { |
143 | 0 | Centroid { |
144 | 0 | mean: 0_f64, |
145 | 0 | weight: 1_f64, |
146 | 0 | } |
147 | 0 | } |
148 | | } |
149 | | |
/// T-Digest to be operated on.
#[derive(Debug, PartialEq, Clone)]
pub struct TDigest {
    // Compressed clusters, kept sorted by mean.
    centroids: Vec<Centroid>,
    // Maximum number of centroids retained (compression factor).
    max_size: usize,
    // Running sum of all values folded in.
    sum: f64,
    // Number of values folded in.
    count: u64,
    // Observed extrema; NaN while the digest is empty.
    max: f64,
    min: f64,
}
160 | | |
161 | | impl TDigest { |
162 | 0 | pub fn new(max_size: usize) -> Self { |
163 | 0 | TDigest { |
164 | 0 | centroids: Vec::new(), |
165 | 0 | max_size, |
166 | 0 | sum: 0_f64, |
167 | 0 | count: 0, |
168 | 0 | max: f64::NAN, |
169 | 0 | min: f64::NAN, |
170 | 0 | } |
171 | 0 | } |
172 | | |
173 | 0 | pub fn new_with_centroid(max_size: usize, centroid: Centroid) -> Self { |
174 | 0 | TDigest { |
175 | 0 | centroids: vec![centroid.clone()], |
176 | 0 | max_size, |
177 | 0 | sum: centroid.mean * centroid.weight, |
178 | 0 | count: 1, |
179 | 0 | max: centroid.mean, |
180 | 0 | min: centroid.mean, |
181 | 0 | } |
182 | 0 | } |
183 | | |
184 | | #[inline] |
185 | 0 | pub fn count(&self) -> u64 { |
186 | 0 | self.count |
187 | 0 | } |
188 | | |
189 | | #[inline] |
190 | 0 | pub fn max(&self) -> f64 { |
191 | 0 | self.max |
192 | 0 | } |
193 | | |
194 | | #[inline] |
195 | 0 | pub fn min(&self) -> f64 { |
196 | 0 | self.min |
197 | 0 | } |
198 | | |
199 | | #[inline] |
200 | 0 | pub fn max_size(&self) -> usize { |
201 | 0 | self.max_size |
202 | 0 | } |
203 | | |
204 | | /// Size in bytes including `Self`. |
205 | 0 | pub fn size(&self) -> usize { |
206 | 0 | std::mem::size_of_val(self) |
207 | 0 | + (std::mem::size_of::<Centroid>() * self.centroids.capacity()) |
208 | 0 | } |
209 | | } |
210 | | |
211 | | impl Default for TDigest { |
212 | 0 | fn default() -> Self { |
213 | 0 | TDigest { |
214 | 0 | centroids: Vec::new(), |
215 | 0 | max_size: 100, |
216 | 0 | sum: 0_f64, |
217 | 0 | count: 0, |
218 | 0 | max: f64::NAN, |
219 | 0 | min: f64::NAN, |
220 | 0 | } |
221 | 0 | } |
222 | | } |
223 | | |
224 | | impl TDigest { |
225 | 0 | fn k_to_q(k: u64, d: usize) -> f64 { |
226 | 0 | let k_div_d = k as f64 / d as f64; |
227 | 0 | if k_div_d >= 0.5 { |
228 | 0 | let base = 1.0 - k_div_d; |
229 | 0 | 1.0 - 2.0 * base * base |
230 | | } else { |
231 | 0 | 2.0 * k_div_d * k_div_d |
232 | | } |
233 | 0 | } |
234 | | |
235 | 0 | fn clamp(v: f64, lo: f64, hi: f64) -> f64 { |
236 | 0 | if lo.is_nan() || hi.is_nan() { |
237 | 0 | return v; |
238 | 0 | } |
239 | 0 | v.clamp(lo, hi) |
240 | 0 | } |
241 | | |
242 | | // public for testing in other modules |
243 | 0 | pub fn merge_unsorted_f64(&self, unsorted_values: Vec<f64>) -> TDigest { |
244 | 0 | let mut values = unsorted_values; |
245 | 0 | values.sort_by(|a, b| a.total_cmp(b)); |
246 | 0 | self.merge_sorted_f64(&values) |
247 | 0 | } |
248 | | |
    /// Fold the pre-sorted `sorted_values` into a copy of this digest,
    /// compressing the combined stream back down to at most `max_size`
    /// centroids.
    ///
    /// The input must be sorted ascending; this is checked in debug builds
    /// only.
    pub fn merge_sorted_f64(&self, sorted_values: &[f64]) -> TDigest {
        #[cfg(debug_assertions)]
        debug_assert!(is_sorted(sorted_values), "unsorted input to TDigest");

        if sorted_values.is_empty() {
            return self.clone();
        }

        let mut result = TDigest::new(self.max_size());
        result.count = self.count() + sorted_values.len() as u64;

        // The input is sorted, so its extrema are the first and last values.
        let maybe_min = *sorted_values.first().unwrap();
        let maybe_max = *sorted_values.last().unwrap();

        if self.count() > 0 {
            result.min = self.min.min(maybe_min);
            result.max = self.max.max(maybe_max);
        } else {
            // Empty digest: min/max are NaN and must be replaced outright.
            result.min = maybe_min;
            result.max = maybe_max;
        }

        let mut compressed: Vec<Centroid> = Vec::with_capacity(self.max_size);

        // Cumulative-weight budget for the current output cluster, derived
        // from the scale function `k_to_q`.
        let mut k_limit: u64 = 1;
        let mut q_limit_times_count =
            Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
        k_limit += 1;

        // Merge the two sorted streams — existing centroids and new values
        // (the latter treated as unit-weight centroids) — in mean order.
        let mut iter_centroids = self.centroids.iter().peekable();
        let mut iter_sorted_values = sorted_values.iter().peekable();

        let mut curr: Centroid = if let Some(c) = iter_centroids.peek() {
            let curr = **iter_sorted_values.peek().unwrap();
            if c.mean() < curr {
                iter_centroids.next().unwrap().clone()
            } else {
                Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
            }
        } else {
            Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
        };

        let mut weight_so_far = curr.weight();

        // Sum/weight pending absorption into `curr`, flushed when the
        // cluster's weight budget is exceeded.
        let mut sums_to_merge = 0_f64;
        let mut weights_to_merge = 0_f64;

        while iter_centroids.peek().is_some() || iter_sorted_values.peek().is_some() {
            // Take whichever stream yields the next-smallest mean.
            let next: Centroid = if let Some(c) = iter_centroids.peek() {
                if iter_sorted_values.peek().is_none()
                    || c.mean() < **iter_sorted_values.peek().unwrap()
                {
                    iter_centroids.next().unwrap().clone()
                } else {
                    Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
                }
            } else {
                Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
            };

            let next_sum = next.mean() * next.weight();
            weight_so_far += next.weight();

            if weight_so_far <= q_limit_times_count {
                // Still within budget: absorb `next` into the pending cluster.
                sums_to_merge += next_sum;
                weights_to_merge += next.weight();
            } else {
                // Budget exhausted: finalise `curr`, then start a new cluster
                // at `next` with the next quantile boundary as its budget.
                result.sum += curr.add(sums_to_merge, weights_to_merge);
                sums_to_merge = 0_f64;
                weights_to_merge = 0_f64;

                compressed.push(curr.clone());
                q_limit_times_count =
                    Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
                k_limit += 1;
                curr = next;
            }
        }

        // Flush the final pending cluster.
        result.sum += curr.add(sums_to_merge, weights_to_merge);
        compressed.push(curr);
        compressed.shrink_to_fit();
        compressed.sort();

        result.centroids = compressed;
        result
    }
337 | | |
338 | 0 | fn external_merge( |
339 | 0 | centroids: &mut [Centroid], |
340 | 0 | first: usize, |
341 | 0 | middle: usize, |
342 | 0 | last: usize, |
343 | 0 | ) { |
344 | 0 | let mut result: Vec<Centroid> = Vec::with_capacity(centroids.len()); |
345 | 0 |
|
346 | 0 | let mut i = first; |
347 | 0 | let mut j = middle; |
348 | | |
349 | 0 | while i < middle && j < last { |
350 | 0 | match centroids[i].cmp(¢roids[j]) { |
351 | 0 | Ordering::Less => { |
352 | 0 | result.push(centroids[i].clone()); |
353 | 0 | i += 1; |
354 | 0 | } |
355 | 0 | Ordering::Greater => { |
356 | 0 | result.push(centroids[j].clone()); |
357 | 0 | j += 1; |
358 | 0 | } |
359 | 0 | Ordering::Equal => { |
360 | 0 | result.push(centroids[i].clone()); |
361 | 0 | i += 1; |
362 | 0 | } |
363 | | } |
364 | | } |
365 | | |
366 | 0 | while i < middle { |
367 | 0 | result.push(centroids[i].clone()); |
368 | 0 | i += 1; |
369 | 0 | } |
370 | | |
371 | 0 | while j < last { |
372 | 0 | result.push(centroids[j].clone()); |
373 | 0 | j += 1; |
374 | 0 | } |
375 | | |
376 | 0 | i = first; |
377 | 0 | for centroid in result.into_iter() { |
378 | 0 | centroids[i] = centroid; |
379 | 0 | i += 1; |
380 | 0 | } |
381 | 0 | } |
382 | | |
    // Merge multiple T-Digests
    //
    // Concatenates every digest's centroids, merge-sorts them bottom-up
    // (each digest's run is already sorted), then compresses the combined
    // run back down to `max_size` centroids. The `max_size` of the first
    // digest is used for the result.
    pub fn merge_digests<'a>(digests: impl IntoIterator<Item = &'a TDigest>) -> TDigest {
        let digests = digests.into_iter().collect::<Vec<_>>();
        let n_centroids: usize = digests.iter().map(|d| d.centroids.len()).sum();
        if n_centroids == 0 {
            return TDigest::default();
        }

        let max_size = digests.first().unwrap().max_size;
        let mut centroids: Vec<Centroid> = Vec::with_capacity(n_centroids);
        // Offset into `centroids` where each digest's (sorted) run begins.
        let mut starts: Vec<usize> = Vec::with_capacity(digests.len());

        let mut count = 0;
        let mut min = f64::INFINITY;
        let mut max = f64::NEG_INFINITY;

        let mut start: usize = 0;
        for digest in digests.iter() {
            starts.push(start);

            let curr_count = digest.count();
            if curr_count > 0 {
                min = min.min(digest.min);
                max = max.max(digest.max);
                count += curr_count;
                for centroid in &digest.centroids {
                    centroids.push(centroid.clone());
                    start += 1;
                }
            }
        }

        // Bottom-up merge sort over the concatenated runs, doubling the
        // number of runs merged per pass.
        let mut digests_per_block: usize = 1;
        while digests_per_block < starts.len() {
            for i in (0..starts.len()).step_by(digests_per_block * 2) {
                if i + digests_per_block < starts.len() {
                    let first = starts[i];
                    let middle = starts[i + digests_per_block];
                    let last = if i + 2 * digests_per_block < starts.len() {
                        starts[i + 2 * digests_per_block]
                    } else {
                        centroids.len()
                    };

                    debug_assert!(first <= middle && middle <= last);
                    Self::external_merge(&mut centroids, first, middle, last);
                }
            }

            digests_per_block *= 2;
        }

        // Compress the fully sorted centroid run, mirroring the loop in
        // `merge_sorted_f64`.
        let mut result = TDigest::new(max_size);
        let mut compressed: Vec<Centroid> = Vec::with_capacity(max_size);

        let mut k_limit = 1;
        let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;

        let mut iter_centroids = centroids.iter_mut();
        let mut curr = iter_centroids.next().unwrap();
        let mut weight_so_far = curr.weight();
        let mut sums_to_merge = 0_f64;
        let mut weights_to_merge = 0_f64;

        for centroid in iter_centroids {
            weight_so_far += centroid.weight();

            if weight_so_far <= q_limit_times_count {
                sums_to_merge += centroid.mean() * centroid.weight();
                weights_to_merge += centroid.weight();
            } else {
                result.sum += curr.add(sums_to_merge, weights_to_merge);
                sums_to_merge = 0_f64;
                weights_to_merge = 0_f64;
                compressed.push(curr.clone());
                // NOTE(review): unlike `merge_sorted_f64`, `k_limit` was not
                // advanced after the initial budget above, so the first
                // recomputation here reuses k = 1 — confirm this asymmetry
                // is intentional before changing it.
                q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;
                k_limit += 1;
                curr = centroid;
            }
        }

        result.sum += curr.add(sums_to_merge, weights_to_merge);
        compressed.push(curr.clone());
        compressed.shrink_to_fit();
        compressed.sort();

        result.count = count;
        result.min = min;
        result.max = max;
        result.centroids = compressed;
        result
    }
475 | | |
    /// To estimate the value located at `q` quantile
    ///
    /// Locates the centroid spanning cumulative weight `q * count`, then
    /// linearly interpolates between neighbouring centroid means. Returns
    /// `0.0` for an empty digest, and the exact `min`/`max` for `q <= 0` /
    /// `q >= 1`.
    pub fn estimate_quantile(&self, q: f64) -> f64 {
        if self.centroids.is_empty() {
            return 0.0;
        }

        // Target cumulative weight ("rank") of the requested quantile.
        let rank = q * self.count as f64;

        let mut pos: usize;
        // Cumulative weight preceding the centroid at `pos`.
        let mut t;
        if q > 0.5 {
            if q >= 1.0 {
                return self.max();
            }

            // Upper half: scan from the top, subtracting weights until the
            // centroid containing `rank` is found.
            pos = 0;
            t = self.count as f64;

            for (k, centroid) in self.centroids.iter().enumerate().rev() {
                t -= centroid.weight();

                if rank >= t {
                    pos = k;
                    break;
                }
            }
        } else {
            if q <= 0.0 {
                return self.min();
            }

            // Lower half: scan from the bottom, accumulating weights.
            pos = self.centroids.len() - 1;
            t = 0_f64;

            for (k, centroid) in self.centroids.iter().enumerate() {
                if rank < t + centroid.weight() {
                    pos = k;
                    break;
                }

                t += centroid.weight();
            }
        }

        // Interpolation window: half the distance to each neighbouring
        // centroid mean, bounded by those means (or the digest extrema at
        // the ends).
        let mut delta = 0_f64;
        let mut min = self.min;
        let mut max = self.max;

        if self.centroids.len() > 1 {
            if pos == 0 {
                delta = self.centroids[pos + 1].mean() - self.centroids[pos].mean();
                max = self.centroids[pos + 1].mean();
            } else if pos == (self.centroids.len() - 1) {
                delta = self.centroids[pos].mean() - self.centroids[pos - 1].mean();
                min = self.centroids[pos - 1].mean();
            } else {
                delta = (self.centroids[pos + 1].mean() - self.centroids[pos - 1].mean())
                    / 2.0;
                min = self.centroids[pos - 1].mean();
                max = self.centroids[pos + 1].mean();
            }
        }

        // Linear interpolation within the chosen centroid.
        let value = self.centroids[pos].mean()
            + ((rank - t) / self.centroids[pos].weight() - 0.5) * delta;

        // In `merge_digests()`: `min` is initialized to Inf, `max` is initialized to -Inf
        // and gets updated according to different `TDigest`s
        // However, `min`/`max` won't get updated if there is only one `NaN` within `TDigest`
        // The following two checks is for such edge case
        if !min.is_finite() && min.is_sign_positive() {
            min = f64::NEG_INFINITY;
        }

        if !max.is_finite() && max.is_sign_negative() {
            max = f64::INFINITY;
        }

        Self::clamp(value, min, max)
    }
556 | | |
557 | | /// This method decomposes the [`TDigest`] and its [`Centroid`] instances |
558 | | /// into a series of primitive scalar values. |
559 | | /// |
560 | | /// First the values of the TDigest are packed, followed by the variable |
561 | | /// number of centroids packed into a [`ScalarValue::List`] of |
562 | | /// [`ScalarValue::Float64`]: |
563 | | /// |
564 | | /// ```text |
565 | | /// |
566 | | /// ┌────────┬────────┬────────┬───────┬────────┬────────┐ |
567 | | /// │max_size│ sum │ count │ max │ min │centroid│ |
568 | | /// └────────┴────────┴────────┴───────┴────────┴────────┘ |
569 | | /// │ |
570 | | /// ┌─────────────────────┘ |
571 | | /// ▼ |
572 | | /// ┌ List ───┐ |
573 | | /// │┌ ─ ─ ─ ┐│ |
574 | | /// │ mean │ |
575 | | /// │├ ─ ─ ─ ┼│─ ─ Centroid 1 |
576 | | /// │ weight │ |
577 | | /// │└ ─ ─ ─ ┘│ |
578 | | /// │ │ |
579 | | /// │┌ ─ ─ ─ ┐│ |
580 | | /// │ mean │ |
581 | | /// │├ ─ ─ ─ ┼│─ ─ Centroid 2 |
582 | | /// │ weight │ |
583 | | /// │└ ─ ─ ─ ┘│ |
584 | | /// │ │ |
585 | | /// ... |
586 | | /// |
587 | | /// ``` |
588 | | /// |
589 | | /// The [`TDigest::from_scalar_state()`] method reverses this processes, |
590 | | /// consuming the output of this method and returning an unpacked |
591 | | /// [`TDigest`]. |
592 | 0 | pub fn to_scalar_state(&self) -> Vec<ScalarValue> { |
593 | 0 | // Gather up all the centroids |
594 | 0 | let centroids: Vec<ScalarValue> = self |
595 | 0 | .centroids |
596 | 0 | .iter() |
597 | 0 | .flat_map(|c| [c.mean(), c.weight()]) |
598 | 0 | .map(|v| ScalarValue::Float64(Some(v))) |
599 | 0 | .collect(); |
600 | 0 |
|
601 | 0 | let arr = ScalarValue::new_list_nullable(¢roids, &DataType::Float64); |
602 | 0 |
|
603 | 0 | vec![ |
604 | 0 | ScalarValue::UInt64(Some(self.max_size as u64)), |
605 | 0 | ScalarValue::Float64(Some(self.sum)), |
606 | 0 | ScalarValue::UInt64(Some(self.count)), |
607 | 0 | ScalarValue::Float64(Some(self.max)), |
608 | 0 | ScalarValue::Float64(Some(self.min)), |
609 | 0 | ScalarValue::List(arr), |
610 | 0 | ] |
611 | 0 | } |
612 | | |
    /// Unpack the serialised state of a [`TDigest`] produced by
    /// [`Self::to_scalar_state()`].
    ///
    /// # Correctness
    ///
    /// Providing input to this method that was not obtained from
    /// [`Self::to_scalar_state()`] results in undefined behaviour and may
    /// panic.
    pub fn from_scalar_state(state: &[ScalarValue]) -> Self {
        assert_eq!(state.len(), 6, "invalid TDigest state");

        let max_size = match &state[0] {
            ScalarValue::UInt64(Some(v)) => *v as usize,
            v => panic!("invalid max_size type {v:?}"),
        };

        // The centroid list is a flat sequence of (mean, weight) pairs, so
        // each chunk of two rebuilds one centroid.
        let centroids: Vec<_> = match &state[5] {
            ScalarValue::List(arr) => {
                let array = arr.values();

                let f64arr =
                    as_primitive_array::<Float64Type>(array).expect("expected f64 array");
                f64arr
                    .values()
                    .chunks(2)
                    .map(|v| Centroid::new(v[0], v[1]))
                    .collect()
            }
            v => panic!("invalid centroids type {v:?}"),
        };

        let max = cast_scalar_f64!(&state[3]);
        let min = cast_scalar_f64!(&state[4]);

        // Sanity check; `total_cmp` also holds for an empty digest, where
        // both extrema are NaN (NaN compares equal to NaN under totalOrder).
        assert!(max.total_cmp(&min).is_ge());

        Self {
            max_size,
            sum: cast_scalar_f64!(state[1]),
            count: cast_scalar_u64!(&state[2]),
            max,
            min,
            centroids,
        }
    }
658 | | } |
659 | | |
// Debug-only helper: true when `values` is ascending under the `totalOrder`
// predicate (vacuously true for empty and single-element slices).
#[cfg(debug_assertions)]
fn is_sorted(values: &[f64]) -> bool {
    values
        .windows(2)
        .all(|pair| pair[0].total_cmp(&pair[1]) != Ordering::Greater)
}
664 | | |
665 | | #[cfg(test)] |
666 | | mod tests { |
667 | | use super::*; |
668 | | |
    // A macro to assert the specified `quantile` estimated by `t` is within the
    // allowable relative error bound.
    //
    // Defaults to a 1% relative error when `allowable_error` is omitted.
    macro_rules! assert_error_bounds {
        ($t:ident, quantile = $quantile:literal, want = $want:literal) => {
            assert_error_bounds!(
                $t,
                quantile = $quantile,
                want = $want,
                allowable_error = 0.01
            )
        };
        ($t:ident, quantile = $quantile:literal, want = $want:literal, allowable_error = $re:literal) => {
            let ans = $t.estimate_quantile($quantile);
            let expected: f64 = $want;
            let percentage: f64 = (expected - ans).abs() / expected;
            assert!(
                percentage < $re,
                "relative error {} is more than {}% (got quantile {}, want {})",
                percentage,
                $re,
                ans,
                expected
            );
        };
    }
694 | | |
    // Round-trip `t` through `to_scalar_state` / `from_scalar_state` and
    // assert the reconstructed digest is identical to the original.
    macro_rules! assert_state_roundtrip {
        ($t:ident) => {
            let state = $t.to_scalar_state();
            let other = TDigest::from_scalar_state(&state);
            assert_eq!($t, other);
        };
    }
702 | | |
703 | | #[test] |
704 | | fn test_int64_uniform() { |
705 | | let values = (1i64..=1000).map(|v| v as f64).collect(); |
706 | | |
707 | | let t = TDigest::new(100); |
708 | | let t = t.merge_unsorted_f64(values); |
709 | | |
710 | | assert_error_bounds!(t, quantile = 0.1, want = 100.0); |
711 | | assert_error_bounds!(t, quantile = 0.5, want = 500.0); |
712 | | assert_error_bounds!(t, quantile = 0.9, want = 900.0); |
713 | | assert_state_roundtrip!(t); |
714 | | } |
715 | | |
    // Regression test: incremental one-value merges must keep centroid
    // means correct when values repeat.
    #[test]
    fn test_centroid_addition_regression() {
        // https://github.com/MnO2/t-digest/pull/1

        let vals = vec![1.0, 1.0, 1.0, 2.0, 1.0, 1.0];
        let mut t = TDigest::new(10);

        for v in vals {
            t = t.merge_unsorted_f64(vec![v]);
        }

        assert_error_bounds!(t, quantile = 0.5, want = 1.0);
        assert_error_bounds!(t, quantile = 0.95, want = 2.0);
        assert_state_roundtrip!(t);
    }
731 | | |
    // One large uniform batch: estimates across the whole quantile range,
    // including the exact endpoints q = 0 and q = 1.
    #[test]
    fn test_merge_unsorted_against_uniform_distro() {
        let t = TDigest::new(100);
        let values: Vec<_> = (1..=1_000_000).map(f64::from).collect();

        let t = t.merge_unsorted_f64(values);

        assert_error_bounds!(t, quantile = 1.0, want = 1_000_000.0);
        assert_error_bounds!(t, quantile = 0.99, want = 990_000.0);
        assert_error_bounds!(t, quantile = 0.01, want = 10_000.0);
        assert_error_bounds!(t, quantile = 0.0, want = 1.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500_000.0);
        assert_state_roundtrip!(t);
    }
746 | | |
    // Skewed input: 60% uniform values plus a heavy spike at 1_000_000
    // (40% of the stream) must not distort mid/low quantiles.
    #[test]
    fn test_merge_unsorted_against_skewed_distro() {
        let t = TDigest::new(100);
        let mut values: Vec<_> = (1..=600_000).map(f64::from).collect();
        values.resize(1_000_000, 1_000_000_f64);

        let t = t.merge_unsorted_f64(values);

        assert_error_bounds!(t, quantile = 0.99, want = 1_000_000.0);
        assert_error_bounds!(t, quantile = 0.01, want = 10_000.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500_000.0);
        assert_state_roundtrip!(t);
    }
760 | | |
    // Merging 100 identical digests must preserve the quantiles of the
    // underlying distribution (wider tolerance at the extreme low tail).
    #[test]
    fn test_merge_digests() {
        let mut digests: Vec<TDigest> = Vec::new();

        for _ in 1..=100 {
            let t = TDigest::new(100);
            let values: Vec<_> = (1..=1_000).map(f64::from).collect();
            let t = t.merge_unsorted_f64(values);
            digests.push(t)
        }

        let t = TDigest::merge_digests(&digests);

        assert_error_bounds!(t, quantile = 1.0, want = 1000.0);
        assert_error_bounds!(t, quantile = 0.99, want = 990.0);
        assert_error_bounds!(t, quantile = 0.01, want = 10.0, allowable_error = 0.2);
        assert_error_bounds!(t, quantile = 0.0, want = 1.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500.0);
        assert_state_roundtrip!(t);
    }
781 | | |
782 | | #[test] |
783 | | fn test_size() { |
784 | | let t = TDigest::new(10); |
785 | | let t = t.merge_unsorted_f64(vec![0.0, 1.0]); |
786 | | |
787 | | assert_eq!(t.size(), 96); |
788 | | } |
789 | | } |