From 4df7dabe8ddffbf1e2cca59213ca003679d85565 Mon Sep 17 00:00:00 2001 From: Steve Lau Date: Sat, 10 Dec 2022 21:01:52 +0800 Subject: [PATCH] feat: I/O safety for 'poll' --- src/poll.rs | 34 ++++++++++++++++++++++++++++------ test/test_poll.rs | 22 ++++++++++------------ 2 files changed, 38 insertions(+), 18 deletions(-) diff --git a/src/poll.rs b/src/poll.rs index 6f227fee9e..1877ac105e 100644 --- a/src/poll.rs +++ b/src/poll.rs @@ -1,5 +1,5 @@ //! Wait for events to trigger on specific file descriptors -use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd}; use crate::errno::Errno; use crate::Result; @@ -21,10 +21,24 @@ pub struct PollFd { impl PollFd { /// Creates a new `PollFd` specifying the events of interest /// for a given file descriptor. - pub const fn new(fd: RawFd, events: PollFlags) -> PollFd { + // + // Different from other I/O-safe interfaces, here, we have to take `AsFd` + // by reference to prevent the case where the `fd` is closed but it is + // still in use. For example: + // + // ```rust + // let file: File = File::open("some_file").unwrap(); + // + // // If `PollFd::new()` takes `AsFd` by value, then `file` will be consumed, + // // but the internal file descriptor of `file` will still be in use. + // let pollfd = PollFd::new(file, flag); + // + // // Do something with `pollfd`, which uses the CLOSED fd. + // ``` + pub fn new(fd: &Fd, events: PollFlags) -> PollFd { PollFd { pollfd: libc::pollfd { - fd, + fd: fd.as_fd().as_raw_fd(), events: events.bits(), revents: PollFlags::empty().bits(), }, @@ -68,9 +82,17 @@ impl PollFd { } } -impl AsRawFd for PollFd { - fn as_raw_fd(&self) -> RawFd { - self.pollfd.fd +impl AsFd for PollFd { + fn as_fd(&self) -> BorrowedFd<'_> { + // BorrowedFd::borrow_raw() requires that the raw fd being passed must + // remain open for the duration of the returned BorrowedFd, this is + // guaranteed as the returned BorrowedFd has lifetime parameter same + // as `self`: + // "fn as_fd<'self>(&'self self) -> BorrowedFd<'self>" + // which means that `self` (PollFd) is guaranteed to outlive the returned + // BorrowedFd. Since `self` is valid, the fd within `self` is valid, + // then the returned BorrowedFd is valid. + unsafe { BorrowedFd::borrow_raw(self.pollfd.fd) } } } diff --git a/test/test_poll.rs b/test/test_poll.rs index 53964e26bb..045ccd3df1 100644 --- a/test/test_poll.rs +++ b/test/test_poll.rs @@ -1,8 +1,9 @@ use nix::{ errno::Errno, poll::{poll, PollFd, PollFlags}, - unistd::{pipe, write}, + unistd::{close, pipe, write}, }; +use std::os::unix::io::{BorrowedFd, FromRawFd, OwnedFd}; macro_rules! loop_while_eintr { ($poll_expr: expr) => { @@ -19,7 +20,8 @@ macro_rules! loop_while_eintr { #[test] fn test_poll() { let (r, w) = pipe().unwrap(); - let mut fds = [PollFd::new(r, PollFlags::POLLIN)]; + let r = unsafe { OwnedFd::from_raw_fd(r) }; + let mut fds = [PollFd::new(&r, PollFlags::POLLIN)]; // Poll an idle pipe. Should timeout let nfds = loop_while_eintr!(poll(&mut fds, 100)); @@ -32,6 +34,7 @@ fn test_poll() { let nfds = poll(&mut fds, 100).unwrap(); assert_eq!(nfds, 1); assert!(fds[0].revents().unwrap().contains(PollFlags::POLLIN)); + close(w).unwrap(); } // ppoll(2) is the same as poll except for how it handles timeouts and signals. @@ -51,7 +54,8 @@ fn test_ppoll() { let timeout = TimeSpec::milliseconds(1); let (r, w) = pipe().unwrap(); - let mut fds = [PollFd::new(r, PollFlags::POLLIN)]; + let r = unsafe { OwnedFd::from_raw_fd(r) }; + let mut fds = [PollFd::new(&r, PollFlags::POLLIN)]; // Poll an idle pipe. Should timeout let sigset = SigSet::empty(); @@ -65,19 +69,13 @@ fn test_ppoll() { let nfds = ppoll(&mut fds, Some(timeout), None).unwrap(); assert_eq!(nfds, 1); assert!(fds[0].revents().unwrap().contains(PollFlags::POLLIN)); -} - -#[test] -fn test_pollfd_fd() { - use std::os::unix::io::AsRawFd; - - let pfd = PollFd::new(0x1234, PollFlags::empty()); - assert_eq!(pfd.as_raw_fd(), 0x1234); + close(w).unwrap(); } #[test] fn test_pollfd_events() { - let mut pfd = PollFd::new(-1, PollFlags::POLLIN); + let fd_zero = unsafe { BorrowedFd::borrow_raw(0) }; + let mut pfd = PollFd::new(&fd_zero, PollFlags::POLLIN); assert_eq!(pfd.events(), PollFlags::POLLIN); pfd.set_events(PollFlags::POLLOUT); assert_eq!(pfd.events(), PollFlags::POLLOUT);