diff --git a/CHANGELOG.md b/CHANGELOG.md index 240054bb5a..9bdd6a2e24 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ This project adheres to [Semantic Versioning](http://semver.org/). ### Fixed - Made `preadv` take immutable slice of IoVec. ([#914](https://github.com/nix-rust/nix/pull/914)) +- Fixed passing multiple file descriptors over Unix Sockets. + ([#918](https://github.com/nix-rust/nix/pull/918)) ### Removed diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs index b46fa8b096..0706618a2b 100644 --- a/src/sys/socket/mod.rs +++ b/src/sys/socket/mod.rs @@ -287,18 +287,37 @@ impl fmt::Debug for Ipv6MembershipRequest { } } -/// Copy the in-memory representation of src into the byte slice dst, -/// updating the slice to point to the remainder of dst only. Unsafe -/// because it exposes all bytes in src, which may be UB if some of them -/// are uninitialized (including padding). -unsafe fn copy_bytes<'a, 'b, T: ?Sized>(src: &T, dst: &'a mut &'b mut [u8]) { +/// Copy the in-memory representation of `src` into the byte slice `dst`. +/// +/// Returns the remainder of `dst`. +/// +/// Panics when `dst` is too small for `src` (more precisely, panics if +/// `mem::size_of_val(src) >= dst.len()`). +/// +/// Unsafe because it transmutes `src` to raw bytes, which is only safe for some +/// types `T`. Refer to the [Rustonomicon] for details. +/// +/// [Rustonomicon]: https://doc.rust-lang.org/nomicon/transmutes.html +unsafe fn copy_bytes<'a, T: ?Sized>(src: &T, dst: &'a mut [u8]) -> &'a mut [u8] { let srclen = mem::size_of_val(src); - let mut tmpdst = &mut [][..]; - mem::swap(&mut tmpdst, dst); - let (target, mut remainder) = tmpdst.split_at_mut(srclen); - // Safe because the mutable borrow of dst guarantees that src does not alias it. - ptr::copy_nonoverlapping(src as *const T as *const u8, target.as_mut_ptr(), srclen); - mem::swap(dst, &mut remainder); + ptr::copy_nonoverlapping( + src as *const T as *const u8, + dst[..srclen].as_mut_ptr(), + srclen + ); + + &mut dst[srclen..] +} + +/// Fills `dst` with `len` zero bytes and returns the remainder of the slice. +/// +/// Panics when `len >= dst.len()`. +fn pad_bytes(len: usize, dst: &mut [u8]) -> &mut [u8] { + for pad in &mut dst[..len] { + *pad = 0; + } + + &mut dst[len..] } cfg_if! { @@ -434,6 +453,11 @@ pub enum ControlMessage<'a> { /// /// See the description in the "Ancillary messages" section of the /// [unix(7) man page](http://man7.org/linux/man-pages/man7/unix.7.html). + /// + /// Using multiple `ScmRights` messages for a single `sendmsg` call isn't recommended since it + /// causes platform-dependent behaviour: It might swallow all but the first `ScmRights` message + /// or fail with `EINVAL`. Instead, you can put all fds to be passed into a single `ScmRights` + /// message. ScmRights(&'a [RawFd]), /// A message of type `SCM_TIMESTAMP`, containing the time the /// packet was received by the kernel. @@ -545,7 +569,7 @@ impl<'a> ControlMessage<'a> { // Unsafe: start and end of buffer must be cmsg_align'd. Updates // the provided slice; panics if the buffer is too small. - unsafe fn encode_into<'b>(&self, buf: &mut &'b mut [u8]) { + unsafe fn encode_into(&self, buf: &mut [u8]) { match *self { ControlMessage::ScmRights(fds) => { let cmsg = cmsghdr { @@ -554,17 +578,16 @@ impl<'a> ControlMessage<'a> { cmsg_type: libc::SCM_RIGHTS, ..mem::uninitialized() }; - copy_bytes(&cmsg, buf); + let buf = copy_bytes(&cmsg, buf); let padlen = cmsg_align(mem::size_of_val(&cmsg)) - mem::size_of_val(&cmsg); + let buf = pad_bytes(padlen, buf); - let mut tmpbuf = &mut [][..]; - mem::swap(&mut tmpbuf, buf); - let (_padding, mut remainder) = tmpbuf.split_at_mut(padlen); - mem::swap(buf, &mut remainder); + let buf = copy_bytes(fds, buf); - copy_bytes(fds, buf); + let padlen = self.space() - self.len(); + pad_bytes(padlen, buf); }, ControlMessage::ScmTimestamp(t) => { let cmsg = cmsghdr { @@ -573,21 +596,28 @@ impl<'a> ControlMessage<'a> { cmsg_type: libc::SCM_TIMESTAMP, ..mem::uninitialized() }; - copy_bytes(&cmsg, buf); + let buf = copy_bytes(&cmsg, buf); let padlen = cmsg_align(mem::size_of_val(&cmsg)) - mem::size_of_val(&cmsg); + let buf = pad_bytes(padlen, buf); - let mut tmpbuf = &mut [][..]; - mem::swap(&mut tmpbuf, buf); - let (_padding, mut remainder) = tmpbuf.split_at_mut(padlen); - mem::swap(buf, &mut remainder); + let buf = copy_bytes(t, buf); - copy_bytes(t, buf); + let padlen = self.space() - self.len(); + pad_bytes(padlen, buf); }, ControlMessage::Unknown(UnknownCmsg(orig_cmsg, bytes)) => { - copy_bytes(orig_cmsg, buf); - copy_bytes(bytes, buf); + let buf = copy_bytes(orig_cmsg, buf); + + let padlen = cmsg_align(mem::size_of_val(&orig_cmsg)) - + mem::size_of_val(&orig_cmsg); + let buf = pad_bytes(padlen, buf); + + let buf = copy_bytes(bytes, buf); + + let padlen = self.space() - self.len(); + pad_bytes(padlen, buf); } } } @@ -600,23 +630,25 @@ impl<'a> ControlMessage<'a> { /// /// Allocates if cmsgs is nonempty. pub fn sendmsg<'a>(fd: RawFd, iov: &[IoVec<&'a [u8]>], cmsgs: &[ControlMessage<'a>], flags: MsgFlags, addr: Option<&'a SockAddr>) -> Result { - let mut len = 0; let mut capacity = 0; for cmsg in cmsgs { - len += cmsg.len(); capacity += cmsg.space(); } // Note that the resulting vector claims to have length == capacity, // so it's presently uninitialized. let mut cmsg_buffer = unsafe { - let mut vec = Vec::::with_capacity(len); - vec.set_len(len); + let mut vec = Vec::::with_capacity(capacity); + vec.set_len(capacity); vec }; { - let mut ptr = &mut cmsg_buffer[..]; + let mut ofs = 0; for cmsg in cmsgs { - unsafe { cmsg.encode_into(&mut ptr) }; + let mut ptr = &mut cmsg_buffer[ofs..]; + unsafe { + cmsg.encode_into(ptr); + } + ofs += cmsg.space(); } } @@ -669,10 +701,23 @@ pub fn recvmsg<'a, T>(fd: RawFd, iov: &[IoVec<&mut [u8]>], cmsg_buffer: Option<& }; let ret = unsafe { libc::recvmsg(fd, &mut mhdr, flags.bits()) }; + let cmsg_buffer = if msg_controllen > 0 { + // got control message(s) + debug_assert!(!mhdr.msg_control.is_null()); + unsafe { + // Safe: The pointer is not null and the length is correct as part of `recvmsg`s + // contract. + slice::from_raw_parts(mhdr.msg_control as *const u8, + mhdr.msg_controllen as usize) + } + } else { + // No control message, create an empty buffer to avoid creating a slice from a null pointer + &[] + }; + Ok(unsafe { RecvMsg { bytes: try!(Errno::result(ret)) as usize, - cmsg_buffer: slice::from_raw_parts(mhdr.msg_control as *const u8, - mhdr.msg_controllen as usize), + cmsg_buffer, address: sockaddr_storage_to_addr(&address, mhdr.msg_namelen as usize).ok(), flags: MsgFlags::from_bits_truncate(mhdr.msg_flags), diff --git a/test/sys/test_socket.rs b/test/sys/test_socket.rs index a997fbca9e..35e3bf9052 100644 --- a/test/sys/test_socket.rs +++ b/test/sys/test_socket.rs @@ -167,6 +167,51 @@ pub fn test_scm_rights() { close(w).unwrap(); } +/// Tests that passing multiple fds using a single `ControlMessage` works. +#[test] +fn test_scm_rights_single_cmsg_multiple_fds() { + use std::os::unix::net::UnixDatagram; + use std::os::unix::io::{RawFd, AsRawFd}; + use std::thread; + use nix::sys::socket::{CmsgSpace, ControlMessage, MsgFlags, sendmsg, recvmsg}; + use nix::sys::uio::IoVec; + use libc; + + let (send, receive) = UnixDatagram::pair().unwrap(); + let thread = thread::spawn(move || { + let mut buf = [0u8; 8]; + let iovec = [IoVec::from_mut_slice(&mut buf)]; + let mut space = CmsgSpace::<[RawFd; 2]>::new(); + let msg = recvmsg( + receive.as_raw_fd(), + &iovec, + Some(&mut space), + MsgFlags::empty() + ).unwrap(); + assert!(!msg.flags.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC)); + + let mut cmsgs = msg.cmsgs(); + match cmsgs.next() { + Some(ControlMessage::ScmRights(fds)) => { + assert_eq!(fds.len(), 2, + "unexpected fd count (expected 2 fds, got {})", + fds.len()); + }, + _ => panic!(), + } + assert!(cmsgs.next().is_none(), "unexpected control msg"); + + assert_eq!(iovec[0].as_slice(), [1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8]); + }); + + let slice = [1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8]; + let iov = [IoVec::from_slice(&slice)]; + let fds = [libc::STDIN_FILENO, libc::STDOUT_FILENO]; // pass stdin and stdout + let cmsg = [ControlMessage::ScmRights(&fds)]; + sendmsg(send.as_raw_fd(), &iov, &cmsg, MsgFlags::empty(), None).unwrap(); + thread.join().unwrap(); +} + // Verify `sendmsg` builds a valid `msghdr` when passing an empty // `cmsgs` argument. This should result in a msghdr with a nullptr // msg_control field and a msg_controllen of 0 when calling into the