Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/topk/priority_map.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! A `Map<K, V>` / `PriorityQueue` combo that evicts the worst values after reaching `capacity`
19
20
use crate::aggregates::topk::hash_table::{new_hash_table, ArrowHashTable};
21
use crate::aggregates::topk::heap::{new_heap, ArrowHeap};
22
use arrow_array::ArrayRef;
23
use arrow_schema::DataType;
24
use datafusion_common::Result;
25
26
/// A `Map<K, V>` / `PriorityQueue` combo that evicts the worst values after reaching `capacity`
27
pub struct PriorityMap {
28
    map: Box<dyn ArrowHashTable + Send>,
29
    heap: Box<dyn ArrowHeap + Send>,
30
    capacity: usize,
31
    mapper: Vec<(usize, usize)>,
32
}
33
34
impl PriorityMap {
35
10
    pub fn new(
36
10
        key_type: DataType,
37
10
        val_type: DataType,
38
10
        capacity: usize,
39
10
        descending: bool,
40
10
    ) -> Result<Self> {
41
10
        Ok(Self {
42
10
            map: new_hash_table(capacity, key_type)
?0
,
43
10
            heap: new_heap(capacity, descending, val_type)
?0
,
44
10
            capacity,
45
10
            mapper: Vec::with_capacity(capacity),
46
        })
47
10
    }
48
49
10
    pub fn set_batch(&mut self, ids: ArrayRef, vals: ArrayRef) {
50
10
        self.map.set_batch(ids);
51
10
        self.heap.set_batch(vals);
52
10
    }
53
54
20
    pub fn insert(&mut self, row_idx: usize) -> Result<()> {
55
20
        assert!(self.map.len() <= self.capacity, 
"Overflow"0
);
56
57
        // if we're full, and the new val is worse than all our values, just bail
58
20
        if self.heap.is_worse(row_idx) {
59
2
            return Ok(());
60
18
        }
61
18
        let map = &mut self.mapper;
62
18
63
18
        // handle new groups we haven't seen yet
64
18
        map.clear();
65
18
        let replace_idx = self.heap.worst_map_idx();
66
18
        // JUSTIFICATION
67
18
        //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
68
18
        //  Soundness: replace_idx kept valid during resizes
69
18
        let (map_idx, did_insert) =
70
18
            unsafe { self.map.find_or_insert(row_idx, replace_idx, map) };
71
18
        if did_insert {
72
13
            self.heap.renumber(map);
73
13
            map.clear();
74
13
            self.heap.insert(row_idx, map_idx, map);
75
13
            // JUSTIFICATION
76
13
            //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
77
13
            //  Soundness: the map was created on the line above, so all the indexes should be valid
78
13
            unsafe { self.map.update_heap_idx(map) };
79
13
            return Ok(());
80
5
        };
81
5
82
5
        // this is a value for an existing group
83
5
        map.clear();
84
5
        // JUSTIFICATION
85
5
        //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
86
5
        //  Soundness: map_idx was just found, so it is valid
87
5
        let heap_idx = unsafe { self.map.heap_idx_at(map_idx) };
88
5
        self.heap.replace_if_better(heap_idx, row_idx, map);
89
5
        // JUSTIFICATION
90
5
        //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
91
5
        //  Soundness: the index map was just built, so it will be valid
92
5
        unsafe { self.map.update_heap_idx(map) };
93
5
94
5
        Ok(())
95
20
    }
96
97
10
    pub fn emit(&mut self) -> Result<Vec<ArrayRef>> {
98
10
        let (vals, map_idxs) = self.heap.drain();
99
10
        let ids = unsafe { self.map.take_all(map_idxs) };
100
10
        Ok(vec![ids, vals])
101
10
    }
102
103
0
    pub fn is_empty(&self) -> bool {
104
0
        self.map.len() == 0
105
0
    }
106
}
107
108
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::util::pretty::pretty_format_batches;
    use arrow_array::{Int64Array, RecordBatch, StringArray};
    use arrow_schema::{Field, Schema, SchemaRef};
    use std::sync::Arc;

    /// Shared driver for every test below.
    ///
    /// Builds a `PriorityMap` over `Utf8` keys and `Int64` values with the given
    /// `capacity` and `descending` flag, inserts every row of the batch in order
    /// (row 0 first), emits the result and renders it as a pretty-printed table.
    fn run_topk(
        ids: StringArray,
        vals: Vec<i64>,
        capacity: usize,
        descending: bool,
    ) -> Result<String> {
        // Capture the row count before `vals` is moved into the array.
        let row_count = vals.len();
        let ids: ArrayRef = Arc::new(ids);
        let vals: ArrayRef = Arc::new(Int64Array::from(vals));

        let mut agg =
            PriorityMap::new(DataType::Utf8, DataType::Int64, capacity, descending)?;
        agg.set_batch(ids, vals);
        for row_idx in 0..row_count {
            agg.insert(row_idx)?;
        }

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        Ok(pretty_format_batches(&[batch])?.to_string())
    }

    #[test]
    fn should_append() -> Result<()> {
        let actual = run_topk(StringArray::from(vec!["1"]), vec![1], 1, false)?;
        let expected = r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 1            |
+----------+--------------+
        "#
        .trim();
        assert_eq!(actual, expected);

        Ok(())
    }

    #[test]
    fn should_ignore_higher_group() -> Result<()> {
        let actual = run_topk(StringArray::from(vec!["1", "2"]), vec![1, 2], 1, false)?;
        let expected = r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 1            |
+----------+--------------+
        "#
        .trim();
        assert_eq!(actual, expected);

        Ok(())
    }

    #[test]
    fn should_ignore_lower_group() -> Result<()> {
        let actual = run_topk(StringArray::from(vec!["2", "1"]), vec![2, 1], 1, true)?;
        let expected = r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 2        | 2            |
+----------+--------------+
        "#
        .trim();
        assert_eq!(actual, expected);

        Ok(())
    }

    #[test]
    fn should_ignore_higher_same_group() -> Result<()> {
        let actual = run_topk(StringArray::from(vec!["1", "1"]), vec![1, 2], 2, false)?;
        let expected = r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 1            |
+----------+--------------+
        "#
        .trim();
        assert_eq!(actual, expected);

        Ok(())
    }

    #[test]
    fn should_ignore_lower_same_group() -> Result<()> {
        let actual = run_topk(StringArray::from(vec!["1", "1"]), vec![2, 1], 2, true)?;
        let expected = r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 2            |
+----------+--------------+
        "#
        .trim();
        assert_eq!(actual, expected);

        Ok(())
    }

    #[test]
    fn should_accept_lower_group() -> Result<()> {
        let actual = run_topk(StringArray::from(vec!["2", "1"]), vec![2, 1], 1, false)?;
        let expected = r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 1            |
+----------+--------------+
        "#
        .trim();
        assert_eq!(actual, expected);

        Ok(())
    }

    #[test]
    fn should_accept_higher_group() -> Result<()> {
        let actual = run_topk(StringArray::from(vec!["1", "2"]), vec![1, 2], 1, true)?;
        let expected = r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 2        | 2            |
+----------+--------------+
        "#
        .trim();
        assert_eq!(actual, expected);

        Ok(())
    }

    #[test]
    fn should_accept_lower_for_group() -> Result<()> {
        let actual = run_topk(StringArray::from(vec!["1", "1"]), vec![2, 1], 2, false)?;
        let expected = r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 1            |
+----------+--------------+
        "#
        .trim();
        assert_eq!(actual, expected);

        Ok(())
    }

    #[test]
    fn should_accept_higher_for_group() -> Result<()> {
        let actual = run_topk(StringArray::from(vec!["1", "1"]), vec![1, 2], 2, true)?;
        let expected = r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 2            |
+----------+--------------+
        "#
        .trim();
        assert_eq!(actual, expected);

        Ok(())
    }

    #[test]
    fn should_handle_null_ids() -> Result<()> {
        let ids = StringArray::from(vec![Some("1"), None, None]);
        let actual = run_topk(ids, vec![1, 2, 3], 2, true)?;
        let expected = r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
|          | 3            |
| 1        | 1            |
+----------+--------------+
        "#
        .trim();
        assert_eq!(actual, expected);

        Ok(())
    }

    /// Two nullable columns matching what `PriorityMap::emit` returns in these
    /// tests: the `Utf8` ids first, then the `Int64` values.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("trace_id", DataType::Utf8, true),
            Field::new("timestamp_ms", DataType::Int64, true),
        ]))
    }
}