/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/sorts/merge.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Merge that deals with an arbitrary size of streaming inputs. |
19 | | //! This is an order-preserving merge. |
20 | | |
21 | | use std::collections::VecDeque; |
22 | | use std::pin::Pin; |
23 | | use std::sync::Arc; |
24 | | use std::task::{ready, Context, Poll}; |
25 | | |
26 | | use crate::metrics::BaselineMetrics; |
27 | | use crate::sorts::builder::BatchBuilder; |
28 | | use crate::sorts::cursor::{Cursor, CursorValues}; |
29 | | use crate::sorts::stream::PartitionedStream; |
30 | | use crate::RecordBatchStream; |
31 | | |
32 | | use arrow::datatypes::SchemaRef; |
33 | | use arrow::record_batch::RecordBatch; |
34 | | use datafusion_common::Result; |
35 | | use datafusion_execution::memory_pool::MemoryReservation; |
36 | | |
37 | | use futures::Stream; |
38 | | |
39 | | /// A fallible [`PartitionedStream`] of [`Cursor`] and [`RecordBatch`] |
40 | | type CursorStream<C> = Box<dyn PartitionedStream<Output = Result<(C, RecordBatch)>>>; |
41 | | |
42 | | /// Merges a stream of sorted cursors and record batches into a single sorted stream |
43 | | #[derive(Debug)] |
44 | | pub(crate) struct SortPreservingMergeStream<C: CursorValues> { |
45 | | in_progress: BatchBuilder, |
46 | | |
47 | | /// The sorted input streams to merge together |
48 | | streams: CursorStream<C>, |
49 | | |
50 | | /// used to record execution metrics |
51 | | metrics: BaselineMetrics, |
52 | | |
53 | | /// If the stream has encountered an error |
54 | | aborted: bool, |
55 | | |
56 | | /// A loser tree that always produces the minimum cursor |
57 | | /// |
58 | | /// Node 0 stores the top winner, Nodes 1..num_streams store |
59 | | /// the loser nodes |
60 | | /// |
61 | | /// This implements a "Tournament Tree" (aka Loser Tree) to keep |
62 | | /// track of the current smallest element at the top. When the top |
63 | | /// record is taken, the tree structure is not modified, and only |
64 | | /// the path from bottom to top is visited, keeping the number of |
65 | | /// comparisons close to the theoretical limit of `log(S)`. |
66 | | /// |
67 | | /// The current implementation uses a vector to store the tree. |
68 | | /// Conceptually, it looks like this (assuming 8 streams): |
69 | | /// |
70 | | /// ```text |
71 | | /// 0 (winner) |
72 | | /// |
73 | | /// 1 |
74 | | /// / \ |
75 | | /// 2 3 |
76 | | /// / \ / \ |
77 | | /// 4 5 6 7 |
78 | | /// ``` |
79 | | /// |
80 | | /// Where element at index 0 in the vector is the current winner. Element |
81 | | /// at index 1 is the root of the loser tree, element at index 2 is the |
82 | | /// left child of the root, and element at index 3 is the right child of |
83 | | /// the root and so on. |
84 | | /// |
85 | | /// reference: <https://en.wikipedia.org/wiki/K-way_merge_algorithm#Tournament_Tree> |
86 | | loser_tree: Vec<usize>, |
87 | | |
88 | | /// If the most recently yielded overall winner has been replaced |
89 | | /// within the loser tree. A value of `false` indicates that the |
90 | | /// overall winner has been yielded but the loser tree has not |
91 | | /// been updated |
92 | | loser_tree_adjusted: bool, |
93 | | |
94 | | /// Target batch size |
95 | | batch_size: usize, |
96 | | |
97 | | /// Cursors for each input partition. `None` means the input is exhausted |
98 | | cursors: Vec<Option<Cursor<C>>>, |
99 | | |
100 | | /// Optional number of rows to fetch |
101 | | fetch: Option<usize>, |
102 | | |
103 | | /// number of rows produced |
104 | | produced: usize, |
105 | | |
106 | | /// This queue contains partition indices in order. When a partition is polled and returns `Poll::Ready`, |
107 | | /// it is removed from the vector. If a partition returns `Poll::Pending`, it is moved to the end of the |
108 | | /// vector to ensure the next iteration starts with a different partition, preventing the same partition |
109 | | /// from being continuously polled. |
110 | | uninitiated_partitions: VecDeque<usize>, |
111 | | } |
112 | | |
113 | | impl<C: CursorValues> SortPreservingMergeStream<C> { |
114 | 19 | pub(crate) fn new( |
115 | 19 | streams: CursorStream<C>, |
116 | 19 | schema: SchemaRef, |
117 | 19 | metrics: BaselineMetrics, |
118 | 19 | batch_size: usize, |
119 | 19 | fetch: Option<usize>, |
120 | 19 | reservation: MemoryReservation, |
121 | 19 | ) -> Self { |
122 | 19 | let stream_count = streams.partitions(); |
123 | 19 | |
124 | 19 | Self { |
125 | 19 | in_progress: BatchBuilder::new(schema, stream_count, batch_size, reservation), |
126 | 19 | streams, |
127 | 19 | metrics, |
128 | 19 | aborted: false, |
129 | 56 | cursors: (0..stream_count).map(|_| None).collect(), |
130 | 19 | loser_tree: vec![], |
131 | 19 | loser_tree_adjusted: false, |
132 | 19 | batch_size, |
133 | 19 | fetch, |
134 | 19 | produced: 0, |
135 | 19 | uninitiated_partitions: (0..stream_count).collect(), |
136 | 19 | } |
137 | 19 | } |
138 | | |
139 | | /// If the stream at the given index is not exhausted, and the last cursor for the |
140 | | /// stream is finished, poll the stream for the next RecordBatch and create a new |
141 | | /// cursor for the stream from the returned result |
142 | 14.7k | fn maybe_poll_stream( |
143 | 14.7k | &mut self, |
144 | 14.7k | cx: &mut Context<'_>, |
145 | 14.7k | idx: usize, |
146 | 14.7k | ) -> Poll<Result<()>> { |
147 | 14.7k | if self.cursors[idx].is_some() { |
148 | | // Cursor is not finished - don't need a new RecordBatch yet |
149 | 13.8k | return Poll::Ready(Ok(())); |
150 | 927 | } |
151 | | |
152 | 927 | match futures::ready!197 (self.streams.poll_next(cx, idx)) { |
153 | 54 | None => Poll::Ready(Ok(())), |
154 | 0 | Some(Err(e)) => Poll::Ready(Err(e)), |
155 | 676 | Some(Ok((cursor, batch))) => { |
156 | 676 | self.cursors[idx] = Some(Cursor::new(cursor)); |
157 | 676 | Poll::Ready(self.in_progress.push_batch(idx, batch)) |
158 | | } |
159 | | } |
160 | 14.7k | } |
161 | | |
162 | 293 | fn poll_next_inner( |
163 | 293 | &mut self, |
164 | 293 | cx: &mut Context<'_>, |
165 | 293 | ) -> Poll<Option<Result<RecordBatch>>> { |
166 | 293 | if self.aborted { |
167 | 0 | return Poll::Ready(None); |
168 | 293 | } |
169 | 293 | // Once all partitions have set their corresponding cursors for the loser tree, |
170 | 293 | // we skip the following block. Until then, this function may be called multiple |
171 | 293 | // times and can return Poll::Pending if any partition returns Poll::Pending. |
172 | 293 | if self.loser_tree.is_empty() { |
173 | 81 | let remaining_partitions = self.uninitiated_partitions.clone(); |
174 | 135 | for i117 in remaining_partitions { |
175 | 117 | match self.maybe_poll_stream(cx, i) { |
176 | 0 | Poll::Ready(Err(e)) => { |
177 | 0 | self.aborted = true; |
178 | 0 | return Poll::Ready(Some(Err(e))); |
179 | | } |
180 | | Poll::Pending => { |
181 | | // If a partition returns Poll::Pending, to avoid continuously polling it |
182 | | // and potentially increasing upstream buffer sizes, we move it to the |
183 | | // back of the polling queue. |
184 | 63 | if let Some(front) = self.uninitiated_partitions.pop_front() { |
185 | 63 | // This pop_front can never return `None`. |
186 | 63 | self.uninitiated_partitions.push_back(front); |
187 | 63 | }0 |
188 | | // This function could remain in a pending state, so we manually wake it here. |
189 | | // However, this approach can be investigated further to find a more natural way |
190 | | // to avoid disrupting the runtime scheduler. |
191 | 63 | cx.waker().wake_by_ref(); |
192 | 63 | return Poll::Pending; |
193 | | } |
194 | 54 | _ => { |
195 | 54 | // If the polling result is Poll::Ready(Some(batch)) or Poll::Ready(None), |
196 | 54 | // we remove this partition from the queue so it is not polled again. |
197 | 140 | self.uninitiated_partitions.retain(|idx| *idx != i); |
198 | 54 | } |
199 | | } |
200 | | } |
201 | 18 | self.init_loser_tree(); |
202 | 212 | } |
203 | | |
204 | | // NB timer records time taken on drop, so there are no |
205 | | // calls to `timer.done()` below. |
206 | 230 | let elapsed_compute = self.metrics.elapsed_compute().clone(); |
207 | 230 | let _timer = elapsed_compute.timer(); |
208 | | |
209 | | loop { |
210 | | // Adjust the loser tree if necessary, returning control if needed |
211 | 14.6k | if !self.loser_tree_adjusted { |
212 | 14.6k | let winner = self.loser_tree[0]; |
213 | 14.6k | if let Err(e0 ) = ready!134 (self.maybe_poll_stream(cx, winner)) { |
214 | 0 | self.aborted = true; |
215 | 0 | return Poll::Ready(Some(Err(e))); |
216 | 14.4k | } |
217 | 14.4k | self.update_loser_tree(); |
218 | 35 | } |
219 | | |
220 | 14.5k | let stream_idx = self.loser_tree[0]; |
221 | 14.5k | if self.advance(stream_idx) { |
222 | 14.4k | self.loser_tree_adjusted = false; |
223 | 14.4k | self.in_progress.push_row(stream_idx); |
224 | 14.4k | |
225 | 14.4k | // stop sorting if fetch has been reached |
226 | 14.4k | if self.fetch_reached() { |
227 | 0 | self.aborted = true; |
228 | 14.4k | } else if self.in_progress.len() < self.batch_size { |
229 | 14.4k | continue; |
230 | 61 | } |
231 | 35 | } |
232 | | |
233 | 96 | self.produced += self.in_progress.len(); |
234 | 96 | |
235 | 96 | return Poll::Ready(self.in_progress.build_record_batch().transpose()); |
236 | | } |
237 | 293 | } |
238 | | |
239 | 14.4k | fn fetch_reached(&mut self) -> bool { |
240 | 14.4k | self.fetch |
241 | 14.4k | .map(|fetch| self.produced + self.in_progress.len() >= fetch0 ) |
242 | 14.4k | .unwrap_or(false) |
243 | 14.4k | } |
244 | | |
245 | 14.5k | fn advance(&mut self, stream_idx: usize) -> bool { |
246 | 14.5k | let slot = &mut self.cursors[stream_idx]; |
247 | 14.5k | match slot.as_mut() { |
248 | 14.4k | Some(c) => { |
249 | 14.4k | c.advance(); |
250 | 14.4k | if c.is_finished() { |
251 | 676 | *slot = None; |
252 | 13.8k | } |
253 | 14.4k | true |
254 | | } |
255 | 35 | None => false, |
256 | | } |
257 | 14.5k | } |
258 | | |
259 | | /// Returns `true` if the cursor at index `a` is greater than at index `b` |
260 | | #[inline] |
261 | 32.1k | fn is_gt(&self, a: usize, b: usize) -> bool { |
262 | 32.1k | match (&self.cursors[a], &self.cursors[b]) { |
263 | 70 | (None, _) => true, |
264 | 194 | (_, None) => false, |
265 | 31.9k | (Some(ac), Some(bc)) => ac.cmp(bc).then_with(|| a.cmp(&b)13.4k ).is_gt(), |
266 | | } |
267 | 32.1k | } |
268 | | |
269 | | /// Find the leaf node index in the loser tree for the given cursor index |
270 | | /// |
271 | | /// Note that this is not necessarily a leaf node in the tree, but it can |
272 | | /// also be a half-node (a node with only one child). This happens when the |
273 | | /// number of cursors/streams is not a power of two. Thus, the loser tree |
274 | | /// will be unbalanced, but it will still work correctly. |
275 | | /// |
276 | | /// For example, with 5 streams, the loser tree will look like this: |
277 | | /// |
278 | | /// ```text |
279 | | /// 0 (winner) |
280 | | /// |
281 | | /// 1 |
282 | | /// / \ |
283 | | /// 2 3 |
284 | | /// / \ / \ |
285 | | /// 4 | | | |
286 | | /// / \ | | | |
287 | | /// -+---+--+---+---+---- Below is not a part of loser tree |
288 | | /// S3 S4 S0 S1 S2 |
289 | | /// ``` |
290 | | /// |
291 | | /// S0, S1, ... S4 are the streams (read: stream at index 0, stream at |
292 | | /// index 1, etc.) |
293 | | /// |
294 | | /// Zooming in at node 2 in the loser tree as an example, we can see that |
295 | | /// it takes as input the next item at (S0) and the loser of (S3, S4). |
296 | | /// |
297 | | #[inline] |
298 | 14.5k | fn lt_leaf_node_index(&self, cursor_index: usize) -> usize { |
299 | 14.5k | (self.cursors.len() + cursor_index) / 2 |
300 | 14.5k | } |
301 | | |
302 | | /// Find the parent node index for the given node index |
303 | | #[inline] |
304 | 32.1k | fn lt_parent_node_index(&self, node_idx: usize) -> usize { |
305 | 32.1k | node_idx / 2 |
306 | 32.1k | } |
307 | | |
308 | | /// Attempts to initialize the loser tree with one value from each |
309 | | /// non exhausted input, if possible |
310 | 18 | fn init_loser_tree(&mut self) { |
311 | 18 | // Init loser tree |
312 | 18 | self.loser_tree = vec![usize::MAX; self.cursors.len()]; |
313 | 54 | for i in 0..self.cursors.len()18 { |
314 | 54 | let mut winner = i; |
315 | 54 | let mut cmp_node = self.lt_leaf_node_index(i); |
316 | 90 | while cmp_node != 0 && self.loser_tree[cmp_node] != usize::MAX72 { |
317 | 36 | let challenger = self.loser_tree[cmp_node]; |
318 | 36 | if self.is_gt(winner, challenger) { |
319 | 34 | self.loser_tree[cmp_node] = winner; |
320 | 34 | winner = challenger; |
321 | 34 | }2 |
322 | | |
323 | 36 | cmp_node = self.lt_parent_node_index(cmp_node); |
324 | | } |
325 | 54 | self.loser_tree[cmp_node] = winner; |
326 | | } |
327 | 18 | self.loser_tree_adjusted = true; |
328 | 18 | } |
329 | | |
330 | | /// Attempts to update the loser tree, following winner replacement, if possible |
331 | 14.4k | fn update_loser_tree(&mut self) { |
332 | 14.4k | let mut winner = self.loser_tree[0]; |
333 | 14.4k | // Replace overall winner by walking tree of losers |
334 | 14.4k | let mut cmp_node = self.lt_leaf_node_index(winner); |
335 | 46.6k | while cmp_node != 0 { |
336 | 32.1k | let challenger = self.loser_tree[cmp_node]; |
337 | 32.1k | if self.is_gt(winner, challenger) { |
338 | 2.89k | self.loser_tree[cmp_node] = winner; |
339 | 2.89k | winner = challenger; |
340 | 29.2k | } |
341 | 32.1k | cmp_node = self.lt_parent_node_index(cmp_node); |
342 | | } |
343 | 14.4k | self.loser_tree[0] = winner; |
344 | 14.4k | self.loser_tree_adjusted = true; |
345 | 14.4k | } |
346 | | } |
347 | | |
348 | | impl<C: CursorValues + Unpin> Stream for SortPreservingMergeStream<C> { |
349 | | type Item = Result<RecordBatch>; |
350 | | |
351 | 293 | fn poll_next( |
352 | 293 | mut self: Pin<&mut Self>, |
353 | 293 | cx: &mut Context<'_>, |
354 | 293 | ) -> Poll<Option<Self::Item>> { |
355 | 293 | let poll = self.poll_next_inner(cx); |
356 | 293 | self.metrics.record_poll(poll) |
357 | 293 | } |
358 | | } |
359 | | |
360 | | impl<C: CursorValues + Unpin> RecordBatchStream for SortPreservingMergeStream<C> { |
361 | 0 | fn schema(&self) -> SchemaRef { |
362 | 0 | Arc::clone(self.in_progress.schema()) |
363 | 0 | } |
364 | | } |