Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/repartition/distributor_channels.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Special channel construction to distribute data from various inputs into N outputs
19
//! minimizing buffering but preventing deadlocks when repartitoning
20
//!
21
//! # Design
22
//!
23
//! ```text
24
//! +----+      +------+
25
//! | TX |==||  | Gate |
26
//! +----+  ||  |      |  +--------+  +----+
27
//!         ====|      |==| Buffer |==| RX |
28
//! +----+  ||  |      |  +--------+  +----+
29
//! | TX |==||  |      |
30
//! +----+      |      |
31
//!             |      |
32
//! +----+      |      |  +--------+  +----+
33
//! | TX |======|      |==| Buffer |==| RX |
34
//! +----+      +------+  +--------+  +----+
35
//! ```
36
//!
37
//! There are `N` virtual MPSC (multi-producer, single consumer) channels with unbounded capacity. However, if all
38
//! buffers/channels are non-empty, than a global gate will be closed preventing new data from being written (the
39
//! sender futures will be [pending](Poll::Pending)) until at least one channel is empty (and not closed).
40
use std::{
41
    collections::VecDeque,
42
    future::Future,
43
    ops::DerefMut,
44
    pin::Pin,
45
    sync::{
46
        atomic::{AtomicUsize, Ordering},
47
        Arc,
48
    },
49
    task::{Context, Poll, Waker},
50
};
51
52
use parking_lot::Mutex;
53
54
/// Create `n` empty channels.
55
1.40k
pub fn channels<T>(
56
1.40k
    n: usize,
57
1.40k
) -> (Vec<DistributionSender<T>>, Vec<DistributionReceiver<T>>) {
58
1.40k
    let channels = (0..n)
59
5.57k
        .map(|id| Arc::new(Channel::new_with_one_sender(id)))
60
1.40k
        .collect::<Vec<_>>();
61
1.40k
    let gate = Arc::new(Gate {
62
1.40k
        empty_channels: AtomicUsize::new(n),
63
1.40k
        send_wakers: Mutex::new(None),
64
1.40k
    });
65
1.40k
    let senders = channels
66
1.40k
        .iter()
67
5.57k
        .map(|channel| DistributionSender {
68
5.57k
            channel: Arc::clone(channel),
69
5.57k
            gate: Arc::clone(&gate),
70
5.57k
        })
71
1.40k
        .collect();
72
1.40k
    let receivers = channels
73
1.40k
        .into_iter()
74
5.57k
        .map(|channel| DistributionReceiver {
75
5.57k
            channel,
76
5.57k
            gate: Arc::clone(&gate),
77
5.57k
        })
78
1.40k
        .collect();
79
1.40k
    (senders, receivers)
80
1.40k
}
81
82
type PartitionAwareSenders<T> = Vec<Vec<DistributionSender<T>>>;
83
type PartitionAwareReceivers<T> = Vec<Vec<DistributionReceiver<T>>>;
84
85
/// Create `n_out` empty channels for each of the `n_in` inputs.
86
/// This way, each distinct partition will communicate via a dedicated channel.
87
/// This SPSC structure enables us to track which partition input data comes from.
88
0
pub fn partition_aware_channels<T>(
89
0
    n_in: usize,
90
0
    n_out: usize,
91
0
) -> (PartitionAwareSenders<T>, PartitionAwareReceivers<T>) {
92
0
    (0..n_in).map(|_| channels(n_out)).unzip()
93
0
}
94
95
/// Erroring during [send](DistributionSender::send).
96
///
97
/// This occurs when the [receiver](DistributionReceiver) is gone.
98
#[derive(PartialEq, Eq)]
99
pub struct SendError<T>(pub T);
100
101
impl<T> std::fmt::Debug for SendError<T> {
102
0
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103
0
        f.debug_tuple("SendError").finish()
104
0
    }
105
}
106
107
impl<T> std::fmt::Display for SendError<T> {
108
0
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109
0
        write!(f, "cannot send data, receiver is gone")
110
0
    }
111
}
112
113
impl<T> std::error::Error for SendError<T> {}
114
115
/// Sender side of distribution [channels].
116
///
117
/// This handle can be cloned. All clones will write into the same channel. Dropping the last sender will close the
118
/// channel. In this case, the [receiver](DistributionReceiver) will still be able to poll the remaining data, but will
119
/// receive `None` afterwards.
120
#[derive(Debug)]
121
pub struct DistributionSender<T> {
122
    /// To prevent lock inversion / deadlock, channel lock is always acquired prior to gate lock
123
    channel: SharedChannel<T>,
124
    gate: SharedGate,
125
}
126
127
impl<T> DistributionSender<T> {
128
    /// Send data.
129
    ///
130
    /// This fails if the [receiver](DistributionReceiver) is gone.
131
23.0k
    pub fn send(&self, element: T) -> SendFuture<'_, T> {
132
23.0k
        SendFuture {
133
23.0k
            channel: &self.channel,
134
23.0k
            gate: &self.gate,
135
23.0k
            element: Box::new(Some(element)),
136
23.0k
        }
137
23.0k
    }
138
}
139
140
impl<T> Clone for DistributionSender<T> {
141
11.2k
    fn clone(&self) -> Self {
142
11.2k
        self.channel.n_senders.fetch_add(1, Ordering::SeqCst);
143
11.2k
144
11.2k
        Self {
145
11.2k
            channel: Arc::clone(&self.channel),
146
11.2k
            gate: Arc::clone(&self.gate),
147
11.2k
        }
148
11.2k
    }
149
}
150
151
impl<T> Drop for DistributionSender<T> {
152
16.8k
    fn drop(&mut self) {
153
16.8k
        let n_senders_pre = self.channel.n_senders.fetch_sub(1, Ordering::SeqCst);
154
16.8k
        // is the the last copy of the sender side?
155
16.8k
        if n_senders_pre > 1 {
156
11.2k
            return;
157
5.57k
        }
158
159
5.57k
        let receivers = {
160
5.57k
            let mut state = self.channel.state.lock();
161
5.57k
162
5.57k
            // During the shutdown of a empty channel, both the sender and the receiver side will be dropped. However we
163
5.57k
            // only want to decrement the "empty channels" counter once.
164
5.57k
            //
165
5.57k
            // We are within a critical section here, so we we can safely assume that either the last sender or the
166
5.57k
            // receiver (there's only one) will be dropped first.
167
5.57k
            //
168
5.57k
            // If the last sender is dropped first, `state.data` will still exists and the sender side decrements the
169
5.57k
            // signal. The receiver side then MUST check the `n_senders` counter during the section and if it is zero,
170
5.57k
            // it inferres that it is dropped afterwards and MUST NOT decrement the counter.
171
5.57k
            //
172
5.57k
            // If the receiver end is dropped first, it will inferr -- based on `n_senders` -- that there are still
173
5.57k
            // senders and it will decrement the `empty_channels` counter. It will also set `data` to `None`. The sender
174
5.57k
            // side will then see that `data` is `None` and can therefore inferr that the receiver end was dropped, and
175
5.57k
            // hence it MUST NOT decrement the `empty_channels` counter.
176
5.57k
            if state
177
5.57k
                .data
178
5.57k
                .as_ref()
179
5.57k
                .map(|data| 
data.is_empty()5.52k
)
180
5.57k
                .unwrap_or_default()
181
15
            {
182
15
                // channel is gone, so we need to clear our signal
183
15
                self.gate.decr_empty_channels();
184
5.55k
            }
185
186
            // make sure that nobody can add wakers anymore
187
5.57k
            state.recv_wakers.take().expect("not closed yet")
188
        };
189
190
        // wake outside of lock scope
191
5.57k
        for 
recv5
in receivers {
192
5
            recv.wake();
193
5
        }
194
16.8k
    }
195
}
196
197
/// Future backing [send](DistributionSender::send).
198
#[derive(Debug)]
199
pub struct SendFuture<'a, T> {
200
    channel: &'a SharedChannel<T>,
201
    gate: &'a SharedGate,
202
    // the additional Box is required for `Self: Unpin`
203
    element: Box<Option<T>>,
204
}
205
206
impl<'a, T> Future for SendFuture<'a, T> {
207
    type Output = Result<(), SendError<T>>;
208
209
26.6k
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
210
26.6k
        let this = &mut *self;
211
26.6k
        assert!(this.element.is_some(), 
"polled ready future"2
);
212
213
        // lock scope
214
23.0k
        let to_wake = {
215
26.6k
            let mut guard_channel_state = this.channel.state.lock();
216
217
26.6k
            let Some(
data26.6k
) = guard_channel_state.data.as_mut() else {
218
                // receiver end dead
219
4
                return Poll::Ready(Err(SendError(
220
4
                    this.element.take().expect("just checked"),
221
4
                )));
222
            };
223
224
            // does ANY receiver need data?
225
            // if so, allow sender to create another
226
26.6k
            if this.gate.empty_channels.load(Ordering::SeqCst) == 0 {
227
3.58k
                let mut guard = this.gate.send_wakers.lock();
228
3.58k
                if let Some(send_wakers) = guard.deref_mut() {
229
3.58k
                    send_wakers.push((cx.waker().clone(), this.channel.id));
230
3.58k
                    return Poll::Pending;
231
0
                }
232
23.0k
            }
233
234
23.0k
            let was_empty = data.is_empty();
235
23.0k
            data.push_back(this.element.take().expect("just checked"));
236
23.0k
237
23.0k
            if was_empty {
238
10.3k
                this.gate.decr_empty_channels();
239
10.3k
                guard_channel_state.take_recv_wakers()
240
            } else {
241
12.6k
                Vec::with_capacity(0)
242
            }
243
        };
244
245
        // wake outside of lock scope
246
26.3k
        for 
receiver3.32k
in to_wake {
247
3.32k
            receiver.wake();
248
3.32k
        }
249
250
23.0k
        Poll::Ready(Ok(()))
251
26.6k
    }
252
}
253
254
/// Receiver side of distribution [channels].
255
#[derive(Debug)]
256
pub struct DistributionReceiver<T> {
257
    channel: SharedChannel<T>,
258
    gate: SharedGate,
259
}
260
261
impl<T> DistributionReceiver<T> {
262
    /// Receive data from channel.
263
    ///
264
    /// Returns `None` if the channel is empty and no [senders](DistributionSender) are left.
265
26.3k
    pub fn recv(&mut self) -> RecvFuture<'_, T> {
266
26.3k
        RecvFuture {
267
26.3k
            channel: &mut self.channel,
268
26.3k
            gate: &mut self.gate,
269
26.3k
            rdy: false,
270
26.3k
        }
271
26.3k
    }
272
}
273
274
impl<T> Drop for DistributionReceiver<T> {
275
5.57k
    fn drop(&mut self) {
276
5.57k
        let mut guard_channel_state = self.channel.state.lock();
277
5.57k
        let data = guard_channel_state.data.take().expect("not dropped yet");
278
5.57k
279
5.57k
        // See `DistributedSender::drop` for an explanation of the drop order and when the "empty channels" counter is
280
5.57k
        // decremented.
281
5.57k
        if data.is_empty() && 
(self.channel.n_senders.load(Ordering::SeqCst) > 0)5.56k
{
282
27
            // channel is gone, so we need to clear our signal
283
27
            self.gate.decr_empty_channels();
284
5.54k
        }
285
286
        // senders may be waiting for gate to open but should error now that the channel is closed
287
5.57k
        self.gate.wake_channel_senders(self.channel.id);
288
5.57k
    }
289
}
290
291
/// Future backing [recv](DistributionReceiver::recv).
292
pub struct RecvFuture<'a, T> {
293
    channel: &'a mut SharedChannel<T>,
294
    gate: &'a mut SharedGate,
295
    rdy: bool,
296
}
297
298
impl<'a, T> Future for RecvFuture<'a, T> {
299
    type Output = Option<T>;
300
301
26.3k
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
302
26.3k
        let this = &mut *self;
303
26.3k
        assert!(!this.rdy, 
"polled ready future"2
);
304
305
26.3k
        let mut guard_channel_state = this.channel.state.lock();
306
26.3k
        let channel_state = guard_channel_state.deref_mut();
307
26.3k
        let data = channel_state.data.as_mut().expect("not dropped yet");
308
26.3k
309
26.3k
        match data.pop_front() {
310
23.0k
            Some(element) => {
311
23.0k
                // change "empty" signal for this channel?
312
23.0k
                if data.is_empty() && 
channel_state.recv_wakers.is_some()10.3k
{
313
                    // update counter
314
4.85k
                    let old_counter =
315
4.85k
                        this.gate.empty_channels.fetch_add(1, Ordering::SeqCst);
316
317
                    // open gate?
318
4.85k
                    let to_wake = if old_counter == 0 {
319
3.39k
                        let mut guard = this.gate.send_wakers.lock();
320
3.39k
321
3.39k
                        // check after lock to see if we should still change the state
322
3.39k
                        if this.gate.empty_channels.load(Ordering::SeqCst) > 0 {
323
3.39k
                            guard.take().unwrap_or_default()
324
                        } else {
325
0
                            Vec::with_capacity(0)
326
                        }
327
                    } else {
328
1.45k
                        Vec::with_capacity(0)
329
                    };
330
331
4.85k
                    drop(guard_channel_state);
332
333
                    // wake outside of lock scope
334
8.42k
                    for (
waker, _channel_id3.57k
) in to_wake {
335
3.57k
                        waker.wake();
336
3.57k
                    }
337
18.1k
                }
338
339
23.0k
                this.rdy = true;
340
23.0k
                Poll::Ready(Some(element))
341
            }
342
            None => {
343
3.32k
                if let Some(
recv_wakers3.32k
) = channel_state.recv_wakers.as_mut() {
344
3.32k
                    recv_wakers.push(cx.waker().clone());
345
3.32k
                    Poll::Pending
346
                } else {
347
4
                    this.rdy = true;
348
4
                    Poll::Ready(None)
349
                }
350
            }
351
        }
352
26.3k
    }
353
}
354
355
/// Links senders and receivers.
356
#[derive(Debug)]
357
struct Channel<T> {
358
    /// Reference counter for the sender side.
359
    n_senders: AtomicUsize,
360
361
    /// Channel ID.
362
    ///
363
    /// This is used to address [send wakers](Gate::send_wakers).
364
    id: usize,
365
366
    /// Mutable state.
367
    state: Mutex<ChannelState<T>>,
368
}
369
370
impl<T> Channel<T> {
371
    /// Create new channel with one sender (so we don't need to [fetch-add](AtomicUsize::fetch_add) directly afterwards).
372
5.57k
    fn new_with_one_sender(id: usize) -> Self {
373
5.57k
        Channel {
374
5.57k
            n_senders: AtomicUsize::new(1),
375
5.57k
            id,
376
5.57k
            state: Mutex::new(ChannelState {
377
5.57k
                data: Some(VecDeque::default()),
378
5.57k
                recv_wakers: Some(Vec::default()),
379
5.57k
            }),
380
5.57k
        }
381
5.57k
    }
382
}
383
384
#[derive(Debug)]
385
struct ChannelState<T> {
386
    /// Buffered data.
387
    ///
388
    /// This is [`None`] when the receiver is gone.
389
    data: Option<VecDeque<T>>,
390
391
    /// Wakers for the receiver side.
392
    ///
393
    /// The receiver will be pending if the [buffer](Self::data) is empty and
394
    /// there are senders left (otherwise this is set to [`None`]).
395
    recv_wakers: Option<Vec<Waker>>,
396
}
397
398
impl<T> ChannelState<T> {
399
    /// Get all [`recv_wakers`](Self::recv_wakers) and replace with identically-sized buffer.
400
    ///
401
    /// The wakers should be woken AFTER the lock to [this state](Self) was dropped.
402
    ///
403
    /// # Panics
404
    /// Assumes that channel is NOT closed yet, i.e. that [`recv_wakers`](Self::recv_wakers) is not [`None`].
405
10.3k
    fn take_recv_wakers(&mut self) -> Vec<Waker> {
406
10.3k
        let to_wake = self.recv_wakers.as_mut().expect("not closed");
407
10.3k
        let mut tmp = Vec::with_capacity(to_wake.capacity());
408
10.3k
        std::mem::swap(to_wake, &mut tmp);
409
10.3k
        tmp
410
10.3k
    }
411
}
412
413
/// Shared channel.
414
///
415
/// One or multiple senders and a single receiver will share a channel.
416
type SharedChannel<T> = Arc<Channel<T>>;
417
418
/// The "all channels have data" gate.
419
#[derive(Debug)]
420
struct Gate {
421
    /// Number of currently empty (and still open) channels.
422
    empty_channels: AtomicUsize,
423
424
    /// Wakers for the sender side, including their channel IDs.
425
    ///
426
    /// This is `None` if the there are non-empty channels.
427
    send_wakers: Mutex<Option<Vec<(Waker, usize)>>>,
428
}
429
430
impl Gate {
431
    /// Wake senders for a specific channel.
432
    ///
433
    /// This is helpful to signal that the receiver side is gone and the senders shall now error.
434
5.57k
    fn wake_channel_senders(&self, id: usize) {
435
        // lock scope
436
5.57k
        let to_wake = {
437
5.57k
            let mut guard = self.send_wakers.lock();
438
439
5.57k
            if let Some(
send_wakers5.55k
) = guard.deref_mut() {
440
                // `drain_filter` is unstable, so implement our own
441
5.55k
                let (wake, keep) =
442
5.55k
                    send_wakers.drain(..).partition(|(_waker, id2)| 
id == *id24.04k
);
443
5.55k
444
5.55k
                *send_wakers = keep;
445
5.55k
446
5.55k
                wake
447
            } else {
448
15
                Vec::with_capacity(0)
449
            }
450
        };
451
452
        // wake outside of lock scope
453
5.58k
        for (
waker, _id7
) in to_wake {
454
7
            waker.wake();
455
7
        }
456
5.57k
    }
457
458
10.4k
    fn decr_empty_channels(&self) {
459
10.4k
        let old_count = self.empty_channels.fetch_sub(1, Ordering::SeqCst);
460
10.4k
461
10.4k
        if old_count == 1 {
462
4.79k
            let mut guard = self.send_wakers.lock();
463
4.79k
464
4.79k
            // double-check state during lock
465
4.79k
            if self.empty_channels.load(Ordering::SeqCst) == 0 && guard.is_none() {
466
4.79k
                *guard = Some(Vec::new());
467
4.79k
            }
0
468
5.62k
        }
469
10.4k
    }
470
}
471
472
/// Gate shared by all senders and receivers.
473
type SharedGate = Arc<Gate>;
474
475
#[cfg(test)]
476
mod tests {
477
    use std::sync::atomic::AtomicBool;
478
479
    use futures::{task::ArcWake, FutureExt};
480
481
    use super::*;
482
483
    #[test]
484
1
    fn test_single_channel_no_gate() {
485
1
        // use two channels so that the first one never hits the gate
486
1
        let (mut txs, mut rxs) = channels(2);
487
1
488
1
        let mut recv_fut = rxs[0].recv();
489
1
        let waker = poll_pending(&mut recv_fut);
490
1
491
1
        poll_ready(&mut txs[0].send("foo")).unwrap();
492
1
        assert!(waker.woken());
493
1
        assert_eq!(poll_ready(&mut recv_fut), Some("foo"),);
494
495
1
        poll_ready(&mut txs[0].send("bar")).unwrap();
496
1
        poll_ready(&mut txs[0].send("baz")).unwrap();
497
1
        poll_ready(&mut txs[0].send("end")).unwrap();
498
1
        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("bar"),);
499
1
        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("baz"),);
500
501
        // close channel
502
1
        txs.remove(0);
503
1
        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("end"),);
504
1
        assert_eq!(poll_ready(&mut rxs[0].recv()), None,);
505
1
        assert_eq!(poll_ready(&mut rxs[0].recv()), None,);
506
1
    }
507
508
    #[test]
509
1
    fn test_multi_sender() {
510
1
        // use two channels so that the first one never hits the gate
511
1
        let (txs, mut rxs) = channels(2);
512
1
513
1
        let tx_clone = txs[0].clone();
514
1
515
1
        poll_ready(&mut txs[0].send("foo")).unwrap();
516
1
        poll_ready(&mut tx_clone.send("bar")).unwrap();
517
1
518
1
        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("foo"),);
519
1
        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("bar"),);
520
1
    }
521
522
    #[test]
523
1
    fn test_gate() {
524
1
        let (txs, mut rxs) = channels(2);
525
1
526
1
        // gate initially open
527
1
        poll_ready(&mut txs[0].send("0_a")).unwrap();
528
1
529
1
        // gate still open because channel 1 is still empty
530
1
        poll_ready(&mut txs[0].send("0_b")).unwrap();
531
1
532
1
        // gate still open because channel 1 is still empty prior to this call, so this call still goes through
533
1
        poll_ready(&mut txs[1].send("1_a")).unwrap();
534
1
535
1
        // both channels non-empty => gate closed
536
1
537
1
        let mut send_fut = txs[1].send("1_b");
538
1
        let waker = poll_pending(&mut send_fut);
539
1
540
1
        // drain channel 0
541
1
        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("0_a"),);
542
1
        poll_pending(&mut send_fut);
543
1
        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("0_b"),);
544
545
        // channel 0 empty => gate open
546
1
        assert!(waker.woken());
547
1
        poll_ready(&mut send_fut).unwrap();
548
1
    }
549
550
    #[test]
551
1
    fn test_close_channel_by_dropping_tx() {
552
1
        let (mut txs, mut rxs) = channels(2);
553
1
554
1
        let tx0 = txs.remove(0);
555
1
        let tx1 = txs.remove(0);
556
1
        let tx0_clone = tx0.clone();
557
1
558
1
        let mut recv_fut = rxs[0].recv();
559
1
560
1
        poll_ready(&mut tx1.send("a")).unwrap();
561
1
        let recv_waker = poll_pending(&mut recv_fut);
562
1
563
1
        // drop original sender
564
1
        drop(tx0);
565
1
566
1
        // not yet closed (there's a clone left)
567
1
        assert!(!recv_waker.woken());
568
1
        poll_ready(&mut tx1.send("b")).unwrap();
569
1
        let recv_waker = poll_pending(&mut recv_fut);
570
1
571
1
        // create new clone
572
1
        let tx0_clone2 = tx0_clone.clone();
573
1
        assert!(!recv_waker.woken());
574
1
        poll_ready(&mut tx1.send("c")).unwrap();
575
1
        let recv_waker = poll_pending(&mut recv_fut);
576
1
577
1
        // drop first clone
578
1
        drop(tx0_clone);
579
1
        assert!(!recv_waker.woken());
580
1
        poll_ready(&mut tx1.send("d")).unwrap();
581
1
        let recv_waker = poll_pending(&mut recv_fut);
582
1
583
1
        // drop last clone
584
1
        drop(tx0_clone2);
585
1
586
1
        // channel closed => also close gate
587
1
        poll_pending(&mut tx1.send("e"));
588
1
        assert!(recv_waker.woken());
589
1
        assert_eq!(poll_ready(&mut recv_fut), None,);
590
1
    }
591
592
    #[test]
593
1
    fn test_close_channel_by_dropping_rx_on_open_gate() {
594
1
        let (txs, mut rxs) = channels(2);
595
1
596
1
        let rx0 = rxs.remove(0);
597
1
        let _rx1 = rxs.remove(0);
598
1
599
1
        poll_ready(&mut txs[1].send("a")).unwrap();
600
1
601
1
        // drop receiver => also close gate
602
1
        drop(rx0);
603
1
604
1
        poll_pending(&mut txs[1].send("b"));
605
1
        assert_eq!(poll_ready(&mut txs[0].send("foo")), Err(SendError("foo")),);
606
1
    }
607
608
    #[test]
609
1
    fn test_close_channel_by_dropping_rx_on_closed_gate() {
610
1
        let (txs, mut rxs) = channels(2);
611
1
612
1
        let rx0 = rxs.remove(0);
613
1
        let mut rx1 = rxs.remove(0);
614
1
615
1
        // fill both channels
616
1
        poll_ready(&mut txs[0].send("0_a")).unwrap();
617
1
        poll_ready(&mut txs[1].send("1_a")).unwrap();
618
1
619
1
        let mut send_fut0 = txs[0].send("0_b");
620
1
        let mut send_fut1 = txs[1].send("1_b");
621
1
        let waker0 = poll_pending(&mut send_fut0);
622
1
        let waker1 = poll_pending(&mut send_fut1);
623
1
624
1
        // drop receiver
625
1
        drop(rx0);
626
1
627
1
        assert!(waker0.woken());
628
1
        assert!(!waker1.woken());
629
1
        assert_eq!(poll_ready(&mut send_fut0), Err(SendError("0_b")),);
630
631
        // gate closed, so cannot send on channel 1
632
1
        poll_pending(&mut send_fut1);
633
1
634
1
        // channel 1 can still receive data
635
1
        assert_eq!(poll_ready(&mut rx1.recv()), Some("1_a"),);
636
1
    }
637
638
    #[test]
639
1
    fn test_drop_rx_three_channels() {
640
1
        let (mut txs, mut rxs) = channels(3);
641
1
642
1
        let tx0 = txs.remove(0);
643
1
        let tx1 = txs.remove(0);
644
1
        let tx2 = txs.remove(0);
645
1
        let mut rx0 = rxs.remove(0);
646
1
        let rx1 = rxs.remove(0);
647
1
        let _rx2 = rxs.remove(0);
648
1
649
1
        // fill channels
650
1
        poll_ready(&mut tx0.send("0_a")).unwrap();
651
1
        poll_ready(&mut tx1.send("1_a")).unwrap();
652
1
        poll_ready(&mut tx2.send("2_a")).unwrap();
653
1
654
1
        // drop / close one channel
655
1
        drop(rx1);
656
1
657
1
        // receive data
658
1
        assert_eq!(poll_ready(&mut rx0.recv()), Some("0_a"),);
659
660
        // use senders again
661
1
        poll_ready(&mut tx0.send("0_b")).unwrap();
662
1
        assert_eq!(poll_ready(&mut tx1.send("1_b")), Err(SendError("1_b")),);
663
1
        poll_pending(&mut tx2.send("2_b"));
664
1
    }
665
666
    #[test]
667
1
    fn test_close_channel_by_dropping_rx_clears_data() {
668
1
        let (txs, rxs) = channels(1);
669
1
670
1
        let obj = Arc::new(());
671
1
        let counter = Arc::downgrade(&obj);
672
1
        assert_eq!(counter.strong_count(), 1);
673
674
        // add object to channel
675
1
        poll_ready(&mut txs[0].send(obj)).unwrap();
676
1
        assert_eq!(counter.strong_count(), 1);
677
678
        // drop receiver
679
1
        drop(rxs);
680
1
681
1
        assert_eq!(counter.strong_count(), 0);
682
1
    }
683
684
    /// Ensure that polling "pending" futures work even when you poll them too often (which happens under some circumstances).
685
    #[test]
686
1
    fn test_poll_empty_channel_twice() {
687
1
        let (txs, mut rxs) = channels(1);
688
1
689
1
        let mut recv_fut = rxs[0].recv();
690
1
        let waker_1a = poll_pending(&mut recv_fut);
691
1
        let waker_1b = poll_pending(&mut recv_fut);
692
1
693
1
        let mut recv_fut = rxs[0].recv();
694
1
        let waker_2 = poll_pending(&mut recv_fut);
695
1
696
1
        poll_ready(&mut txs[0].send("a")).unwrap();
697
1
        assert!(waker_1a.woken());
698
1
        assert!(waker_1b.woken());
699
1
        assert!(waker_2.woken());
700
1
        assert_eq!(poll_ready(&mut recv_fut), Some("a"),);
701
702
1
        poll_ready(&mut txs[0].send("b")).unwrap();
703
1
        let mut send_fut = txs[0].send("c");
704
1
        let waker_3 = poll_pending(&mut send_fut);
705
1
        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("b"),);
706
1
        assert!(waker_3.woken());
707
1
        poll_ready(&mut send_fut).unwrap();
708
1
        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("c"));
709
710
1
        let mut recv_fut = rxs[0].recv();
711
1
        let waker_4 = poll_pending(&mut recv_fut);
712
1
713
1
        let mut recv_fut = rxs[0].recv();
714
1
        let waker_5 = poll_pending(&mut recv_fut);
715
1
716
1
        poll_ready(&mut txs[0].send("d")).unwrap();
717
1
        let mut send_fut = txs[0].send("e");
718
1
        let waker_6a = poll_pending(&mut send_fut);
719
1
        let waker_6b = poll_pending(&mut send_fut);
720
1
721
1
        assert!(waker_4.woken());
722
1
        assert!(waker_5.woken());
723
1
        assert_eq!(poll_ready(&mut recv_fut), Some("d"),);
724
725
1
        assert!(waker_6a.woken());
726
1
        assert!(waker_6b.woken());
727
1
        poll_ready(&mut send_fut).unwrap();
728
1
    }
729
730
    #[test]
731
    #[should_panic(expected = "polled ready future")]
732
1
    fn test_panic_poll_send_future_after_ready_ok() {
733
1
        let (txs, _rxs) = channels(1);
734
1
        let mut fut = txs[0].send("foo");
735
1
        poll_ready(&mut fut).unwrap();
736
1
        poll_ready(&mut fut).ok();
737
1
    }
738
739
    #[test]
740
    #[should_panic(expected = "polled ready future")]
741
1
    fn test_panic_poll_send_future_after_ready_err() {
742
1
        let (txs, rxs) = channels(1);
743
1
744
1
        drop(rxs);
745
1
746
1
        let mut fut = txs[0].send("foo");
747
1
        poll_ready(&mut fut).unwrap_err();
748
1
        poll_ready(&mut fut).ok();
749
1
    }
750
751
    #[test]
752
    #[should_panic(expected = "polled ready future")]
753
1
    fn test_panic_poll_recv_future_after_ready_some() {
754
1
        let (txs, mut rxs) = channels(1);
755
1
756
1
        poll_ready(&mut txs[0].send("foo")).unwrap();
757
1
758
1
        let mut fut = rxs[0].recv();
759
1
        poll_ready(&mut fut).unwrap();
760
1
        poll_ready(&mut fut);
761
1
    }
762
763
    #[test]
764
    #[should_panic(expected = "polled ready future")]
765
1
    fn test_panic_poll_recv_future_after_ready_none() {
766
1
        let (txs, mut rxs) = channels::<u8>(1);
767
1
768
1
        drop(txs);
769
1
770
1
        let mut fut = rxs[0].recv();
771
1
        assert!(poll_ready(&mut fut).is_none());
772
1
        poll_ready(&mut fut);
773
1
    }
774
775
    #[test]
776
    #[should_panic(expected = "future is pending")]
777
1
    fn test_meta_poll_ready_wrong_state() {
778
1
        let mut fut = futures::future::pending::<u8>();
779
1
        poll_ready(&mut fut);
780
1
    }
781
782
    #[test]
783
    #[should_panic(expected = "future is ready")]
784
1
    fn test_meta_poll_pending_wrong_state() {
785
1
        let mut fut = futures::future::ready(1);
786
1
        poll_pending(&mut fut);
787
1
    }
788
789
    /// Test [`poll_pending`] (i.e. the testing utils, not the actual library code).
790
    #[test]
791
1
    fn test_meta_poll_pending_waker() {
792
1
        let (tx, mut rx) = futures::channel::oneshot::channel();
793
1
        let waker = poll_pending(&mut rx);
794
1
        assert!(!waker.woken());
795
1
        tx.send(1).unwrap();
796
1
        assert!(waker.woken());
797
1
    }
798
799
    /// Poll a given [`Future`] and ensure it is [ready](Poll::Ready).
800
    #[track_caller]
801
57
    fn poll_ready<F>(fut: &mut F) -> F::Output
802
57
    where
803
57
        F: Future + Unpin,
804
57
    {
805
57
        match poll(fut).0 {
806
56
            Poll::Ready(x) => x,
807
1
            Poll::Pending => panic!("future is pending"),
808
        }
809
56
    }
810
811
    /// Poll a given [`Future`] and ensure it is [pending](Poll::Pending).
812
    ///
813
    /// Returns a waker that can later be checked.
814
    #[track_caller]
815
23
    fn poll_pending<F>(fut: &mut F) -> Arc<TestWaker>
816
23
    where
817
23
        F: Future + Unpin,
818
23
    {
819
23
        let (res, waker) = poll(fut);
820
23
        match res {
821
1
            Poll::Ready(_) => panic!("future is ready"),
822
22
            Poll::Pending => waker,
823
22
        }
824
22
    }
825
826
80
    fn poll<F>(fut: &mut F) -> (Poll<F::Output>, Arc<TestWaker>)
827
80
    where
828
80
        F: Future + Unpin,
829
80
    {
830
80
        let test_waker = Arc::new(TestWaker::default());
831
80
        let waker = futures::task::waker(Arc::clone(&test_waker));
832
80
        let mut cx = std::task::Context::from_waker(&waker);
833
80
        let res = fut.poll_unpin(&mut cx);
834
80
        (res, test_waker)
835
80
    }
836
837
    /// A test [`Waker`] that signal if [`wake`](Waker::wake) was called.
838
    #[derive(Debug, Default)]
839
    struct TestWaker {
840
        woken: AtomicBool,
841
    }
842
843
    impl TestWaker {
844
        /// Was [`wake`](Waker::wake) called?
845
18
        fn woken(&self) -> bool {
846
18
            self.woken.load(Ordering::SeqCst)
847
18
        }
848
    }
849
850
    impl ArcWake for TestWaker {
851
22
        fn wake_by_ref(arc_self: &Arc<Self>) {
852
22
            arc_self.woken.store(true, Ordering::SeqCst);
853
22
        }
854
    }
855
}