Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/common-runtime/src/common.rs
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use std::future::Future;
19
20
use tokio::task::{JoinError, JoinSet};
21
22
/// Helper that  provides a simple API to spawn a single task and join it.
23
/// Provides guarantees of aborting on `Drop` to keep it cancel-safe.
24
///
25
/// Technically, it's just a wrapper of `JoinSet` (with size=1).
26
#[derive(Debug)]
27
pub struct SpawnedTask<R> {
28
    inner: JoinSet<R>,
29
}
30
31
impl<R: 'static> SpawnedTask<R> {
32
2.81k
    pub fn spawn<T>(task: T) -> Self
33
2.81k
    where
34
2.81k
        T: Future<Output = R>,
35
2.81k
        T: Send + 'static,
36
2.81k
        R: Send,
37
2.81k
    {
38
2.81k
        let mut inner = JoinSet::new();
39
2.81k
        inner.spawn(task);
40
2.81k
        Self { inner }
41
2.81k
    }
42
43
    pub fn spawn_blocking<T>(task: T) -> Self
44
    where
45
        T: FnOnce() -> R,
46
        T: Send + 'static,
47
        R: Send,
48
    {
49
        let mut inner = JoinSet::new();
50
        inner.spawn_blocking(task);
51
        Self { inner }
52
    }
53
54
    /// Joins the task, returning the result of join (`Result<R, JoinError>`).
55
1.40k
    pub async fn join(mut self) -> Result<R, JoinError> {
56
1.40k
        self.inner
57
1.40k
            .join_next()
58
1.29k
            .await
59
1.40k
            .expect("`SpawnedTask` instance always contains exactly 1 task")
60
1.40k
    }
61
62
    /// Joins the task and unwinds the panic if it happens.
63
    pub async fn join_unwind(self) -> Result<R, JoinError> {
64
        self.join().await.map_err(|e| {
65
            // `JoinError` can be caused either by panic or cancellation. We have to handle panics:
66
            if e.is_panic() {
67
                std::panic::resume_unwind(e.into_panic());
68
            } else {
69
                // Cancellation may be caused by two reasons:
70
                // 1. Abort is called, but since we consumed `self`, it's not our case (`JoinHandle` not accessible outside).
71
                // 2. The runtime is shutting down.
72
                log::warn!("SpawnedTask was polled during shutdown");
73
                e
74
            }
75
        })
76
    }
77
}
78
79
#[cfg(test)]
80
mod tests {
81
    use super::*;
82
83
    use std::future::{pending, Pending};
84
85
    use tokio::runtime::Runtime;
86
87
    #[tokio::test]
88
    async fn runtime_shutdown() {
89
        let rt = Runtime::new().unwrap();
90
        let task = rt
91
            .spawn(async {
92
                SpawnedTask::spawn(async {
93
                    let fut: Pending<()> = pending();
94
                    fut.await;
95
                    unreachable!("should never return");
96
                })
97
            })
98
            .await
99
            .unwrap();
100
101
        // caller shutdown their DF runtime (e.g. timeout, error in caller, etc)
102
        rt.shutdown_background();
103
104
        // race condition
105
        // poll occurs during shutdown (buffered stream poll calls, etc)
106
        assert!(matches!(
107
            task.join_unwind().await,
108
            Err(e) if e.is_cancelled()
109
        ));
110
    }
111
112
    #[tokio::test]
113
    #[should_panic(expected = "foo")]
114
    async fn panic_resume() {
115
        // this should panic w/o an `unwrap`
116
        SpawnedTask::spawn(async { panic!("foo") })
117
            .join_unwind()
118
            .await
119
            .ok();
120
    }
121
}