/Users/andrewlamb/Software/datafusion/datafusion/common-runtime/src/common.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use std::future::Future; |
19 | | |
20 | | use tokio::task::{JoinError, JoinSet}; |
21 | | |
22 | | /// Helper that provides a simple API to spawn a single task and join it. |
23 | | /// Provides guarantees of aborting on `Drop` to keep it cancel-safe. |
24 | | /// |
25 | | /// Technically, it's just a wrapper of `JoinSet` (with size=1). |
26 | | #[derive(Debug)] |
27 | | pub struct SpawnedTask<R> { |
28 | | inner: JoinSet<R>, |
29 | | } |
30 | | |
31 | | impl<R: 'static> SpawnedTask<R> { |
32 | 2.81k | pub fn spawn<T>(task: T) -> Self |
33 | 2.81k | where |
34 | 2.81k | T: Future<Output = R>, |
35 | 2.81k | T: Send + 'static, |
36 | 2.81k | R: Send, |
37 | 2.81k | { |
38 | 2.81k | let mut inner = JoinSet::new(); |
39 | 2.81k | inner.spawn(task); |
40 | 2.81k | Self { inner } |
41 | 2.81k | } |
42 | | |
43 | | pub fn spawn_blocking<T>(task: T) -> Self |
44 | | where |
45 | | T: FnOnce() -> R, |
46 | | T: Send + 'static, |
47 | | R: Send, |
48 | | { |
49 | | let mut inner = JoinSet::new(); |
50 | | inner.spawn_blocking(task); |
51 | | Self { inner } |
52 | | } |
53 | | |
54 | | /// Joins the task, returning the result of join (`Result<R, JoinError>`). |
55 | 1.40k | pub async fn join(mut self) -> Result<R, JoinError> { |
56 | 1.40k | self.inner |
57 | 1.40k | .join_next() |
58 | 1.29k | .await |
59 | 1.40k | .expect("`SpawnedTask` instance always contains exactly 1 task") |
60 | 1.40k | } |
61 | | |
62 | | /// Joins the task and unwinds the panic if it happens. |
63 | | pub async fn join_unwind(self) -> Result<R, JoinError> { |
64 | | self.join().await.map_err(|e| { |
65 | | // `JoinError` can be caused either by panic or cancellation. We have to handle panics: |
66 | | if e.is_panic() { |
67 | | std::panic::resume_unwind(e.into_panic()); |
68 | | } else { |
69 | | // Cancellation may be caused by two reasons: |
70 | | // 1. Abort is called, but since we consumed `self`, it's not our case (`JoinHandle` not accessible outside). |
71 | | // 2. The runtime is shutting down. |
72 | | log::warn!("SpawnedTask was polled during shutdown"); |
73 | | e |
74 | | } |
75 | | }) |
76 | | } |
77 | | } |
78 | | |
79 | | #[cfg(test)] |
80 | | mod tests { |
81 | | use super::*; |
82 | | |
83 | | use std::future::{pending, Pending}; |
84 | | |
85 | | use tokio::runtime::Runtime; |
86 | | |
87 | | #[tokio::test] |
88 | | async fn runtime_shutdown() { |
89 | | let rt = Runtime::new().unwrap(); |
90 | | let task = rt |
91 | | .spawn(async { |
92 | | SpawnedTask::spawn(async { |
93 | | let fut: Pending<()> = pending(); |
94 | | fut.await; |
95 | | unreachable!("should never return"); |
96 | | }) |
97 | | }) |
98 | | .await |
99 | | .unwrap(); |
100 | | |
101 | | // caller shutdown their DF runtime (e.g. timeout, error in caller, etc) |
102 | | rt.shutdown_background(); |
103 | | |
104 | | // race condition |
105 | | // poll occurs during shutdown (buffered stream poll calls, etc) |
106 | | assert!(matches!( |
107 | | task.join_unwind().await, |
108 | | Err(e) if e.is_cancelled() |
109 | | )); |
110 | | } |
111 | | |
112 | | #[tokio::test] |
113 | | #[should_panic(expected = "foo")] |
114 | | async fn panic_resume() { |
115 | | // this should panic w/o an `unwrap` |
116 | | SpawnedTask::spawn(async { panic!("foo") }) |
117 | | .join_unwind() |
118 | | .await |
119 | | .ok(); |
120 | | } |
121 | | } |