/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/topk/priority_map.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! A `Map<K, V>` / `PriorityQueue` combo that evicts the worst values after reaching `capacity` |
19 | | |
20 | | use crate::aggregates::topk::hash_table::{new_hash_table, ArrowHashTable}; |
21 | | use crate::aggregates::topk::heap::{new_heap, ArrowHeap}; |
22 | | use arrow_array::ArrayRef; |
23 | | use arrow_schema::DataType; |
24 | | use datafusion_common::Result; |
25 | | |
26 | | /// A `Map<K, V>` / `PriorityQueue` combo that evicts the worst values after reaching `capacity` |
27 | | pub struct PriorityMap { |
28 | | map: Box<dyn ArrowHashTable + Send>, |
29 | | heap: Box<dyn ArrowHeap + Send>, |
30 | | capacity: usize, |
31 | | mapper: Vec<(usize, usize)>, |
32 | | } |
33 | | |
34 | | impl PriorityMap { |
35 | 10 | pub fn new( |
36 | 10 | key_type: DataType, |
37 | 10 | val_type: DataType, |
38 | 10 | capacity: usize, |
39 | 10 | descending: bool, |
40 | 10 | ) -> Result<Self> { |
41 | 10 | Ok(Self { |
42 | 10 | map: new_hash_table(capacity, key_type)?0 , |
43 | 10 | heap: new_heap(capacity, descending, val_type)?0 , |
44 | 10 | capacity, |
45 | 10 | mapper: Vec::with_capacity(capacity), |
46 | | }) |
47 | 10 | } |
48 | | |
49 | 10 | pub fn set_batch(&mut self, ids: ArrayRef, vals: ArrayRef) { |
50 | 10 | self.map.set_batch(ids); |
51 | 10 | self.heap.set_batch(vals); |
52 | 10 | } |
53 | | |
54 | 20 | pub fn insert(&mut self, row_idx: usize) -> Result<()> { |
55 | 20 | assert!(self.map.len() <= self.capacity, "Overflow"0 ); |
56 | | |
57 | | // if we're full, and the new val is worse than all our values, just bail |
58 | 20 | if self.heap.is_worse(row_idx) { |
59 | 2 | return Ok(()); |
60 | 18 | } |
61 | 18 | let map = &mut self.mapper; |
62 | 18 | |
63 | 18 | // handle new groups we haven't seen yet |
64 | 18 | map.clear(); |
65 | 18 | let replace_idx = self.heap.worst_map_idx(); |
66 | 18 | // JUSTIFICATION |
67 | 18 | // Benefit: ~15% speedup + required to index into RawTable from binary heap |
68 | 18 | // Soundness: replace_idx kept valid during resizes |
69 | 18 | let (map_idx, did_insert) = |
70 | 18 | unsafe { self.map.find_or_insert(row_idx, replace_idx, map) }; |
71 | 18 | if did_insert { |
72 | 13 | self.heap.renumber(map); |
73 | 13 | map.clear(); |
74 | 13 | self.heap.insert(row_idx, map_idx, map); |
75 | 13 | // JUSTIFICATION |
76 | 13 | // Benefit: ~15% speedup + required to index into RawTable from binary heap |
77 | 13 | // Soundness: the map was created on the line above, so all the indexes should be valid |
78 | 13 | unsafe { self.map.update_heap_idx(map) }; |
79 | 13 | return Ok(()); |
80 | 5 | }; |
81 | 5 | |
82 | 5 | // this is a value for an existing group |
83 | 5 | map.clear(); |
84 | 5 | // JUSTIFICATION |
85 | 5 | // Benefit: ~15% speedup + required to index into RawTable from binary heap |
86 | 5 | // Soundness: map_idx was just found, so it is valid |
87 | 5 | let heap_idx = unsafe { self.map.heap_idx_at(map_idx) }; |
88 | 5 | self.heap.replace_if_better(heap_idx, row_idx, map); |
89 | 5 | // JUSTIFICATION |
90 | 5 | // Benefit: ~15% speedup + required to index into RawTable from binary heap |
91 | 5 | // Soundness: the index map was just built, so it will be valid |
92 | 5 | unsafe { self.map.update_heap_idx(map) }; |
93 | 5 | |
94 | 5 | Ok(()) |
95 | 20 | } |
96 | | |
97 | 10 | pub fn emit(&mut self) -> Result<Vec<ArrayRef>> { |
98 | 10 | let (vals, map_idxs) = self.heap.drain(); |
99 | 10 | let ids = unsafe { self.map.take_all(map_idxs) }; |
100 | 10 | Ok(vec![ids, vals]) |
101 | 10 | } |
102 | | |
103 | 0 | pub fn is_empty(&self) -> bool { |
104 | 0 | self.map.len() == 0 |
105 | 0 | } |
106 | | } |
107 | | |
108 | | #[cfg(test)] |
109 | | mod tests { |
110 | | use super::*; |
111 | | use arrow::util::pretty::pretty_format_batches; |
112 | | use arrow_array::{Int64Array, RecordBatch, StringArray}; |
113 | | use arrow_schema::Field; |
114 | | use arrow_schema::Schema; |
115 | | use arrow_schema::SchemaRef; |
116 | | use std::sync::Arc; |
117 | | |
118 | | #[test] |
119 | 1 | fn should_append() -> Result<()> { |
120 | 1 | let ids: ArrayRef = Arc::new(StringArray::from(vec!["1"])); |
121 | 1 | let vals: ArrayRef = Arc::new(Int64Array::from(vec![1])); |
122 | 1 | let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, false)?0 ; |
123 | 1 | agg.set_batch(ids, vals); |
124 | 1 | agg.insert(0)?0 ; |
125 | | |
126 | 1 | let cols = agg.emit()?0 ; |
127 | 1 | let batch = RecordBatch::try_new(test_schema(), cols)?0 ; |
128 | 1 | let actual = format!("{}", pretty_format_batches(&[batch])?0 ); |
129 | 1 | let expected = r#" |
130 | 1 | +----------+--------------+ |
131 | 1 | | trace_id | timestamp_ms | |
132 | 1 | +----------+--------------+ |
133 | 1 | | 1 | 1 | |
134 | 1 | +----------+--------------+ |
135 | 1 | "# |
136 | 1 | .trim(); |
137 | 1 | assert_eq!(actual, expected); |
138 | | |
139 | 1 | Ok(()) |
140 | 1 | } |
141 | | |
142 | | #[test] |
143 | 1 | fn should_ignore_higher_group() -> Result<()> { |
144 | 1 | let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "2"])); |
145 | 1 | let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2])); |
146 | 1 | let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, false)?0 ; |
147 | 1 | agg.set_batch(ids, vals); |
148 | 1 | agg.insert(0)?0 ; |
149 | 1 | agg.insert(1)?0 ; |
150 | | |
151 | 1 | let cols = agg.emit()?0 ; |
152 | 1 | let batch = RecordBatch::try_new(test_schema(), cols)?0 ; |
153 | 1 | let actual = format!("{}", pretty_format_batches(&[batch])?0 ); |
154 | 1 | let expected = r#" |
155 | 1 | +----------+--------------+ |
156 | 1 | | trace_id | timestamp_ms | |
157 | 1 | +----------+--------------+ |
158 | 1 | | 1 | 1 | |
159 | 1 | +----------+--------------+ |
160 | 1 | "# |
161 | 1 | .trim(); |
162 | 1 | assert_eq!(actual, expected); |
163 | | |
164 | 1 | Ok(()) |
165 | 1 | } |
166 | | |
167 | | #[test] |
168 | 1 | fn should_ignore_lower_group() -> Result<()> { |
169 | 1 | let ids: ArrayRef = Arc::new(StringArray::from(vec!["2", "1"])); |
170 | 1 | let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1])); |
171 | 1 | let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, true)?0 ; |
172 | 1 | agg.set_batch(ids, vals); |
173 | 1 | agg.insert(0)?0 ; |
174 | 1 | agg.insert(1)?0 ; |
175 | | |
176 | 1 | let cols = agg.emit()?0 ; |
177 | 1 | let batch = RecordBatch::try_new(test_schema(), cols)?0 ; |
178 | 1 | let actual = format!("{}", pretty_format_batches(&[batch])?0 ); |
179 | 1 | let expected = r#" |
180 | 1 | +----------+--------------+ |
181 | 1 | | trace_id | timestamp_ms | |
182 | 1 | +----------+--------------+ |
183 | 1 | | 2 | 2 | |
184 | 1 | +----------+--------------+ |
185 | 1 | "# |
186 | 1 | .trim(); |
187 | 1 | assert_eq!(actual, expected); |
188 | | |
189 | 1 | Ok(()) |
190 | 1 | } |
191 | | |
192 | | #[test] |
193 | 1 | fn should_ignore_higher_same_group() -> Result<()> { |
194 | 1 | let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"])); |
195 | 1 | let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2])); |
196 | 1 | let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, false)?0 ; |
197 | 1 | agg.set_batch(ids, vals); |
198 | 1 | agg.insert(0)?0 ; |
199 | 1 | agg.insert(1)?0 ; |
200 | | |
201 | 1 | let cols = agg.emit()?0 ; |
202 | 1 | let batch = RecordBatch::try_new(test_schema(), cols)?0 ; |
203 | 1 | let actual = format!("{}", pretty_format_batches(&[batch])?0 ); |
204 | 1 | let expected = r#" |
205 | 1 | +----------+--------------+ |
206 | 1 | | trace_id | timestamp_ms | |
207 | 1 | +----------+--------------+ |
208 | 1 | | 1 | 1 | |
209 | 1 | +----------+--------------+ |
210 | 1 | "# |
211 | 1 | .trim(); |
212 | 1 | assert_eq!(actual, expected); |
213 | | |
214 | 1 | Ok(()) |
215 | 1 | } |
216 | | |
217 | | #[test] |
218 | 1 | fn should_ignore_lower_same_group() -> Result<()> { |
219 | 1 | let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"])); |
220 | 1 | let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1])); |
221 | 1 | let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, true)?0 ; |
222 | 1 | agg.set_batch(ids, vals); |
223 | 1 | agg.insert(0)?0 ; |
224 | 1 | agg.insert(1)?0 ; |
225 | | |
226 | 1 | let cols = agg.emit()?0 ; |
227 | 1 | let batch = RecordBatch::try_new(test_schema(), cols)?0 ; |
228 | 1 | let actual = format!("{}", pretty_format_batches(&[batch])?0 ); |
229 | 1 | let expected = r#" |
230 | 1 | +----------+--------------+ |
231 | 1 | | trace_id | timestamp_ms | |
232 | 1 | +----------+--------------+ |
233 | 1 | | 1 | 2 | |
234 | 1 | +----------+--------------+ |
235 | 1 | "# |
236 | 1 | .trim(); |
237 | 1 | assert_eq!(actual, expected); |
238 | | |
239 | 1 | Ok(()) |
240 | 1 | } |
241 | | |
242 | | #[test] |
243 | 1 | fn should_accept_lower_group() -> Result<()> { |
244 | 1 | let ids: ArrayRef = Arc::new(StringArray::from(vec!["2", "1"])); |
245 | 1 | let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1])); |
246 | 1 | let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, false)?0 ; |
247 | 1 | agg.set_batch(ids, vals); |
248 | 1 | agg.insert(0)?0 ; |
249 | 1 | agg.insert(1)?0 ; |
250 | | |
251 | 1 | let cols = agg.emit()?0 ; |
252 | 1 | let batch = RecordBatch::try_new(test_schema(), cols)?0 ; |
253 | 1 | let actual = format!("{}", pretty_format_batches(&[batch])?0 ); |
254 | 1 | let expected = r#" |
255 | 1 | +----------+--------------+ |
256 | 1 | | trace_id | timestamp_ms | |
257 | 1 | +----------+--------------+ |
258 | 1 | | 1 | 1 | |
259 | 1 | +----------+--------------+ |
260 | 1 | "# |
261 | 1 | .trim(); |
262 | 1 | assert_eq!(actual, expected); |
263 | | |
264 | 1 | Ok(()) |
265 | 1 | } |
266 | | |
267 | | #[test] |
268 | 1 | fn should_accept_higher_group() -> Result<()> { |
269 | 1 | let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "2"])); |
270 | 1 | let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2])); |
271 | 1 | let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, true)?0 ; |
272 | 1 | agg.set_batch(ids, vals); |
273 | 1 | agg.insert(0)?0 ; |
274 | 1 | agg.insert(1)?0 ; |
275 | | |
276 | 1 | let cols = agg.emit()?0 ; |
277 | 1 | let batch = RecordBatch::try_new(test_schema(), cols)?0 ; |
278 | 1 | let actual = format!("{}", pretty_format_batches(&[batch])?0 ); |
279 | 1 | let expected = r#" |
280 | 1 | +----------+--------------+ |
281 | 1 | | trace_id | timestamp_ms | |
282 | 1 | +----------+--------------+ |
283 | 1 | | 2 | 2 | |
284 | 1 | +----------+--------------+ |
285 | 1 | "# |
286 | 1 | .trim(); |
287 | 1 | assert_eq!(actual, expected); |
288 | | |
289 | 1 | Ok(()) |
290 | 1 | } |
291 | | |
292 | | #[test] |
293 | 1 | fn should_accept_lower_for_group() -> Result<()> { |
294 | 1 | let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"])); |
295 | 1 | let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1])); |
296 | 1 | let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, false)?0 ; |
297 | 1 | agg.set_batch(ids, vals); |
298 | 1 | agg.insert(0)?0 ; |
299 | 1 | agg.insert(1)?0 ; |
300 | | |
301 | 1 | let cols = agg.emit()?0 ; |
302 | 1 | let batch = RecordBatch::try_new(test_schema(), cols)?0 ; |
303 | 1 | let actual = format!("{}", pretty_format_batches(&[batch])?0 ); |
304 | 1 | let expected = r#" |
305 | 1 | +----------+--------------+ |
306 | 1 | | trace_id | timestamp_ms | |
307 | 1 | +----------+--------------+ |
308 | 1 | | 1 | 1 | |
309 | 1 | +----------+--------------+ |
310 | 1 | "# |
311 | 1 | .trim(); |
312 | 1 | assert_eq!(actual, expected); |
313 | | |
314 | 1 | Ok(()) |
315 | 1 | } |
316 | | |
317 | | #[test] |
318 | 1 | fn should_accept_higher_for_group() -> Result<()> { |
319 | 1 | let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"])); |
320 | 1 | let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2])); |
321 | 1 | let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, true)?0 ; |
322 | 1 | agg.set_batch(ids, vals); |
323 | 1 | agg.insert(0)?0 ; |
324 | 1 | agg.insert(1)?0 ; |
325 | | |
326 | 1 | let cols = agg.emit()?0 ; |
327 | 1 | let batch = RecordBatch::try_new(test_schema(), cols)?0 ; |
328 | 1 | let actual = format!("{}", pretty_format_batches(&[batch])?0 ); |
329 | 1 | let expected = r#" |
330 | 1 | +----------+--------------+ |
331 | 1 | | trace_id | timestamp_ms | |
332 | 1 | +----------+--------------+ |
333 | 1 | | 1 | 2 | |
334 | 1 | +----------+--------------+ |
335 | 1 | "# |
336 | 1 | .trim(); |
337 | 1 | assert_eq!(actual, expected); |
338 | | |
339 | 1 | Ok(()) |
340 | 1 | } |
341 | | |
342 | | #[test] |
343 | 1 | fn should_handle_null_ids() -> Result<()> { |
344 | 1 | let ids: ArrayRef = Arc::new(StringArray::from(vec![Some("1"), None, None])); |
345 | 1 | let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3])); |
346 | 1 | let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, true)?0 ; |
347 | 1 | agg.set_batch(ids, vals); |
348 | 1 | agg.insert(0)?0 ; |
349 | 1 | agg.insert(1)?0 ; |
350 | 1 | agg.insert(2)?0 ; |
351 | | |
352 | 1 | let cols = agg.emit()?0 ; |
353 | 1 | let batch = RecordBatch::try_new(test_schema(), cols)?0 ; |
354 | 1 | let actual = format!("{}", pretty_format_batches(&[batch])?0 ); |
355 | 1 | let expected = r#" |
356 | 1 | +----------+--------------+ |
357 | 1 | | trace_id | timestamp_ms | |
358 | 1 | +----------+--------------+ |
359 | 1 | | | 3 | |
360 | 1 | | 1 | 1 | |
361 | 1 | +----------+--------------+ |
362 | 1 | "# |
363 | 1 | .trim(); |
364 | 1 | assert_eq!(actual, expected); |
365 | | |
366 | 1 | Ok(()) |
367 | 1 | } |
368 | | |
369 | 10 | fn test_schema() -> SchemaRef { |
370 | 10 | Arc::new(Schema::new(vec![ |
371 | 10 | Field::new("trace_id", DataType::Utf8, true), |
372 | 10 | Field::new("timestamp_ms", DataType::Int64, true), |
373 | 10 | ])) |
374 | 10 | } |
375 | | } |