/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/topk/heap.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! A custom binary heap implementation for performant top K aggregation |
19 | | |
20 | | use arrow::datatypes::i256; |
21 | | use arrow_array::cast::AsArray; |
22 | | use arrow_array::{downcast_primitive, ArrayRef, ArrowPrimitiveType, PrimitiveArray}; |
23 | | use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; |
24 | | use arrow_schema::DataType; |
25 | | use datafusion_common::DataFusionError; |
26 | | use datafusion_common::Result; |
27 | | use datafusion_physical_expr::aggregate::utils::adjust_output_array; |
28 | | use half::f16; |
29 | | use std::cmp::Ordering; |
30 | | use std::fmt::{Debug, Display, Formatter}; |
31 | | use std::sync::Arc; |
32 | | |
33 | | /// A custom version of `Ord` that only exists to we can implement it for the Values in our heap |
34 | | pub trait Comparable { |
35 | | fn comp(&self, other: &Self) -> Ordering; |
36 | | } |
37 | | |
38 | | impl Comparable for Option<String> { |
39 | 0 | fn comp(&self, other: &Self) -> Ordering { |
40 | 0 | self.cmp(other) |
41 | 0 | } |
42 | | } |
43 | | |
44 | | /// A "type alias" for Values which are stored in our heap |
45 | | pub trait ValueType: Comparable + Clone + Debug {} |
46 | | |
47 | | impl<T> ValueType for T where T: Comparable + Clone + Debug {} |
48 | | |
49 | | /// An entry in our heap, which contains both the value and a index into an external HashTable |
50 | | struct HeapItem<VAL: ValueType> { |
51 | | val: VAL, |
52 | | map_idx: usize, |
53 | | } |
54 | | |
55 | | /// A custom heap implementation that allows several things that couldn't be achieved with |
56 | | /// `collections::BinaryHeap`: |
57 | | /// 1. It allows values to be updated at arbitrary positions (when group values change) |
58 | | /// 2. It can be either a min or max heap |
59 | | /// 3. It can use our `HeapItem` type & `Comparable` trait |
60 | | /// 4. It is specialized to grow to a certain limit, then always replace without grow & shrink |
61 | | struct TopKHeap<VAL: ValueType> { |
62 | | desc: bool, |
63 | | len: usize, |
64 | | capacity: usize, |
65 | | heap: Vec<Option<HeapItem<VAL>>>, |
66 | | } |
67 | | |
68 | | /// An interface to hide the generic type signature of TopKHeap behind arrow arrays |
69 | | pub trait ArrowHeap { |
70 | | fn set_batch(&mut self, vals: ArrayRef); |
71 | | fn is_worse(&self, idx: usize) -> bool; |
72 | | fn worst_map_idx(&self) -> usize; |
73 | | fn renumber(&mut self, heap_to_map: &[(usize, usize)]); |
74 | | fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>); |
75 | | fn replace_if_better( |
76 | | &mut self, |
77 | | heap_idx: usize, |
78 | | row_idx: usize, |
79 | | map: &mut Vec<(usize, usize)>, |
80 | | ); |
81 | | fn drain(&mut self) -> (ArrayRef, Vec<usize>); |
82 | | } |
83 | | |
84 | | /// An implementation of `ArrowHeap` that deals with primitive values |
85 | | pub struct PrimitiveHeap<VAL: ArrowPrimitiveType> |
86 | | where |
87 | | <VAL as ArrowPrimitiveType>::Native: Comparable, |
88 | | { |
89 | | batch: ArrayRef, |
90 | | heap: TopKHeap<VAL::Native>, |
91 | | desc: bool, |
92 | | data_type: DataType, |
93 | | } |
94 | | |
95 | | impl<VAL: ArrowPrimitiveType> PrimitiveHeap<VAL> |
96 | | where |
97 | | <VAL as ArrowPrimitiveType>::Native: Comparable, |
98 | | { |
99 | 10 | pub fn new(limit: usize, desc: bool, data_type: DataType) -> Self { |
100 | 10 | let owned: ArrayRef = Arc::new(PrimitiveArray::<VAL>::builder(0).finish()); |
101 | 10 | Self { |
102 | 10 | batch: owned, |
103 | 10 | heap: TopKHeap::new(limit, desc), |
104 | 10 | desc, |
105 | 10 | data_type, |
106 | 10 | } |
107 | 10 | } |
108 | | } |
109 | | |
110 | | impl<VAL: ArrowPrimitiveType> ArrowHeap for PrimitiveHeap<VAL> |
111 | | where |
112 | | <VAL as ArrowPrimitiveType>::Native: Comparable, |
113 | | { |
114 | 10 | fn set_batch(&mut self, vals: ArrayRef) { |
115 | 10 | self.batch = vals; |
116 | 10 | } |
117 | | |
118 | 20 | fn is_worse(&self, row_idx: usize) -> bool { |
119 | 20 | if !self.heap.is_full() { |
120 | 15 | return false; |
121 | 5 | } |
122 | 5 | let vals = self.batch.as_primitive::<VAL>(); |
123 | 5 | let new_val = vals.value(row_idx); |
124 | 5 | let worst_val = self.heap.worst_val().expect("Missing root"); |
125 | 5 | (!self.desc && new_val > *worst_val2 ) || (self.desc4 && new_val < *worst_val3 ) |
126 | 20 | } |
127 | | |
128 | 18 | fn worst_map_idx(&self) -> usize { |
129 | 18 | self.heap.worst_map_idx() |
130 | 18 | } |
131 | | |
132 | 13 | fn renumber(&mut self, heap_to_map: &[(usize, usize)]) { |
133 | 13 | self.heap.renumber(heap_to_map); |
134 | 13 | } |
135 | | |
136 | 13 | fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>) { |
137 | 13 | let vals = self.batch.as_primitive::<VAL>(); |
138 | 13 | let new_val = vals.value(row_idx); |
139 | 13 | self.heap.append_or_replace(new_val, map_idx, map); |
140 | 13 | } |
141 | | |
142 | 5 | fn replace_if_better( |
143 | 5 | &mut self, |
144 | 5 | heap_idx: usize, |
145 | 5 | row_idx: usize, |
146 | 5 | map: &mut Vec<(usize, usize)>, |
147 | 5 | ) { |
148 | 5 | let vals = self.batch.as_primitive::<VAL>(); |
149 | 5 | let new_val = vals.value(row_idx); |
150 | 5 | self.heap.replace_if_better(heap_idx, new_val, map); |
151 | 5 | } |
152 | | |
153 | 10 | fn drain(&mut self) -> (ArrayRef, Vec<usize>) { |
154 | 10 | let (vals, map_idxs) = self.heap.drain(); |
155 | 10 | let vals = Arc::new(PrimitiveArray::<VAL>::from_iter_values(vals)); |
156 | 10 | let vals = adjust_output_array(&self.data_type, vals).expect("Type is incorrect"); |
157 | 10 | (vals, map_idxs) |
158 | 10 | } |
159 | | } |
160 | | |
161 | | impl<VAL: ValueType> TopKHeap<VAL> { |
162 | 17 | pub fn new(limit: usize, desc: bool) -> Self { |
163 | 17 | Self { |
164 | 17 | desc, |
165 | 17 | capacity: limit, |
166 | 17 | len: 0, |
167 | 89 | heap: (0..=limit).map(|_| None).collect::<Vec<_>>(), |
168 | 17 | } |
169 | 17 | } |
170 | | |
171 | 6 | pub fn worst_val(&self) -> Option<&VAL> { |
172 | 6 | let root = self.heap.first()?0 ; |
173 | 6 | let hi = match root { |
174 | 0 | None => return None, |
175 | 6 | Some(hi) => hi, |
176 | 6 | }; |
177 | 6 | Some(&hi.val) |
178 | 6 | } |
179 | | |
180 | 19 | pub fn worst_map_idx(&self) -> usize { |
181 | 19 | self.heap[0].as_ref().map(|hi| hi.map_idx9 ).unwrap_or(0) |
182 | 19 | } |
183 | | |
184 | 50 | pub fn is_full(&self) -> bool { |
185 | 50 | self.len >= self.capacity |
186 | 50 | } |
187 | | |
188 | 38 | pub fn len(&self) -> usize { |
189 | 38 | self.len |
190 | 38 | } |
191 | | |
192 | 30 | pub fn append_or_replace( |
193 | 30 | &mut self, |
194 | 30 | new_val: VAL, |
195 | 30 | map_idx: usize, |
196 | 30 | map: &mut Vec<(usize, usize)>, |
197 | 30 | ) { |
198 | 30 | if self.is_full() { |
199 | 3 | self.replace_root(new_val, map_idx, map); |
200 | 27 | } else { |
201 | 27 | self.append(new_val, map_idx, map); |
202 | 27 | } |
203 | 30 | } |
204 | | |
205 | 27 | fn append(&mut self, new_val: VAL, map_idx: usize, mapper: &mut Vec<(usize, usize)>) { |
206 | 27 | let hi = HeapItem::new(new_val, map_idx); |
207 | 27 | self.heap[self.len] = Some(hi); |
208 | 27 | self.heapify_up(self.len, mapper); |
209 | 27 | self.len += 1; |
210 | 27 | } |
211 | | |
212 | 24 | fn pop(&mut self, map: &mut Vec<(usize, usize)>) -> Option<HeapItem<VAL>> { |
213 | 24 | if self.len() == 0 { |
214 | 11 | return None; |
215 | 13 | } |
216 | 13 | if self.len() == 1 { |
217 | 11 | self.len = 0; |
218 | 11 | return self.heap[0].take(); |
219 | 2 | } |
220 | 2 | self.swap(0, self.len - 1, map); |
221 | 2 | let former_root = self.heap[self.len - 1].take(); |
222 | 2 | self.len -= 1; |
223 | 2 | self.heapify_down(0, map); |
224 | 2 | former_root |
225 | 24 | } |
226 | | |
227 | 11 | pub fn drain(&mut self) -> (Vec<VAL>, Vec<usize>) { |
228 | 11 | let mut map = Vec::with_capacity(self.len); |
229 | 11 | let mut vals = Vec::with_capacity(self.len); |
230 | 11 | let mut map_idxs = Vec::with_capacity(self.len); |
231 | 24 | while let Some(worst_hi13 ) = self.pop(&mut map) { |
232 | 13 | vals.push(worst_hi.val); |
233 | 13 | map_idxs.push(worst_hi.map_idx); |
234 | 13 | } |
235 | 11 | vals.reverse(); |
236 | 11 | map_idxs.reverse(); |
237 | 11 | (vals, map_idxs) |
238 | 11 | } |
239 | | |
240 | 3 | fn replace_root( |
241 | 3 | &mut self, |
242 | 3 | new_val: VAL, |
243 | 3 | map_idx: usize, |
244 | 3 | mapper: &mut Vec<(usize, usize)>, |
245 | 3 | ) { |
246 | 3 | let hi = self.heap[0].as_mut().expect("No root"); |
247 | 3 | hi.val = new_val; |
248 | 3 | hi.map_idx = map_idx; |
249 | 3 | self.heapify_down(0, mapper); |
250 | 3 | } |
251 | | |
252 | 6 | pub fn replace_if_better( |
253 | 6 | &mut self, |
254 | 6 | heap_idx: usize, |
255 | 6 | new_val: VAL, |
256 | 6 | mapper: &mut Vec<(usize, usize)>, |
257 | 6 | ) { |
258 | 6 | let existing = self.heap[heap_idx].as_mut().expect("Missing heap item"); |
259 | 6 | if (!self.desc && new_val.comp(&existing.val) != Ordering::Less3 ) |
260 | 5 | || (self.desc && new_val.comp(&existing.val) != Ordering::Greater3 ) |
261 | | { |
262 | 2 | return; |
263 | 4 | } |
264 | 4 | existing.val = new_val; |
265 | 4 | self.heapify_down(heap_idx, mapper); |
266 | 6 | } |
267 | | |
268 | 14 | pub fn renumber(&mut self, heap_to_map: &[(usize, usize)]) { |
269 | 14 | for (heap_idx, map_idx2 ) in heap_to_map.iter() { |
270 | 2 | if let Some(Some(hi)) = self.heap.get_mut(*heap_idx) { |
271 | 2 | hi.map_idx = *map_idx; |
272 | 2 | }0 |
273 | | } |
274 | 14 | } |
275 | | |
276 | 27 | fn heapify_up(&mut self, mut idx: usize, mapper: &mut Vec<(usize, usize)>) { |
277 | 27 | let desc = self.desc; |
278 | 37 | while idx != 0 { |
279 | 11 | let parent_idx = (idx - 1) / 2; |
280 | 11 | let node = self.heap[idx].as_ref().expect("No heap item"); |
281 | 11 | let parent = self.heap[parent_idx].as_ref().expect("No heap item"); |
282 | 11 | if (!desc && node.val.comp(&parent.val) != Ordering::Greater10 ) |
283 | 11 | || (desc && node.val.comp(&parent.val) != Ordering::Less1 ) |
284 | | { |
285 | 1 | return; |
286 | 10 | } |
287 | 10 | self.swap(idx, parent_idx, mapper); |
288 | 10 | idx = parent_idx; |
289 | | } |
290 | 27 | } |
291 | | |
292 | 14 | fn swap(&mut self, a_idx: usize, b_idx: usize, mapper: &mut Vec<(usize, usize)>) { |
293 | 14 | let a_hi = self.heap[a_idx].take().expect("Missing heap entry"); |
294 | 14 | let b_hi = self.heap[b_idx].take().expect("Missing heap entry"); |
295 | 14 | |
296 | 14 | mapper.push((a_hi.map_idx, b_idx)); |
297 | 14 | mapper.push((b_hi.map_idx, a_idx)); |
298 | 14 | |
299 | 14 | self.heap[a_idx] = Some(b_hi); |
300 | 14 | self.heap[b_idx] = Some(a_hi); |
301 | 14 | } |
302 | | |
303 | 11 | fn heapify_down(&mut self, node_idx: usize, mapper: &mut Vec<(usize, usize)>) { |
304 | 11 | let left_child = node_idx * 2 + 1; |
305 | 11 | let desc = self.desc; |
306 | 11 | let entry = self.heap.get(node_idx).expect("Missing node!"); |
307 | 11 | let entry = entry.as_ref().expect("Missing node!"); |
308 | 11 | let mut best_idx = node_idx; |
309 | 11 | let mut best_val = &entry.val; |
310 | 22 | for child_idx in left_child..=left_child + 111 { |
311 | 22 | if let Some(Some(child3 )) = self.heap.get(child_idx) { |
312 | 3 | if (!desc && child.val.comp(best_val) == Ordering::Greater) |
313 | 0 | || (desc && child.val.comp(best_val) == Ordering::Less) |
314 | 3 | { |
315 | 3 | best_val = &child.val; |
316 | 3 | best_idx = child_idx; |
317 | 3 | }0 |
318 | 19 | } |
319 | | } |
320 | 11 | if best_val.comp(&entry.val) != Ordering::Equal { |
321 | 2 | self.swap(best_idx, node_idx, mapper); |
322 | 2 | self.heapify_down(best_idx, mapper); |
323 | 9 | } |
324 | 11 | } |
325 | | |
326 | 25 | fn _tree_print( |
327 | 25 | &self, |
328 | 25 | idx: usize, |
329 | 25 | prefix: String, |
330 | 25 | is_tail: bool, |
331 | 25 | output: &mut String, |
332 | 25 | ) { |
333 | 25 | if let Some(Some(hi)) = self.heap.get(idx) { |
334 | 25 | let connector = if idx != 0 { |
335 | 15 | if is_tail { |
336 | 11 | "└── " |
337 | | } else { |
338 | 4 | "├── " |
339 | | } |
340 | | } else { |
341 | 10 | "" |
342 | | }; |
343 | 25 | output.push_str(&format!( |
344 | 25 | "{}{}val={:?} idx={}, bucket={}\n", |
345 | 25 | prefix, connector, hi.val, idx, hi.map_idx |
346 | 25 | )); |
347 | 25 | let new_prefix = if is_tail { ""21 } else { "│ "4 }; |
348 | 25 | let child_prefix = format!("{}{}", prefix, new_prefix); |
349 | 25 | |
350 | 25 | let left_idx = idx * 2 + 1; |
351 | 25 | let right_idx = idx * 2 + 2; |
352 | 25 | |
353 | 25 | let left_exists = left_idx < self.len; |
354 | 25 | let right_exists = right_idx < self.len; |
355 | 25 | |
356 | 25 | if left_exists { |
357 | 11 | self._tree_print(left_idx, child_prefix.clone(), !right_exists, output); |
358 | 14 | } |
359 | 25 | if right_exists { |
360 | 4 | self._tree_print(right_idx, child_prefix, true, output); |
361 | 21 | } |
362 | 0 | } |
363 | 25 | } |
364 | | } |
365 | | |
366 | | impl<VAL: ValueType> Display for TopKHeap<VAL> { |
367 | 10 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
368 | 10 | let mut output = String::new(); |
369 | 10 | if self.heap.first().is_some() { |
370 | 10 | self._tree_print(0, String::new(), true, &mut output); |
371 | 10 | }0 |
372 | 10 | write!(f, "{}", output) |
373 | 10 | } |
374 | | } |
375 | | |
376 | | impl<VAL: ValueType> HeapItem<VAL> { |
377 | 27 | pub fn new(val: VAL, buk_idx: usize) -> Self { |
378 | 27 | Self { |
379 | 27 | val, |
380 | 27 | map_idx: buk_idx, |
381 | 27 | } |
382 | 27 | } |
383 | | } |
384 | | |
385 | | impl<VAL: ValueType> Debug for HeapItem<VAL> { |
386 | 0 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
387 | 0 | f.write_str("bucket=")?; |
388 | 0 | Debug::fmt(&self.map_idx, f)?; |
389 | 0 | f.write_str(" val=")?; |
390 | 0 | Debug::fmt(&self.val, f)?; |
391 | 0 | f.write_str("\n")?; |
392 | 0 | Ok(()) |
393 | 0 | } |
394 | | } |
395 | | |
396 | | impl<VAL: ValueType> Eq for HeapItem<VAL> {} |
397 | | |
398 | | impl<VAL: ValueType> PartialEq<Self> for HeapItem<VAL> { |
399 | 0 | fn eq(&self, other: &Self) -> bool { |
400 | 0 | self.cmp(other) == Ordering::Equal |
401 | 0 | } |
402 | | } |
403 | | |
404 | | impl<VAL: ValueType> PartialOrd<Self> for HeapItem<VAL> { |
405 | 0 | fn partial_cmp(&self, other: &Self) -> Option<Ordering> { |
406 | 0 | Some(self.cmp(other)) |
407 | 0 | } |
408 | | } |
409 | | |
410 | | impl<VAL: ValueType> Ord for HeapItem<VAL> { |
411 | 0 | fn cmp(&self, other: &Self) -> Ordering { |
412 | 0 | let res = self.val.comp(&other.val); |
413 | 0 | if res != Ordering::Equal { |
414 | 0 | return res; |
415 | 0 | } |
416 | 0 | self.map_idx.cmp(&other.map_idx) |
417 | 0 | } |
418 | | } |
419 | | |
420 | | macro_rules! compare_float { |
421 | | ($($t:ty),+) => { |
422 | | $(impl Comparable for Option<$t> { |
423 | 0 | fn comp(&self, other: &Self) -> Ordering { |
424 | 0 | match (self, other) { |
425 | 0 | (Some(me), Some(other)) => me.total_cmp(other), |
426 | 0 | (Some(_), None) => Ordering::Greater, |
427 | 0 | (None, Some(_)) => Ordering::Less, |
428 | 0 | (None, None) => Ordering::Equal, |
429 | | } |
430 | 0 | } |
431 | | })+ |
432 | | |
433 | | $(impl Comparable for $t { |
434 | 0 | fn comp(&self, other: &Self) -> Ordering { |
435 | 0 | self.total_cmp(other) |
436 | 0 | } |
437 | | })+ |
438 | | }; |
439 | | } |
440 | | |
441 | | macro_rules! compare_integer { |
442 | | ($($t:ty),+) => { |
443 | | $(impl Comparable for Option<$t> { |
444 | 0 | fn comp(&self, other: &Self) -> Ordering { |
445 | 0 | self.cmp(other) |
446 | 0 | } |
447 | | })+ |
448 | | |
449 | | $(impl Comparable for $t { |
450 | 31 | fn comp(&self, other: &Self) -> Ordering { |
451 | 31 | self.cmp(other) |
452 | 31 | } |
453 | | })+ |
454 | | }; |
455 | | } |
456 | | |
457 | | compare_integer!(i8, i16, i32, i64, i128, i256); |
458 | | compare_integer!(u8, u16, u32, u64); |
459 | | compare_integer!(IntervalDayTime, IntervalMonthDayNano); |
460 | | compare_float!(f16, f32, f64); |
461 | | |
462 | 10 | pub fn new_heap( |
463 | 10 | limit: usize, |
464 | 10 | desc: bool, |
465 | 10 | vt: DataType, |
466 | 10 | ) -> Result<Box<dyn ArrowHeap + Send>> { |
467 | | macro_rules! downcast_helper { |
468 | | ($vt:ty, $d:ident) => { |
469 | | return Ok(Box::new(PrimitiveHeap::<$vt>::new(limit, desc, vt))) |
470 | | }; |
471 | | } |
472 | | |
473 | 0 | downcast_primitive! { |
474 | 0 | vt => (downcast_helper, vt), |
475 | 0 | _ => {} |
476 | 0 | } |
477 | 0 |
|
478 | 0 | Err(DataFusionError::Execution(format!( |
479 | 0 | "Can't group type: {vt:?}" |
480 | 0 | ))) |
481 | 10 | } |
482 | | |
483 | | #[cfg(test)] |
484 | | mod tests { |
485 | | use super::*; |
486 | | |
487 | | #[test] |
488 | 1 | fn should_append() -> Result<()> { |
489 | 1 | let mut map = vec![]; |
490 | 1 | let mut heap = TopKHeap::new(10, false); |
491 | 1 | heap.append_or_replace(1, 1, &mut map); |
492 | 1 | |
493 | 1 | let actual = heap.to_string(); |
494 | 1 | let expected = r#" |
495 | 1 | val=1 idx=0, bucket=1 |
496 | 1 | "#; |
497 | 1 | assert_eq!(actual.trim(), expected.trim()); |
498 | | |
499 | 1 | Ok(()) |
500 | 1 | } |
501 | | |
502 | | #[test] |
503 | 1 | fn should_heapify_up() -> Result<()> { |
504 | 1 | let mut map = vec![]; |
505 | 1 | let mut heap = TopKHeap::new(10, false); |
506 | 1 | |
507 | 1 | heap.append_or_replace(1, 1, &mut map); |
508 | 1 | assert_eq!(map, vec![]); |
509 | | |
510 | 1 | heap.append_or_replace(2, 2, &mut map); |
511 | 1 | assert_eq!(map, vec![(2, 0), (1, 1)]); |
512 | | |
513 | 1 | let actual = heap.to_string(); |
514 | 1 | let expected = r#" |
515 | 1 | val=2 idx=0, bucket=2 |
516 | 1 | └── val=1 idx=1, bucket=1 |
517 | 1 | "#; |
518 | 1 | assert_eq!(actual.trim(), expected.trim()); |
519 | | |
520 | 1 | Ok(()) |
521 | 1 | } |
522 | | |
523 | | #[test] |
524 | 1 | fn should_heapify_down() -> Result<()> { |
525 | 1 | let mut map = vec![]; |
526 | 1 | let mut heap = TopKHeap::new(3, false); |
527 | 1 | |
528 | 1 | heap.append_or_replace(1, 1, &mut map); |
529 | 1 | heap.append_or_replace(2, 2, &mut map); |
530 | 1 | heap.append_or_replace(3, 3, &mut map); |
531 | 1 | let actual = heap.to_string(); |
532 | 1 | let expected = r#" |
533 | 1 | val=3 idx=0, bucket=3 |
534 | 1 | ├── val=1 idx=1, bucket=1 |
535 | 1 | └── val=2 idx=2, bucket=2 |
536 | 1 | "#; |
537 | 1 | assert_eq!(actual.trim(), expected.trim()); |
538 | | |
539 | 1 | let mut map = vec![]; |
540 | 1 | heap.append_or_replace(0, 0, &mut map); |
541 | 1 | let actual = heap.to_string(); |
542 | 1 | let expected = r#" |
543 | 1 | val=2 idx=0, bucket=2 |
544 | 1 | ├── val=1 idx=1, bucket=1 |
545 | 1 | └── val=0 idx=2, bucket=0 |
546 | 1 | "#; |
547 | 1 | assert_eq!(actual.trim(), expected.trim()); |
548 | 1 | assert_eq!(map, vec![(2, 0), (0, 2)]); |
549 | | |
550 | 1 | Ok(()) |
551 | 1 | } |
552 | | |
553 | | #[test] |
554 | 1 | fn should_replace() -> Result<()> { |
555 | 1 | let mut map = vec![]; |
556 | 1 | let mut heap = TopKHeap::new(4, false); |
557 | 1 | |
558 | 1 | heap.append_or_replace(1, 1, &mut map); |
559 | 1 | heap.append_or_replace(2, 2, &mut map); |
560 | 1 | heap.append_or_replace(3, 3, &mut map); |
561 | 1 | heap.append_or_replace(4, 4, &mut map); |
562 | 1 | let actual = heap.to_string(); |
563 | 1 | let expected = r#" |
564 | 1 | val=4 idx=0, bucket=4 |
565 | 1 | ├── val=3 idx=1, bucket=3 |
566 | 1 | │ └── val=1 idx=3, bucket=1 |
567 | 1 | └── val=2 idx=2, bucket=2 |
568 | 1 | "#; |
569 | 1 | assert_eq!(actual.trim(), expected.trim()); |
570 | | |
571 | 1 | let mut map = vec![]; |
572 | 1 | heap.replace_if_better(1, 0, &mut map); |
573 | 1 | let actual = heap.to_string(); |
574 | 1 | let expected = r#" |
575 | 1 | val=4 idx=0, bucket=4 |
576 | 1 | ├── val=1 idx=1, bucket=1 |
577 | 1 | │ └── val=0 idx=3, bucket=3 |
578 | 1 | └── val=2 idx=2, bucket=2 |
579 | 1 | "#; |
580 | 1 | assert_eq!(actual.trim(), expected.trim()); |
581 | 1 | assert_eq!(map, vec![(1, 1), (3, 3)]); |
582 | | |
583 | 1 | Ok(()) |
584 | 1 | } |
585 | | |
586 | | #[test] |
587 | 1 | fn should_find_worst() -> Result<()> { |
588 | 1 | let mut map = vec![]; |
589 | 1 | let mut heap = TopKHeap::new(10, false); |
590 | 1 | |
591 | 1 | heap.append_or_replace(1, 1, &mut map); |
592 | 1 | heap.append_or_replace(2, 2, &mut map); |
593 | 1 | |
594 | 1 | let actual = heap.to_string(); |
595 | 1 | let expected = r#" |
596 | 1 | val=2 idx=0, bucket=2 |
597 | 1 | └── val=1 idx=1, bucket=1 |
598 | 1 | "#; |
599 | 1 | assert_eq!(actual.trim(), expected.trim()); |
600 | | |
601 | 1 | assert_eq!(heap.worst_val(), Some(&2)); |
602 | 1 | assert_eq!(heap.worst_map_idx(), 2); |
603 | | |
604 | 1 | Ok(()) |
605 | 1 | } |
606 | | |
607 | | #[test] |
608 | 1 | fn should_drain() -> Result<()> { |
609 | 1 | let mut map = vec![]; |
610 | 1 | let mut heap = TopKHeap::new(10, false); |
611 | 1 | |
612 | 1 | heap.append_or_replace(1, 1, &mut map); |
613 | 1 | heap.append_or_replace(2, 2, &mut map); |
614 | 1 | |
615 | 1 | let actual = heap.to_string(); |
616 | 1 | let expected = r#" |
617 | 1 | val=2 idx=0, bucket=2 |
618 | 1 | └── val=1 idx=1, bucket=1 |
619 | 1 | "#; |
620 | 1 | assert_eq!(actual.trim(), expected.trim()); |
621 | | |
622 | 1 | let (vals, map_idxs) = heap.drain(); |
623 | 1 | assert_eq!(vals, vec![1, 2]); |
624 | 1 | assert_eq!(map_idxs, vec![1, 2]); |
625 | 1 | assert_eq!(heap.len(), 0); |
626 | | |
627 | 1 | Ok(()) |
628 | 1 | } |
629 | | |
630 | | #[test] |
631 | 1 | fn should_renumber() -> Result<()> { |
632 | 1 | let mut map = vec![]; |
633 | 1 | let mut heap = TopKHeap::new(10, false); |
634 | 1 | |
635 | 1 | heap.append_or_replace(1, 1, &mut map); |
636 | 1 | heap.append_or_replace(2, 2, &mut map); |
637 | 1 | |
638 | 1 | let actual = heap.to_string(); |
639 | 1 | let expected = r#" |
640 | 1 | val=2 idx=0, bucket=2 |
641 | 1 | └── val=1 idx=1, bucket=1 |
642 | 1 | "#; |
643 | 1 | assert_eq!(actual.trim(), expected.trim()); |
644 | | |
645 | 1 | let numbers = vec![(0, 1), (1, 2)]; |
646 | 1 | heap.renumber(numbers.as_slice()); |
647 | 1 | let actual = heap.to_string(); |
648 | 1 | let expected = r#" |
649 | 1 | val=2 idx=0, bucket=1 |
650 | 1 | └── val=1 idx=1, bucket=2 |
651 | 1 | "#; |
652 | 1 | assert_eq!(actual.trim(), expected.trim()); |
653 | | |
654 | 1 | Ok(()) |
655 | 1 | } |
656 | | } |