Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/sorts/merge.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Merge that deals with an arbitrary size of streaming inputs.
19
//! This is an order-preserving merge.
20
21
use std::collections::VecDeque;
22
use std::pin::Pin;
23
use std::sync::Arc;
24
use std::task::{ready, Context, Poll};
25
26
use crate::metrics::BaselineMetrics;
27
use crate::sorts::builder::BatchBuilder;
28
use crate::sorts::cursor::{Cursor, CursorValues};
29
use crate::sorts::stream::PartitionedStream;
30
use crate::RecordBatchStream;
31
32
use arrow::datatypes::SchemaRef;
33
use arrow::record_batch::RecordBatch;
34
use datafusion_common::Result;
35
use datafusion_execution::memory_pool::MemoryReservation;
36
37
use futures::Stream;
38
39
/// A fallible [`PartitionedStream`] of [`Cursor`] and [`RecordBatch`]
40
type CursorStream<C> = Box<dyn PartitionedStream<Output = Result<(C, RecordBatch)>>>;
41
42
/// Merges a stream of sorted cursors and record batches into a single sorted stream
43
#[derive(Debug)]
44
pub(crate) struct SortPreservingMergeStream<C: CursorValues> {
45
    in_progress: BatchBuilder,
46
47
    /// The sorted input streams to merge together
48
    streams: CursorStream<C>,
49
50
    /// used to record execution metrics
51
    metrics: BaselineMetrics,
52
53
    /// If the stream has encountered an error
54
    aborted: bool,
55
56
    /// A loser tree that always produces the minimum cursor
57
    ///
58
    /// Node 0 stores the top winner, Nodes 1..num_streams store
59
    /// the loser nodes
60
    ///
61
    /// This implements a "Tournament Tree" (aka Loser Tree) to keep
62
    /// track of the current smallest element at the top. When the top
63
    /// record is taken, the tree structure is not modified, and only
64
    /// the path from bottom to top is visited, keeping the number of
65
    /// comparisons close to the theoretical limit of `log(S)`.
66
    ///
67
    /// The current implementation uses a vector to store the tree.
68
    /// Conceptually, it looks like this (assuming 8 streams):
69
    ///
70
    /// ```text
71
    ///     0 (winner)
72
    ///
73
    ///     1
74
    ///    / \
75
    ///   2   3
76
    ///  / \ / \
77
    /// 4  5 6  7
78
    /// ```
79
    ///
80
    /// Where element at index 0 in the vector is the current winner. Element
81
    /// at index 1 is the root of the loser tree, element at index 2 is the
82
    /// left child of the root, and element at index 3 is the right child of
83
    /// the root and so on.
84
    ///
85
    /// reference: <https://en.wikipedia.org/wiki/K-way_merge_algorithm#Tournament_Tree>
86
    loser_tree: Vec<usize>,
87
88
    /// If the most recently yielded overall winner has been replaced
89
    /// within the loser tree. A value of `false` indicates that the
90
    /// overall winner has been yielded but the loser tree has not
91
    /// been updated
92
    loser_tree_adjusted: bool,
93
94
    /// Target batch size
95
    batch_size: usize,
96
97
    /// Cursors for each input partition. `None` means the input is exhausted
98
    cursors: Vec<Option<Cursor<C>>>,
99
100
    /// Optional number of rows to fetch
101
    fetch: Option<usize>,
102
103
    /// number of rows produced
104
    produced: usize,
105
106
    /// This queue contains partition indices in order. When a partition is polled and returns `Poll::Ready`,
107
    /// it is removed from the vector. If a partition returns `Poll::Pending`, it is moved to the end of the
108
    /// vector to ensure the next iteration starts with a different partition, preventing the same partition
109
    /// from being continuously polled.
110
    uninitiated_partitions: VecDeque<usize>,
111
}
112
113
impl<C: CursorValues> SortPreservingMergeStream<C> {
114
19
    pub(crate) fn new(
115
19
        streams: CursorStream<C>,
116
19
        schema: SchemaRef,
117
19
        metrics: BaselineMetrics,
118
19
        batch_size: usize,
119
19
        fetch: Option<usize>,
120
19
        reservation: MemoryReservation,
121
19
    ) -> Self {
122
19
        let stream_count = streams.partitions();
123
19
124
19
        Self {
125
19
            in_progress: BatchBuilder::new(schema, stream_count, batch_size, reservation),
126
19
            streams,
127
19
            metrics,
128
19
            aborted: false,
129
56
            cursors: (0..stream_count).map(|_| None).collect(),
130
19
            loser_tree: vec![],
131
19
            loser_tree_adjusted: false,
132
19
            batch_size,
133
19
            fetch,
134
19
            produced: 0,
135
19
            uninitiated_partitions: (0..stream_count).collect(),
136
19
        }
137
19
    }
138
139
    /// If the stream at the given index is not exhausted, and the last cursor for the
140
    /// stream is finished, poll the stream for the next RecordBatch and create a new
141
    /// cursor for the stream from the returned result
142
14.7k
    fn maybe_poll_stream(
143
14.7k
        &mut self,
144
14.7k
        cx: &mut Context<'_>,
145
14.7k
        idx: usize,
146
14.7k
    ) -> Poll<Result<()>> {
147
14.7k
        if self.cursors[idx].is_some() {
148
            // Cursor is not finished - don't need a new RecordBatch yet
149
13.8k
            return Poll::Ready(Ok(()));
150
927
        }
151
152
927
        match 
futures::ready!197
(self.streams.poll_next(cx, idx)) {
153
54
            None => Poll::Ready(Ok(())),
154
0
            Some(Err(e)) => Poll::Ready(Err(e)),
155
676
            Some(Ok((cursor, batch))) => {
156
676
                self.cursors[idx] = Some(Cursor::new(cursor));
157
676
                Poll::Ready(self.in_progress.push_batch(idx, batch))
158
            }
159
        }
160
14.7k
    }
161
162
293
    fn poll_next_inner(
163
293
        &mut self,
164
293
        cx: &mut Context<'_>,
165
293
    ) -> Poll<Option<Result<RecordBatch>>> {
166
293
        if self.aborted {
167
0
            return Poll::Ready(None);
168
293
        }
169
293
        // Once all partitions have set their corresponding cursors for the loser tree,
170
293
        // we skip the following block. Until then, this function may be called multiple
171
293
        // times and can return Poll::Pending if any partition returns Poll::Pending.
172
293
        if self.loser_tree.is_empty() {
173
81
            let remaining_partitions = self.uninitiated_partitions.clone();
174
135
            for 
i117
in remaining_partitions {
175
117
                match self.maybe_poll_stream(cx, i) {
176
0
                    Poll::Ready(Err(e)) => {
177
0
                        self.aborted = true;
178
0
                        return Poll::Ready(Some(Err(e)));
179
                    }
180
                    Poll::Pending => {
181
                        // If a partition returns Poll::Pending, to avoid continuously polling it
182
                        // and potentially increasing upstream buffer sizes, we move it to the
183
                        // back of the polling queue.
184
63
                        if let Some(front) = self.uninitiated_partitions.pop_front() {
185
63
                            // This pop_front can never return `None`.
186
63
                            self.uninitiated_partitions.push_back(front);
187
63
                        }
0
188
                        // This function could remain in a pending state, so we manually wake it here.
189
                        // However, this approach can be investigated further to find a more natural way
190
                        // to avoid disrupting the runtime scheduler.
191
63
                        cx.waker().wake_by_ref();
192
63
                        return Poll::Pending;
193
                    }
194
54
                    _ => {
195
54
                        // If the polling result is Poll::Ready(Some(batch)) or Poll::Ready(None),
196
54
                        // we remove this partition from the queue so it is not polled again.
197
140
                        self.uninitiated_partitions.retain(|idx| *idx != i);
198
54
                    }
199
                }
200
            }
201
18
            self.init_loser_tree();
202
212
        }
203
204
        // NB timer records time taken on drop, so there are no
205
        // calls to `timer.done()` below.
206
230
        let elapsed_compute = self.metrics.elapsed_compute().clone();
207
230
        let _timer = elapsed_compute.timer();
208
209
        loop {
210
            // Adjust the loser tree if necessary, returning control if needed
211
14.6k
            if !self.loser_tree_adjusted {
212
14.6k
                let winner = self.loser_tree[0];
213
14.6k
                if let Err(
e0
) =
ready!134
(self.maybe_poll_stream(cx, winner)) {
214
0
                    self.aborted = true;
215
0
                    return Poll::Ready(Some(Err(e)));
216
14.4k
                }
217
14.4k
                self.update_loser_tree();
218
35
            }
219
220
14.5k
            let stream_idx = self.loser_tree[0];
221
14.5k
            if self.advance(stream_idx) {
222
14.4k
                self.loser_tree_adjusted = false;
223
14.4k
                self.in_progress.push_row(stream_idx);
224
14.4k
225
14.4k
                // stop sorting if fetch has been reached
226
14.4k
                if self.fetch_reached() {
227
0
                    self.aborted = true;
228
14.4k
                } else if self.in_progress.len() < self.batch_size {
229
14.4k
                    continue;
230
61
                }
231
35
            }
232
233
96
            self.produced += self.in_progress.len();
234
96
235
96
            return Poll::Ready(self.in_progress.build_record_batch().transpose());
236
        }
237
293
    }
238
239
14.4k
    fn fetch_reached(&mut self) -> bool {
240
14.4k
        self.fetch
241
14.4k
            .map(|fetch| 
self.produced + self.in_progress.len() >= fetch0
)
242
14.4k
            .unwrap_or(false)
243
14.4k
    }
244
245
14.5k
    fn advance(&mut self, stream_idx: usize) -> bool {
246
14.5k
        let slot = &mut self.cursors[stream_idx];
247
14.5k
        match slot.as_mut() {
248
14.4k
            Some(c) => {
249
14.4k
                c.advance();
250
14.4k
                if c.is_finished() {
251
676
                    *slot = None;
252
13.8k
                }
253
14.4k
                true
254
            }
255
35
            None => false,
256
        }
257
14.5k
    }
258
259
    /// Returns `true` if the cursor at index `a` is greater than at index `b`
260
    #[inline]
261
32.1k
    fn is_gt(&self, a: usize, b: usize) -> bool {
262
32.1k
        match (&self.cursors[a], &self.cursors[b]) {
263
70
            (None, _) => true,
264
194
            (_, None) => false,
265
31.9k
            (Some(ac), Some(bc)) => ac.cmp(bc).then_with(|| 
a.cmp(&b)13.4k
).is_gt(),
266
        }
267
32.1k
    }
268
269
    /// Find the leaf node index in the loser tree for the given cursor index
270
    ///
271
    /// Note that this is not necessarily a leaf node in the tree, but it can
272
    /// also be a half-node (a node with only one child). This happens when the
273
    /// number of cursors/streams is not a power of two. Thus, the loser tree
274
    /// will be unbalanced, but it will still work correctly.
275
    ///
276
    /// For example, with 5 streams, the loser tree will look like this:
277
    ///
278
    /// ```text
279
    ///           0 (winner)
280
    ///
281
    ///           1
282
    ///        /     \
283
    ///       2       3
284
    ///     /  \     / \
285
    ///    4    |   |   |
286
    ///   / \   |   |   |
287
    /// -+---+--+---+---+---- Below is not a part of loser tree
288
    ///  S3 S4 S0   S1  S2
289
    /// ```
290
    ///
291
    /// S0, S1, ... S4 are the streams (read: stream at index 0, stream at
292
    /// index 1, etc.)
293
    ///
294
    /// Zooming in at node 2 in the loser tree as an example, we can see that
295
    /// it takes as input the next item at (S0) and the loser of (S3, S4).
296
    ///
297
    #[inline]
298
14.5k
    fn lt_leaf_node_index(&self, cursor_index: usize) -> usize {
299
14.5k
        (self.cursors.len() + cursor_index) / 2
300
14.5k
    }
301
302
    /// Find the parent node index for the given node index
303
    #[inline]
304
32.1k
    fn lt_parent_node_index(&self, node_idx: usize) -> usize {
305
32.1k
        node_idx / 2
306
32.1k
    }
307
308
    /// Attempts to initialize the loser tree with one value from each
309
    /// non exhausted input, if possible
310
18
    fn init_loser_tree(&mut self) {
311
18
        // Init loser tree
312
18
        self.loser_tree = vec![usize::MAX; self.cursors.len()];
313
54
        for i in 0..
self.cursors.len()18
{
314
54
            let mut winner = i;
315
54
            let mut cmp_node = self.lt_leaf_node_index(i);
316
90
            while cmp_node != 0 && 
self.loser_tree[cmp_node] != usize::MAX72
{
317
36
                let challenger = self.loser_tree[cmp_node];
318
36
                if self.is_gt(winner, challenger) {
319
34
                    self.loser_tree[cmp_node] = winner;
320
34
                    winner = challenger;
321
34
                }
2
322
323
36
                cmp_node = self.lt_parent_node_index(cmp_node);
324
            }
325
54
            self.loser_tree[cmp_node] = winner;
326
        }
327
18
        self.loser_tree_adjusted = true;
328
18
    }
329
330
    /// Attempts to update the loser tree, following winner replacement, if possible
331
14.4k
    fn update_loser_tree(&mut self) {
332
14.4k
        let mut winner = self.loser_tree[0];
333
14.4k
        // Replace overall winner by walking tree of losers
334
14.4k
        let mut cmp_node = self.lt_leaf_node_index(winner);
335
46.6k
        while cmp_node != 0 {
336
32.1k
            let challenger = self.loser_tree[cmp_node];
337
32.1k
            if self.is_gt(winner, challenger) {
338
2.89k
                self.loser_tree[cmp_node] = winner;
339
2.89k
                winner = challenger;
340
29.2k
            }
341
32.1k
            cmp_node = self.lt_parent_node_index(cmp_node);
342
        }
343
14.4k
        self.loser_tree[0] = winner;
344
14.4k
        self.loser_tree_adjusted = true;
345
14.4k
    }
346
}
347
348
impl<C: CursorValues + Unpin> Stream for SortPreservingMergeStream<C> {
349
    type Item = Result<RecordBatch>;
350
351
293
    fn poll_next(
352
293
        mut self: Pin<&mut Self>,
353
293
        cx: &mut Context<'_>,
354
293
    ) -> Poll<Option<Self::Item>> {
355
293
        let poll = self.poll_next_inner(cx);
356
293
        self.metrics.record_poll(poll)
357
293
    }
358
}
359
360
impl<C: CursorValues + Unpin> RecordBatchStream for SortPreservingMergeStream<C> {
361
0
    fn schema(&self) -> SchemaRef {
362
0
        Arc::clone(self.in_progress.schema())
363
0
    }
364
}