Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/stream.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Stream wrappers for physical operators
19
20
use std::pin::Pin;
21
use std::sync::Arc;
22
use std::task::Context;
23
use std::task::Poll;
24
25
use super::metrics::BaselineMetrics;
26
use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
27
use crate::displayable;
28
29
use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
30
use datafusion_common::{internal_err, Result};
31
use datafusion_execution::TaskContext;
32
33
use futures::stream::BoxStream;
34
use futures::{Future, Stream, StreamExt};
35
use log::debug;
36
use pin_project_lite::pin_project;
37
use tokio::sync::mpsc::{Receiver, Sender};
38
use tokio::task::JoinSet;
39
40
/// Creates a stream from a collection of producing tasks, routing panics to the stream.
41
///
42
/// Note that this is similar to [`ReceiverStream` from tokio-stream], with the differences being:
43
///
44
/// 1. Methods to bound and "detach" tasks (`spawn()` and `spawn_blocking()`).
45
///
46
/// 2. Propagates panics, whereas the `tokio` version doesn't propagate panics to the receiver.
47
///
48
/// 3. Automatically cancels any outstanding tasks when the receiver stream is dropped.
49
///
50
/// [`ReceiverStream` from tokio-stream]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.ReceiverStream.html
51
52
pub(crate) struct ReceiverStreamBuilder<O> {
53
    tx: Sender<Result<O>>,
54
    rx: Receiver<Result<O>>,
55
    join_set: JoinSet<Result<()>>,
56
}
57
58
impl<O: Send + 'static> ReceiverStreamBuilder<O> {
59
    /// create new channels with the specified buffer size
60
71
    pub fn new(capacity: usize) -> Self {
61
71
        let (tx, rx) = tokio::sync::mpsc::channel(capacity);
62
71
63
71
        Self {
64
71
            tx,
65
71
            rx,
66
71
            join_set: JoinSet::new(),
67
71
        }
68
71
    }
69
70
    /// Get a handle for sending data to the output
71
274
    pub fn tx(&self) -> Sender<Result<O>> {
72
274
        self.tx.clone()
73
274
    }
74
75
    /// Spawn task that will be aborted if this builder (or the stream
76
    /// built from it) is dropped
77
265
    pub fn spawn<F>(&mut self, task: F)
78
265
    where
79
265
        F: Future<Output = Result<()>>,
80
265
        F: Send + 'static,
81
265
    {
82
265
        self.join_set.spawn(task);
83
265
    }
84
85
    /// Spawn a blocking task that will be aborted if this builder (or the stream
86
    /// built from it) is dropped
87
    ///
88
    /// This is often used to spawn tasks that write to the sender
89
    /// retrieved from `Self::tx`
90
9
    pub fn spawn_blocking<F>(&mut self, f: F)
91
9
    where
92
9
        F: FnOnce() -> Result<()>,
93
9
        F: Send + 'static,
94
9
    {
95
9
        self.join_set.spawn_blocking(f);
96
9
    }
97
98
    /// Create a stream of all data written to `tx`
99
71
    pub fn build(self) -> BoxStream<'static, Result<O>> {
100
71
        let Self {
101
71
            tx,
102
71
            rx,
103
71
            mut join_set,
104
71
        } = self;
105
71
106
71
        // don't need tx
107
71
        drop(tx);
108
71
109
71
        // future that checks the result of the join set, and propagates panic if seen
110
71
        let check = async move 
{70
111
997
            while let Some(
result259
) =
join_set.join_next()326
.await {
112
259
                match result {
113
256
                    Ok(task_result) => {
114
256
                        match task_result {
115
                            // nothing to report
116
256
                            Ok(_) => continue,
117
                            // This means a blocking task error
118
0
                            Err(error) => return Some(Err(error)),
119
                        }
120
                    }
121
                    // This means a tokio task error, likely a panic
122
3
                    Err(e) => {
123
3
                        if e.is_panic() {
124
                            // resume on the main thread
125
3
                            std::panic::resume_unwind(e.into_panic());
126
                        } else {
127
                            // This should only occur if the task is
128
                            // cancelled, which would only occur if
129
                            // the JoinSet were aborted, which in turn
130
                            // would imply that the receiver has been
131
                            // dropped and this code is not running
132
0
                            return Some(internal_err!("Non Panic Task error: {e}"));
133
                        }
134
                    }
135
                }
136
            }
137
65
            None
138
65
        };
139
140
71
        let check_stream = futures::stream::once(check)
141
71
            // unwrap Option / only return the error
142
71
            .filter_map(|item| async move 
{ item 65
}65
);
143
71
144
71
        // Convert the receiver into a stream
145
1.23k
        let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
146
1.23k
            let 
next_item1.23k
= rx.recv().
await465
;
147
1.23k
            next_item.map(|next_item| 
(next_item, rx)1.17k
)
148
1.23k
        });
149
71
150
71
        // Merge the streams together so whichever is ready first
151
71
        // produces the batch
152
71
        futures::stream::select(rx_stream, check_stream).boxed()
153
71
    }
154
}
155
156
/// Builder for `RecordBatchReceiverStream` that propagates errors
157
/// and panics correctly.
158
///
159
/// [`RecordBatchReceiverStreamBuilder`] is used to spawn one or more tasks
160
/// that produce [`RecordBatch`]es and send them to a single
161
/// `Receiver` which can improve parallelism.
162
///
163
/// This also handles propagating panics and canceling the tasks.
164
///
165
/// # Example
166
///
167
/// The following example spawns 2 tasks that will write [`RecordBatch`]es to
168
/// the `tx` end of the builder. After building the stream, we can receive
169
/// those batches by calling `.next()`
170
///
171
/// ```
172
/// # use std::sync::Arc;
173
/// # use datafusion_common::arrow::datatypes::{Schema, Field, DataType};
174
/// # use datafusion_common::arrow::array::RecordBatch;
175
/// # use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder;
176
/// # use futures::stream::StreamExt;
177
/// # use tokio::runtime::Builder;
178
/// # let rt = Builder::new_current_thread().build().unwrap();
179
/// #
180
/// # rt.block_on(async {
181
/// let schema = Arc::new(Schema::new(vec![Field::new("foo", DataType::Int8, false)]));
182
/// let mut builder = RecordBatchReceiverStreamBuilder::new(Arc::clone(&schema), 10);
183
///
184
/// // task 1
185
/// let tx_1 = builder.tx();
186
/// let schema_1 = Arc::clone(&schema);
187
/// builder.spawn(async move {
188
///     // Your task needs to send batches to the tx
189
///     tx_1.send(Ok(RecordBatch::new_empty(schema_1))).await.unwrap();
190
///
191
///     Ok(())
192
/// });
193
///
194
/// // task 2
195
/// let tx_2 = builder.tx();
196
/// let schema_2 = Arc::clone(&schema);
197
/// builder.spawn(async move {
198
///     // Your task needs to send batches to the tx
199
///     tx_2.send(Ok(RecordBatch::new_empty(schema_2))).await.unwrap();
200
///
201
///     Ok(())
202
/// });
203
///
204
/// let mut stream = builder.build();
205
/// while let Some(res_batch) = stream.next().await {
206
///     // `res_batch` can come from either task 1 or 2
207
///
208
///     // do something with `res_batch`
209
/// }
210
/// # });
211
/// ```
212
pub struct RecordBatchReceiverStreamBuilder {
213
    schema: SchemaRef,
214
    inner: ReceiverStreamBuilder<RecordBatch>,
215
}
216
217
impl RecordBatchReceiverStreamBuilder {
218
    /// create new channels with the specified buffer size
219
71
    pub fn new(schema: SchemaRef, capacity: usize) -> Self {
220
71
        Self {
221
71
            schema,
222
71
            inner: ReceiverStreamBuilder::new(capacity),
223
71
        }
224
71
    }
225
226
    /// Get a handle for sending [`RecordBatch`] to the output
227
274
    pub fn tx(&self) -> Sender<Result<RecordBatch>> {
228
274
        self.inner.tx()
229
274
    }
230
231
    /// Spawn task that will be aborted if this builder (or the stream
232
    /// built from it) is dropped
233
    ///
234
    /// This is often used to spawn tasks that write to the sender
235
    /// retrieved from [`Self::tx`]. For examples, see the documentation
236
    /// of this type.
237
20
    pub fn spawn<F>(&mut self, task: F)
238
20
    where
239
20
        F: Future<Output = Result<()>>,
240
20
        F: Send + 'static,
241
20
    {
242
20
        self.inner.spawn(task)
243
20
    }
244
245
    /// Spawn a blocking task that will be aborted if this builder (or the stream
246
    /// built from it) is dropped
247
    ///
248
    /// This is often used to spawn tasks that write to the sender
249
    /// retrieved from [`Self::tx`]. For examples, see the documentation
250
    /// of this type.
251
9
    pub fn spawn_blocking<F>(&mut self, f: F)
252
9
    where
253
9
        F: FnOnce() -> Result<()>,
254
9
        F: Send + 'static,
255
9
    {
256
9
        self.inner.spawn_blocking(f)
257
9
    }
258
259
    /// runs the `partition` of the `input` ExecutionPlan on the
260
    /// tokio threadpool and writes its outputs to this stream
261
    ///
262
    /// If the input partition produces an error, the error will be
263
    /// sent to the output stream and no further results are sent.
264
245
    pub(crate) fn run_input(
265
245
        &mut self,
266
245
        input: Arc<dyn ExecutionPlan>,
267
245
        partition: usize,
268
245
        context: Arc<TaskContext>,
269
245
    ) {
270
245
        let output = self.tx();
271
245
272
245
        self.inner.spawn(async move 
{241
273
241
            let mut stream = match input.execute(partition, context) {
274
0
                Err(e) => {
275
0
                    // If send fails, the plan is being torn down; there
276
0
                    // is no place to send the error and no reason to continue.
277
0
                    output.send(Err(e)).await.ok();
278
0
                    debug!(
279
0
                        "Stopping execution: error executing input: {}",
280
0
                        displayable(input.as_ref()).one_line()
281
                    );
282
0
                    return Ok(());
283
                }
284
241
                Ok(stream) => stream,
285
            };
286
287
            // Transfer batches from inner stream to the output tx
288
            // immediately.
289
1.11k
            while let Some(
item875
) = stream.next().
await7
{
290
875
                let is_err = item.is_err();
291
875
292
875
                // If send fails, the plan is being torn down; there is no
293
875
                // place to send the error and no reason to continue.
294
875
                if output.send(item).
await487
.
is_err()874
{
295
0
                    debug!(
296
0
                        "Stopping execution: output is gone, plan cancelling: {}",
297
0
                        displayable(input.as_ref()).one_line()
298
                    );
299
0
                    return Ok(());
300
874
                }
301
874
302
874
                // stop after the first error is encountered (don't
303
874
                // drive all streams to completion)
304
874
                if is_err {
305
1
                    debug!(
306
0
                        "Stopping execution: plan returned error: {}",
307
0
                        displayable(input.as_ref()).one_line()
308
                    );
309
1
                    return Ok(());
310
873
                }
311
            }
312
313
226
            Ok(())
314
245
        
}227
);
315
245
    }
316
317
    /// Create a stream of all [`RecordBatch`] written to `tx`
318
71
    pub fn build(self) -> SendableRecordBatchStream {
319
71
        Box::pin(RecordBatchStreamAdapter::new(
320
71
            self.schema,
321
71
            self.inner.build(),
322
71
        ))
323
71
    }
324
}
325
326
#[doc(hidden)]
327
pub struct RecordBatchReceiverStream {}
328
329
impl RecordBatchReceiverStream {
330
    /// Create a builder with an internal buffer of capacity batches.
331
71
    pub fn builder(
332
71
        schema: SchemaRef,
333
71
        capacity: usize,
334
71
    ) -> RecordBatchReceiverStreamBuilder {
335
71
        RecordBatchReceiverStreamBuilder::new(schema, capacity)
336
71
    }
337
}
338
339
pin_project! {
340
    /// Combines a [`Stream`] with a [`SchemaRef`] implementing
341
    /// [`RecordBatchStream`] for the combination
342
    pub struct RecordBatchStreamAdapter<S> {
343
        schema: SchemaRef,
344
345
        #[pin]
346
        stream: S,
347
    }
348
}
349
350
impl<S> RecordBatchStreamAdapter<S> {
351
    /// Creates a new [`RecordBatchStreamAdapter`] from the provided schema and stream
352
5.69k
    pub fn new(schema: SchemaRef, stream: S) -> Self {
353
5.69k
        Self { schema, stream }
354
5.69k
    }
355
}
356
357
impl<S> std::fmt::Debug for RecordBatchStreamAdapter<S> {
358
0
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
359
0
        f.debug_struct("RecordBatchStreamAdapter")
360
0
            .field("schema", &self.schema)
361
0
            .finish()
362
0
    }
363
}
364
365
impl<S> Stream for RecordBatchStreamAdapter<S>
366
where
367
    S: Stream<Item = Result<RecordBatch>>,
368
{
369
    type Item = Result<RecordBatch>;
370
371
28.3k
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
372
28.3k
        self.project().stream.poll_next(cx)
373
28.3k
    }
374
375
0
    fn size_hint(&self) -> (usize, Option<usize>) {
376
0
        self.stream.size_hint()
377
0
    }
378
}
379
380
impl<S> RecordBatchStream for RecordBatchStreamAdapter<S>
381
where
382
    S: Stream<Item = Result<RecordBatch>>,
383
{
384
736
    fn schema(&self) -> SchemaRef {
385
736
        Arc::clone(&self.schema)
386
736
    }
387
}
388
389
/// `EmptyRecordBatchStream` can be used to create a [`RecordBatchStream`]
390
/// that will produce no results
391
pub struct EmptyRecordBatchStream {
392
    /// Schema wrapped by Arc
393
    schema: SchemaRef,
394
}
395
396
impl EmptyRecordBatchStream {
397
    /// Create an empty RecordBatchStream
398
0
    pub fn new(schema: SchemaRef) -> Self {
399
0
        Self { schema }
400
0
    }
401
}
402
403
impl RecordBatchStream for EmptyRecordBatchStream {
404
0
    fn schema(&self) -> SchemaRef {
405
0
        Arc::clone(&self.schema)
406
0
    }
407
}
408
409
impl Stream for EmptyRecordBatchStream {
410
    type Item = Result<RecordBatch>;
411
412
0
    fn poll_next(
413
0
        self: Pin<&mut Self>,
414
0
        _cx: &mut Context<'_>,
415
0
    ) -> Poll<Option<Self::Item>> {
416
0
        Poll::Ready(None)
417
0
    }
418
}
419
420
/// Stream wrapper that records `BaselineMetrics` for a particular
421
/// [`SendableRecordBatchStream`] (likely a partition)
422
pub(crate) struct ObservedStream {
423
    inner: SendableRecordBatchStream,
424
    baseline_metrics: BaselineMetrics,
425
}
426
427
impl ObservedStream {
428
46
    pub fn new(
429
46
        inner: SendableRecordBatchStream,
430
46
        baseline_metrics: BaselineMetrics,
431
46
    ) -> Self {
432
46
        Self {
433
46
            inner,
434
46
            baseline_metrics,
435
46
        }
436
46
    }
437
}
438
439
impl RecordBatchStream for ObservedStream {
440
31
    fn schema(&self) -> arrow::datatypes::SchemaRef {
441
31
        self.inner.schema()
442
31
    }
443
}
444
445
impl futures::Stream for ObservedStream {
446
    type Item = Result<RecordBatch>;
447
448
1.16k
    fn poll_next(
449
1.16k
        mut self: Pin<&mut Self>,
450
1.16k
        cx: &mut Context<'_>,
451
1.16k
    ) -> Poll<Option<Self::Item>> {
452
1.16k
        let poll = self.inner.poll_next_unpin(cx);
453
1.16k
        self.baseline_metrics.record_poll(poll)
454
1.16k
    }
455
}
456
457
#[cfg(test)]
458
mod test {
459
    use super::*;
460
    use crate::test::exec::{
461
        assert_strong_count_converges_to_zero, BlockingExec, MockExec, PanicExec,
462
    };
463
464
    use arrow_schema::{DataType, Field, Schema};
465
    use datafusion_common::exec_err;
466
467
4
    fn schema() -> SchemaRef {
468
4
        Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]))
469
4
    }
470
471
    #[tokio::test]
472
    #[should_panic(expected = "PanickingStream did panic")]
473
1
    async fn record_batch_receiver_stream_propagates_panics() {
474
1
        let schema = schema();
475
1
476
1
        let num_partitions = 10;
477
1
        let input = PanicExec::new(Arc::clone(&schema), num_partitions);
478
1
        consume(input, 10).await
479
1
    }
480
481
    #[tokio::test]
482
    #[should_panic(expected = "PanickingStream did panic: 1")]
483
1
    async fn record_batch_receiver_stream_propagates_panics_early_shutdown() {
484
1
        let schema = schema();
485
1
486
1
        // make 2 partitions, second partition panics before the first
487
1
        let num_partitions = 2;
488
1
        let input = PanicExec::new(Arc::clone(&schema), num_partitions)
489
1
            .with_partition_panic(0, 10)
490
1
            .with_partition_panic(1, 3); // partition 1 should panic first (after 3 )
491
1
492
1
        // ensure that the panic results in an early shutdown (that
493
1
        // everything stops after the first panic).
494
1
495
1
        // Since the stream reads every other batch: (0,1,0,1,0,panic)
496
1
        // so should not exceed 5 batches prior to the panic
497
1
        let max_batches = 5;
498
3
        consume(input, max_batches).await
499
1
    }
500
501
    #[tokio::test]
502
1
    async fn record_batch_receiver_stream_drop_cancel() {
503
1
        let task_ctx = Arc::new(TaskContext::default());
504
1
        let schema = schema();
505
1
506
1
        // Make an input that never proceeds
507
1
        let input = BlockingExec::new(Arc::clone(&schema), 1);
508
1
        let refs = input.refs();
509
1
510
1
        // Configure a RecordBatchReceiverStream to consume the input
511
1
        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
512
1
        builder.run_input(Arc::new(input), 0, Arc::clone(&task_ctx));
513
1
        let stream = builder.build();
514
1
515
1
        // input should still be present
516
1
        assert!(std::sync::Weak::strong_count(&refs) > 0);
517
1
518
1
        // drop the stream, ensure the refs go to zero
519
1
        drop(stream);
520
1
        assert_strong_count_converges_to_zero(refs).await;
521
1
    }
522
523
    #[tokio::test]
524
    /// Ensure that if an error is received in one stream, the
525
    /// `RecordBatchReceiverStream` stops early and does not drive
526
    /// other streams to completion.
527
1
    async fn record_batch_receiver_stream_error_does_not_drive_completion() {
528
1
        let task_ctx = Arc::new(TaskContext::default());
529
1
        let schema = schema();
530
1
531
1
        // make an input that will error twice
532
1
        let error_stream = MockExec::new(
533
1
            vec![exec_err!("Test1"), exec_err!("Test2")],
534
1
            Arc::clone(&schema),
535
1
        )
536
1
        .with_use_task(false);
537
1
538
1
        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
539
1
        builder.run_input(Arc::new(error_stream), 0, Arc::clone(&task_ctx));
540
1
        let mut stream = builder.build();
541
1
542
1
        // get the first result, which should be an error
543
1
        let first_batch = stream.next().await.unwrap();
544
1
        let first_err = first_batch.unwrap_err();
545
1
        assert_eq!(first_err.strip_backtrace(), "Execution error: Test1");
546
1
547
1
        // There should be no more batches produced (should not get the second error)
548
1
        assert!(stream.next().
await0
.is_none());
549
1
    }
550
551
    /// Consumes all the input's partitions into a
552
    /// RecordBatchReceiverStream and runs it to completion
553
    ///
554
    /// Panics if the number of batches seen reaches `max_batches`.
555
2
    async fn consume(input: PanicExec, max_batches: usize) {
556
2
        let task_ctx = Arc::new(TaskContext::default());
557
2
558
2
        let input = Arc::new(input);
559
2
        let num_partitions = input.properties().output_partitioning().partition_count();
560
2
561
2
        // Configure a RecordBatchReceiverStream to consume all the input partitions
562
2
        let mut builder =
563
2
            RecordBatchReceiverStream::builder(input.schema(), num_partitions);
564
12
        for partition in 0..
num_partitions2
{
565
12
            builder.run_input(
566
12
                Arc::clone(&input) as Arc<dyn ExecutionPlan>,
567
12
                partition,
568
12
                Arc::clone(&task_ctx),
569
12
            );
570
12
        }
571
2
        let mut stream = builder.build();
572
2
573
2
        // drain the stream until it is complete, panic'ing on error
574
2
        let mut num_batches = 0;
575
6
        while let Some(
next4
) = stream.next().
await4
{
576
4
            next.unwrap();
577
4
            num_batches += 1;
578
4
            assert!(
579
4
                num_batches < max_batches,
580
0
                "Got the limit of {num_batches} batches before seeing panic"
581
            );
582
        }
583
0
    }
584
}