From afbe06056e9756582e022cf6a9bf3ffdbfd967f2 Mon Sep 17 00:00:00 2001 From: Michael Baikov Date: Wed, 22 Jun 2022 10:11:33 +0800 Subject: [PATCH] also support sendmmsg renames: RecvMMsg -> MultHdrs RecvMMsgItems -> MultiResults Adding a lifetime reference to RecvMsg The name is not 100% correct now, it can be useful for both sending and receiving messages: to collect hardware sending timestamps you need to use control messages as well --- src/sys/socket/mod.rs | 155 +++++++++++++++++++++------------------- test/sys/test_socket.rs | 48 ++++++------- 2 files changed, 107 insertions(+), 96 deletions(-) diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs index 1de1d479f0..2680c43fc3 100644 --- a/src/sys/socket/mod.rs +++ b/src/sys/socket/mod.rs @@ -574,15 +574,20 @@ macro_rules! cmsg_space { } #[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub struct RecvMsg<'a, S> { +/// Contains outcome of sending or receiving a message +/// +/// Use [`cmsgs`][RecvMsg::cmsgs] to access all the control messages present, and +/// [`iovs`][RecvMsg::iovs`] to access underlying io slices. +pub struct RecvMsg<'a, 's, S> { pub bytes: usize, cmsghdr: Option<&'a cmsghdr>, pub address: Option, pub flags: MsgFlags, + iobufs: std::marker::PhantomData<& 's()>, mhdr: msghdr, } -impl<'a, S> RecvMsg<'a, S> { +impl<'a, S> RecvMsg<'a, '_, S> { /// Iterate over the valid control messages pointed to by this /// msghdr. pub fn cmsgs(&self) -> CmsgIterator { @@ -1411,24 +1416,6 @@ pub fn sendmsg(fd: RawFd, iov: &[IoSlice<'_>], cmsgs: &[ControlMessage], Errno::result(ret).map(|r| r as usize) } -#[cfg(any( - target_os = "linux", - target_os = "android", - target_os = "freebsd", - target_os = "netbsd", -))] -#[derive(Debug)] -pub struct SendMmsgData<'a, I, C, S> - where - I: AsRef<[IoSlice<'a>]>, - C: AsRef<[ControlMessage<'a>]>, - S: SockaddrLike + 'a -{ - pub iov: I, - pub cmsgs: C, - pub addr: Option, - pub _lt: std::marker::PhantomData<&'a I>, -} /// An extension of `sendmsg` that allows the caller to transmit multiple /// messages on a socket using a single system call. This has performance @@ -1453,51 +1440,66 @@ pub struct SendMmsgData<'a, I, C, S> target_os = "freebsd", target_os = "netbsd", ))] -pub fn sendmmsg<'a, I, C, S>( +pub fn sendmmsg<'a, XS, AS, C, I, S>( fd: RawFd, - data: impl std::iter::IntoIterator>, + data: &'a mut MultHdrs, + slices: XS, + // one address per group of slices + addrs: AS, + // shared across all the messages + cmsgs: C, flags: MsgFlags -) -> Result> +) -> crate::Result> where + XS: IntoIterator, + AS: AsRef<[Option]>, I: AsRef<[IoSlice<'a>]> + 'a, C: AsRef<[ControlMessage<'a>]> + 'a, S: SockaddrLike + 'a { - let iter = data.into_iter(); - let size_hint = iter.size_hint(); - let reserve_items = size_hint.1.unwrap_or(size_hint.0); + let mut count = 0; - let mut output = Vec::::with_capacity(reserve_items); - let mut cmsgs_buffers = Vec::>::with_capacity(reserve_items); - - for d in iter { - let capacity: usize = d.cmsgs.as_ref().iter().map(|c| c.space()).sum(); - let mut cmsgs_buffer = vec![0u8; capacity]; + for (i, ((slice, addr), mmsghdr)) in slices.into_iter().zip(addrs.as_ref()).zip(data.items.iter_mut() ).enumerate() { + let mut p = &mut mmsghdr.msg_hdr; + p.msg_iov = slice.as_ref().as_ptr() as *mut libc::iovec; + p.msg_iovlen = slice.as_ref().len() as _; - output.push(libc::mmsghdr { - msg_hdr: pack_mhdr_to_send( - &mut cmsgs_buffer, - &d.iov, - &d.cmsgs, - d.addr.as_ref() - ), - msg_len: 0, - }); - cmsgs_buffers.push(cmsgs_buffer); - }; + (*p).msg_namelen = addr.as_ref().map_or(0, S::len); + (*p).msg_name = addr.as_ref().map_or(ptr::null(), S::as_ptr) as _; + + // Encode each cmsg. This must happen after initializing the header because + // CMSG_NEXT_HDR and friends read the msg_control and msg_controllen fields. + // CMSG_FIRSTHDR is always safe + let mut pmhdr: *mut cmsghdr = unsafe { CMSG_FIRSTHDR(p) }; + for cmsg in cmsgs.as_ref() { + assert_ne!(pmhdr, ptr::null_mut()); + // Safe because we know that pmhdr is valid, and we initialized it with + // sufficient space + unsafe { cmsg.encode_into(pmhdr) }; + // Safe because mhdr is valid + pmhdr = unsafe { CMSG_NXTHDR(p, pmhdr) }; + } - let ret = unsafe { libc::sendmmsg(fd, output.as_mut_ptr(), output.len() as _, flags.bits() as _) }; + count = i+1; + } - let sent_messages = Errno::result(ret)? as usize; - let mut sent_bytes = Vec::with_capacity(sent_messages); + let sent = Errno::result(unsafe { + libc::sendmmsg( + fd, + data.items.as_mut_ptr(), + count as _, + flags.bits() as _ + ) + })? as usize; - for item in &output { - sent_bytes.push(item.msg_len as usize); - } + Ok(MultiResults { + rmm: data, + current_index: 0, + received: sent + }) - Ok(sent_bytes) } @@ -1508,8 +1510,8 @@ pub fn sendmmsg<'a, I, C, S>( target_os = "netbsd", ))] #[derive(Debug)] -/// Preallocated structures needed for [`recvmmsg`] function -pub struct RecvMMsg { +/// Preallocated structures needed for [`recvmmsg`] and [`sendmmsg`] functions +pub struct MultHdrs { // preallocated boxed slice of mmsghdr items: Box<[libc::mmsghdr]>, addresses: Box<[mem::MaybeUninit]>, @@ -1526,8 +1528,8 @@ pub struct RecvMMsg { target_os = "freebsd", target_os = "netbsd", ))] -impl RecvMMsg { - /// Preallocate structure used by [`recvmmsg`], takes number of headers to preallocate +impl MultHdrs { + /// Preallocate structure used by [`recvmmsg`] and [`sendmmsg`] takes number of headers to preallocate /// /// `cmsg_buffer` should be created with [`cmsg_space!`] if needed pub fn preallocate(num_slices: usize, cmsg_buffer: Option>) -> Self @@ -1598,21 +1600,21 @@ impl RecvMMsg { ))] pub fn recvmmsg<'a, XS, S, I>( fd: RawFd, - data: &'a mut RecvMMsg, + data: &'a mut MultHdrs, slices: XS, flags: MsgFlags, mut timeout: Option, -) -> crate::Result> +) -> crate::Result> where - XS: ExactSizeIterator, + XS: IntoIterator, I: AsRef<[IoSliceMut<'a>]>, { - let count = std::cmp::min(slices.len(), data.items.len()); - - for (slice, mmsghdr) in slices.zip(data.items.iter_mut()) { + let mut count = 0; + for (i, (slice, mmsghdr)) in slices.into_iter().zip(data.items.iter_mut()).enumerate() { let mut p = &mut mmsghdr.msg_hdr; p.msg_iov = slice.as_ref().as_ptr() as *mut libc::iovec; p.msg_iovlen = slice.as_ref().len() as _; + count = i + 1; } let timeout_ptr = timeout @@ -1629,7 +1631,7 @@ where ) })? as usize; - Ok(RecvMMsgItems { + Ok(MultiResults { rmm: data, current_index: 0, received, @@ -1643,9 +1645,12 @@ where target_os = "netbsd", ))] #[derive(Debug)] -pub struct RecvMMsgItems<'a, S> { +/// Iterator over results of [`recvmmsg`]/[`sendmmsg`] +/// +/// +pub struct MultiResults<'a, S> { // preallocated structures - rmm: &'a RecvMMsg, + rmm: &'a MultHdrs, current_index: usize, received: usize, } @@ -1656,11 +1661,11 @@ pub struct RecvMMsgItems<'a, S> { target_os = "freebsd", target_os = "netbsd", ))] -impl<'a, S> Iterator for RecvMMsgItems<'a, S> +impl<'a, S> Iterator for MultiResults<'a, S> where S: Copy + SockaddrLike, { - type Item = RecvMsg<'a, S>; + type Item = RecvMsg<'a, 'a, S>; fn next(&mut self) -> Option { if self.current_index >= self.received { @@ -1684,13 +1689,17 @@ where } } -impl<'a, S> RecvMsg<'a, S> { +impl<'a, S> RecvMsg<'_, 'a, S> { /// Iterate over the filled io slices pointed by this msghdr - pub fn iovs(&self) -> IoSliceIterator { + pub fn iovs(&self) -> IoSliceIterator<'a> { IoSliceIterator { index: 0, remaining: self.bytes, slices: unsafe { + // safe for as long as mgdr is properly initialized and references are valid. + // for multi messages API we initialize it with an empty + // slice and replace with a concrete buffer + // for single message API we hold a lifetime reference to ioslices std::slice::from_raw_parts(self.mhdr.msg_iov as *const _, self.mhdr.msg_iovlen as _) }, } @@ -1782,7 +1791,7 @@ mod test { let cmsg = cmsg_space!(crate::sys::socket::Timestamps); sendmsg(ssock, &iov1, &[], flags, Some(&sock_addr)).unwrap(); - let mut data = super::RecvMMsg::<()>::preallocate(recv_iovs.len(), Some(cmsg)); + let mut data = super::MultHdrs::<()>::preallocate(recv_iovs.len(), Some(cmsg)); let t = sys::time::TimeSpec::from_duration(std::time::Duration::from_secs(10)); @@ -1817,12 +1826,12 @@ mod test { Ok(()) } } -unsafe fn read_mhdr<'a, S>( +unsafe fn read_mhdr<'a, 'i, S>( mhdr: msghdr, r: isize, msg_controllen: usize, address: S, -) -> RecvMsg<'a, S> +) -> RecvMsg<'a, 'i, S> where S: SockaddrLike { let cmsghdr = { @@ -1841,6 +1850,7 @@ unsafe fn read_mhdr<'a, S>( address: Some(address), flags: MsgFlags::from_bits_truncate(mhdr.msg_flags), mhdr, + iobufs: std::marker::PhantomData, } } @@ -1948,8 +1958,9 @@ fn pack_mhdr_to_send<'a, I, C, S>( /// [recvmsg(2)](https://pubs.opengroup.org/onlinepubs/9699919799/functions/recvmsg.html) pub fn recvmsg<'a, 'outer, 'inner, S>(fd: RawFd, iov: &'outer mut [IoSliceMut<'inner>], mut cmsg_buffer: Option<&'a mut Vec>, - flags: MsgFlags) -> Result> - where S: SockaddrLike + 'a + flags: MsgFlags) -> Result> + where S: SockaddrLike + 'a, + 'inner: 'outer { let mut address = mem::MaybeUninit::uninit(); diff --git a/test/sys/test_socket.rs b/test/sys/test_socket.rs index 50565094de..ded2647934 100644 --- a/test/sys/test_socket.rs +++ b/test/sys/test_socket.rs @@ -501,31 +501,31 @@ mod recvfrom { rsock, ssock, move |s, m, flags| { - let iov = [IoSlice::new(m)]; - let mut msgs = vec![SendMmsgData { - iov: &iov, - cmsgs: &[], - addr: Some(sock_addr), - _lt: Default::default(), - }]; - let batch_size = 15; + let mut iovs = Vec::with_capacity(1 + batch_size); + let mut addrs = Vec::with_capacity(1 + batch_size); + let mut data = MultHdrs::preallocate(1 + batch_size, None); + let iov = IoSlice::new(m); + // first chunk: + iovs.push([iov]); + addrs.push(Some(sock_addr)); for _ in 0..batch_size { - msgs.push(SendMmsgData { - iov: &iov, - cmsgs: &[], - addr: Some(sock_addr2), - _lt: Default::default(), - }); + iovs.push([iov]); + addrs.push(Some(sock_addr2)); } - sendmmsg(s, msgs.iter(), flags).map(move |sent_bytes| { - assert!(!sent_bytes.is_empty()); - for sent in &sent_bytes { - assert_eq!(*sent, m.len()); - } - sent_bytes.len() - }) + + let res = sendmmsg(s, &mut data, &iovs, addrs, &[], flags)?; + let mut sent_messages = 0; + let mut sent_bytes = 0; + for item in res { + sent_messages += 1; + sent_bytes += item.bytes; + } + // + assert_eq!(sent_messages, iovs.len()); + assert_eq!(sent_bytes, sent_messages * m.len()); + Ok(sent_messages) }, |_, _| {}, ); @@ -582,7 +582,7 @@ mod recvfrom { .iter_mut() .map(|buf| [IoSliceMut::new(&mut buf[..])]), ); - let mut data = RecvMMsg::::preallocate(msgs.len(), None); + let mut data = MultHdrs::::preallocate(msgs.len(), None); let res: Vec> = recvmmsg(rsock, &mut data, msgs.iter(), MsgFlags::empty(), None) @@ -658,7 +658,7 @@ mod recvfrom { ); let mut data = - RecvMMsg::::preallocate(NUM_MESSAGES_SENT + 2, None); + MultHdrs::::preallocate(NUM_MESSAGES_SENT + 2, None); let res: Vec> = recvmmsg( rsock, @@ -1943,7 +1943,7 @@ fn test_recvmmsg_timestampns() { let mut buffer = vec![0u8; message.len()]; let cmsgspace = nix::cmsg_space!(TimeSpec); let iov = vec![[IoSliceMut::new(&mut buffer)]]; - let mut data = RecvMMsg::preallocate(1, Some(cmsgspace)); + let mut data = MultHdrs::preallocate(1, Some(cmsgspace)); let r: Vec> = recvmmsg(in_socket, &mut data, iov.iter(), flags, None) .unwrap()