Skip to content

Commit

Permalink
Merge #1937
Browse files Browse the repository at this point in the history
1937: feat: I/O safety for 'poll' r=asomers a=SteveLauC

#### What this PR does:
1. Adds I/O safety for module `poll`.

Co-authored-by: Steve Lau <[email protected]>
  • Loading branch information
bors[bot] and SteveLauC authored Dec 11, 2022
2 parents df5877c + 47ecc9a commit 6c8ff7b
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 20 deletions.
52 changes: 44 additions & 8 deletions src/poll.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Wait for events to trigger on specific file descriptors
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd};

use crate::errno::Errno;
use crate::Result;
Expand All @@ -14,20 +14,36 @@ use crate::Result;
/// retrieved by calling [`revents()`](#method.revents) on the `PollFd`.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct PollFd {
pub struct PollFd<'fd> {
pollfd: libc::pollfd,
_fd: std::marker::PhantomData<BorrowedFd<'fd>>,
}

impl PollFd {
impl<'fd> PollFd<'fd> {
/// Creates a new `PollFd` specifying the events of interest
/// for a given file descriptor.
pub const fn new(fd: RawFd, events: PollFlags) -> PollFd {
//
// Different from other I/O-safe interfaces, here, we have to take `AsFd`
// by reference to prevent the case where the `fd` is closed but it is
// still in use. For example:
//
// ```rust
// let (reader, _) = pipe().unwrap();
//
// // If `PollFd::new()` takes `AsFd` by value, then `reader` will be consumed,
// // but the file descriptor of `reader` will still be in use.
// let pollfd = PollFd::new(reader, flag);
//
// // Do something with `pollfd`, which uses the CLOSED fd.
// ```
pub fn new<Fd: AsFd>(fd: &'fd Fd, events: PollFlags) -> PollFd<'fd> {
PollFd {
pollfd: libc::pollfd {
fd,
fd: fd.as_fd().as_raw_fd(),
events: events.bits(),
revents: PollFlags::empty().bits(),
},
_fd: std::marker::PhantomData,
}
}

Expand Down Expand Up @@ -68,9 +84,29 @@ impl PollFd {
}
}

impl AsRawFd for PollFd {
fn as_raw_fd(&self) -> RawFd {
self.pollfd.fd
impl<'fd> AsFd for PollFd<'fd> {
fn as_fd(&self) -> BorrowedFd<'_> {
// Safety:
//
// BorrowedFd::borrow_raw(RawFd) requires that the raw fd being passed
// must remain open for the duration of the returned BorrowedFd, this is
// guaranteed as the returned BorrowedFd has the lifetime parameter same
// as `self`:
// "fn as_fd<'self>(&'self self) -> BorrowedFd<'self>"
// which means that `self` (PollFd) is guaranteed to outlive the returned
// BorrowedFd. (Lifetime: PollFd > BorrowedFd)
//
// And the lifetime parameter of PollFd::new(fd, ...) ensures that `fd`
// (an owned file descriptor) must outlive the returned PollFd:
// "pub fn new<Fd: AsFd>(fd: &'fd Fd, events: PollFlags) -> PollFd<'fd>"
// (Lifetime: Owned fd > PollFd)
//
// With two above relationships, we can conclude that the `Owned file
// descriptor` will outlive the returned BorrowedFd,
// (Lifetime: Owned fd > BorrowedFd)
// i.e., the raw fd being passed will remain valid for the lifetime of
// the returned BorrowedFd.
unsafe { BorrowedFd::borrow_raw(self.pollfd.fd) }
}
}

Expand Down
22 changes: 10 additions & 12 deletions test/test_poll.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use nix::{
errno::Errno,
poll::{poll, PollFd, PollFlags},
unistd::{pipe, write},
unistd::{close, pipe, write},
};
use std::os::unix::io::{BorrowedFd, FromRawFd, OwnedFd};

macro_rules! loop_while_eintr {
($poll_expr: expr) => {
Expand All @@ -19,7 +20,8 @@ macro_rules! loop_while_eintr {
#[test]
fn test_poll() {
let (r, w) = pipe().unwrap();
let mut fds = [PollFd::new(r, PollFlags::POLLIN)];
let r = unsafe { OwnedFd::from_raw_fd(r) };
let mut fds = [PollFd::new(&r, PollFlags::POLLIN)];

// Poll an idle pipe. Should timeout
let nfds = loop_while_eintr!(poll(&mut fds, 100));
Expand All @@ -32,6 +34,7 @@ fn test_poll() {
let nfds = poll(&mut fds, 100).unwrap();
assert_eq!(nfds, 1);
assert!(fds[0].revents().unwrap().contains(PollFlags::POLLIN));
close(w).unwrap();
}

// ppoll(2) is the same as poll except for how it handles timeouts and signals.
Expand All @@ -51,7 +54,8 @@ fn test_ppoll() {

let timeout = TimeSpec::milliseconds(1);
let (r, w) = pipe().unwrap();
let mut fds = [PollFd::new(r, PollFlags::POLLIN)];
let r = unsafe { OwnedFd::from_raw_fd(r) };
let mut fds = [PollFd::new(&r, PollFlags::POLLIN)];

// Poll an idle pipe. Should timeout
let sigset = SigSet::empty();
Expand All @@ -65,19 +69,13 @@ fn test_ppoll() {
let nfds = ppoll(&mut fds, Some(timeout), None).unwrap();
assert_eq!(nfds, 1);
assert!(fds[0].revents().unwrap().contains(PollFlags::POLLIN));
}

#[test]
fn test_pollfd_fd() {
use std::os::unix::io::AsRawFd;

let pfd = PollFd::new(0x1234, PollFlags::empty());
assert_eq!(pfd.as_raw_fd(), 0x1234);
close(w).unwrap();
}

#[test]
fn test_pollfd_events() {
let mut pfd = PollFd::new(-1, PollFlags::POLLIN);
let fd_zero = unsafe { BorrowedFd::borrow_raw(0) };
let mut pfd = PollFd::new(&fd_zero, PollFlags::POLLIN);
assert_eq!(pfd.events(), PollFlags::POLLIN);
pfd.set_events(PollFlags::POLLOUT);
assert_eq!(pfd.events(), PollFlags::POLLOUT);
Expand Down

0 comments on commit 6c8ff7b

Please sign in to comment.