/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/topk/mod.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! TopK: Combination of Sort / LIMIT

use arrow::{
    compute::interleave,
    row::{RowConverter, Rows, SortField},
};
use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc};

use arrow_array::{Array, ArrayRef, RecordBatch};
use arrow_schema::SchemaRef;
use datafusion_common::Result;
use datafusion_execution::{
    memory_pool::{MemoryConsumer, MemoryReservation},
    runtime_env::RuntimeEnv,
};
use datafusion_physical_expr::PhysicalSortExpr;
use hashbrown::HashMap;

use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream};

use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder};

/// Global TopK
///
/// # Background
///
/// "Top K" is a common query optimization used for queries such as
/// "find the top 3 customers by revenue". The (simplified) SQL for
/// such a query might be:
///
/// ```sql
/// SELECT customer_id, revenue FROM 'sales.csv' ORDER BY revenue DESC limit 3;
/// ```
///
/// The simple plan would be:
///
/// ```sql
/// > explain SELECT customer_id, revenue FROM sales ORDER BY revenue DESC limit 3;
/// +--------------+----------------------------------------+
/// | plan_type    | plan                                   |
/// +--------------+----------------------------------------+
/// | logical_plan | Limit: 3                               |
/// |              |   Sort: revenue DESC NULLS FIRST       |
/// |              |     Projection: customer_id, revenue   |
/// |              |       TableScan: sales                 |
/// +--------------+----------------------------------------+
/// ```
///
/// While this plan produces the correct answer, it fully sorts the
/// input before discarding everything other than the top 3 elements.
///
/// The same answer can be produced by simply keeping track of the top
/// K=3 elements, reducing the total amount of required buffer memory.
///
/// # Structure
///
/// This operator tracks the top K items using a `TopKHeap`.
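///
/// As a minimal sketch of the same idea on plain integers (this
/// hypothetical `top_k` helper is illustrative only, not this module's
/// API): keep a max-heap of the `k` smallest keys seen so far, so the
/// top of the heap is the current cutoff, and anything not smaller
/// than it can be discarded immediately.
///
/// ```rust
/// use std::collections::BinaryHeap;
///
/// // Returns the k smallest keys, in ascending order (assumes k > 0)
/// fn top_k(keys: impl IntoIterator<Item = i32>, k: usize) -> Vec<i32> {
///     let mut heap = BinaryHeap::new();
///     for key in keys {
///         if heap.len() < k {
///             heap.push(key);
///         } else if key < *heap.peek().unwrap() {
///             // evict the cutoff: the largest of the k smallest so far
///             heap.pop();
///             heap.push(key);
///         }
///     }
///     heap.into_sorted_vec()
/// }
///
/// assert_eq!(top_k([5, 1, 4, 2, 3], 3), vec![1, 2, 3]);
/// ```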
pub struct TopK {
    /// schema of the output (and the input)
    schema: SchemaRef,
    /// Runtime metrics
    metrics: TopKMetrics,
    /// Reservation
    reservation: MemoryReservation,
    /// The target number of rows for output batches
    batch_size: usize,
    /// sort expressions
    expr: Arc<[PhysicalSortExpr]>,
    /// row converter, for sort keys
    row_converter: RowConverter,
    /// scratch space for converting rows
    scratch_rows: Rows,
    /// stores the top k values and their sort key values, in order
    heap: TopKHeap,
}

impl TopK {
    /// Create a new [`TopK`] that stores the top `k` values, as
    /// defined by the sort expressions in `expr`.
    // TODO: make a builder or some other nicer API to avoid the
    // clippy warning
    #[allow(clippy::too_many_arguments)]
    pub fn try_new(
        partition_id: usize,
        schema: SchemaRef,
        expr: Vec<PhysicalSortExpr>,
        k: usize,
        batch_size: usize,
        runtime: Arc<RuntimeEnv>,
        metrics: &ExecutionPlanMetricsSet,
        partition: usize,
    ) -> Result<Self> {
        let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]"))
            .register(&runtime.memory_pool);

        let expr: Arc<[PhysicalSortExpr]> = expr.into();

        let sort_fields: Vec<_> = expr
            .iter()
            .map(|e| {
                Ok(SortField::new_with_options(
                    e.expr.data_type(&schema)?,
                    e.options,
                ))
            })
            .collect::<Result<_>>()?;

        // TODO there is potential to add special cases for single column sort fields
        // to improve performance
        let row_converter = RowConverter::new(sort_fields)?;
        let scratch_rows = row_converter.empty_rows(
            batch_size,
            20 * batch_size, // guesstimate 20 bytes per row
        );

        Ok(Self {
            schema: Arc::clone(&schema),
            metrics: TopKMetrics::new(metrics, partition),
            reservation,
            batch_size,
            expr,
            row_converter,
            scratch_rows,
            heap: TopKHeap::new(k, batch_size, schema),
        })
    }

    /// Insert `batch`, remembering if any of its values are among
    /// the top k seen so far.
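    ///
    /// The intended call pattern is sketched below (not compiled here:
    /// `schema`, `sort_expr`, `runtime`, `metrics`, and `batches` are
    /// assumed to exist; the argument order mirrors [`Self::try_new`]):
    ///
    /// ```rust,ignore
    /// let mut topk = TopK::try_new(
    ///     0,          // partition_id
    ///     schema,
    ///     vec![sort_expr],
    ///     3,          // k: number of rows to keep
    ///     8192,       // batch_size
    ///     runtime,
    ///     &metrics,
    ///     0,          // partition
    /// )?;
    /// for batch in batches {
    ///     topk.insert_batch(batch)?;
    /// }
    /// let stream = topk.emit()?;
    /// ```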
    pub fn insert_batch(&mut self, batch: RecordBatch) -> Result<()> {
        // Updates on drop
        let _timer = self.metrics.baseline.elapsed_compute().timer();

        let sort_keys: Vec<ArrayRef> = self
            .expr
            .iter()
            .map(|expr| {
                let value = expr.expr.evaluate(&batch)?;
                value.into_array(batch.num_rows())
            })
            .collect::<Result<Vec<_>>>()?;

        // reuse existing `Rows` to avoid reallocations
        let rows = &mut self.scratch_rows;
        rows.clear();
        self.row_converter.append(rows, &sort_keys)?;

        // TODO make this algorithmically better?:
        // Idea: filter out rows >= self.heap.max() early (before passing to `RowConverter`)
        // this avoids some work and also might be better vectorizable.
        let mut batch_entry = self.heap.register_batch(batch);
        for (index, row) in rows.iter().enumerate() {
            match self.heap.max() {
                // heap has k items, and the new row is greater than the
                // current max in the heap ==> it is not a new topk
                Some(max_row) if row.as_ref() >= max_row.row() => {}
                // don't yet have k items, or the new row is smaller than
                // the current max ==> it belongs in the top k
                None | Some(_) => {
                    self.heap.add(&mut batch_entry, row, index);
                    self.metrics.row_replacements.add(1);
                }
            }
        }
        self.heap.insert_batch_entry(batch_entry);

        // conserve memory
        self.heap.maybe_compact()?;

        // update memory reservation
        self.reservation.try_resize(self.size())?;
        Ok(())
    }

    /// Returns the top k results broken into `batch_size` [`RecordBatch`]es, consuming the heap
    pub fn emit(self) -> Result<SendableRecordBatchStream> {
        let Self {
            schema,
            metrics,
            reservation: _,
            batch_size,
            expr: _,
            row_converter: _,
            scratch_rows: _,
            mut heap,
        } = self;
        let _timer = metrics.baseline.elapsed_compute().timer(); // time updated on drop

        let mut batch = heap.emit()?;
        metrics.baseline.output_rows().add(batch.num_rows());

        // break into record batches as needed
        let mut batches = vec![];
        loop {
            if batch.num_rows() <= batch_size {
                batches.push(Ok(batch));
                break;
            } else {
                batches.push(Ok(batch.slice(0, batch_size)));
                let remaining_length = batch.num_rows() - batch_size;
                batch = batch.slice(batch_size, remaining_length);
            }
        }
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            schema,
            futures::stream::iter(batches),
        )))
    }

    /// return the size of memory used by this operator, in bytes
    fn size(&self) -> usize {
        std::mem::size_of::<Self>()
            + self.row_converter.size()
            + self.scratch_rows.size()
            + self.heap.size()
    }
}

struct TopKMetrics {
    /// common baseline metrics
    pub baseline: BaselineMetrics,

    /// count of how many rows were replaced in the heap
    pub row_replacements: Count,
}

impl TopKMetrics {
    fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
        Self {
            baseline: BaselineMetrics::new(metrics, partition),
            row_replacements: MetricBuilder::new(metrics)
                .counter("row_replacements", partition),
        }
    }
}

/// This structure keeps at most the *smallest* k items, using the
/// [arrow::row] format for sort keys. While it is called "topK", for
/// values like `1, 2, 3, 4, 5` the "top 3" really means the
/// *smallest* 3, `1, 2, 3`, not the *largest* 3, `3, 4, 5`.
///
/// Using the `Row` format handles things such as ascending vs
/// descending and nulls first vs nulls last.
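///
/// For example (a minimal sketch using the `arrow` row format this
/// module already depends on), integer sort keys are encoded so that
/// plain byte comparison of the rows matches the sort order:
///
/// ```rust
/// use std::sync::Arc;
/// use arrow::row::{RowConverter, SortField};
/// use arrow_array::{ArrayRef, Int32Array};
/// use arrow_schema::DataType;
///
/// let converter = RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap();
/// let array: ArrayRef = Arc::new(Int32Array::from(vec![3, 1, 2]));
/// let rows = converter.convert_columns(&[array]).unwrap();
/// // memcmp order of the encoded rows is the sort order: 1 < 2 < 3
/// assert!(rows.row(1) < rows.row(2));
/// assert!(rows.row(2) < rows.row(0));
/// ```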
struct TopKHeap {
    /// The maximum number of elements to store in this heap.
    k: usize,
    /// The target number of rows for output batches
    batch_size: usize,
    /// Storage for at most `k` items in a max `BinaryHeap`: the top of
    /// the heap is the largest of the smallest `k` rows seen so far,
    /// and thus the first candidate for eviction
    inner: BinaryHeap<TopKRow>,
    /// Stores the original row values (`TopKRow` only has the sort key)
    store: RecordBatchStore,
    /// The size of all owned data held by this heap
    owned_bytes: usize,
}

impl TopKHeap {
    fn new(k: usize, batch_size: usize, schema: SchemaRef) -> Self {
        assert!(k > 0);
        Self {
            k,
            batch_size,
            inner: BinaryHeap::new(),
            store: RecordBatchStore::new(schema),
            owned_bytes: 0,
        }
    }

    /// Register a [`RecordBatch`] with the heap, returning the
    /// appropriate entry
    pub fn register_batch(&mut self, batch: RecordBatch) -> RecordBatchEntry {
        self.store.register(batch)
    }

    /// Insert a [`RecordBatchEntry`] created by a previous call to
    /// [`Self::register_batch`] into storage.
    pub fn insert_batch_entry(&mut self, entry: RecordBatchEntry) {
        self.store.insert(entry)
    }

    /// Returns the largest value stored by the heap if there are k
    /// items, otherwise returns None. Remember this structure is
    /// keeping the "smallest" k values
    fn max(&self) -> Option<&TopKRow> {
        if self.inner.len() < self.k {
            None
        } else {
            self.inner.peek()
        }
    }

    /// Adds `row` to this heap. If inserting this new item would
    /// increase the size past `k`, removes the previously largest
    /// item (the max of the heap, i.e. the weakest of the current
    /// top k).
    fn add(
        &mut self,
        batch_entry: &mut RecordBatchEntry,
        row: impl AsRef<[u8]>,
        index: usize,
    ) {
        let batch_id = batch_entry.id;
        batch_entry.uses += 1;

        assert!(self.inner.len() <= self.k);
        let row = row.as_ref();

        // Reuse storage for evicted item if possible
        let new_top_k = if self.inner.len() == self.k {
            let prev_max = self.inner.pop().unwrap();

            // Update batch use
            if prev_max.batch_id == batch_entry.id {
                batch_entry.uses -= 1;
            } else {
                self.store.unuse(prev_max.batch_id);
            }

            // update memory accounting
            self.owned_bytes -= prev_max.owned_size();
            prev_max.with_new_row(row, batch_id, index)
        } else {
            TopKRow::new(row, batch_id, index)
        };

        self.owned_bytes += new_top_k.owned_size();

        // put the new row into the heap
        self.inner.push(new_top_k)
    }

    /// Returns the values stored in this heap, from values low to
    /// high, as a single [`RecordBatch`], resetting the inner heap
    pub fn emit(&mut self) -> Result<RecordBatch> {
        Ok(self.emit_with_state()?.0)
    }

    /// Returns the values stored in this heap, from values low to
    /// high, as a single [`RecordBatch`], and a sorted vec of the
    /// current heap's contents
    pub fn emit_with_state(&mut self) -> Result<(RecordBatch, Vec<TopKRow>)> {
        let schema = Arc::clone(self.store.schema());

        // generate sorted rows
        let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec();

        if self.store.is_empty() {
            return Ok((RecordBatch::new_empty(schema), topk_rows));
        }

        // Indices for each row within its respective RecordBatch
        let indices: Vec<_> = topk_rows
            .iter()
            .enumerate()
            .map(|(i, k)| (i, k.index))
            .collect();

        let num_columns = schema.fields().len();

        // build the output columns one at a time, using the
        // `interleave` kernel to pick rows from different arrays
        let output_columns: Vec<_> = (0..num_columns)
            .map(|col| {
                let input_arrays: Vec<_> = topk_rows
                    .iter()
                    .map(|k| {
                        let entry =
                            self.store.get(k.batch_id).expect("invalid stored batch id");
                        entry.batch.column(col) as &dyn Array
                    })
                    .collect();

                // at this point `indices` contains indexes within the
                // rows and `input_arrays` contains a reference to the
                // relevant Array for that index. `interleave` pulls
                // them together into a single new array
                Ok(interleave(&input_arrays, &indices)?)
            })
            .collect::<Result<_>>()?;

        let new_batch = RecordBatch::try_new(schema, output_columns)?;
        Ok((new_batch, topk_rows))
    }

    /// Compact this heap, rewriting all stored batches into a single
    /// input batch
    pub fn maybe_compact(&mut self) -> Result<()> {
        // we compact if the number of "unused" rows in the store is
        // past some pre-defined threshold. Target holding up to
        // around 20 batches, but handle cases of large k where some
        // batches might be partially full
        let max_unused_rows = (20 * self.batch_size) + self.k;
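        // e.g. with a hypothetical batch_size of 8192 and k of 100, the
        // store may accumulate up to 20 * 8192 + 100 = 163,940 unused
        // rows before a compaction is triggered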
        let unused_rows = self.store.unused_rows();

        // don't compact if the store has at most two batches, or the
        // number of unused rows is under the threshold
        if self.store.len() <= 2 || unused_rows < max_unused_rows {
            return Ok(());
        }
        // at first, compact the entire thing always into a new batch
        // (maybe we can get fancier in the future about ignoring
        // batches that have a high usage ratio already)

        // Note: new batch is in the same order as inner
        let num_rows = self.inner.len();
        let (new_batch, mut topk_rows) = self.emit_with_state()?;

        // clear all old entries in store (this invalidates all
        // store_ids in `inner`)
        self.store.clear();

        let mut batch_entry = self.register_batch(new_batch);
        batch_entry.uses = num_rows;

        // rewrite all existing entries to use the new batch, and
        // remove old entries. The sortedness and their relative
        // position do not change
        for (i, topk_row) in topk_rows.iter_mut().enumerate() {
            topk_row.batch_id = batch_entry.id;
            topk_row.index = i;
        }
        self.insert_batch_entry(batch_entry);
        // restore the heap
        self.inner = BinaryHeap::from(topk_rows);

        Ok(())
    }

    /// return the size of memory used by this heap, in bytes
    fn size(&self) -> usize {
        std::mem::size_of::<Self>()
            + (self.inner.capacity() * std::mem::size_of::<TopKRow>())
            + self.store.size()
            + self.owned_bytes
    }
}

/// Represents one of the top K rows held in this heap. Orders
/// according to memcmp of row (e.g. the arrow Row format, but could
/// also be primitive values)
///
/// Reuses allocations to minimize runtime overhead of creating new Vecs
#[derive(Debug, PartialEq)]
struct TopKRow {
    /// the value of the sort key for this row. This contains the
    /// bytes that could be stored in `OwnedRow` but uses `Vec<u8>` to
    /// reuse allocations.
    row: Vec<u8>,
    /// the RecordBatch this row came from: an id into a [`RecordBatchStore`]
    batch_id: u32,
    /// the index in this record batch the row came from
    index: usize,
}

impl TopKRow {
    /// Create a new TopKRow with new allocation
    fn new(row: impl AsRef<[u8]>, batch_id: u32, index: usize) -> Self {
        Self {
            row: row.as_ref().to_vec(),
            batch_id,
            index,
        }
    }

    /// Create a new TopKRow reusing the existing allocation
    fn with_new_row(
        self,
        new_row: impl AsRef<[u8]>,
        batch_id: u32,
        index: usize,
    ) -> Self {
        let Self {
            mut row,
            batch_id: _,
            index: _,
        } = self;
        row.clear();
        row.extend_from_slice(new_row.as_ref());

        Self {
            row,
            batch_id,
            index,
        }
    }

    /// Returns the number of bytes owned by this row in the heap (not
    /// including itself)
    fn owned_size(&self) -> usize {
        self.row.capacity()
    }

    /// Returns a slice to the owned row value
    fn row(&self) -> &[u8] {
        self.row.as_slice()
    }
}

impl Eq for TopKRow {}

impl PartialOrd for TopKRow {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for TopKRow {
    fn cmp(&self, other: &Self) -> Ordering {
        self.row.cmp(&other.row)
    }
}

#[derive(Debug)]
struct RecordBatchEntry {
    id: u32,
    batch: RecordBatch,
    // for this batch, how many of its rows are currently referenced by the heap
    uses: usize,
}

/// This structure tracks [`RecordBatch`] by an id so that:
///
/// 1. The batches can be tracked via an id that can be copied cheaply
/// 2. The total memory held by all batches is tracked
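///
/// As a sketch of the expected usage (mirroring how `TopKHeap` drives
/// this store): `register` a batch to get a [`RecordBatchEntry`], bump
/// the entry's `uses` once per row the heap keeps, then `insert` it;
/// call `unuse` whenever a referenced row is evicted, which drops the
/// batch once no rows reference it.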
#[derive(Debug)]
struct RecordBatchStore {
    /// id generator
    next_id: u32,
    /// storage
    batches: HashMap<u32, RecordBatchEntry>,
    /// total size of all record batches tracked by this store
    batches_size: usize,
    /// schema of the batches
    schema: SchemaRef,
}

impl RecordBatchStore {
    fn new(schema: SchemaRef) -> Self {
        Self {
            next_id: 0,
            batches: HashMap::new(),
            batches_size: 0,
            schema,
        }
    }

    /// Register this batch with the store and assign an ID. No
    /// attempt is made to compare this batch to other batches
    pub fn register(&mut self, batch: RecordBatch) -> RecordBatchEntry {
        let id = self.next_id;
        self.next_id += 1;
        RecordBatchEntry { id, batch, uses: 0 }
    }

    /// Insert a record batch entry into this store, tracking its
    /// memory use, if it has any uses
    pub fn insert(&mut self, entry: RecordBatchEntry) {
        // uses of 0 means that none of the rows in the batch were stored in the topk
        if entry.uses > 0 {
            self.batches_size += entry.batch.get_array_memory_size();
            self.batches.insert(entry.id, entry);
        }
    }

    /// Clear all values in this store, invalidating all previous batch ids
    fn clear(&mut self) {
        self.batches.clear();
        self.batches_size = 0;
    }

    fn get(&self, id: u32) -> Option<&RecordBatchEntry> {
        self.batches.get(&id)
    }

    /// returns the total number of batches stored in this store
    fn len(&self) -> usize {
        self.batches.len()
    }

    /// Returns the total number of rows in batches minus the number
    /// which are in use
    fn unused_rows(&self) -> usize {
        self.batches
            .values()
            .map(|batch_entry| batch_entry.batch.num_rows() - batch_entry.uses)
            .sum()
    }

    /// returns true if the store has nothing stored
    fn is_empty(&self) -> bool {
        self.batches.is_empty()
    }

    /// return the schema of batches stored
    fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// remove a use from the specified batch id. If the use count
    /// reaches zero the batch entry is removed from the store
    ///
    /// panics if there were no remaining uses of id
    pub fn unuse(&mut self, id: u32) {
        let remove = if let Some(batch_entry) = self.batches.get_mut(&id) {
            batch_entry.uses = batch_entry.uses.checked_sub(1).expect("underflow");
            batch_entry.uses == 0
        } else {
            panic!("No entry for id {id}");
        };

        if remove {
            let old_entry = self.batches.remove(&id).unwrap();
            self.batches_size = self
                .batches_size
                .checked_sub(old_entry.batch.get_array_memory_size())
                .unwrap();
        }
    }

    /// returns the size of memory used by this store, including all
    /// referenced `RecordBatch`es, in bytes
    pub fn size(&self) -> usize {
        std::mem::size_of::<Self>()
            + self.batches.capacity()
                * (std::mem::size_of::<u32>() + std::mem::size_of::<RecordBatchEntry>())
            + self.batches_size
    }
}