/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/repartition/distributor_channels.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Special channel construction to distribute data from various inputs into N outputs,
//! minimizing buffering while preventing deadlocks during repartitioning.
//!
//! # Design
//!
//! ```text
//! +----+      +------+
//! | TX |==||  | Gate |
//! +----+  ||  |      |  +--------+  +----+
//!         ====|      |==| Buffer |==| RX |
//! +----+  ||  |      |  +--------+  +----+
//! | TX |==||  |      |
//! +----+      |      |
//!             |      |
//! +----+      |      |  +--------+  +----+
//! | TX |======|      |==| Buffer |==| RX |
//! +----+      +------+  +--------+  +----+
//! ```
//!
//! There are `N` virtual MPSC (multi-producer, single consumer) channels with unbounded capacity. However, if all
//! buffers/channels are non-empty, then a global gate will be closed, preventing new data from being written (the
//! sender futures will be [pending](Poll::Pending)) until at least one channel is empty (and not closed).
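//!
//! # Example
//!
//! A minimal usage sketch (illustrative only, assuming an async runtime drives
//! the futures; this snippet is not part of the original docs):
//!
//! ```text
//! let (txs, mut rxs) = channels(2);
//!
//! // resolves immediately while at least one channel is still empty
//! txs[0].send("data").await.unwrap();
//!
//! // elements are received in FIFO order per channel
//! assert_eq!(rxs[0].recv().await, Some("data"));
//! ```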
use std::{
    collections::VecDeque,
    future::Future,
    ops::DerefMut,
    pin::Pin,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    },
    task::{Context, Poll, Waker},
};

use parking_lot::Mutex;

/// Create `n` empty channels.
pub fn channels<T>(
    n: usize,
) -> (Vec<DistributionSender<T>>, Vec<DistributionReceiver<T>>) {
    let channels = (0..n)
        .map(|id| Arc::new(Channel::new_with_one_sender(id)))
        .collect::<Vec<_>>();
    let gate = Arc::new(Gate {
        empty_channels: AtomicUsize::new(n),
        send_wakers: Mutex::new(None),
    });
    let senders = channels
        .iter()
        .map(|channel| DistributionSender {
            channel: Arc::clone(channel),
            gate: Arc::clone(&gate),
        })
        .collect();
    let receivers = channels
        .into_iter()
        .map(|channel| DistributionReceiver {
            channel,
            gate: Arc::clone(&gate),
        })
        .collect();
    (senders, receivers)
}

type PartitionAwareSenders<T> = Vec<Vec<DistributionSender<T>>>;
type PartitionAwareReceivers<T> = Vec<Vec<DistributionReceiver<T>>>;

/// Create `n_out` empty channels for each of the `n_in` inputs.
/// This way, each distinct partition will communicate via a dedicated channel.
/// This SPSC structure enables us to track which partition input data comes from.
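///
/// For example (shape sketch, added for illustration): `partition_aware_channels(2, 3)`
/// returns a `2 x 3` matrix of senders and a matching `2 x 3` matrix of
/// receivers, i.e. one SPSC channel per `(input, output)` pair.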
pub fn partition_aware_channels<T>(
    n_in: usize,
    n_out: usize,
) -> (PartitionAwareSenders<T>, PartitionAwareReceivers<T>) {
    (0..n_in).map(|_| channels(n_out)).unzip()
}

/// Error returned by [send](DistributionSender::send).
///
/// This occurs when the [receiver](DistributionReceiver) is gone.
#[derive(PartialEq, Eq)]
pub struct SendError<T>(pub T);

impl<T> std::fmt::Debug for SendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("SendError").finish()
    }
}

impl<T> std::fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "cannot send data, receiver is gone")
    }
}

impl<T> std::error::Error for SendError<T> {}

/// Sender side of distribution [channels].
///
/// This handle can be cloned. All clones will write into the same channel. Dropping the last sender will close the
/// channel. In this case, the [receiver](DistributionReceiver) will still be able to poll the remaining data, but will
/// receive `None` afterwards.
#[derive(Debug)]
pub struct DistributionSender<T> {
    /// To prevent lock inversion / deadlock, the channel lock is always acquired prior to the gate lock
    channel: SharedChannel<T>,
    gate: SharedGate,
}

impl<T> DistributionSender<T> {
    /// Send data.
    ///
    /// This fails if the [receiver](DistributionReceiver) is gone.
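    ///
    /// Note (usage sketch, inferred from [`SendFuture`]): the returned future
    /// must be awaited/polled for the send to take effect, e.g.
    /// `tx.send(item).await` resolves to `Err(SendError(item))` once the
    /// receiver has been dropped.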
    pub fn send(&self, element: T) -> SendFuture<'_, T> {
        SendFuture {
            channel: &self.channel,
            gate: &self.gate,
            element: Box::new(Some(element)),
        }
    }
}

impl<T> Clone for DistributionSender<T> {
    fn clone(&self) -> Self {
        self.channel.n_senders.fetch_add(1, Ordering::SeqCst);

        Self {
            channel: Arc::clone(&self.channel),
            gate: Arc::clone(&self.gate),
        }
    }
}

impl<T> Drop for DistributionSender<T> {
    fn drop(&mut self) {
        let n_senders_pre = self.channel.n_senders.fetch_sub(1, Ordering::SeqCst);
        // is this the last copy of the sender side?
        if n_senders_pre > 1 {
            return;
        }

        let receivers = {
            let mut state = self.channel.state.lock();

            // During the shutdown of an empty channel, both the sender and the receiver side will be dropped. However we
            // only want to decrement the "empty channels" counter once.
            //
            // We are within a critical section here, so we can safely assume that either the last sender or the
            // receiver (there's only one) will be dropped first.
            //
            // If the last sender is dropped first, `state.data` will still exist and the sender side decrements the
            // signal. The receiver side then MUST check the `n_senders` counter during the section and if it is zero,
            // it infers that it is dropped afterwards and MUST NOT decrement the counter.
            //
            // If the receiver end is dropped first, it will infer -- based on `n_senders` -- that there are still
            // senders and it will decrement the `empty_channels` counter. It will also set `data` to `None`. The sender
            // side will then see that `data` is `None` and can therefore infer that the receiver end was dropped, and
            // hence it MUST NOT decrement the `empty_channels` counter.
            if state
                .data
                .as_ref()
                .map(|data| data.is_empty())
                .unwrap_or_default()
            {
                // channel is gone, so we need to clear our signal
                self.gate.decr_empty_channels();
            }

            // make sure that nobody can add wakers anymore
            state.recv_wakers.take().expect("not closed yet")
        };

        // wake outside of lock scope
        for recv in receivers {
            recv.wake();
        }
    }
}

/// Future backing [send](DistributionSender::send).
#[derive(Debug)]
pub struct SendFuture<'a, T> {
    channel: &'a SharedChannel<T>,
    gate: &'a SharedGate,
    // the additional Box is required for `Self: Unpin`
    element: Box<Option<T>>,
}

impl<'a, T> Future for SendFuture<'a, T> {
    type Output = Result<(), SendError<T>>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = &mut *self;
        assert!(this.element.is_some(), "polled ready future");

        // lock scope
        let to_wake = {
            let mut guard_channel_state = this.channel.state.lock();

            let Some(data) = guard_channel_state.data.as_mut() else {
                // receiver end dead
                return Poll::Ready(Err(SendError(
                    this.element.take().expect("just checked"),
                )));
            };

            // does ANY receiver need data?
            // if so, allow sender to create another
            if this.gate.empty_channels.load(Ordering::SeqCst) == 0 {
                let mut guard = this.gate.send_wakers.lock();
                if let Some(send_wakers) = guard.deref_mut() {
                    send_wakers.push((cx.waker().clone(), this.channel.id));
                    return Poll::Pending;
                }
            }

            let was_empty = data.is_empty();
            data.push_back(this.element.take().expect("just checked"));

            if was_empty {
                this.gate.decr_empty_channels();
                guard_channel_state.take_recv_wakers()
            } else {
                Vec::with_capacity(0)
            }
        };

        // wake outside of lock scope
        for receiver in to_wake {
            receiver.wake();
        }

        Poll::Ready(Ok(()))
    }
}

/// Receiver side of distribution [channels].
#[derive(Debug)]
pub struct DistributionReceiver<T> {
    channel: SharedChannel<T>,
    gate: SharedGate,
}

impl<T> DistributionReceiver<T> {
    /// Receive data from the channel.
    ///
    /// Returns `None` if the channel is empty and no [senders](DistributionSender) are left.
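    ///
    /// Cancellation note (inferred from [`RecvFuture`], added commentary):
    /// dropping the returned future before it resolves loses no data, because
    /// an element is only popped from the buffer when the future completes.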
    pub fn recv(&mut self) -> RecvFuture<'_, T> {
        RecvFuture {
            channel: &mut self.channel,
            gate: &mut self.gate,
            rdy: false,
        }
    }
}

impl<T> Drop for DistributionReceiver<T> {
    fn drop(&mut self) {
        let mut guard_channel_state = self.channel.state.lock();
        let data = guard_channel_state.data.take().expect("not dropped yet");

        // See `DistributionSender::drop` for an explanation of the drop order and when the "empty channels" counter is
        // decremented.
        if data.is_empty() && (self.channel.n_senders.load(Ordering::SeqCst) > 0) {
            // channel is gone, so we need to clear our signal
            self.gate.decr_empty_channels();
        }

        // senders may be waiting for the gate to open but should error now that the channel is closed
        self.gate.wake_channel_senders(self.channel.id);
    }
}

/// Future backing [recv](DistributionReceiver::recv).
pub struct RecvFuture<'a, T> {
    channel: &'a mut SharedChannel<T>,
    gate: &'a mut SharedGate,
    rdy: bool,
}

impl<'a, T> Future for RecvFuture<'a, T> {
    type Output = Option<T>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = &mut *self;
        assert!(!this.rdy, "polled ready future");

        let mut guard_channel_state = this.channel.state.lock();
        let channel_state = guard_channel_state.deref_mut();
        let data = channel_state.data.as_mut().expect("not dropped yet");

        match data.pop_front() {
            Some(element) => {
                // change "empty" signal for this channel?
                if data.is_empty() && channel_state.recv_wakers.is_some() {
                    // update counter
                    let old_counter =
                        this.gate.empty_channels.fetch_add(1, Ordering::SeqCst);

                    // open gate?
                    let to_wake = if old_counter == 0 {
                        let mut guard = this.gate.send_wakers.lock();

                        // check after lock to see if we should still change the state
                        if this.gate.empty_channels.load(Ordering::SeqCst) > 0 {
                            guard.take().unwrap_or_default()
                        } else {
                            Vec::with_capacity(0)
                        }
                    } else {
                        Vec::with_capacity(0)
                    };

                    drop(guard_channel_state);

                    // wake outside of lock scope
                    for (waker, _channel_id) in to_wake {
                        waker.wake();
                    }
                }

                this.rdy = true;
                Poll::Ready(Some(element))
            }
            None => {
                if let Some(recv_wakers) = channel_state.recv_wakers.as_mut() {
                    recv_wakers.push(cx.waker().clone());
                    Poll::Pending
                } else {
                    this.rdy = true;
                    Poll::Ready(None)
                }
            }
        }
    }
}

/// Links senders and receivers.
#[derive(Debug)]
struct Channel<T> {
    /// Reference counter for the sender side.
    n_senders: AtomicUsize,

    /// Channel ID.
    ///
    /// This is used to address [send wakers](Gate::send_wakers).
    id: usize,

    /// Mutable state.
    state: Mutex<ChannelState<T>>,
}

impl<T> Channel<T> {
    /// Create a new channel with one sender (so we don't need to [fetch-add](AtomicUsize::fetch_add) directly afterwards).
    fn new_with_one_sender(id: usize) -> Self {
        Channel {
            n_senders: AtomicUsize::new(1),
            id,
            state: Mutex::new(ChannelState {
                data: Some(VecDeque::default()),
                recv_wakers: Some(Vec::default()),
            }),
        }
    }
}

#[derive(Debug)]
struct ChannelState<T> {
    /// Buffered data.
    ///
    /// This is [`None`] when the receiver is gone.
    data: Option<VecDeque<T>>,

    /// Wakers for the receiver side.
    ///
    /// The receiver will be pending if the [buffer](Self::data) is empty and
    /// there are senders left (otherwise this is set to [`None`]).
    recv_wakers: Option<Vec<Waker>>,
}

impl<T> ChannelState<T> {
    /// Get all [`recv_wakers`](Self::recv_wakers) and replace them with an identically-sized buffer.
    ///
    /// The wakers should be woken AFTER the lock to [this state](Self) was dropped.
    ///
    /// # Panics
    /// Assumes that the channel is NOT closed yet, i.e. that [`recv_wakers`](Self::recv_wakers) is not [`None`].
    fn take_recv_wakers(&mut self) -> Vec<Waker> {
        let to_wake = self.recv_wakers.as_mut().expect("not closed");
        let mut tmp = Vec::with_capacity(to_wake.capacity());
        std::mem::swap(to_wake, &mut tmp);
        tmp
    }
}

/// Shared channel.
///
/// One or multiple senders and a single receiver will share a channel.
type SharedChannel<T> = Arc<Channel<T>>;

/// The "all channels have data" gate.
#[derive(Debug)]
struct Gate {
    /// Number of currently empty (and still open) channels.
    empty_channels: AtomicUsize,

    /// Wakers for the sender side, including their channel IDs.
    ///
    /// This is `None` while at least one channel is empty, i.e. while the gate is open.
    send_wakers: Mutex<Option<Vec<(Waker, usize)>>>,
}

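// Informal protocol sketch (added commentary, derived from the code below):
// - `empty_channels` counts channels whose buffer is empty and still open.
// - While `empty_channels == 0` the gate is closed: `send_wakers` is `Some`,
//   and senders park their wakers there instead of buffering more data.
// - When a receiver drains a channel to empty (the counter's 0 -> 1
//   transition), it takes all parked wakers and wakes the senders.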
impl Gate {
    /// Wake senders for a specific channel.
    ///
    /// This is helpful to signal that the receiver side is gone and the senders shall now error.
    fn wake_channel_senders(&self, id: usize) {
        // lock scope
        let to_wake = {
            let mut guard = self.send_wakers.lock();

            if let Some(send_wakers) = guard.deref_mut() {
                // `drain_filter` is unstable, so implement our own
                let (wake, keep) =
                    send_wakers.drain(..).partition(|(_waker, id2)| id == *id2);

                *send_wakers = keep;

                wake
            } else {
                Vec::with_capacity(0)
            }
        };

        // wake outside of lock scope
        for (waker, _id) in to_wake {
            waker.wake();
        }
    }

    fn decr_empty_channels(&self) {
        let old_count = self.empty_channels.fetch_sub(1, Ordering::SeqCst);

        if old_count == 1 {
            let mut guard = self.send_wakers.lock();

            // double-check state during lock
            if self.empty_channels.load(Ordering::SeqCst) == 0 && guard.is_none() {
                *guard = Some(Vec::new());
            }
        }
    }
}

/// Gate shared by all senders and receivers.
type SharedGate = Arc<Gate>;

#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicBool;

    use futures::{task::ArcWake, FutureExt};

    use super::*;

    #[test]
    fn test_single_channel_no_gate() {
        // use two channels so that the first one never hits the gate
        let (mut txs, mut rxs) = channels(2);

        let mut recv_fut = rxs[0].recv();
        let waker = poll_pending(&mut recv_fut);

        poll_ready(&mut txs[0].send("foo")).unwrap();
        assert!(waker.woken());
        assert_eq!(poll_ready(&mut recv_fut), Some("foo"),);

        poll_ready(&mut txs[0].send("bar")).unwrap();
        poll_ready(&mut txs[0].send("baz")).unwrap();
        poll_ready(&mut txs[0].send("end")).unwrap();
        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("bar"),);
        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("baz"),);

        // close channel
        txs.remove(0);
        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("end"),);
        assert_eq!(poll_ready(&mut rxs[0].recv()), None,);
        assert_eq!(poll_ready(&mut rxs[0].recv()), None,);
    }

    #[test]
    fn test_multi_sender() {
        // use two channels so that the first one never hits the gate
        let (txs, mut rxs) = channels(2);

        let tx_clone = txs[0].clone();

        poll_ready(&mut txs[0].send("foo")).unwrap();
        poll_ready(&mut tx_clone.send("bar")).unwrap();

        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("foo"),);
        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("bar"),);
    }

    #[test]
    fn test_gate() {
        let (txs, mut rxs) = channels(2);

        // gate initially open
        poll_ready(&mut txs[0].send("0_a")).unwrap();

        // gate still open because channel 1 is still empty
        poll_ready(&mut txs[0].send("0_b")).unwrap();

        // gate still open because channel 1 is still empty prior to this call, so this call still goes through
        poll_ready(&mut txs[1].send("1_a")).unwrap();

        // both channels non-empty => gate closed

        let mut send_fut = txs[1].send("1_b");
        let waker = poll_pending(&mut send_fut);

        // drain channel 0
        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("0_a"),);
        poll_pending(&mut send_fut);
        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("0_b"),);

        // channel 0 empty => gate open
        assert!(waker.woken());
        poll_ready(&mut send_fut).unwrap();
    }

    #[test]
    fn test_close_channel_by_dropping_tx() {
        let (mut txs, mut rxs) = channels(2);

        let tx0 = txs.remove(0);
        let tx1 = txs.remove(0);
        let tx0_clone = tx0.clone();

        let mut recv_fut = rxs[0].recv();

        poll_ready(&mut tx1.send("a")).unwrap();
        let recv_waker = poll_pending(&mut recv_fut);

        // drop original sender
        drop(tx0);

        // not yet closed (there's a clone left)
        assert!(!recv_waker.woken());
        poll_ready(&mut tx1.send("b")).unwrap();
        let recv_waker = poll_pending(&mut recv_fut);

        // create new clone
        let tx0_clone2 = tx0_clone.clone();
        assert!(!recv_waker.woken());
        poll_ready(&mut tx1.send("c")).unwrap();
        let recv_waker = poll_pending(&mut recv_fut);

        // drop first clone
        drop(tx0_clone);
        assert!(!recv_waker.woken());
        poll_ready(&mut tx1.send("d")).unwrap();
        let recv_waker = poll_pending(&mut recv_fut);

        // drop last clone
        drop(tx0_clone2);

        // channel closed => also close gate
        poll_pending(&mut tx1.send("e"));
        assert!(recv_waker.woken());
        assert_eq!(poll_ready(&mut recv_fut), None,);
    }

    #[test]
    fn test_close_channel_by_dropping_rx_on_open_gate() {
        let (txs, mut rxs) = channels(2);

        let rx0 = rxs.remove(0);
        let _rx1 = rxs.remove(0);

        poll_ready(&mut txs[1].send("a")).unwrap();

        // drop receiver => also close gate
        drop(rx0);

        poll_pending(&mut txs[1].send("b"));
        assert_eq!(poll_ready(&mut txs[0].send("foo")), Err(SendError("foo")),);
    }

    #[test]
    fn test_close_channel_by_dropping_rx_on_closed_gate() {
        let (txs, mut rxs) = channels(2);

        let rx0 = rxs.remove(0);
        let mut rx1 = rxs.remove(0);

        // fill both channels
        poll_ready(&mut txs[0].send("0_a")).unwrap();
        poll_ready(&mut txs[1].send("1_a")).unwrap();

        let mut send_fut0 = txs[0].send("0_b");
        let mut send_fut1 = txs[1].send("1_b");
        let waker0 = poll_pending(&mut send_fut0);
        let waker1 = poll_pending(&mut send_fut1);

        // drop receiver
        drop(rx0);

        assert!(waker0.woken());
        assert!(!waker1.woken());
        assert_eq!(poll_ready(&mut send_fut0), Err(SendError("0_b")),);

        // gate closed, so cannot send on channel 1
        poll_pending(&mut send_fut1);

        // channel 1 can still receive data
        assert_eq!(poll_ready(&mut rx1.recv()), Some("1_a"),);
    }

    #[test]
    fn test_drop_rx_three_channels() {
        let (mut txs, mut rxs) = channels(3);

        let tx0 = txs.remove(0);
        let tx1 = txs.remove(0);
        let tx2 = txs.remove(0);
        let mut rx0 = rxs.remove(0);
        let rx1 = rxs.remove(0);
        let _rx2 = rxs.remove(0);

        // fill channels
        poll_ready(&mut tx0.send("0_a")).unwrap();
        poll_ready(&mut tx1.send("1_a")).unwrap();
        poll_ready(&mut tx2.send("2_a")).unwrap();

        // drop / close one channel
        drop(rx1);

        // receive data
        assert_eq!(poll_ready(&mut rx0.recv()), Some("0_a"),);

        // use senders again
        poll_ready(&mut tx0.send("0_b")).unwrap();
        assert_eq!(poll_ready(&mut tx1.send("1_b")), Err(SendError("1_b")),);
        poll_pending(&mut tx2.send("2_b"));
    }

    #[test]
    fn test_close_channel_by_dropping_rx_clears_data() {
        let (txs, rxs) = channels(1);

        let obj = Arc::new(());
        let counter = Arc::downgrade(&obj);
        assert_eq!(counter.strong_count(), 1);

        // add object to channel
        poll_ready(&mut txs[0].send(obj)).unwrap();
        assert_eq!(counter.strong_count(), 1);

        // drop receiver
        drop(rxs);

        assert_eq!(counter.strong_count(), 0);
    }
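
    /// Added sketch (not part of the original suite): elements within a single
    /// channel are delivered in FIFO order.
    #[test]
    fn test_fifo_order_within_channel() {
        // use two channels so that the first one never hits the gate
        let (txs, mut rxs) = channels(2);

        for i in 0..3 {
            poll_ready(&mut txs[0].send(i)).unwrap();
        }
        for i in 0..3 {
            assert_eq!(poll_ready(&mut rxs[0].recv()), Some(i));
        }
    }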

    /// Ensure that polling "pending" futures works even when you poll them too often (which happens under some circumstances).
    #[test]
    fn test_poll_empty_channel_twice() {
        let (txs, mut rxs) = channels(1);

        let mut recv_fut = rxs[0].recv();
        let waker_1a = poll_pending(&mut recv_fut);
        let waker_1b = poll_pending(&mut recv_fut);

        let mut recv_fut = rxs[0].recv();
        let waker_2 = poll_pending(&mut recv_fut);

        poll_ready(&mut txs[0].send("a")).unwrap();
        assert!(waker_1a.woken());
        assert!(waker_1b.woken());
        assert!(waker_2.woken());
        assert_eq!(poll_ready(&mut recv_fut), Some("a"),);

        poll_ready(&mut txs[0].send("b")).unwrap();
        let mut send_fut = txs[0].send("c");
        let waker_3 = poll_pending(&mut send_fut);
        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("b"),);
        assert!(waker_3.woken());
        poll_ready(&mut send_fut).unwrap();
        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("c"));

        let mut recv_fut = rxs[0].recv();
        let waker_4 = poll_pending(&mut recv_fut);

        let mut recv_fut = rxs[0].recv();
        let waker_5 = poll_pending(&mut recv_fut);

        poll_ready(&mut txs[0].send("d")).unwrap();
        let mut send_fut = txs[0].send("e");
        let waker_6a = poll_pending(&mut send_fut);
        let waker_6b = poll_pending(&mut send_fut);

        assert!(waker_4.woken());
        assert!(waker_5.woken());
        assert_eq!(poll_ready(&mut recv_fut), Some("d"),);

        assert!(waker_6a.woken());
        assert!(waker_6b.woken());
        poll_ready(&mut send_fut).unwrap();
    }

    #[test]
    #[should_panic(expected = "polled ready future")]
    fn test_panic_poll_send_future_after_ready_ok() {
        let (txs, _rxs) = channels(1);
        let mut fut = txs[0].send("foo");
        poll_ready(&mut fut).unwrap();
        poll_ready(&mut fut).ok();
    }

    #[test]
    #[should_panic(expected = "polled ready future")]
    fn test_panic_poll_send_future_after_ready_err() {
        let (txs, rxs) = channels(1);

        drop(rxs);

        let mut fut = txs[0].send("foo");
        poll_ready(&mut fut).unwrap_err();
        poll_ready(&mut fut).ok();
    }

    #[test]
    #[should_panic(expected = "polled ready future")]
    fn test_panic_poll_recv_future_after_ready_some() {
        let (txs, mut rxs) = channels(1);

        poll_ready(&mut txs[0].send("foo")).unwrap();

        let mut fut = rxs[0].recv();
        poll_ready(&mut fut).unwrap();
        poll_ready(&mut fut);
    }

    #[test]
    #[should_panic(expected = "polled ready future")]
    fn test_panic_poll_recv_future_after_ready_none() {
        let (txs, mut rxs) = channels::<u8>(1);

        drop(txs);

        let mut fut = rxs[0].recv();
        assert!(poll_ready(&mut fut).is_none());
        poll_ready(&mut fut);
    }

    #[test]
    #[should_panic(expected = "future is pending")]
    fn test_meta_poll_ready_wrong_state() {
        let mut fut = futures::future::pending::<u8>();
        poll_ready(&mut fut);
    }

    #[test]
    #[should_panic(expected = "future is ready")]
    fn test_meta_poll_pending_wrong_state() {
        let mut fut = futures::future::ready(1);
        poll_pending(&mut fut);
    }

    /// Test [`poll_pending`] (i.e. the testing utils, not the actual library code).
    #[test]
    fn test_meta_poll_pending_waker() {
        let (tx, mut rx) = futures::channel::oneshot::channel();
        let waker = poll_pending(&mut rx);
        assert!(!waker.woken());
        tx.send(1).unwrap();
        assert!(waker.woken());
    }

    /// Poll a given [`Future`] and ensure it is [ready](Poll::Ready).
    #[track_caller]
    fn poll_ready<F>(fut: &mut F) -> F::Output
    where
        F: Future + Unpin,
    {
        match poll(fut).0 {
            Poll::Ready(x) => x,
            Poll::Pending => panic!("future is pending"),
        }
    }

    /// Poll a given [`Future`] and ensure it is [pending](Poll::Pending).
    ///
    /// Returns a waker that can later be checked.
    #[track_caller]
    fn poll_pending<F>(fut: &mut F) -> Arc<TestWaker>
    where
        F: Future + Unpin,
    {
        let (res, waker) = poll(fut);
        match res {
            Poll::Ready(_) => panic!("future is ready"),
            Poll::Pending => waker,
        }
    }

    fn poll<F>(fut: &mut F) -> (Poll<F::Output>, Arc<TestWaker>)
    where
        F: Future + Unpin,
    {
        let test_waker = Arc::new(TestWaker::default());
        let waker = futures::task::waker(Arc::clone(&test_waker));
        let mut cx = std::task::Context::from_waker(&waker);
        let res = fut.poll_unpin(&mut cx);
        (res, test_waker)
    }

    /// A test [`Waker`] that signals whether [`wake`](Waker::wake) was called.
    #[derive(Debug, Default)]
    struct TestWaker {
        woken: AtomicBool,
    }

    impl TestWaker {
        /// Was [`wake`](Waker::wake) called?
        fn woken(&self) -> bool {
            self.woken.load(Ordering::SeqCst)
        }
    }

    impl ArcWake for TestWaker {
        fn wake_by_ref(arc_self: &Arc<Self>) {
            arc_self.woken.store(true, Ordering::SeqCst);
        }
    }
}