diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e55b7ae19..a5bf6fff6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ This project adheres to [Semantic Versioning](http://semver.org/). ([#922](https://github.com/nix-rust/nix/pull/922)) - Support the `SO_PEERCRED` socket option and the `UnixCredentials` type on all Linux and Android targets. ([#921](https://github.com/nix-rust/nix/pull/921)) +- Added support for `SCM_CREDENTIALS`, allowing to send process credentials over Unix sockets. + ([#923](https://github.com/nix-rust/nix/pull/923)) ### Changed diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs index ecef05d8d6..e9537c4b0c 100644 --- a/src/sys/socket/mod.rs +++ b/src/sys/socket/mod.rs @@ -205,6 +205,18 @@ cfg_if! { } impl Eq for UnixCredentials {} + impl From for UnixCredentials { + fn from(cred: libc::ucred) -> Self { + UnixCredentials(cred) + } + } + + impl Into for UnixCredentials { + fn into(self) -> libc::ucred { + self.0 + } + } + impl fmt::Debug for UnixCredentials { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("UnixCredentials") @@ -359,7 +371,7 @@ impl CmsgSpace { } } -#[allow(missing_debug_implementations)] +#[derive(Debug)] pub struct RecvMsg<'a> { // The number of bytes received. pub bytes: usize, @@ -374,15 +386,14 @@ impl<'a> RecvMsg<'a> { pub fn cmsgs(&self) -> CmsgIterator { CmsgIterator { buf: self.cmsg_buffer, - next: 0 } } } -#[allow(missing_debug_implementations)] +#[derive(Debug)] pub struct CmsgIterator<'a> { + /// Control message buffer to decode from. Must adhere to cmsg alignment. buf: &'a [u8], - next: usize, } impl<'a> Iterator for CmsgIterator<'a> { @@ -392,53 +403,27 @@ impl<'a> Iterator for CmsgIterator<'a> { // although we handle the invariants in slightly different places to // get a better iterator interface. fn next(&mut self) -> Option> { - let sizeof_cmsghdr = mem::size_of::(); - if self.buf.len() < sizeof_cmsghdr { + if self.buf.len() == 0 { + // The iterator assumes that `self.buf` always contains exactly the + // bytes we need, so we're at the end when the buffer is empty. return None; } - let cmsg: &'a cmsghdr = unsafe { &*(self.buf.as_ptr() as *const cmsghdr) }; - // This check is only in the glibc implementation of CMSG_NXTHDR - // (although it claims the kernel header checks this), but such - // a structure is clearly invalid, either way. - let cmsg_len = cmsg.cmsg_len as usize; - if cmsg_len < sizeof_cmsghdr { - return None; - } - let len = cmsg_len - sizeof_cmsghdr; - let aligned_cmsg_len = if self.next == 0 { - // CMSG_FIRSTHDR - cmsg_len - } else { - // CMSG_NXTHDR - cmsg_align(cmsg_len) + // Safe if: `self.buf` is `cmsghdr`-aligned. + let cmsg: &'a cmsghdr = unsafe { + &*(self.buf[..mem::size_of::()].as_ptr() as *const cmsghdr) }; + let cmsg_len = cmsg.cmsg_len as usize; + // Advance our internal pointer. - if aligned_cmsg_len > self.buf.len() { - return None; - } - let cmsg_data = &self.buf[cmsg_align(sizeof_cmsghdr)..cmsg_len]; - self.buf = &self.buf[aligned_cmsg_len..]; - self.next += 1; - - match (cmsg.cmsg_level, cmsg.cmsg_type) { - (libc::SOL_SOCKET, libc::SCM_RIGHTS) => unsafe { - Some(ControlMessage::ScmRights( - slice::from_raw_parts(cmsg_data.as_ptr() as *const _, - cmsg_data.len() / mem::size_of::()))) - }, - (libc::SOL_SOCKET, libc::SCM_TIMESTAMP) => unsafe { - Some(ControlMessage::ScmTimestamp( - &*(cmsg_data.as_ptr() as *const _))) - }, - (_, _) => unsafe { - Some(ControlMessage::Unknown(UnknownCmsg( - cmsg, - slice::from_raw_parts( - cmsg_data.as_ptr() as *const _, - len)))) - } + let cmsg_data = &self.buf[cmsg_align(mem::size_of::())..cmsg_len]; + self.buf = &self.buf[cmsg_align(cmsg_len)..]; + + // Safe if: `cmsg_data` contains the expected (amount of) content. This + // is verified by the kernel. + unsafe { + Some(ControlMessage::decode_from(cmsg, cmsg_data)) } } } @@ -459,6 +444,20 @@ pub enum ControlMessage<'a> { /// or fail with `EINVAL`. Instead, you can put all fds to be passed into a single `ScmRights` /// message. ScmRights(&'a [RawFd]), + /// A message of type `SCM_CREDENTIALS`, containing the pid, uid and gid of + /// a process connected to the socket. + /// + /// This is similar to the socket option `SO_PEERCRED`, but requires a + /// process to explicitly send its credentials. A process running as root is + /// allowed to specify any credentials, while credentials sent by other + /// processes are verified by the kernel. + /// + /// For further information, please refer to the + /// [`unix(7)`](http://man7.org/linux/man-pages/man7/unix.7.html) man page. + // FIXME: When `#[repr(transparent)]` is stable, use it on `UnixCredentials` + // and put that in here instead of a raw ucred. + #[cfg(any(target_os = "android", target_os = "linux"))] + ScmCredentials(&'a libc::ucred), /// A message of type `SCM_TIMESTAMP`, containing the time the /// packet was received by the kernel. /// @@ -527,6 +526,7 @@ pub enum ControlMessage<'a> { /// nix::unistd::close(in_socket).unwrap(); /// ``` ScmTimestamp(&'a TimeVal), + /// Catch-all variant for unimplemented cmsg types. #[doc(hidden)] Unknown(UnknownCmsg<'a>), } @@ -558,6 +558,10 @@ impl<'a> ControlMessage<'a> { ControlMessage::ScmRights(fds) => { mem::size_of_val(fds) }, + #[cfg(any(target_os = "android", target_os = "linux"))] + ControlMessage::ScmCredentials(creds) => { + mem::size_of_val(creds) + } ControlMessage::ScmTimestamp(t) => { mem::size_of_val(t) }, @@ -567,57 +571,87 @@ impl<'a> ControlMessage<'a> { } } + /// Returns the value to put into the `cmsg_type` field of the header. + fn cmsg_type(&self) -> libc::c_int { + match *self { + ControlMessage::ScmRights(_) => libc::SCM_RIGHTS, + #[cfg(any(target_os = "android", target_os = "linux"))] + ControlMessage::ScmCredentials(_) => libc::SCM_CREDENTIALS, + ControlMessage::ScmTimestamp(_) => libc::SCM_TIMESTAMP, + ControlMessage::Unknown(ref cmsg) => cmsg.0.cmsg_type, + } + } + // Unsafe: start and end of buffer must be cmsg_align'd. Updates // the provided slice; panics if the buffer is too small. unsafe fn encode_into(&self, buf: &mut [u8]) { - match *self { - ControlMessage::ScmRights(fds) => { - let cmsg = cmsghdr { - cmsg_len: self.len() as _, - cmsg_level: libc::SOL_SOCKET, - cmsg_type: libc::SCM_RIGHTS, - ..mem::uninitialized() - }; - let buf = copy_bytes(&cmsg, buf); - - let padlen = cmsg_align(mem::size_of_val(&cmsg)) - - mem::size_of_val(&cmsg); - let buf = pad_bytes(padlen, buf); - - let buf = copy_bytes(fds, buf); - - let padlen = self.space() - self.len(); - pad_bytes(padlen, buf); - }, - ControlMessage::ScmTimestamp(t) => { - let cmsg = cmsghdr { - cmsg_len: self.len() as _, - cmsg_level: libc::SOL_SOCKET, - cmsg_type: libc::SCM_TIMESTAMP, - ..mem::uninitialized() - }; - let buf = copy_bytes(&cmsg, buf); - - let padlen = cmsg_align(mem::size_of_val(&cmsg)) - - mem::size_of_val(&cmsg); - let buf = pad_bytes(padlen, buf); - - let buf = copy_bytes(t, buf); - - let padlen = self.space() - self.len(); - pad_bytes(padlen, buf); - }, - ControlMessage::Unknown(UnknownCmsg(orig_cmsg, bytes)) => { - let buf = copy_bytes(orig_cmsg, buf); + let final_buf = if let ControlMessage::Unknown(ref cmsg) = *self { + let &UnknownCmsg(orig_cmsg, bytes) = cmsg; + + let buf = copy_bytes(orig_cmsg, buf); - let padlen = cmsg_align(mem::size_of_val(&orig_cmsg)) - - mem::size_of_val(&orig_cmsg); - let buf = pad_bytes(padlen, buf); + let padlen = cmsg_align(mem::size_of_val(&orig_cmsg)) - + mem::size_of_val(&orig_cmsg); + let buf = pad_bytes(padlen, buf); - let buf = copy_bytes(bytes, buf); + copy_bytes(bytes, buf) + } else { + let cmsg = cmsghdr { + cmsg_len: self.len() as _, + cmsg_level: libc::SOL_SOCKET, + cmsg_type: self.cmsg_type(), + ..mem::zeroed() // zero out platform-dependent padding fields + }; + let buf = copy_bytes(&cmsg, buf); + + let padlen = cmsg_align(mem::size_of_val(&cmsg)) - + mem::size_of_val(&cmsg); + let buf = pad_bytes(padlen, buf); + + match *self { + ControlMessage::ScmRights(fds) => { + copy_bytes(fds, buf) + }, + #[cfg(any(target_os = "android", target_os = "linux"))] + ControlMessage::ScmCredentials(creds) => { + copy_bytes(creds, buf) + } + ControlMessage::ScmTimestamp(t) => { + copy_bytes(t, buf) + }, + ControlMessage::Unknown(_) => unreachable!(), + } + }; - let padlen = self.space() - self.len(); - pad_bytes(padlen, buf); + let padlen = self.space() - self.len(); + pad_bytes(padlen, final_buf); + } + + /// Decodes a `ControlMessage` from raw bytes. + /// + /// This is only safe to call if the data is correct for the message type + /// specified in the header. Normally, the kernel ensures that this is the + /// case. "Correct" in this case includes correct length, alignment and + /// actual content. + unsafe fn decode_from(header: &'a cmsghdr, data: &'a [u8]) -> ControlMessage<'a> { + match (header.cmsg_level, header.cmsg_type) { + (libc::SOL_SOCKET, libc::SCM_RIGHTS) => { + ControlMessage::ScmRights( + slice::from_raw_parts(data.as_ptr() as *const _, + data.len() / mem::size_of::())) + }, + #[cfg(any(target_os = "android", target_os = "linux"))] + (libc::SOL_SOCKET, libc::SCM_CREDENTIALS) => { + ControlMessage::ScmCredentials( + &*(data.as_ptr() as *const _) + ) + } + (libc::SOL_SOCKET, libc::SCM_TIMESTAMP) => { + ControlMessage::ScmTimestamp( + &*(data.as_ptr() as *const _)) + }, + (_, _) => { + ControlMessage::Unknown(UnknownCmsg(header, data)) } } } diff --git a/src/sys/socket/sockopt.rs b/src/sys/socket/sockopt.rs index 56f3a1ee10..494de4f577 100644 --- a/src/sys/socket/sockopt.rs +++ b/src/sys/socket/sockopt.rs @@ -255,6 +255,8 @@ sockopt_impl!(Both, BindAny, libc::SOL_SOCKET, libc::SO_BINDANY, bool); sockopt_impl!(Both, BindAny, libc::IPPROTO_IP, libc::IP_BINDANY, bool); #[cfg(target_os = "linux")] sockopt_impl!(Both, Mark, libc::SOL_SOCKET, libc::SO_MARK, u32); +#[cfg(any(target_os = "android", target_os = "linux"))] +sockopt_impl!(Both, PassCred, libc::SOL_SOCKET, libc::SO_PASSCRED, bool); /* * diff --git a/src/unistd.rs b/src/unistd.rs index 8022aa0b2c..32d0405e0e 100644 --- a/src/unistd.rs +++ b/src/unistd.rs @@ -48,6 +48,11 @@ impl Uid { pub fn is_root(&self) -> bool { *self == ROOT } + + /// Get the raw `uid_t` wrapped by `self`. + pub fn as_raw(&self) -> uid_t { + self.0 + } } impl From for uid_t { @@ -87,6 +92,11 @@ impl Gid { pub fn effective() -> Self { getegid() } + + /// Get the raw `gid_t` wrapped by `self`. + pub fn as_raw(&self) -> gid_t { + self.0 + } } impl From for gid_t { @@ -123,6 +133,11 @@ impl Pid { pub fn parent() -> Self { getppid() } + + /// Get the raw `pid_t` wrapped by `self`. + pub fn as_raw(&self) -> pid_t { + self.0 + } } impl From for pid_t { diff --git a/test/sys/test_socket.rs b/test/sys/test_socket.rs index 35e3bf9052..e5a69eb0f2 100644 --- a/test/sys/test_socket.rs +++ b/test/sys/test_socket.rs @@ -247,6 +247,158 @@ pub fn test_sendmsg_empty_cmsgs() { } } +#[cfg(any(target_os = "android", target_os = "linux"))] +#[test] +fn test_scm_credentials() { + use libc; + use nix::sys::uio::IoVec; + use nix::unistd::{close, getpid, getuid, getgid}; + use nix::sys::socket::{socketpair, sendmsg, recvmsg, setsockopt, + AddressFamily, SockType, SockFlag, + ControlMessage, CmsgSpace, MsgFlags}; + use nix::sys::socket::sockopt::PassCred; + + let (send, recv) = socketpair(AddressFamily::Unix, SockType::Stream, None, SockFlag::empty()) + .unwrap(); + setsockopt(recv, PassCred, &true).unwrap(); + + { + let iov = [IoVec::from_slice(b"hello")]; + let cred = libc::ucred { + pid: getpid().as_raw(), + uid: getuid().as_raw(), + gid: getgid().as_raw(), + }; + let cmsg = ControlMessage::ScmCredentials(&cred); + assert_eq!(sendmsg(send, &iov, &[cmsg], MsgFlags::empty(), None).unwrap(), 5); + close(send).unwrap(); + } + + { + let mut buf = [0u8; 5]; + let iov = [IoVec::from_mut_slice(&mut buf[..])]; + let mut cmsgspace: CmsgSpace = CmsgSpace::new(); + let msg = recvmsg(recv, &iov, Some(&mut cmsgspace), MsgFlags::empty()).unwrap(); + let mut received_cred = None; + + for cmsg in msg.cmsgs() { + if let ControlMessage::ScmCredentials(cred) = cmsg { + assert!(received_cred.is_none()); + assert_eq!(cred.pid, getpid().as_raw()); + assert_eq!(cred.uid, getuid().as_raw()); + assert_eq!(cred.gid, getgid().as_raw()); + received_cred = Some(*cred); + } else { + panic!("unexpected cmsg"); + } + } + received_cred.expect("no creds received"); + assert!(!msg.flags.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC)); + close(recv).unwrap(); + } +} + +/// Ensure that we can send `SCM_CREDENTIALS` and `SCM_RIGHTS` with a single +/// `sendmsg` call. +#[cfg(any(target_os = "android", target_os = "linux"))] +// qemu's handling of multiple cmsgs is bugged, ignore tests on non-x86 +// see https://bugs.launchpad.net/qemu/+bug/1781280 +#[cfg_attr(not(any(target_arch = "x86_64", target_arch = "x86")), ignore)] +#[test] +fn test_scm_credentials_and_rights() { + use nix::sys::socket::CmsgSpace; + use libc; + + test_impl_scm_credentials_and_rights(CmsgSpace::<(libc::ucred, CmsgSpace)>::new()); +} + +/// Ensure that passing a `CmsgSpace` with too much space for the received +/// messages still works. +#[cfg(any(target_os = "android", target_os = "linux"))] +// qemu's handling of multiple cmsgs is bugged, ignore tests on non-x86 +// see https://bugs.launchpad.net/qemu/+bug/1781280 +#[cfg_attr(not(any(target_arch = "x86_64", target_arch = "x86")), ignore)] +#[test] +fn test_too_large_cmsgspace() { + use nix::sys::socket::CmsgSpace; + + test_impl_scm_credentials_and_rights(CmsgSpace::<[u8; 1024]>::new()); +} + +#[cfg(any(target_os = "android", target_os = "linux"))] +fn test_impl_scm_credentials_and_rights(mut space: ::nix::sys::socket::CmsgSpace) { + use libc; + use nix::sys::uio::IoVec; + use nix::unistd::{pipe, read, write, close, getpid, getuid, getgid}; + use nix::sys::socket::{socketpair, sendmsg, recvmsg, setsockopt, + AddressFamily, SockType, SockFlag, + ControlMessage, MsgFlags}; + use nix::sys::socket::sockopt::PassCred; + + let (send, recv) = socketpair(AddressFamily::Unix, SockType::Stream, None, SockFlag::empty()) + .unwrap(); + setsockopt(recv, PassCred, &true).unwrap(); + + let (r, w) = pipe().unwrap(); + let mut received_r: Option = None; + + { + let iov = [IoVec::from_slice(b"hello")]; + let cred = libc::ucred { + pid: getpid().as_raw(), + uid: getuid().as_raw(), + gid: getgid().as_raw(), + }; + let fds = [r]; + let cmsgs = [ + ControlMessage::ScmCredentials(&cred), + ControlMessage::ScmRights(&fds), + ]; + assert_eq!(sendmsg(send, &iov, &cmsgs, MsgFlags::empty(), None).unwrap(), 5); + close(r).unwrap(); + close(send).unwrap(); + } + + { + let mut buf = [0u8; 5]; + let iov = [IoVec::from_mut_slice(&mut buf[..])]; + let msg = recvmsg(recv, &iov, Some(&mut space), MsgFlags::empty()).unwrap(); + let mut received_cred = None; + + assert_eq!(msg.cmsgs().count(), 2, "expected 2 cmsgs"); + + for cmsg in msg.cmsgs() { + match cmsg { + ControlMessage::ScmRights(fds) => { + assert_eq!(received_r, None, "already received fd"); + assert_eq!(fds.len(), 1); + received_r = Some(fds[0]); + } + ControlMessage::ScmCredentials(cred) => { + assert!(received_cred.is_none()); + assert_eq!(cred.pid, getpid().as_raw()); + assert_eq!(cred.uid, getuid().as_raw()); + assert_eq!(cred.gid, getgid().as_raw()); + received_cred = Some(*cred); + } + _ => panic!("unexpected cmsg"), + } + } + received_cred.expect("no creds received"); + assert!(!msg.flags.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC)); + close(recv).unwrap(); + } + + let received_r = received_r.expect("Did not receive passed fd"); + // Ensure that the received file descriptor works + write(w, b"world").unwrap(); + let mut buf = [0u8; 5]; + read(received_r, &mut buf).unwrap(); + assert_eq!(&buf[..], b"world"); + close(received_r).unwrap(); + close(w).unwrap(); +} + // Test creating and using named unix domain sockets #[test] pub fn test_unixdomain() {