Skip to content

Commit

Permalink
Merge #918
Browse files Browse the repository at this point in the history
918: Fix passing multiple file descriptors / control messages via sendmsg r=asomers a=jonas-schievink

Fixes #464
Closes #874 because it's incorporated here
Closes #756 because it adds the test from that issue (with fixes)

Co-authored-by: alecmocatta <[email protected]>
  • Loading branch information
bors[bot] and alecmocatta committed Jul 5, 2018
2 parents e0577cc + 5d6dc26 commit 90b1c17
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 34 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ This project adheres to [Semantic Versioning](http://semver.org/).
### Fixed
- Made `preadv` take immutable slice of IoVec.
([#914](https://github.com/nix-rust/nix/pull/914))
- Fixed passing multiple file descriptors over Unix Sockets.
([#918](https://github.com/nix-rust/nix/pull/918))

### Removed

Expand Down
113 changes: 79 additions & 34 deletions src/sys/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,18 +287,37 @@ impl fmt::Debug for Ipv6MembershipRequest {
}
}

/// Copy the in-memory representation of src into the byte slice dst,
/// updating the slice to point to the remainder of dst only. Unsafe
/// because it exposes all bytes in src, which may be UB if some of them
/// are uninitialized (including padding).
unsafe fn copy_bytes<'a, 'b, T: ?Sized>(src: &T, dst: &'a mut &'b mut [u8]) {
/// Copy the in-memory representation of `src` into the byte slice `dst`.
///
/// Returns the remainder of `dst`.
///
/// Panics when `dst` is too small for `src` (more precisely, panics if
/// `mem::size_of_val(src) >= dst.len()`).
///
/// Unsafe because it transmutes `src` to raw bytes, which is only safe for some
/// types `T`. Refer to the [Rustonomicon] for details.
///
/// [Rustonomicon]: https://doc.rust-lang.org/nomicon/transmutes.html
unsafe fn copy_bytes<'a, T: ?Sized>(src: &T, dst: &'a mut [u8]) -> &'a mut [u8] {
let srclen = mem::size_of_val(src);
let mut tmpdst = &mut [][..];
mem::swap(&mut tmpdst, dst);
let (target, mut remainder) = tmpdst.split_at_mut(srclen);
// Safe because the mutable borrow of dst guarantees that src does not alias it.
ptr::copy_nonoverlapping(src as *const T as *const u8, target.as_mut_ptr(), srclen);
mem::swap(dst, &mut remainder);
ptr::copy_nonoverlapping(
src as *const T as *const u8,
dst[..srclen].as_mut_ptr(),
srclen
);

&mut dst[srclen..]
}

/// Fills `dst` with `len` zero bytes and returns the remainder of the slice.
///
/// Panics when `len >= dst.len()`.
fn pad_bytes(len: usize, dst: &mut [u8]) -> &mut [u8] {
for pad in &mut dst[..len] {
*pad = 0;
}

&mut dst[len..]
}

cfg_if! {
Expand Down Expand Up @@ -434,6 +453,11 @@ pub enum ControlMessage<'a> {
///
/// See the description in the "Ancillary messages" section of the
/// [unix(7) man page](http://man7.org/linux/man-pages/man7/unix.7.html).
///
/// Using multiple `ScmRights` messages for a single `sendmsg` call isn't recommended since it
/// causes platform-dependent behaviour: It might swallow all but the first `ScmRights` message
/// or fail with `EINVAL`. Instead, you can put all fds to be passed into a single `ScmRights`
/// message.
ScmRights(&'a [RawFd]),
/// A message of type `SCM_TIMESTAMP`, containing the time the
/// packet was received by the kernel.
Expand Down Expand Up @@ -545,7 +569,7 @@ impl<'a> ControlMessage<'a> {

// Unsafe: start and end of buffer must be cmsg_align'd. Updates
// the provided slice; panics if the buffer is too small.
unsafe fn encode_into<'b>(&self, buf: &mut &'b mut [u8]) {
unsafe fn encode_into(&self, buf: &mut [u8]) {
match *self {
ControlMessage::ScmRights(fds) => {
let cmsg = cmsghdr {
Expand All @@ -554,17 +578,16 @@ impl<'a> ControlMessage<'a> {
cmsg_type: libc::SCM_RIGHTS,
..mem::uninitialized()
};
copy_bytes(&cmsg, buf);
let buf = copy_bytes(&cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
mem::size_of_val(&cmsg);
let buf = pad_bytes(padlen, buf);

let mut tmpbuf = &mut [][..];
mem::swap(&mut tmpbuf, buf);
let (_padding, mut remainder) = tmpbuf.split_at_mut(padlen);
mem::swap(buf, &mut remainder);
let buf = copy_bytes(fds, buf);

copy_bytes(fds, buf);
let padlen = self.space() - self.len();
pad_bytes(padlen, buf);
},
ControlMessage::ScmTimestamp(t) => {
let cmsg = cmsghdr {
Expand All @@ -573,21 +596,28 @@ impl<'a> ControlMessage<'a> {
cmsg_type: libc::SCM_TIMESTAMP,
..mem::uninitialized()
};
copy_bytes(&cmsg, buf);
let buf = copy_bytes(&cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
mem::size_of_val(&cmsg);
let buf = pad_bytes(padlen, buf);

let mut tmpbuf = &mut [][..];
mem::swap(&mut tmpbuf, buf);
let (_padding, mut remainder) = tmpbuf.split_at_mut(padlen);
mem::swap(buf, &mut remainder);
let buf = copy_bytes(t, buf);

copy_bytes(t, buf);
let padlen = self.space() - self.len();
pad_bytes(padlen, buf);
},
ControlMessage::Unknown(UnknownCmsg(orig_cmsg, bytes)) => {
copy_bytes(orig_cmsg, buf);
copy_bytes(bytes, buf);
let buf = copy_bytes(orig_cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&orig_cmsg)) -
mem::size_of_val(&orig_cmsg);
let buf = pad_bytes(padlen, buf);

let buf = copy_bytes(bytes, buf);

let padlen = self.space() - self.len();
pad_bytes(padlen, buf);
}
}
}
Expand All @@ -600,23 +630,25 @@ impl<'a> ControlMessage<'a> {
///
/// Allocates if cmsgs is nonempty.
pub fn sendmsg<'a>(fd: RawFd, iov: &[IoVec<&'a [u8]>], cmsgs: &[ControlMessage<'a>], flags: MsgFlags, addr: Option<&'a SockAddr>) -> Result<usize> {
let mut len = 0;
let mut capacity = 0;
for cmsg in cmsgs {
len += cmsg.len();
capacity += cmsg.space();
}
// Note that the resulting vector claims to have length == capacity,
// so it's presently uninitialized.
let mut cmsg_buffer = unsafe {
let mut vec = Vec::<u8>::with_capacity(len);
vec.set_len(len);
let mut vec = Vec::<u8>::with_capacity(capacity);
vec.set_len(capacity);
vec
};
{
let mut ptr = &mut cmsg_buffer[..];
let mut ofs = 0;
for cmsg in cmsgs {
unsafe { cmsg.encode_into(&mut ptr) };
let mut ptr = &mut cmsg_buffer[ofs..];
unsafe {
cmsg.encode_into(ptr);
}
ofs += cmsg.space();
}
}

Expand Down Expand Up @@ -669,10 +701,23 @@ pub fn recvmsg<'a, T>(fd: RawFd, iov: &[IoVec<&mut [u8]>], cmsg_buffer: Option<&
};
let ret = unsafe { libc::recvmsg(fd, &mut mhdr, flags.bits()) };

let cmsg_buffer = if msg_controllen > 0 {
// got control message(s)
debug_assert!(!mhdr.msg_control.is_null());
unsafe {
// Safe: The pointer is not null and the length is correct as part of `recvmsg`s
// contract.
slice::from_raw_parts(mhdr.msg_control as *const u8,
mhdr.msg_controllen as usize)
}
} else {
// No control message, create an empty buffer to avoid creating a slice from a null pointer
&[]
};

Ok(unsafe { RecvMsg {
bytes: try!(Errno::result(ret)) as usize,
cmsg_buffer: slice::from_raw_parts(mhdr.msg_control as *const u8,
mhdr.msg_controllen as usize),
cmsg_buffer,
address: sockaddr_storage_to_addr(&address,
mhdr.msg_namelen as usize).ok(),
flags: MsgFlags::from_bits_truncate(mhdr.msg_flags),
Expand Down
45 changes: 45 additions & 0 deletions test/sys/test_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,51 @@ pub fn test_scm_rights() {
close(w).unwrap();
}

/// Tests that passing multiple fds using a single `ControlMessage` works.
#[test]
fn test_scm_rights_single_cmsg_multiple_fds() {
use std::os::unix::net::UnixDatagram;
use std::os::unix::io::{RawFd, AsRawFd};
use std::thread;
use nix::sys::socket::{CmsgSpace, ControlMessage, MsgFlags, sendmsg, recvmsg};
use nix::sys::uio::IoVec;
use libc;

let (send, receive) = UnixDatagram::pair().unwrap();
let thread = thread::spawn(move || {
let mut buf = [0u8; 8];
let iovec = [IoVec::from_mut_slice(&mut buf)];
let mut space = CmsgSpace::<[RawFd; 2]>::new();
let msg = recvmsg(
receive.as_raw_fd(),
&iovec,
Some(&mut space),
MsgFlags::empty()
).unwrap();
assert!(!msg.flags.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));

let mut cmsgs = msg.cmsgs();
match cmsgs.next() {
Some(ControlMessage::ScmRights(fds)) => {
assert_eq!(fds.len(), 2,
"unexpected fd count (expected 2 fds, got {})",
fds.len());
},
_ => panic!(),
}
assert!(cmsgs.next().is_none(), "unexpected control msg");

assert_eq!(iovec[0].as_slice(), [1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8]);
});

let slice = [1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8];
let iov = [IoVec::from_slice(&slice)];
let fds = [libc::STDIN_FILENO, libc::STDOUT_FILENO]; // pass stdin and stdout
let cmsg = [ControlMessage::ScmRights(&fds)];
sendmsg(send.as_raw_fd(), &iov, &cmsg, MsgFlags::empty(), None).unwrap();
thread.join().unwrap();
}

// Verify `sendmsg` builds a valid `msghdr` when passing an empty
// `cmsgs` argument. This should result in a msghdr with a nullptr
// msg_control field and a msg_controllen of 0 when calling into the
Expand Down

0 comments on commit 90b1c17

Please sign in to comment.