Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix passing multiple file descriptors / control messages via sendmsg #918

Merged
merged 1 commit into from
Jul 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ This project adheres to [Semantic Versioning](http://semver.org/).
### Fixed
- Made `preadv` take immutable slice of IoVec.
([#914](https://github.com/nix-rust/nix/pull/914))
- Fixed passing multiple file descriptors over Unix Sockets.
([#918](https://github.com/nix-rust/nix/pull/918))

### Removed

Expand Down
113 changes: 79 additions & 34 deletions src/sys/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,18 +287,37 @@ impl fmt::Debug for Ipv6MembershipRequest {
}
}

/// Copy the in-memory representation of src into the byte slice dst,
/// updating the slice to point to the remainder of dst only. Unsafe
/// because it exposes all bytes in src, which may be UB if some of them
/// are uninitialized (including padding).
unsafe fn copy_bytes<'a, 'b, T: ?Sized>(src: &T, dst: &'a mut &'b mut [u8]) {
/// Copy the in-memory representation of `src` into the byte slice `dst`.
///
/// Returns the remainder of `dst`.
///
/// Panics when `dst` is too small for `src` (more precisely, panics if
/// `mem::size_of_val(src) >= dst.len()`).
///
/// Unsafe because it transmutes `src` to raw bytes, which is only safe for some
/// types `T`. Refer to the [Rustonomicon] for details.
///
/// [Rustonomicon]: https://doc.rust-lang.org/nomicon/transmutes.html
unsafe fn copy_bytes<'a, T: ?Sized>(src: &T, dst: &'a mut [u8]) -> &'a mut [u8] {
let srclen = mem::size_of_val(src);
let mut tmpdst = &mut [][..];
mem::swap(&mut tmpdst, dst);
let (target, mut remainder) = tmpdst.split_at_mut(srclen);
// Safe because the mutable borrow of dst guarantees that src does not alias it.
ptr::copy_nonoverlapping(src as *const T as *const u8, target.as_mut_ptr(), srclen);
mem::swap(dst, &mut remainder);
ptr::copy_nonoverlapping(
src as *const T as *const u8,
dst[..srclen].as_mut_ptr(),
srclen
);

&mut dst[srclen..]
}

/// Fills `dst` with `len` zero bytes and returns the remainder of the slice.
///
/// Panics when `len >= dst.len()`.
fn pad_bytes(len: usize, dst: &mut [u8]) -> &mut [u8] {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much nicer than the previous version. I like.

for pad in &mut dst[..len] {
*pad = 0;
}

&mut dst[len..]
}

cfg_if! {
Expand Down Expand Up @@ -434,6 +453,11 @@ pub enum ControlMessage<'a> {
///
/// See the description in the "Ancillary messages" section of the
/// [unix(7) man page](http://man7.org/linux/man-pages/man7/unix.7.html).
///
/// Using multiple `ScmRights` messages for a single `sendmsg` call isn't recommended since it
/// causes platform-dependent behaviour: It might swallow all but the first `ScmRights` message
/// or fail with `EINVAL`. Instead, you can put all fds to be passed into a single `ScmRights`
/// message.
ScmRights(&'a [RawFd]),
/// A message of type `SCM_TIMESTAMP`, containing the time the
/// packet was received by the kernel.
Expand Down Expand Up @@ -545,7 +569,7 @@ impl<'a> ControlMessage<'a> {

// Unsafe: start and end of buffer must be cmsg_align'd. Updates
// the provided slice; panics if the buffer is too small.
unsafe fn encode_into<'b>(&self, buf: &mut &'b mut [u8]) {
unsafe fn encode_into(&self, buf: &mut [u8]) {
match *self {
ControlMessage::ScmRights(fds) => {
let cmsg = cmsghdr {
Expand All @@ -554,17 +578,16 @@ impl<'a> ControlMessage<'a> {
cmsg_type: libc::SCM_RIGHTS,
..mem::uninitialized()
};
copy_bytes(&cmsg, buf);
let buf = copy_bytes(&cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
mem::size_of_val(&cmsg);
let buf = pad_bytes(padlen, buf);

let mut tmpbuf = &mut [][..];
mem::swap(&mut tmpbuf, buf);
let (_padding, mut remainder) = tmpbuf.split_at_mut(padlen);
mem::swap(buf, &mut remainder);
let buf = copy_bytes(fds, buf);

copy_bytes(fds, buf);
let padlen = self.space() - self.len();
pad_bytes(padlen, buf);
},
ControlMessage::ScmTimestamp(t) => {
let cmsg = cmsghdr {
Expand All @@ -573,21 +596,28 @@ impl<'a> ControlMessage<'a> {
cmsg_type: libc::SCM_TIMESTAMP,
..mem::uninitialized()
};
copy_bytes(&cmsg, buf);
let buf = copy_bytes(&cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
mem::size_of_val(&cmsg);
let buf = pad_bytes(padlen, buf);

let mut tmpbuf = &mut [][..];
mem::swap(&mut tmpbuf, buf);
let (_padding, mut remainder) = tmpbuf.split_at_mut(padlen);
mem::swap(buf, &mut remainder);
let buf = copy_bytes(t, buf);

copy_bytes(t, buf);
let padlen = self.space() - self.len();
pad_bytes(padlen, buf);
},
ControlMessage::Unknown(UnknownCmsg(orig_cmsg, bytes)) => {
copy_bytes(orig_cmsg, buf);
copy_bytes(bytes, buf);
let buf = copy_bytes(orig_cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&orig_cmsg)) -
mem::size_of_val(&orig_cmsg);
let buf = pad_bytes(padlen, buf);

let buf = copy_bytes(bytes, buf);

let padlen = self.space() - self.len();
pad_bytes(padlen, buf);
}
}
}
Expand All @@ -600,23 +630,25 @@ impl<'a> ControlMessage<'a> {
///
/// Allocates if cmsgs is nonempty.
pub fn sendmsg<'a>(fd: RawFd, iov: &[IoVec<&'a [u8]>], cmsgs: &[ControlMessage<'a>], flags: MsgFlags, addr: Option<&'a SockAddr>) -> Result<usize> {
let mut len = 0;
let mut capacity = 0;
for cmsg in cmsgs {
len += cmsg.len();
capacity += cmsg.space();
}
// Note that the resulting vector claims to have length == capacity,
// so it's presently uninitialized.
let mut cmsg_buffer = unsafe {
let mut vec = Vec::<u8>::with_capacity(len);
vec.set_len(len);
let mut vec = Vec::<u8>::with_capacity(capacity);
vec.set_len(capacity);
vec
};
{
let mut ptr = &mut cmsg_buffer[..];
let mut ofs = 0;
for cmsg in cmsgs {
unsafe { cmsg.encode_into(&mut ptr) };
let mut ptr = &mut cmsg_buffer[ofs..];
unsafe {
cmsg.encode_into(ptr);
}
ofs += cmsg.space();
}
}

Expand Down Expand Up @@ -669,10 +701,23 @@ pub fn recvmsg<'a, T>(fd: RawFd, iov: &[IoVec<&mut [u8]>], cmsg_buffer: Option<&
};
let ret = unsafe { libc::recvmsg(fd, &mut mhdr, flags.bits()) };

let cmsg_buffer = if msg_controllen > 0 {
// got control message(s)
debug_assert!(!mhdr.msg_control.is_null());
unsafe {
// Safe: The pointer is not null and the length is correct as part of `recvmsg`s
// contract.
slice::from_raw_parts(mhdr.msg_control as *const u8,
mhdr.msg_controllen as usize)
}
} else {
// No control message, create an empty buffer to avoid creating a slice from a null pointer
&[]
};

Ok(unsafe { RecvMsg {
bytes: try!(Errno::result(ret)) as usize,
cmsg_buffer: slice::from_raw_parts(mhdr.msg_control as *const u8,
mhdr.msg_controllen as usize),
cmsg_buffer,
address: sockaddr_storage_to_addr(&address,
mhdr.msg_namelen as usize).ok(),
flags: MsgFlags::from_bits_truncate(mhdr.msg_flags),
Expand Down
45 changes: 45 additions & 0 deletions test/sys/test_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,51 @@ pub fn test_scm_rights() {
close(w).unwrap();
}

/// Tests that passing multiple fds using a single `ControlMessage` works.
#[test]
fn test_scm_rights_single_cmsg_multiple_fds() {
use std::os::unix::net::UnixDatagram;
use std::os::unix::io::{RawFd, AsRawFd};
use std::thread;
use nix::sys::socket::{CmsgSpace, ControlMessage, MsgFlags, sendmsg, recvmsg};
use nix::sys::uio::IoVec;
use libc;

let (send, receive) = UnixDatagram::pair().unwrap();
let thread = thread::spawn(move || {
let mut buf = [0u8; 8];
let iovec = [IoVec::from_mut_slice(&mut buf)];
let mut space = CmsgSpace::<[RawFd; 2]>::new();
let msg = recvmsg(
receive.as_raw_fd(),
&iovec,
Some(&mut space),
MsgFlags::empty()
).unwrap();
assert!(!msg.flags.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));

let mut cmsgs = msg.cmsgs();
match cmsgs.next() {
Some(ControlMessage::ScmRights(fds)) => {
assert_eq!(fds.len(), 2,
"unexpected fd count (expected 2 fds, got {})",
fds.len());
},
_ => panic!(),
}
assert!(cmsgs.next().is_none(), "unexpected control msg");

assert_eq!(iovec[0].as_slice(), [1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8]);
});

let slice = [1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8];
let iov = [IoVec::from_slice(&slice)];
let fds = [libc::STDIN_FILENO, libc::STDOUT_FILENO]; // pass stdin and stdout
let cmsg = [ControlMessage::ScmRights(&fds)];
sendmsg(send.as_raw_fd(), &iov, &cmsg, MsgFlags::empty(), None).unwrap();
thread.join().unwrap();
}

// Verify `sendmsg` builds a valid `msghdr` when passing an empty
// `cmsgs` argument. This should result in a msghdr with a nullptr
// msg_control field and a msg_controllen of 0 when calling into the
Expand Down