From 2be852de8e5ccb4ec6933836d7d76fa6bde62763 Mon Sep 17 00:00:00 2001 From: Steve Lau Date: Mon, 12 Dec 2022 14:18:49 +0800 Subject: [PATCH] feat: I/O safety for 'sys/select' --- src/sys/select.rs | 313 ++++++++++++++++++++++++++-------------- test/sys/test_select.rs | 31 ++-- 2 files changed, 222 insertions(+), 122 deletions(-) diff --git a/src/sys/select.rs b/src/sys/select.rs index 7a94cff87e..0e2193b130 100644 --- a/src/sys/select.rs +++ b/src/sys/select.rs @@ -7,7 +7,7 @@ use std::convert::TryFrom; use std::iter::FusedIterator; use std::mem; use std::ops::Range; -use std::os::unix::io::RawFd; +use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, RawFd}; use std::ptr::{null, null_mut}; pub use libc::FD_SETSIZE; @@ -15,7 +15,10 @@ pub use libc::FD_SETSIZE; /// Contains a set of file descriptors used by [`select`] #[repr(transparent)] #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] -pub struct FdSet(libc::fd_set); +pub struct FdSet<'fd> { + set: libc::fd_set, + _fd: std::marker::PhantomData>, +} fn assert_fd_valid(fd: RawFd) { assert!( @@ -24,37 +27,40 @@ fn assert_fd_valid(fd: RawFd) { ); } -impl FdSet { +impl<'fd> FdSet<'fd> { /// Create an empty `FdSet` - pub fn new() -> FdSet { + pub fn new() -> FdSet<'fd> { let mut fdset = mem::MaybeUninit::uninit(); unsafe { libc::FD_ZERO(fdset.as_mut_ptr()); - FdSet(fdset.assume_init()) + Self { + set: fdset.assume_init(), + _fd: std::marker::PhantomData, + } } } /// Add a file descriptor to an `FdSet` - pub fn insert(&mut self, fd: RawFd) { - assert_fd_valid(fd); - unsafe { libc::FD_SET(fd, &mut self.0) }; + pub fn insert(&mut self, fd: &'fd Fd) { + assert_fd_valid(fd.as_fd().as_raw_fd()); + unsafe { libc::FD_SET(fd.as_fd().as_raw_fd(), &mut self.set) }; } /// Remove a file descriptor from an `FdSet` - pub fn remove(&mut self, fd: RawFd) { - assert_fd_valid(fd); - unsafe { libc::FD_CLR(fd, &mut self.0) }; + pub fn remove(&mut self, fd: &'fd Fd) { + assert_fd_valid(fd.as_fd().as_raw_fd()); + unsafe { libc::FD_CLR(fd.as_fd().as_raw_fd(), &mut self.set) }; } /// Test an `FdSet` for the presence of a certain file descriptor. - pub fn contains(&self, fd: RawFd) -> bool { - assert_fd_valid(fd); - unsafe { libc::FD_ISSET(fd, &self.0) } + pub fn contains(&self, fd: &'fd Fd) -> bool { + assert_fd_valid(fd.as_fd().as_raw_fd()); + unsafe { libc::FD_ISSET(fd.as_fd().as_raw_fd(), &self.set) } } /// Remove all file descriptors from this `FdSet`. pub fn clear(&mut self) { - unsafe { libc::FD_ZERO(&mut self.0) }; + unsafe { libc::FD_ZERO(&mut self.set) }; } /// Finds the highest file descriptor in the set. @@ -66,15 +72,18 @@ impl FdSet { /// # Example /// /// ``` + /// # use std::os::unix::io::{AsRawFd, BorrowedFd}; /// # use nix::sys::select::FdSet; + /// let fd_four = unsafe {BorrowedFd::borrow_raw(4)}; + /// let fd_nine = unsafe {BorrowedFd::borrow_raw(9)}; /// let mut set = FdSet::new(); - /// set.insert(4); - /// set.insert(9); - /// assert_eq!(set.highest(), Some(9)); + /// set.insert(&fd_four); + /// set.insert(&fd_nine); + /// assert_eq!(set.highest().map(|borrowed_fd|borrowed_fd.as_raw_fd()), Some(9)); /// ``` /// /// [`select`]: fn.select.html - pub fn highest(&self) -> Option { + pub fn highest(&self) -> Option> { self.fds(None).next_back() } @@ -88,11 +97,13 @@ impl FdSet { /// /// ``` /// # use nix::sys::select::FdSet; - /// # use std::os::unix::io::RawFd; + /// # use std::os::unix::io::{AsRawFd, BorrowedFd, RawFd}; /// let mut set = FdSet::new(); - /// set.insert(4); - /// set.insert(9); - /// let fds: Vec = set.fds(None).collect(); + /// let fd_four = unsafe {BorrowedFd::borrow_raw(4)}; + /// let fd_nine = unsafe {BorrowedFd::borrow_raw(9)}; + /// set.insert(&fd_four); + /// set.insert(&fd_nine); + /// let fds: Vec = set.fds(None).map(|borrowed_fd|borrowed_fd.as_raw_fd()).collect(); /// assert_eq!(fds, vec![4, 9]); /// ``` #[inline] @@ -104,7 +115,7 @@ impl FdSet { } } -impl Default for FdSet { +impl<'fd> Default for FdSet<'fd> { fn default() -> Self { Self::new() } @@ -112,18 +123,19 @@ impl Default for FdSet { /// Iterator over `FdSet`. #[derive(Debug)] -pub struct Fds<'a> { - set: &'a FdSet, +pub struct Fds<'a, 'fd> { + set: &'a FdSet<'fd>, range: Range, } -impl<'a> Iterator for Fds<'a> { - type Item = RawFd; +impl<'a, 'fd> Iterator for Fds<'a, 'fd> { + type Item = BorrowedFd<'fd>; - fn next(&mut self) -> Option { + fn next(&mut self) -> Option { for i in &mut self.range { - if self.set.contains(i as RawFd) { - return Some(i as RawFd); + let borrowed_i = unsafe { BorrowedFd::borrow_raw(i as RawFd) }; + if self.set.contains(&borrowed_i) { + return Some(borrowed_i); } } None @@ -136,19 +148,20 @@ impl<'a> Iterator for Fds<'a> { } } -impl<'a> DoubleEndedIterator for Fds<'a> { +impl<'a, 'fd> DoubleEndedIterator for Fds<'a, 'fd> { #[inline] - fn next_back(&mut self) -> Option { + fn next_back(&mut self) -> Option> { while let Some(i) = self.range.next_back() { - if self.set.contains(i as RawFd) { - return Some(i as RawFd); + let borrowed_i = unsafe { BorrowedFd::borrow_raw(i as RawFd) }; + if self.set.contains(&borrowed_i) { + return Some(borrowed_i); } } None } } -impl<'a> FusedIterator for Fds<'a> {} +impl<'a, 'fd> FusedIterator for Fds<'a, 'fd> {} /// Monitors file descriptors for readiness /// @@ -173,7 +186,7 @@ impl<'a> FusedIterator for Fds<'a> {} /// [select(2)](https://pubs.opengroup.org/onlinepubs/9699919799/functions/select.html) /// /// [`FdSet::highest`]: struct.FdSet.html#method.highest -pub fn select<'a, N, R, W, E, T>( +pub fn select<'a, 'fd, N, R, W, E, T>( nfds: N, readfds: R, writefds: W, @@ -181,10 +194,11 @@ pub fn select<'a, N, R, W, E, T>( timeout: T, ) -> Result where + 'fd: 'a, N: Into>, - R: Into>, - W: Into>, - E: Into>, + R: Into>>, + W: Into>>, + E: Into>>, T: Into>, { let mut readfds = readfds.into(); @@ -197,7 +211,11 @@ where .iter_mut() .chain(writefds.iter_mut()) .chain(errorfds.iter_mut()) - .map(|set| set.highest().unwrap_or(-1)) + .map(|set| { + set.highest() + .map(|borrowed_fd| borrowed_fd.as_raw_fd()) + .unwrap_or(-1) + }) .max() .unwrap_or(-1) + 1 @@ -256,17 +274,18 @@ use crate::sys::signal::SigSet; /// [The new pselect() system call](https://lwn.net/Articles/176911/) /// /// [`FdSet::highest`]: struct.FdSet.html#method.highest -pub fn pselect<'a, N, R, W, E, T, S>(nfds: N, +pub fn pselect<'a, 'fd, N, R, W, E, T, S>(nfds: N, readfds: R, writefds: W, errorfds: E, timeout: T, sigmask: S) -> Result where + 'fd: 'a, N: Into>, - R: Into>, - W: Into>, - E: Into>, + R: Into>>, + W: Into>>, + E: Into>>, T: Into>, S: Into>, { @@ -280,7 +299,7 @@ where readfds.iter_mut() .chain(writefds.iter_mut()) .chain(errorfds.iter_mut()) - .map(|set| set.highest().unwrap_or(-1)) + .map(|set| set.highest().map(|borrowed_fd|borrowed_fd.as_raw_fd()).unwrap_or(-1)) .max() .unwrap_or(-1) + 1 }); @@ -303,20 +322,22 @@ where mod tests { use super::*; use crate::sys::time::{TimeVal, TimeValLike}; - use crate::unistd::{pipe, write}; - use std::os::unix::io::RawFd; + use crate::unistd::{close, pipe, write}; + use std::os::unix::io::{FromRawFd, OwnedFd, RawFd}; #[test] fn fdset_insert() { let mut fd_set = FdSet::new(); for i in 0..FD_SETSIZE { - assert!(!fd_set.contains(i as RawFd)); + let borrowed_i = unsafe { BorrowedFd::borrow_raw(i as RawFd) }; + assert!(!fd_set.contains(&borrowed_i)); } - fd_set.insert(7); + let fd_seven = unsafe { BorrowedFd::borrow_raw(7) }; + fd_set.insert(&fd_seven); - assert!(fd_set.contains(7)); + assert!(fd_set.contains(&fd_seven)); } #[test] @@ -324,107 +345,183 @@ mod tests { let mut fd_set = FdSet::new(); for i in 0..FD_SETSIZE { - assert!(!fd_set.contains(i as RawFd)); + let borrowed_i = unsafe { BorrowedFd::borrow_raw(i as RawFd) }; + assert!(!fd_set.contains(&borrowed_i)); } - fd_set.insert(7); - fd_set.remove(7); + let fd_seven = unsafe { BorrowedFd::borrow_raw(7) }; + fd_set.insert(&fd_seven); + fd_set.remove(&fd_seven); for i in 0..FD_SETSIZE { - assert!(!fd_set.contains(i as RawFd)); + let borrowed_i = unsafe { BorrowedFd::borrow_raw(i as RawFd) }; + assert!(!fd_set.contains(&borrowed_i)); } } #[test] + #[allow(non_snake_case)] fn fdset_clear() { let mut fd_set = FdSet::new(); - fd_set.insert(1); - fd_set.insert((FD_SETSIZE / 2) as RawFd); - fd_set.insert((FD_SETSIZE - 1) as RawFd); + let fd_one = unsafe { BorrowedFd::borrow_raw(1) }; + let fd_FD_SETSIZE_devided_by_two = + unsafe { BorrowedFd::borrow_raw((FD_SETSIZE / 2) as RawFd) }; + let fd_FD_SETSIZE_minus_one = + unsafe { BorrowedFd::borrow_raw((FD_SETSIZE - 1) as RawFd) }; + fd_set.insert(&fd_one); + fd_set.insert(&fd_FD_SETSIZE_devided_by_two); + fd_set.insert(&fd_FD_SETSIZE_minus_one); fd_set.clear(); for i in 0..FD_SETSIZE { - assert!(!fd_set.contains(i as RawFd)); + let borrowed_i = unsafe { BorrowedFd::borrow_raw(i as RawFd) }; + assert!(!fd_set.contains(&borrowed_i)); } } #[test] fn fdset_highest() { let mut set = FdSet::new(); - assert_eq!(set.highest(), None); - set.insert(0); - assert_eq!(set.highest(), Some(0)); - set.insert(90); - assert_eq!(set.highest(), Some(90)); - set.remove(0); - assert_eq!(set.highest(), Some(90)); - set.remove(90); - assert_eq!(set.highest(), None); - - set.insert(4); - set.insert(5); - set.insert(7); - assert_eq!(set.highest(), Some(7)); + assert_eq!( + set.highest().map(|borrowed_fd| borrowed_fd.as_raw_fd()), + None + ); + let fd_zero = unsafe { BorrowedFd::borrow_raw(0) }; + let fd_ninety = unsafe { BorrowedFd::borrow_raw(90) }; + set.insert(&fd_zero); + assert_eq!( + set.highest().map(|borrowed_fd| borrowed_fd.as_raw_fd()), + Some(0) + ); + set.insert(&fd_ninety); + assert_eq!( + set.highest().map(|borrowed_fd| borrowed_fd.as_raw_fd()), + Some(90) + ); + set.remove(&fd_zero); + assert_eq!( + set.highest().map(|borrowed_fd| borrowed_fd.as_raw_fd()), + Some(90) + ); + set.remove(&fd_ninety); + assert_eq!( + set.highest().map(|borrowed_fd| borrowed_fd.as_raw_fd()), + None + ); + + let fd_four = unsafe { BorrowedFd::borrow_raw(4) }; + let fd_five = unsafe { BorrowedFd::borrow_raw(5) }; + let fd_seven = unsafe { BorrowedFd::borrow_raw(7) }; + set.insert(&fd_four); + set.insert(&fd_five); + set.insert(&fd_seven); + assert_eq!( + set.highest().map(|borrowed_fd| borrowed_fd.as_raw_fd()), + Some(7) + ); } #[test] fn fdset_fds() { let mut set = FdSet::new(); - assert_eq!(set.fds(None).collect::>(), vec![]); - set.insert(0); - assert_eq!(set.fds(None).collect::>(), vec![0]); - set.insert(90); - assert_eq!(set.fds(None).collect::>(), vec![0, 90]); + let fd_zero = unsafe { BorrowedFd::borrow_raw(0) }; + let fd_ninety = unsafe { BorrowedFd::borrow_raw(90) }; + assert_eq!( + set.fds(None) + .map(|borrowed_fd| borrowed_fd.as_raw_fd()) + .collect::>(), + vec![] + ); + set.insert(&fd_zero); + assert_eq!( + set.fds(None) + .map(|borrowed_fd| borrowed_fd.as_raw_fd()) + .collect::>(), + vec![0] + ); + set.insert(&fd_ninety); + assert_eq!( + set.fds(None) + .map(|borrowed_fd| borrowed_fd.as_raw_fd()) + .collect::>(), + vec![0, 90] + ); // highest limit - assert_eq!(set.fds(Some(89)).collect::>(), vec![0]); - assert_eq!(set.fds(Some(90)).collect::>(), vec![0, 90]); + assert_eq!( + set.fds(Some(89)) + .map(|borrowed_fd| borrowed_fd.as_raw_fd()) + .collect::>(), + vec![0] + ); + assert_eq!( + set.fds(Some(90)) + .map(|borrowed_fd| borrowed_fd.as_raw_fd()) + .collect::>(), + vec![0, 90] + ); } #[test] fn test_select() { let (r1, w1) = pipe().unwrap(); - write(w1, b"hi!").unwrap(); + let r1 = unsafe { OwnedFd::from_raw_fd(r1) }; + let w1 = unsafe { OwnedFd::from_raw_fd(w1) }; let (r2, _w2) = pipe().unwrap(); + let r2 = unsafe { OwnedFd::from_raw_fd(r2) }; + write(w1.as_raw_fd(), b"hi!").unwrap(); let mut fd_set = FdSet::new(); - fd_set.insert(r1); - fd_set.insert(r2); + fd_set.insert(&r1); + fd_set.insert(&r2); let mut timeout = TimeVal::seconds(10); assert_eq!( 1, select(None, &mut fd_set, None, None, &mut timeout).unwrap() ); - assert!(fd_set.contains(r1)); - assert!(!fd_set.contains(r2)); + assert!(fd_set.contains(&r1)); + assert!(!fd_set.contains(&r2)); + close(_w2).unwrap(); } #[test] fn test_select_nfds() { let (r1, w1) = pipe().unwrap(); - write(w1, b"hi!").unwrap(); let (r2, _w2) = pipe().unwrap(); + let r1 = unsafe { OwnedFd::from_raw_fd(r1) }; + let w1 = unsafe { OwnedFd::from_raw_fd(w1) }; + let r2 = unsafe { OwnedFd::from_raw_fd(r2) }; + write(w1.as_raw_fd(), b"hi!").unwrap(); let mut fd_set = FdSet::new(); - fd_set.insert(r1); - fd_set.insert(r2); + fd_set.insert(&r1); + fd_set.insert(&r2); let mut timeout = TimeVal::seconds(10); - assert_eq!( - 1, - select( - Some(fd_set.highest().unwrap() + 1), - &mut fd_set, - None, - None, - &mut timeout - ) - .unwrap() - ); - assert!(fd_set.contains(r1)); - assert!(!fd_set.contains(r2)); + { + assert_eq!( + 1, + select( + Some( + fd_set + .highest() + .map(|borrowed_fd| borrowed_fd.as_raw_fd()) + .unwrap() + + 1 + ), + &mut fd_set, + None, + None, + &mut timeout + ) + .unwrap() + ); + } + assert!(fd_set.contains(&r1)); + assert!(!fd_set.contains(&r2)); + close(_w2).unwrap(); } #[test] @@ -432,16 +529,17 @@ mod tests { let (r1, w1) = pipe().unwrap(); write(w1, b"hi!").unwrap(); let (r2, _w2) = pipe().unwrap(); - + let r1 = unsafe { OwnedFd::from_raw_fd(r1) }; + let r2 = unsafe { OwnedFd::from_raw_fd(r2) }; let mut fd_set = FdSet::new(); - fd_set.insert(r1); - fd_set.insert(r2); + fd_set.insert(&r1); + fd_set.insert(&r2); let mut timeout = TimeVal::seconds(10); assert_eq!( 1, select( - ::std::cmp::max(r1, r2) + 1, + std::cmp::max(r1.as_raw_fd(), r2.as_raw_fd()) + 1, &mut fd_set, None, None, @@ -449,7 +547,8 @@ mod tests { ) .unwrap() ); - assert!(fd_set.contains(r1)); - assert!(!fd_set.contains(r2)); + assert!(fd_set.contains(&r1)); + assert!(!fd_set.contains(&r2)); + close(_w2).unwrap(); } } diff --git a/test/sys/test_select.rs b/test/sys/test_select.rs index 40bda4d90a..79f75de3b4 100644 --- a/test/sys/test_select.rs +++ b/test/sys/test_select.rs @@ -2,6 +2,7 @@ use nix::sys::select::*; use nix::sys::signal::SigSet; use nix::sys::time::{TimeSpec, TimeValLike}; use nix::unistd::{pipe, write}; +use std::os::unix::io::{AsRawFd, BorrowedFd, FromRawFd, OwnedFd}; #[test] pub fn test_pselect() { @@ -9,11 +10,13 @@ pub fn test_pselect() { let (r1, w1) = pipe().unwrap(); write(w1, b"hi!").unwrap(); + let r1 = unsafe { OwnedFd::from_raw_fd(r1) }; let (r2, _w2) = pipe().unwrap(); + let r2 = unsafe { OwnedFd::from_raw_fd(r2) }; let mut fd_set = FdSet::new(); - fd_set.insert(r1); - fd_set.insert(r2); + fd_set.insert(&r1); + fd_set.insert(&r2); let timeout = TimeSpec::seconds(10); let sigmask = SigSet::empty(); @@ -21,25 +24,27 @@ pub fn test_pselect() { 1, pselect(None, &mut fd_set, None, None, &timeout, &sigmask).unwrap() ); - assert!(fd_set.contains(r1)); - assert!(!fd_set.contains(r2)); + assert!(fd_set.contains(&r1)); + assert!(!fd_set.contains(&r2)); } #[test] pub fn test_pselect_nfds2() { let (r1, w1) = pipe().unwrap(); write(w1, b"hi!").unwrap(); + let r1 = unsafe { OwnedFd::from_raw_fd(r1) }; let (r2, _w2) = pipe().unwrap(); + let r2 = unsafe { OwnedFd::from_raw_fd(r2) }; let mut fd_set = FdSet::new(); - fd_set.insert(r1); - fd_set.insert(r2); + fd_set.insert(&r1); + fd_set.insert(&r2); let timeout = TimeSpec::seconds(10); assert_eq!( 1, pselect( - ::std::cmp::max(r1, r2) + 1, + std::cmp::max(r1.as_raw_fd(), r2.as_raw_fd()) + 1, &mut fd_set, None, None, @@ -48,8 +53,8 @@ pub fn test_pselect_nfds2() { ) .unwrap() ); - assert!(fd_set.contains(r1)); - assert!(!fd_set.contains(r2)); + assert!(fd_set.contains(&r1)); + assert!(!fd_set.contains(&r2)); } macro_rules! generate_fdset_bad_fd_tests { @@ -58,17 +63,13 @@ macro_rules! generate_fdset_bad_fd_tests { #[test] #[should_panic] fn $method() { - FdSet::new().$method($fd); + let bad_fd = unsafe{BorrowedFd::borrow_raw($fd)}; + FdSet::new().$method(&bad_fd); } )* } } -mod test_fdset_negative_fd { - use super::*; - generate_fdset_bad_fd_tests!(-1, insert, remove, contains); -} - mod test_fdset_too_large_fd { use super::*; use std::convert::TryInto;