Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: I/O safety for 'poll' #1937

Merged
merged 1 commit into from
Dec 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 44 additions & 8 deletions src/poll.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Wait for events to trigger on specific file descriptors
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd};

use crate::errno::Errno;
use crate::Result;
Expand All @@ -14,20 +14,36 @@ use crate::Result;
/// retrieved by calling [`revents()`](#method.revents) on the `PollFd`.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct PollFd {
pub struct PollFd<'fd> {
pollfd: libc::pollfd,
_fd: std::marker::PhantomData<BorrowedFd<'fd>>,
}

impl PollFd {
impl<'fd> PollFd<'fd> {
/// Creates a new `PollFd` specifying the events of interest
/// for a given file descriptor.
pub const fn new(fd: RawFd, events: PollFlags) -> PollFd {
//
// Different from other I/O-safe interfaces, here, we have to take `AsFd`
// by reference to prevent the case where the `fd` is closed but it is
// still in use. For example:
//
// ```rust
// let (reader, _) = pipe().unwrap();
SteveLauC marked this conversation as resolved.
Show resolved Hide resolved
//
// // If `PollFd::new()` takes `AsFd` by value, then `reader` will be consumed,
// // but the file descriptor of `reader` will still be in use.
// let pollfd = PollFd::new(reader, flag);
//
// // Do something with `pollfd`, which uses the CLOSED fd.
// ```
pub fn new<Fd: AsFd>(fd: &'fd Fd, events: PollFlags) -> PollFd<'fd> {
PollFd {
pollfd: libc::pollfd {
fd,
fd: fd.as_fd().as_raw_fd(),
events: events.bits(),
revents: PollFlags::empty().bits(),
},
_fd: std::marker::PhantomData,
}
}

Expand Down Expand Up @@ -68,9 +84,29 @@ impl PollFd {
}
}

impl AsRawFd for PollFd {
fn as_raw_fd(&self) -> RawFd {
self.pollfd.fd
impl<'fd> AsFd for PollFd<'fd> {
fn as_fd(&self) -> BorrowedFd<'_> {
// Safety:
//
// BorrowedFd::borrow_raw(RawFd) requires that the raw fd being passed
// must remain open for the duration of the returned BorrowedFd, this is
// guaranteed as the returned BorrowedFd has the lifetime parameter same
// as `self`:
// "fn as_fd<'self>(&'self self) -> BorrowedFd<'self>"
// which means that `self` (PollFd) is guaranteed to outlive the returned
// BorrowedFd. (Lifetime: PollFd > BorrowedFd)
//
// And the lifetime parameter of PollFd::new(fd, ...) ensures that `fd`
// (an owned file descriptor) must outlive the returned PollFd:
// "pub fn new<Fd: AsFd>(fd: &'fd Fd, events: PollFlags) -> PollFd<'fd>"
// (Lifetime: Owned fd > PollFd)
//
// With two above relationships, we can conclude that the `Owned file
// descriptor` will outlive the returned BorrowedFd,
// (Lifetime: Owned fd > BorrowedFd)
// i.e., the raw fd being passed will remain valid for the lifetime of
// the returned BorrowedFd.
unsafe { BorrowedFd::borrow_raw(self.pollfd.fd) }
}
}

Expand Down
22 changes: 10 additions & 12 deletions test/test_poll.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use nix::{
errno::Errno,
poll::{poll, PollFd, PollFlags},
unistd::{pipe, write},
unistd::{close, pipe, write},
};
use std::os::unix::io::{BorrowedFd, FromRawFd, OwnedFd};

macro_rules! loop_while_eintr {
($poll_expr: expr) => {
Expand All @@ -19,7 +20,8 @@ macro_rules! loop_while_eintr {
#[test]
fn test_poll() {
let (r, w) = pipe().unwrap();
let mut fds = [PollFd::new(r, PollFlags::POLLIN)];
let r = unsafe { OwnedFd::from_raw_fd(r) };
let mut fds = [PollFd::new(&r, PollFlags::POLLIN)];

// Poll an idle pipe. Should timeout
let nfds = loop_while_eintr!(poll(&mut fds, 100));
Expand All @@ -32,6 +34,7 @@ fn test_poll() {
let nfds = poll(&mut fds, 100).unwrap();
assert_eq!(nfds, 1);
assert!(fds[0].revents().unwrap().contains(PollFlags::POLLIN));
close(w).unwrap();
}

// ppoll(2) is the same as poll except for how it handles timeouts and signals.
Expand All @@ -51,7 +54,8 @@ fn test_ppoll() {

let timeout = TimeSpec::milliseconds(1);
let (r, w) = pipe().unwrap();
let mut fds = [PollFd::new(r, PollFlags::POLLIN)];
let r = unsafe { OwnedFd::from_raw_fd(r) };
let mut fds = [PollFd::new(&r, PollFlags::POLLIN)];

// Poll an idle pipe. Should timeout
let sigset = SigSet::empty();
Expand All @@ -65,19 +69,13 @@ fn test_ppoll() {
let nfds = ppoll(&mut fds, Some(timeout), None).unwrap();
assert_eq!(nfds, 1);
assert!(fds[0].revents().unwrap().contains(PollFlags::POLLIN));
}

#[test]
fn test_pollfd_fd() {
use std::os::unix::io::AsRawFd;

let pfd = PollFd::new(0x1234, PollFlags::empty());
assert_eq!(pfd.as_raw_fd(), 0x1234);
close(w).unwrap();
}

#[test]
fn test_pollfd_events() {
let mut pfd = PollFd::new(-1, PollFlags::POLLIN);
let fd_zero = unsafe { BorrowedFd::borrow_raw(0) };
let mut pfd = PollFd::new(&fd_zero, PollFlags::POLLIN);
assert_eq!(pfd.events(), PollFlags::POLLIN);
pfd.set_events(PollFlags::POLLOUT);
assert_eq!(pfd.events(), PollFlags::POLLOUT);
Expand Down