Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/topk/mod.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! TopK: Combination of Sort / LIMIT
19
20
use arrow::{
21
    compute::interleave,
22
    row::{RowConverter, Rows, SortField},
23
};
24
use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc};
25
26
use arrow_array::{Array, ArrayRef, RecordBatch};
27
use arrow_schema::SchemaRef;
28
use datafusion_common::Result;
29
use datafusion_execution::{
30
    memory_pool::{MemoryConsumer, MemoryReservation},
31
    runtime_env::RuntimeEnv,
32
};
33
use datafusion_physical_expr::PhysicalSortExpr;
34
use hashbrown::HashMap;
35
36
use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream};
37
38
use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder};
39
40
/// Global TopK
41
///
42
/// # Background
43
///
44
/// "Top K" is a common query optimization used for queries such as
45
/// "find the top 3 customers by revenue". The (simplified) SQL for
46
/// such a query might be:
47
///
48
/// ```sql
49
/// SELECT customer_id, revenue FROM 'sales.csv' ORDER BY revenue DESC limit 3;
50
/// ```
51
///
52
/// The simple plan would be:
53
///
54
/// ```sql
55
/// > explain SELECT customer_id, revenue FROM sales ORDER BY revenue DESC limit 3;
56
/// +--------------+----------------------------------------+
57
/// | plan_type    | plan                                   |
58
/// +--------------+----------------------------------------+
59
/// | logical_plan | Limit: 3                               |
60
/// |              |   Sort: revenue DESC NULLS FIRST       |
61
/// |              |     Projection: customer_id, revenue   |
62
/// |              |       TableScan: sales                 |
63
/// +--------------+----------------------------------------+
64
/// ```
65
///
66
/// While this plan produces the correct answer, it will fully sorts the
67
/// input before discarding everything other than the top 3 elements.
68
///
69
/// The same answer can be produced by simply keeping track of the top
70
/// K=3 elements, reducing the total amount of required buffer memory.
71
///
72
/// # Structure
73
///
74
/// This operator tracks the top K items using a `TopKHeap`.
75
pub struct TopK {
76
    /// schema of the output (and the input)
77
    schema: SchemaRef,
78
    /// Runtime metrics
79
    metrics: TopKMetrics,
80
    /// Reservation
81
    reservation: MemoryReservation,
82
    /// The target number of rows for output batches
83
    batch_size: usize,
84
    /// sort expressions
85
    expr: Arc<[PhysicalSortExpr]>,
86
    /// row converter, for sort keys
87
    row_converter: RowConverter,
88
    /// scratch space for converting rows
89
    scratch_rows: Rows,
90
    /// stores the top k values and their sort key values, in order
91
    heap: TopKHeap,
92
}
93
94
impl TopK {
95
    /// Create a new [`TopK`] that stores the top `k` values, as
96
    /// defined by the sort expressions in `expr`.
97
    // TODO: make a builder or some other nicer API to avoid the
98
    // clippy warning
99
    #[allow(clippy::too_many_arguments)]
100
5
    pub fn try_new(
101
5
        partition_id: usize,
102
5
        schema: SchemaRef,
103
5
        expr: Vec<PhysicalSortExpr>,
104
5
        k: usize,
105
5
        batch_size: usize,
106
5
        runtime: Arc<RuntimeEnv>,
107
5
        metrics: &ExecutionPlanMetricsSet,
108
5
        partition: usize,
109
5
    ) -> Result<Self> {
110
5
        let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]"))
111
5
            .register(&runtime.memory_pool);
112
5
113
5
        let expr: Arc<[PhysicalSortExpr]> = expr.into();
114
115
5
        let sort_fields: Vec<_> = expr
116
5
            .iter()
117
13
            .map(|e| {
118
13
                Ok(SortField::new_with_options(
119
13
                    e.expr.data_type(&schema)
?0
,
120
13
                    e.options,
121
                ))
122
13
            })
123
5
            .collect::<Result<_>>()
?0
;
124
125
        // TODO there is potential to add special cases for single column sort fields
126
        // to improve performance
127
5
        let row_converter = RowConverter::new(sort_fields)
?0
;
128
5
        let scratch_rows = row_converter.empty_rows(
129
5
            batch_size,
130
5
            20 * batch_size, // guestimate 20 bytes per row
131
5
        );
132
5
133
5
        Ok(Self {
134
5
            schema: Arc::clone(&schema),
135
5
            metrics: TopKMetrics::new(metrics, partition),
136
5
            reservation,
137
5
            batch_size,
138
5
            expr,
139
5
            row_converter,
140
5
            scratch_rows,
141
5
            heap: TopKHeap::new(k, batch_size, schema),
142
5
        })
143
5
    }
144
145
    /// Insert `batch`, remembering if any of its values are among
146
    /// the top k seen so far.
147
20
    pub fn insert_batch(&mut self, batch: RecordBatch) -> Result<()> {
148
20
        // Updates on drop
149
20
        let _timer = self.metrics.baseline.elapsed_compute().timer();
150
151
20
        let sort_keys: Vec<ArrayRef> = self
152
20
            .expr
153
20
            .iter()
154
52
            .map(|expr| {
155
52
                let value = expr.expr.evaluate(&batch)
?0
;
156
52
                value.into_array(batch.num_rows())
157
52
            })
158
20
            .collect::<Result<Vec<_>>>()
?0
;
159
160
        // reuse existing `Rows` to avoid reallocations
161
20
        let rows = &mut self.scratch_rows;
162
20
        rows.clear();
163
20
        self.row_converter.append(rows, &sort_keys)
?0
;
164
165
        // TODO make this algorithmically better?:
166
        // Idea: filter out rows >= self.heap.max() early (before passing to `RowConverter`)
167
        //       this avoids some work and also might be better vectorizable.
168
20
        let mut batch_entry = self.heap.register_batch(batch);
169
2.00k
        for (index, row) in 
rows.iter().enumerate()20
{
170
2.00k
            match self.heap.max() {
171
                // heap has k items, and the new row is greater than the
172
                // current max in the heap ==> it is not a new topk
173
1.42k
                Some(
max_row1.32k
) if row.as_ref() >= max_row.row(
) => {}1.32k
174
                // don't yet have k items or new item is lower than the currently k low values
175
676
                None | Some(_) => {
176
676
                    self.heap.add(&mut batch_entry, row, index);
177
676
                    self.metrics.row_replacements.add(1);
178
676
                }
179
            }
180
        }
181
20
        self.heap.insert_batch_entry(batch_entry);
182
20
183
20
        // conserve memory
184
20
        self.heap.maybe_compact()
?0
;
185
186
        // update memory reservation
187
20
        self.reservation.try_resize(self.size())
?0
;
188
20
        Ok(())
189
20
    }
190
191
    /// Returns the top k results broken into `batch_size` [`RecordBatch`]es, consuming the heap
192
5
    pub fn emit(self) -> Result<SendableRecordBatchStream> {
193
5
        let Self {
194
5
            schema,
195
5
            metrics,
196
5
            reservation: _,
197
5
            batch_size,
198
5
            expr: _,
199
5
            row_converter: _,
200
5
            scratch_rows: _,
201
5
            mut heap,
202
5
        } = self;
203
5
        let _timer = metrics.baseline.elapsed_compute().timer(); // time updated on drop
204
205
5
        let mut batch = heap.emit()
?0
;
206
5
        metrics.baseline.output_rows().add(batch.num_rows());
207
5
208
5
        // break into record batches as needed
209
5
        let mut batches = vec![];
210
        loop {
211
5
            if batch.num_rows() <= batch_size {
212
5
                batches.push(Ok(batch));
213
5
                break;
214
0
            } else {
215
0
                batches.push(Ok(batch.slice(0, batch_size)));
216
0
                let remaining_length = batch.num_rows() - batch_size;
217
0
                batch = batch.slice(batch_size, remaining_length);
218
0
            }
219
        }
220
5
        Ok(Box::pin(RecordBatchStreamAdapter::new(
221
5
            schema,
222
5
            futures::stream::iter(batches),
223
5
        )))
224
5
    }
225
226
    /// return the size of memory used by this operator, in bytes
227
20
    fn size(&self) -> usize {
228
20
        std::mem::size_of::<Self>()
229
20
            + self.row_converter.size()
230
20
            + self.scratch_rows.size()
231
20
            + self.heap.size()
232
20
    }
233
}
234
235
struct TopKMetrics {
236
    /// metrics
237
    pub baseline: BaselineMetrics,
238
239
    /// count of how many rows were replaced in the heap
240
    pub row_replacements: Count,
241
}
242
243
impl TopKMetrics {
244
5
    fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
245
5
        Self {
246
5
            baseline: BaselineMetrics::new(metrics, partition),
247
5
            row_replacements: MetricBuilder::new(metrics)
248
5
                .counter("row_replacements", partition),
249
5
        }
250
5
    }
251
}
252
253
/// This structure keeps at most the *smallest* k items, using the
254
/// [arrow::row] format for sort keys. While it is called "topK" for
255
/// values like `1, 2, 3, 4, 5` the "top 3" really means the
256
/// *smallest* 3 , `1, 2, 3`, not the *largest* 3 `3, 4, 5`.
257
///
258
/// Using the `Row` format handles things such as ascending vs
259
/// descending and nulls first vs nulls last.
260
struct TopKHeap {
261
    /// The maximum number of elements to store in this heap.
262
    k: usize,
263
    /// The target number of rows for output batches
264
    batch_size: usize,
265
    /// Storage for up at most `k` items using a BinaryHeap. Reverserd
266
    /// so that the smallest k so far is on the top
267
    inner: BinaryHeap<TopKRow>,
268
    /// Storage the original row values (TopKRow only has the sort key)
269
    store: RecordBatchStore,
270
    /// The size of all owned data held by this heap
271
    owned_bytes: usize,
272
}
273
274
impl TopKHeap {
275
5
    fn new(k: usize, batch_size: usize, schema: SchemaRef) -> Self {
276
5
        assert!(k > 0);
277
5
        Self {
278
5
            k,
279
5
            batch_size,
280
5
            inner: BinaryHeap::new(),
281
5
            store: RecordBatchStore::new(schema),
282
5
            owned_bytes: 0,
283
5
        }
284
5
    }
285
286
    /// Register a [`RecordBatch`] with the heap, returning the
287
    /// appropriate entry
288
20
    pub fn register_batch(&mut self, batch: RecordBatch) -> RecordBatchEntry {
289
20
        self.store.register(batch)
290
20
    }
291
292
    /// Insert a [`RecordBatchEntry`] created by a previous call to
293
    /// [`Self::register_batch`] into storage.
294
20
    pub fn insert_batch_entry(&mut self, entry: RecordBatchEntry) {
295
20
        self.store.insert(entry)
296
20
    }
297
298
    /// Returns the largest value stored by the heap if there are k
299
    /// items, otherwise returns None. Remember this structure is
300
    /// keeping the "smallest" k values
301
2.00k
    fn max(&self) -> Option<&TopKRow> {
302
2.00k
        if self.inner.len() < self.k {
303
571
            None
304
        } else {
305
1.42k
            self.inner.peek()
306
        }
307
2.00k
    }
308
309
    /// Adds `row` to this heap. If inserting this new item would
310
    /// increase the size past `k`, removes the previously smallest
311
    /// item.
312
676
    fn add(
313
676
        &mut self,
314
676
        batch_entry: &mut RecordBatchEntry,
315
676
        row: impl AsRef<[u8]>,
316
676
        index: usize,
317
676
    ) {
318
676
        let batch_id = batch_entry.id;
319
676
        batch_entry.uses += 1;
320
676
321
676
        assert!(self.inner.len() <= self.k);
322
676
        let row = row.as_ref();
323
324
        // Reuse storage for evicted item if possible
325
676
        let new_top_k = if self.inner.len() == self.k {
326
105
            let prev_min = self.inner.pop().unwrap();
327
105
328
105
            // Update batch use
329
105
            if prev_min.batch_id == batch_entry.id {
330
105
                batch_entry.uses -= 1;
331
105
            } else {
332
0
                self.store.unuse(prev_min.batch_id);
333
0
            }
334
335
            // update memory accounting
336
105
            self.owned_bytes -= prev_min.owned_size();
337
105
            prev_min.with_new_row(row, batch_id, index)
338
        } else {
339
571
            TopKRow::new(row, batch_id, index)
340
        };
341
342
676
        self.owned_bytes += new_top_k.owned_size();
343
676
344
676
        // put the new row into the heap
345
676
        self.inner.push(new_top_k)
346
676
    }
347
348
    /// Returns the values stored in this heap, from values low to
349
    /// high, as a single [`RecordBatch`], resetting the inner heap
350
5
    pub fn emit(&mut self) -> Result<RecordBatch> {
351
5
        Ok(self.emit_with_state()
?0
.0)
352
5
    }
353
354
    /// Returns the values stored in this heap, from values low to
355
    /// high, as a single [`RecordBatch`], and a sorted vec of the
356
    /// current heap's contents
357
5
    pub fn emit_with_state(&mut self) -> Result<(RecordBatch, Vec<TopKRow>)> {
358
5
        let schema = Arc::clone(self.store.schema());
359
5
360
5
        // generate sorted rows
361
5
        let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec();
362
5
363
5
        if self.store.is_empty() {
364
0
            return Ok((RecordBatch::new_empty(schema), topk_rows));
365
5
        }
366
5
367
5
        // Indices for each row within its respective RecordBatch
368
5
        let indices: Vec<_> = topk_rows
369
5
            .iter()
370
5
            .enumerate()
371
571
            .map(|(i, k)| (i, k.index))
372
5
            .collect();
373
5
374
5
        let num_columns = schema.fields().len();
375
376
        // build the output columns one at time, using the
377
        // `interleave` kernel to pick rows from different arrays
378
5
        let output_columns: Vec<_> = (0..num_columns)
379
13
            .map(|col| {
380
13
                let input_arrays: Vec<_> = topk_rows
381
13
                    .iter()
382
1.71k
                    .map(|k| {
383
1.71k
                        let entry =
384
1.71k
                            self.store.get(k.batch_id).expect("invalid stored batch id");
385
1.71k
                        entry.batch.column(col) as &dyn Array
386
1.71k
                    })
387
13
                    .collect();
388
13
389
13
                // at this point `indices` contains indexes within the
390
13
                // rows and `input_arrays` contains a reference to the
391
13
                // relevant Array for that index. `interleave` pulls
392
13
                // them together into a single new array
393
13
                Ok(interleave(&input_arrays, &indices)
?0
)
394
13
            })
395
5
            .collect::<Result<_>>()
?0
;
396
397
5
        let new_batch = RecordBatch::try_new(schema, output_columns)
?0
;
398
5
        Ok((new_batch, topk_rows))
399
5
    }
400
401
    /// Compact this heap, rewriting all stored batches into a single
402
    /// input batch
403
20
    pub fn maybe_compact(&mut self) -> Result<()> {
404
20
        // we compact if the number of "unused" rows in the store is
405
20
        // past some pre-defined threshold. Target holding up to
406
20
        // around 20 batches, but handle cases of large k where some
407
20
        // batches might be partially full
408
20
        let max_unused_rows = (20 * self.batch_size) + self.k;
409
20
        let unused_rows = self.store.unused_rows();
410
20
411
20
        // don't compact if the store has one extra batch or
412
20
        // unused rows is under the threshold
413
20
        if self.store.len() <= 2 || 
unused_rows < max_unused_rows2
{
414
20
            return Ok(());
415
0
        }
416
0
        // at first, compact the entire thing always into a new batch
417
0
        // (maybe we can get fancier in the future about ignoring
418
0
        // batches that have a high usage ratio already
419
0
420
0
        // Note: new batch is in the same order as inner
421
0
        let num_rows = self.inner.len();
422
0
        let (new_batch, mut topk_rows) = self.emit_with_state()?;
423
424
        // clear all old entries in store (this invalidates all
425
        // store_ids in `inner`)
426
0
        self.store.clear();
427
0
428
0
        let mut batch_entry = self.register_batch(new_batch);
429
0
        batch_entry.uses = num_rows;
430
431
        // rewrite all existing entries to use the new batch, and
432
        // remove old entries. The sortedness and their relative
433
        // position do not change
434
0
        for (i, topk_row) in topk_rows.iter_mut().enumerate() {
435
0
            topk_row.batch_id = batch_entry.id;
436
0
            topk_row.index = i;
437
0
        }
438
0
        self.insert_batch_entry(batch_entry);
439
0
        // restore the heap
440
0
        self.inner = BinaryHeap::from(topk_rows);
441
0
442
0
        Ok(())
443
20
    }
444
445
    /// return the size of memory used by this heap, in bytes
446
20
    fn size(&self) -> usize {
447
20
        std::mem::size_of::<Self>()
448
20
            + (self.inner.capacity() * std::mem::size_of::<TopKRow>())
449
20
            + self.store.size()
450
20
            + self.owned_bytes
451
20
    }
452
}
453
454
/// One row retained by the top-k heap: the encoded sort key plus the
/// location of the full row (batch id and row index) in the batch store.
///
/// Ordering compares only the sort key bytes lexicographically. The key is
/// typically in the arrow Row format, but could also hold primitive values.
/// The key buffer is recycled by [`TopKRow::with_new_row`] to avoid
/// allocating new `Vec`s.
///
/// NOTE(review): `PartialEq` is derived and compares every field, while
/// `Ord` compares only `row`. Two entries with equal keys from different
/// rows are `Ordering::Equal` but not `==`. The heap relies only on `Ord`.
#[derive(Debug, PartialEq)]
struct TopKRow {
    /// the value of the sort key for this row. This contains the
    /// bytes that could be stored in `OwnedRow` but uses `Vec<u8>` to
    /// reuse allocations.
    row: Vec<u8>,
    /// the RecordBatch this row came from: an id into a `RecordBatchStore`
    batch_id: u32,
    /// the index in this record batch the row came from
    index: usize,
}

impl TopKRow {
    /// Builds an entry by copying `row` into a freshly allocated buffer.
    fn new(row: impl AsRef<[u8]>, batch_id: u32, index: usize) -> Self {
        let key_bytes: &[u8] = row.as_ref();
        Self {
            row: Vec::from(key_bytes),
            batch_id,
            index,
        }
    }

    /// Turns this entry into one describing a different row, reusing the
    /// key buffer's existing allocation.
    fn with_new_row(
        self,
        new_row: impl AsRef<[u8]>,
        batch_id: u32,
        index: usize,
    ) -> Self {
        let mut buffer = self.row;
        buffer.clear();
        buffer.extend_from_slice(new_row.as_ref());
        Self {
            row: buffer,
            batch_id,
            index,
        }
    }

    /// Heap bytes owned by this entry (the key buffer's capacity), not
    /// counting the entry itself.
    fn owned_size(&self) -> usize {
        self.row.capacity()
    }

    /// Borrow the encoded sort key.
    fn row(&self) -> &[u8] {
        &self.row
    }
}

impl Eq for TopKRow {}

impl PartialOrd for TopKRow {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for TopKRow {
    /// Lexicographic comparison of the encoded sort keys only.
    fn cmp(&self, other: &Self) -> Ordering {
        self.row().cmp(other.row())
    }
}
528
529
#[derive(Debug)]
530
struct RecordBatchEntry {
531
    id: u32,
532
    batch: RecordBatch,
533
    // for this batch, how many times has it been used
534
    uses: usize,
535
}
536
537
/// This structure tracks [`RecordBatch`] by an id so that:
538
///
539
/// 1. The baches can be tracked via an id that can be copied cheaply
540
/// 2. The total memory held by all batches is tracked
541
#[derive(Debug)]
542
struct RecordBatchStore {
543
    /// id generator
544
    next_id: u32,
545
    /// storage
546
    batches: HashMap<u32, RecordBatchEntry>,
547
    /// total size of all record batches tracked by this store
548
    batches_size: usize,
549
    /// schema of the batches
550
    schema: SchemaRef,
551
}
552
553
impl RecordBatchStore {
554
5
    fn new(schema: SchemaRef) -> Self {
555
5
        Self {
556
5
            next_id: 0,
557
5
            batches: HashMap::new(),
558
5
            batches_size: 0,
559
5
            schema,
560
5
        }
561
5
    }
562
563
    /// Register this batch with the store and assign an ID. No
564
    /// attempt is made to compare this batch to other batches
565
20
    pub fn register(&mut self, batch: RecordBatch) -> RecordBatchEntry {
566
20
        let id = self.next_id;
567
20
        self.next_id += 1;
568
20
        RecordBatchEntry { id, batch, uses: 0 }
569
20
    }
570
571
    /// Insert a record batch entry into this store, tracking its
572
    /// memory use, if it has any uses
573
20
    pub fn insert(&mut self, entry: RecordBatchEntry) {
574
20
        // uses of 0 means that none of the rows in the batch were stored in the topk
575
20
        if entry.uses > 0 {
576
9
            self.batches_size += entry.batch.get_array_memory_size();
577
9
            self.batches.insert(entry.id, entry);
578
11
        }
579
20
    }
580
581
    /// Clear all values in this store, invalidating all previous batch ids
582
0
    fn clear(&mut self) {
583
0
        self.batches.clear();
584
0
        self.batches_size = 0;
585
0
    }
586
587
1.71k
    fn get(&self, id: u32) -> Option<&RecordBatchEntry> {
588
1.71k
        self.batches.get(&id)
589
1.71k
    }
590
591
    /// returns the total number of batches stored in this store
592
20
    fn len(&self) -> usize {
593
20
        self.batches.len()
594
20
    }
595
596
    /// Returns the total number of rows in batches minus the number
597
    /// which are in use
598
20
    fn unused_rows(&self) -> usize {
599
20
        self.batches
600
20
            .values()
601
31
            .map(|batch_entry| batch_entry.batch.num_rows() - batch_entry.uses)
602
20
            .sum()
603
20
    }
604
605
    /// returns true if the store has nothing stored
606
5
    fn is_empty(&self) -> bool {
607
5
        self.batches.is_empty()
608
5
    }
609
610
    /// return the schema of batches stored
611
5
    fn schema(&self) -> &SchemaRef {
612
5
        &self.schema
613
5
    }
614
615
    /// remove a use from the specified batch id. If the use count
616
    /// reaches zero the batch entry is removed from the store
617
    ///
618
    /// panics if there were no remaining uses of id
619
0
    pub fn unuse(&mut self, id: u32) {
620
0
        let remove = if let Some(batch_entry) = self.batches.get_mut(&id) {
621
0
            batch_entry.uses = batch_entry.uses.checked_sub(1).expect("underflow");
622
0
            batch_entry.uses == 0
623
        } else {
624
0
            panic!("No entry for id {id}");
625
        };
626
627
0
        if remove {
628
0
            let old_entry = self.batches.remove(&id).unwrap();
629
0
            self.batches_size = self
630
0
                .batches_size
631
0
                .checked_sub(old_entry.batch.get_array_memory_size())
632
0
                .unwrap();
633
0
        }
634
0
    }
635
636
    /// returns the size of memory used by this store, including all
637
    /// referenced `RecordBatch`es, in bytes
638
20
    pub fn size(&self) -> usize {
639
20
        std::mem::size_of::<Self>()
640
20
            + self.batches.capacity()
641
20
                * (std::mem::size_of::<u32>() + std::mem::size_of::<RecordBatchEntry>())
642
20
            + self.batches_size
643
20
    }
644
}