From c0d989ed6b4b840a290a80ec0cdbc8edbce2ee57 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Wed, 16 Mar 2016 20:50:45 -0700 Subject: [PATCH] Add unix socket support to the standard library --- src/libstd/fs.rs | 37 +- src/libstd/sys/common/io.rs | 43 +- src/libstd/sys/common/net.rs | 21 +- src/libstd/sys/unix/ext/mod.rs | 1 + src/libstd/sys/unix/ext/net.rs | 1042 ++++++++++++++++++++++++++++++++ src/libstd/sys/unix/net.rs | 43 +- src/libstd/sys/windows/net.rs | 9 + 7 files changed, 1141 insertions(+), 55 deletions(-) create mode 100644 src/libstd/sys/unix/ext/net.rs diff --git a/src/libstd/fs.rs b/src/libstd/fs.rs index b98344b01a7a3..96534f817f80e 100644 --- a/src/libstd/fs.rs +++ b/src/libstd/fs.rs @@ -1457,12 +1457,12 @@ mod tests { use prelude::v1::*; use io::prelude::*; - use env; use fs::{self, File, OpenOptions}; use io::{ErrorKind, SeekFrom}; - use path::{Path, PathBuf}; - use rand::{self, StdRng, Rng}; + use path::Path; + use rand::{StdRng, Rng}; use str; + use sys_common::io::test::{TempDir, tmpdir}; #[cfg(windows)] use os::windows::fs::{symlink_dir, symlink_file}; @@ -1490,37 +1490,6 @@ mod tests { } ) } - pub struct TempDir(PathBuf); - - impl TempDir { - fn join(&self, path: &str) -> PathBuf { - let TempDir(ref p) = *self; - p.join(path) - } - - fn path<'a>(&'a self) -> &'a Path { - let TempDir(ref p) = *self; - p - } - } - - impl Drop for TempDir { - fn drop(&mut self) { - // Gee, seeing how we're testing the fs module I sure hope that we - // at least implement this correctly! - let TempDir(ref p) = *self; - check!(fs::remove_dir_all(p)); - } - } - - pub fn tmpdir() -> TempDir { - let p = env::temp_dir(); - let mut r = rand::thread_rng(); - let ret = p.join(&format!("rust-{}", r.next_u32())); - check!(fs::create_dir(&ret)); - TempDir(ret) - } - // Several test fail on windows if the user does not have permission to // create symlinks (the `SeCreateSymbolicLinkPrivilege`). Instead of // disabling these test on Windows, use this function to test whether we diff --git a/src/libstd/sys/common/io.rs b/src/libstd/sys/common/io.rs index 9f2f0df3a6470..7b08852ba51d1 100644 --- a/src/libstd/sys/common/io.rs +++ b/src/libstd/sys/common/io.rs @@ -51,6 +51,46 @@ pub unsafe fn read_to_end_uninitialized(r: &mut Read, buf: &mut Vec) -> io:: } } +#[cfg(test)] +pub mod test { + use prelude::v1::*; + use path::{Path, PathBuf}; + use env; + use rand::{self, Rng}; + use fs; + + pub struct TempDir(PathBuf); + + impl TempDir { + pub fn join(&self, path: &str) -> PathBuf { + let TempDir(ref p) = *self; + p.join(path) + } + + pub fn path<'a>(&'a self) -> &'a Path { + let TempDir(ref p) = *self; + p + } + } + + impl Drop for TempDir { + fn drop(&mut self) { + // Gee, seeing how we're testing the fs module I sure hope that we + // at least implement this correctly! + let TempDir(ref p) = *self; + fs::remove_dir_all(p).unwrap(); + } + } + + pub fn tmpdir() -> TempDir { + let p = env::temp_dir(); + let mut r = rand::thread_rng(); + let ret = p.join(&format!("rust-{}", r.next_u32())); + fs::create_dir(&ret).unwrap(); + TempDir(ret) + } +} + #[cfg(test)] mod tests { use prelude::v1::*; @@ -58,7 +98,6 @@ mod tests { use super::*; use io; use io::{ErrorKind, Take, Repeat, repeat}; - use test; use slice::from_raw_parts; struct ErrorRepeat { @@ -129,7 +168,7 @@ mod tests { } #[bench] - fn bench_uninitialized(b: &mut test::Bencher) { + fn bench_uninitialized(b: &mut ::test::Bencher) { b.iter(|| { let mut lr = repeat(1).take(10000000); let mut vec = Vec::with_capacity(1024); diff --git a/src/libstd/sys/common/net.rs b/src/libstd/sys/common/net.rs index 42714feb9217e..3fa70d0ce4be6 100644 --- a/src/libstd/sys/common/net.rs +++ b/src/libstd/sys/common/net.rs @@ -257,12 +257,7 @@ impl TcpStream { } pub fn take_error(&self) -> io::Result> { - let raw: c_int = try!(getsockopt(&self.inner, c::SOL_SOCKET, c::SO_ERROR)); - if raw == 0 { - Ok(None) - } else { - Ok(Some(io::Error::from_raw_os_error(raw as i32))) - } + self.inner.take_error() } pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { @@ -367,12 +362,7 @@ impl TcpListener { } pub fn take_error(&self) -> io::Result> { - let raw: c_int = try!(getsockopt(&self.inner, c::SOL_SOCKET, c::SO_ERROR)); - if raw == 0 { - Ok(None) - } else { - Ok(Some(io::Error::from_raw_os_error(raw as i32))) - } + self.inner.take_error() } pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { @@ -564,12 +554,7 @@ impl UdpSocket { } pub fn take_error(&self) -> io::Result> { - let raw: c_int = try!(getsockopt(&self.inner, c::SOL_SOCKET, c::SO_ERROR)); - if raw == 0 { - Ok(None) - } else { - Ok(Some(io::Error::from_raw_os_error(raw as i32))) - } + self.inner.take_error() } pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { diff --git a/src/libstd/sys/unix/ext/mod.rs b/src/libstd/sys/unix/ext/mod.rs index 276ef25a03a44..4d8f12c2d7c42 100644 --- a/src/libstd/sys/unix/ext/mod.rs +++ b/src/libstd/sys/unix/ext/mod.rs @@ -35,6 +35,7 @@ pub mod fs; pub mod process; pub mod raw; pub mod thread; +pub mod net; /// A prelude for conveniently writing platform-specific code. /// diff --git a/src/libstd/sys/unix/ext/net.rs b/src/libstd/sys/unix/ext/net.rs new file mode 100644 index 0000000000000..2ee825f1ec290 --- /dev/null +++ b/src/libstd/sys/unix/ext/net.rs @@ -0,0 +1,1042 @@ +// Copyright 2016 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// http://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. +#![unstable(feature = "unix_socket", reason = "newly added", issue = "32312")] + +//! Unix-specific networking functionality + +use libc; + +use prelude::v1::*; +use ascii; +use ffi::OsStr; +use fmt; +use io; +use mem; +use net::Shutdown; +use os::unix::ffi::OsStrExt; +use os::unix::io::{RawFd, AsRawFd, FromRawFd, IntoRawFd}; +use path::Path; +use time::Duration; +use sys::cvt; +use sys::net::Socket; +use sys_common::{AsInner, FromInner, IntoInner}; + +fn sun_path_offset() -> usize { + unsafe { + // Work with an actual instance of the type since using a null pointer is UB + let addr: libc::sockaddr_un = mem::uninitialized(); + let base = &addr as *const _ as usize; + let path = &addr.sun_path as *const _ as usize; + path - base + } +} + +unsafe fn sockaddr_un(path: &Path) -> io::Result<(libc::sockaddr_un, libc::socklen_t)> { + let mut addr: libc::sockaddr_un = mem::zeroed(); + addr.sun_family = libc::AF_UNIX as libc::sa_family_t; + + let bytes = path.as_os_str().as_bytes(); + + if bytes.contains(&0) { + return Err(io::Error::new(io::ErrorKind::InvalidInput, + "paths may not contain interior null bytes")); + } + + if bytes.len() >= addr.sun_path.len() { + return Err(io::Error::new(io::ErrorKind::InvalidInput, + "path must be shorter than SUN_LEN")); + } + for (dst, src) in addr.sun_path.iter_mut().zip(bytes.iter()) { + *dst = *src as libc::c_char; + } + // null byte for pathname addresses is already there because we zeroed the + // struct + + let mut len = sun_path_offset() + bytes.len(); + match bytes.get(0) { + Some(&0) | None => {} + Some(_) => len += 1, + } + Ok((addr, len as libc::socklen_t)) +} + +enum AddressKind<'a> { + Unnamed, + Pathname(&'a Path), + Abstract(&'a [u8]), +} + +/// An address associated with a Unix socket. +#[derive(Clone)] +pub struct SocketAddr { + addr: libc::sockaddr_un, + len: libc::socklen_t, +} + +impl SocketAddr { + fn new(f: F) -> io::Result + where F: FnOnce(*mut libc::sockaddr, *mut libc::socklen_t) -> libc::c_int + { + unsafe { + let mut addr: libc::sockaddr_un = mem::zeroed(); + let mut len = mem::size_of::() as libc::socklen_t; + try!(cvt(f(&mut addr as *mut _ as *mut _, &mut len))); + SocketAddr::from_parts(addr, len) + } + } + + fn from_parts(addr: libc::sockaddr_un, mut len: libc::socklen_t) -> io::Result { + if len == 0 { + // When there is a datagram from unnamed unix socket + // linux returns zero bytes of address + len = sun_path_offset() as libc::socklen_t; // i.e. zero-length address + } else if addr.sun_family != libc::AF_UNIX as libc::sa_family_t { + return Err(io::Error::new(io::ErrorKind::InvalidInput, + "file descriptor did not correspond to a Unix socket")); + } + + Ok(SocketAddr { + addr: addr, + len: len, + }) + } + + /// Returns true iff the address is unnamed. + pub fn is_unnamed(&self) -> bool { + if let AddressKind::Unnamed = self.address() { + true + } else { + false + } + } + + /// Returns the contents of this address if it is a `pathname` address. + pub fn as_pathname(&self) -> Option<&Path> { + if let AddressKind::Pathname(path) = self.address() { + Some(path) + } else { + None + } + } + + fn address<'a>(&'a self) -> AddressKind<'a> { + let len = self.len as usize - sun_path_offset(); + let path = unsafe { mem::transmute::<&[libc::c_char], &[u8]>(&self.addr.sun_path) }; + + // OSX seems to return a len of 16 and a zeroed sun_path for unnamed addresses + if len == 0 || (cfg!(not(target_os = "linux")) && self.addr.sun_path[0] == 0) { + AddressKind::Unnamed + } else if self.addr.sun_path[0] == 0 { + AddressKind::Abstract(&path[1..len]) + } else { + AddressKind::Pathname(OsStr::from_bytes(&path[..len - 1]).as_ref()) + } + } +} + +impl fmt::Debug for SocketAddr { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match self.address() { + AddressKind::Unnamed => write!(fmt, "(unnamed)"), + AddressKind::Abstract(name) => write!(fmt, "{} (abstract)", AsciiEscaped(name)), + AddressKind::Pathname(path) => write!(fmt, "{:?} (pathname)", path), + } + } +} + +struct AsciiEscaped<'a>(&'a [u8]); + +impl<'a> fmt::Display for AsciiEscaped<'a> { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + try!(write!(fmt, "\"")); + for byte in self.0.iter().cloned().flat_map(ascii::escape_default) { + try!(write!(fmt, "{}", byte as char)); + } + write!(fmt, "\"") + } +} + +/// A Unix stream socket. +/// +/// # Examples +/// +/// ```rust,no_run +/// #![feature(unix_socket)] +/// +/// use std::os::unix::net::UnixStream; +/// use std::io::prelude::*; +/// +/// let mut stream = UnixStream::connect("/path/to/my/socket").unwrap(); +/// stream.write_all(b"hello world").unwrap(); +/// let mut response = String::new(); +/// stream.read_to_string(&mut response).unwrap(); +/// println!("{}", response); +/// ``` +pub struct UnixStream(Socket); + +impl fmt::Debug for UnixStream { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + let mut builder = fmt.debug_struct("UnixStream"); + builder.field("fd", self.0.as_inner()); + if let Ok(addr) = self.local_addr() { + builder.field("local", &addr); + } + if let Ok(addr) = self.peer_addr() { + builder.field("peer", &addr); + } + builder.finish() + } +} + +impl UnixStream { + /// Connects to the socket named by `path`. + pub fn connect>(path: P) -> io::Result { + fn inner(path: &Path) -> io::Result { + unsafe { + let inner = try!(Socket::new_raw(libc::AF_UNIX, libc::SOCK_STREAM)); + let (addr, len) = try!(sockaddr_un(path)); + + try!(cvt(libc::connect(*inner.as_inner(), &addr as *const _ as *const _, len))); + Ok(UnixStream(inner)) + } + } + inner(path.as_ref()) + } + + /// Creates an unnamed pair of connected sockets. + /// + /// Returns two `UnixStream`s which are connected to each other. + pub fn pair() -> io::Result<(UnixStream, UnixStream)> { + let (i1, i2) = try!(Socket::new_pair(libc::AF_UNIX, libc::SOCK_STREAM)); + Ok((UnixStream(i1), UnixStream(i2))) + } + + /// Creates a new independently owned handle to the underlying socket. + /// + /// The returned `UnixStream` is a reference to the same stream that this + /// object references. Both handles will read and write the same stream of + /// data, and options set on one stream will be propogated to the other + /// stream. + pub fn try_clone(&self) -> io::Result { + self.0.duplicate().map(UnixStream) + } + + /// Returns the socket address of the local half of this connection. + pub fn local_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| unsafe { libc::getsockname(*self.0.as_inner(), addr, len) }) + } + + /// Returns the socket address of the remote half of this connection. + pub fn peer_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| unsafe { libc::getpeername(*self.0.as_inner(), addr, len) }) + } + + /// Sets the read timeout for the socket. + /// + /// If the provided value is `None`, then `read` calls will block + /// indefinitely. It is an error to pass the zero `Duration` to this + /// method. + pub fn set_read_timeout(&self, timeout: Option) -> io::Result<()> { + self.0.set_timeout(timeout, libc::SO_RCVTIMEO) + } + + /// Sets the write timeout for the socket. + /// + /// If the provided value is `None`, then `write` calls will block + /// indefinitely. It is an error to pass the zero `Duration` to this + /// method. + pub fn set_write_timeout(&self, timeout: Option) -> io::Result<()> { + self.0.set_timeout(timeout, libc::SO_SNDTIMEO) + } + + /// Returns the read timeout of this socket. + pub fn read_timeout(&self) -> io::Result> { + self.0.timeout(libc::SO_RCVTIMEO) + } + + /// Returns the write timeout of this socket. + pub fn write_timeout(&self) -> io::Result> { + self.0.timeout(libc::SO_SNDTIMEO) + } + + /// Moves the socket into or out of nonblocking mode. + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } + + /// Returns the value of the `SO_ERROR` option. + pub fn take_error(&self) -> io::Result> { + self.0.take_error() + } + + /// Shuts down the read, write, or both halves of this connection. + /// + /// This function will cause all pending and future I/O calls on the + /// specified portions to immediately return with an appropriate value + /// (see the documentation of `Shutdown`). + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + self.0.shutdown(how) + } +} + +impl io::Read for UnixStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + io::Read::read(&mut &*self, buf) + } + + fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { + io::Read::read_to_end(&mut &*self, buf) + } +} + +impl<'a> io::Read for &'a UnixStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } + + fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { + self.0.read_to_end(buf) + } +} + +impl io::Write for UnixStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + io::Write::write(&mut &*self, buf) + } + + fn flush(&mut self) -> io::Result<()> { + io::Write::flush(&mut &*self) + } +} + +impl<'a> io::Write for &'a UnixStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl AsRawFd for UnixStream { + fn as_raw_fd(&self) -> RawFd { + *self.0.as_inner() + } +} + +impl FromRawFd for UnixStream { + unsafe fn from_raw_fd(fd: RawFd) -> UnixStream { + UnixStream(Socket::from_inner(fd)) + } +} + +impl IntoRawFd for UnixStream { + fn into_raw_fd(self) -> RawFd { + self.0.into_inner() + } +} + +/// A structure representing a Unix domain socket server. +/// +/// # Examples +/// +/// ```rust,no_run +/// #![feature(unix_socket)] +/// +/// use std::thread; +/// use std::os::unix::net::{UnixStream, UnixListener}; +/// +/// fn handle_client(stream: UnixStream) { +/// // ... +/// } +/// +/// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); +/// +/// // accept connections and process them, spawning a new thread for each one +/// for stream in listener.incoming() { +/// match stream { +/// Ok(stream) => { +/// /* connection succeeded */ +/// thread::spawn(|| handle_client(stream)); +/// } +/// Err(err) => { +/// /* connection failed */ +/// break; +/// } +/// } +/// } +/// +/// // close the listener socket +/// drop(listener); +/// ``` +pub struct UnixListener(Socket); + +impl fmt::Debug for UnixListener { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + let mut builder = fmt.debug_struct("UnixListener"); + builder.field("fd", self.0.as_inner()); + if let Ok(addr) = self.local_addr() { + builder.field("local", &addr); + } + builder.finish() + } +} + +impl UnixListener { + /// Creates a new `UnixListener` bound to the specified socket. + pub fn bind>(path: P) -> io::Result { + fn inner(path: &Path) -> io::Result { + unsafe { + let inner = try!(Socket::new_raw(libc::AF_UNIX, libc::SOCK_STREAM)); + let (addr, len) = try!(sockaddr_un(path)); + + try!(cvt(libc::bind(*inner.as_inner(), &addr as *const _ as *const _, len))); + try!(cvt(libc::listen(*inner.as_inner(), 128))); + + Ok(UnixListener(inner)) + } + } + inner(path.as_ref()) + } + + /// Accepts a new incoming connection to this listener. + /// + /// This function will block the calling thread until a new Unix connection + /// is established. When established, the corersponding `UnixStream` and + /// the remote peer's address will be returned. + pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { + let mut storage: libc::sockaddr_un = unsafe { mem::zeroed() }; + let mut len = mem::size_of_val(&storage) as libc::socklen_t; + let sock = try!(self.0.accept(&mut storage as *mut _ as *mut _, &mut len)); + let addr = try!(SocketAddr::from_parts(storage, len)); + Ok((UnixStream(sock), addr)) + } + + /// Creates a new independently owned handle to the underlying socket. + /// + /// The returned `UnixListener` is a reference to the same socket that this + /// object references. Both handles can be used to accept incoming + /// connections and options set on one listener will affect the other. + pub fn try_clone(&self) -> io::Result { + self.0.duplicate().map(UnixListener) + } + + /// Returns the local socket address of this listener. + pub fn local_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| unsafe { libc::getsockname(*self.0.as_inner(), addr, len) }) + } + + /// Moves the socket into or out of nonblocking mode. + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } + + /// Returns the value of the `SO_ERROR` option. + pub fn take_error(&self) -> io::Result> { + self.0.take_error() + } + + /// Returns an iterator over incoming connections. + /// + /// The iterator will never return `None` and will also not yield the + /// peer's `SocketAddr` structure. + pub fn incoming<'a>(&'a self) -> Incoming<'a> { + Incoming { listener: self } + } +} + +impl AsRawFd for UnixListener { + fn as_raw_fd(&self) -> RawFd { + *self.0.as_inner() + } +} + +impl FromRawFd for UnixListener { + unsafe fn from_raw_fd(fd: RawFd) -> UnixListener { + UnixListener(Socket::from_inner(fd)) + } +} + +impl IntoRawFd for UnixListener { + fn into_raw_fd(self) -> RawFd { + self.0.into_inner() + } +} + +impl<'a> IntoIterator for &'a UnixListener { + type Item = io::Result; + type IntoIter = Incoming<'a>; + + fn into_iter(self) -> Incoming<'a> { + self.incoming() + } +} + +/// An iterator over incoming connections to a `UnixListener`. +/// +/// It will never return `None`. +#[derive(Debug)] +pub struct Incoming<'a> { + listener: &'a UnixListener, +} + +impl<'a> Iterator for Incoming<'a> { + type Item = io::Result; + + fn next(&mut self) -> Option> { + Some(self.listener.accept().map(|s| s.0)) + } + + fn size_hint(&self) -> (usize, Option) { + (usize::max_value(), None) + } +} + +/// A Unix datagram socket. +/// +/// # Examples +/// +/// ```rust,no_run +/// #![feature(unix_socket)] +/// +/// use std::os::unix::net::UnixDatagram; +/// +/// let socket = UnixDatagram::bind("/path/to/my/socket").unwrap(); +/// socket.send_to(b"hello world", "/path/to/other/socket").unwrap(); +/// let mut buf = [0; 100]; +/// let (count, address) = socket.recv_from(&mut buf).unwrap(); +/// println!("socket {:?} sent {:?}", address, &buf[..count]); +/// ``` +pub struct UnixDatagram(Socket); + +impl fmt::Debug for UnixDatagram { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + let mut builder = fmt.debug_struct("UnixDatagram"); + builder.field("fd", self.0.as_inner()); + if let Ok(addr) = self.local_addr() { + builder.field("local", &addr); + } + if let Ok(addr) = self.peer_addr() { + builder.field("peer", &addr); + } + builder.finish() + } +} + +impl UnixDatagram { + /// Creates a Unix datagram socket bound to the given path. + pub fn bind>(path: P) -> io::Result { + fn inner(path: &Path) -> io::Result { + unsafe { + let socket = try!(UnixDatagram::unbound()); + let (addr, len) = try!(sockaddr_un(path)); + + try!(cvt(libc::bind(*socket.0.as_inner(), &addr as *const _ as *const _, len))); + + Ok(socket) + } + } + inner(path.as_ref()) + } + + /// Creates a Unix Datagram socket which is not bound to any address. + pub fn unbound() -> io::Result { + let inner = try!(Socket::new_raw(libc::AF_UNIX, libc::SOCK_DGRAM)); + Ok(UnixDatagram(inner)) + } + + /// Create an unnamed pair of connected sockets. + /// + /// Returns two `UnixDatagrams`s which are connected to each other. + pub fn pair() -> io::Result<(UnixDatagram, UnixDatagram)> { + let (i1, i2) = try!(Socket::new_pair(libc::AF_UNIX, libc::SOCK_DGRAM)); + Ok((UnixDatagram(i1), UnixDatagram(i2))) + } + + /// Connects the socket to the specified address. + /// + /// The `send` method may be used to send data to the specified address. + /// `recv` and `recv_from` will only receive data from that address. + pub fn connect>(&self, path: P) -> io::Result<()> { + fn inner(d: &UnixDatagram, path: &Path) -> io::Result<()> { + unsafe { + let (addr, len) = try!(sockaddr_un(path)); + + try!(cvt(libc::connect(*d.0.as_inner(), &addr as *const _ as *const _, len))); + + Ok(()) + } + } + inner(self, path.as_ref()) + } + + /// Creates a new independently owned handle to the underlying socket. + /// + /// The returned `UnixListener` is a reference to the same socket that this + /// object references. Both handles can be used to accept incoming + /// connections and options set on one listener will affect the other. + pub fn try_clone(&self) -> io::Result { + self.0.duplicate().map(UnixDatagram) + } + + /// Returns the address of this socket. + pub fn local_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| unsafe { libc::getsockname(*self.0.as_inner(), addr, len) }) + } + + /// Returns the address of this socket's peer. + /// + /// The `connect` method will connect the socket to a peer. + pub fn peer_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| unsafe { libc::getpeername(*self.0.as_inner(), addr, len) }) + } + + /// Receives data from the socket. + /// + /// On success, returns the number of bytes read and the address from + /// whence the data came. + pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + let mut count = 0; + let addr = try!(SocketAddr::new(|addr, len| { + unsafe { + count = libc::recvfrom(*self.0.as_inner(), + buf.as_mut_ptr() as *mut _, + buf.len(), + 0, + addr, + len); + if count > 0 { + 1 + } else if count == 0 { + 0 + } else { + -1 + } + } + })); + + Ok((count as usize, addr)) + } + + /// Receives data from the socket. + /// + /// On success, returns the number of bytes read. + pub fn recv(&self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } + + /// Sends data on the socket to the specified address. + /// + /// On success, returns the number of bytes written. + pub fn send_to>(&self, buf: &[u8], path: P) -> io::Result { + fn inner(d: &UnixDatagram, buf: &[u8], path: &Path) -> io::Result { + unsafe { + let (addr, len) = try!(sockaddr_un(path)); + + let count = try!(cvt(libc::sendto(*d.0.as_inner(), + buf.as_ptr() as *const _, + buf.len(), + 0, + &addr as *const _ as *const _, + len))); + Ok(count as usize) + } + } + inner(self, buf, path.as_ref()) + } + + /// Sends data on the socket to the socket's peer. + /// + /// The peer address may be set by the `connect` method, and this method + /// will return an error if the socket has not already been connected. + /// + /// On success, returns the number of bytes written. + pub fn send(&self, buf: &[u8]) -> io::Result { + self.0.write(buf) + } + + /// Sets the read timeout for the socket. + /// + /// If the provided value is `None`, then `recv` and `recv_from` calls will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + pub fn set_read_timeout(&self, timeout: Option) -> io::Result<()> { + self.0.set_timeout(timeout, libc::SO_RCVTIMEO) + } + + /// Sets the write timeout for the socket. + /// + /// If the provided value is `None`, then `send` and `send_to` calls will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + pub fn set_write_timeout(&self, timeout: Option) -> io::Result<()> { + self.0.set_timeout(timeout, libc::SO_SNDTIMEO) + } + + /// Returns the read timeout of this socket. + pub fn read_timeout(&self) -> io::Result> { + self.0.timeout(libc::SO_RCVTIMEO) + } + + /// Returns the write timeout of this socket. + pub fn write_timeout(&self) -> io::Result> { + self.0.timeout(libc::SO_SNDTIMEO) + } + + /// Moves the socket into or out of nonblocking mode. + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } + + /// Returns the value of the `SO_ERROR` option. + pub fn take_error(&self) -> io::Result> { + self.0.take_error() + } + + /// Shut down the read, write, or both halves of this connection. + /// + /// This function will cause all pending and future I/O calls on the + /// specified portions to immediately return with an appropriate value + /// (see the documentation of `Shutdown`). + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + self.0.shutdown(how) + } +} + +impl AsRawFd for UnixDatagram { + fn as_raw_fd(&self) -> RawFd { + *self.0.as_inner() + } +} + +impl FromRawFd for UnixDatagram { + unsafe fn from_raw_fd(fd: RawFd) -> UnixDatagram { + UnixDatagram(Socket::from_inner(fd)) + } +} + +impl IntoRawFd for UnixDatagram { + fn into_raw_fd(self) -> RawFd { + self.0.into_inner() + } +} + +#[cfg(test)] +mod test { + use prelude::v1::*; + use thread; + use io; + use io::prelude::*; + use time::Duration; + use sys_common::io::test::tmpdir; + + use super::*; + + macro_rules! or_panic { + ($e:expr) => { + match $e { + Ok(e) => e, + Err(e) => panic!("{}", e), + } + } + } + + #[test] + fn basic() { + let dir = tmpdir(); + let socket_path = dir.path().join("sock"); + let msg1 = b"hello"; + let msg2 = b"world!"; + + let listener = or_panic!(UnixListener::bind(&socket_path)); + let thread = thread::spawn(move || { + let mut stream = or_panic!(listener.accept()).0; + let mut buf = [0; 5]; + or_panic!(stream.read(&mut buf)); + assert_eq!(&msg1[..], &buf[..]); + or_panic!(stream.write_all(msg2)); + }); + + let mut stream = or_panic!(UnixStream::connect(&socket_path)); + assert_eq!(Some(&*socket_path), + stream.peer_addr().unwrap().as_pathname()); + or_panic!(stream.write_all(msg1)); + let mut buf = vec![]; + or_panic!(stream.read_to_end(&mut buf)); + assert_eq!(&msg2[..], &buf[..]); + drop(stream); + + thread.join().unwrap(); + } + + #[test] + fn pair() { + let msg1 = b"hello"; + let msg2 = b"world!"; + + let (mut s1, mut s2) = or_panic!(UnixStream::pair()); + let thread = thread::spawn(move || { + // s1 must be moved in or the test will hang! + let mut buf = [0; 5]; + or_panic!(s1.read(&mut buf)); + assert_eq!(&msg1[..], &buf[..]); + or_panic!(s1.write_all(msg2)); + }); + + or_panic!(s2.write_all(msg1)); + let mut buf = vec![]; + or_panic!(s2.read_to_end(&mut buf)); + assert_eq!(&msg2[..], &buf[..]); + drop(s2); + + thread.join().unwrap(); + } + + #[test] + fn try_clone() { + let dir = tmpdir(); + let socket_path = dir.path().join("sock"); + let msg1 = b"hello"; + let msg2 = b"world"; + + let listener = or_panic!(UnixListener::bind(&socket_path)); + let thread = thread::spawn(move || { + let mut stream = or_panic!(listener.accept()).0; + or_panic!(stream.write_all(msg1)); + or_panic!(stream.write_all(msg2)); + }); + + let mut stream = or_panic!(UnixStream::connect(&socket_path)); + let mut stream2 = or_panic!(stream.try_clone()); + + let mut buf = [0; 5]; + or_panic!(stream.read(&mut buf)); + assert_eq!(&msg1[..], &buf[..]); + or_panic!(stream2.read(&mut buf)); + assert_eq!(&msg2[..], &buf[..]); + + thread.join().unwrap(); + } + + #[test] + fn iter() { + let dir = tmpdir(); + let socket_path = dir.path().join("sock"); + + let listener = or_panic!(UnixListener::bind(&socket_path)); + let thread = thread::spawn(move || { + for stream in listener.incoming().take(2) { + let mut stream = or_panic!(stream); + let mut buf = [0]; + or_panic!(stream.read(&mut buf)); + } + }); + + for _ in 0..2 { + let mut stream = or_panic!(UnixStream::connect(&socket_path)); + or_panic!(stream.write_all(&[0])); + } + + thread.join().unwrap(); + } + + #[test] + fn long_path() { + let dir = tmpdir(); + let socket_path = dir.path() + .join("asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfa\ + sasdfasdfasdasdfasdfasdfadfasdfasdfasdfasdfasdf"); + match UnixStream::connect(&socket_path) { + Err(ref e) if e.kind() == io::ErrorKind::InvalidInput => {} + Err(e) => panic!("unexpected error {}", e), + Ok(_) => panic!("unexpected success"), + } + + match UnixListener::bind(&socket_path) { + Err(ref e) if e.kind() == io::ErrorKind::InvalidInput => {} + Err(e) => panic!("unexpected error {}", e), + Ok(_) => panic!("unexpected success"), + } + + match UnixDatagram::bind(&socket_path) { + Err(ref e) if e.kind() == io::ErrorKind::InvalidInput => {} + Err(e) => panic!("unexpected error {}", e), + Ok(_) => panic!("unexpected success"), + } + } + + #[test] + fn timeouts() { + let dir = tmpdir(); + let socket_path = dir.path().join("sock"); + + let _listener = or_panic!(UnixListener::bind(&socket_path)); + + let stream = or_panic!(UnixStream::connect(&socket_path)); + let dur = Duration::new(15410, 0); + + assert_eq!(None, or_panic!(stream.read_timeout())); + + or_panic!(stream.set_read_timeout(Some(dur))); + assert_eq!(Some(dur), or_panic!(stream.read_timeout())); + + assert_eq!(None, or_panic!(stream.write_timeout())); + + or_panic!(stream.set_write_timeout(Some(dur))); + assert_eq!(Some(dur), or_panic!(stream.write_timeout())); + + or_panic!(stream.set_read_timeout(None)); + assert_eq!(None, or_panic!(stream.read_timeout())); + + or_panic!(stream.set_write_timeout(None)); + assert_eq!(None, or_panic!(stream.write_timeout())); + } + + #[test] + fn test_read_timeout() { + let dir = tmpdir(); + let socket_path = dir.path().join("sock"); + + let _listener = or_panic!(UnixListener::bind(&socket_path)); + + let mut stream = or_panic!(UnixStream::connect(&socket_path)); + or_panic!(stream.set_read_timeout(Some(Duration::from_millis(1000)))); + + let mut buf = [0; 10]; + let kind = stream.read(&mut buf).err().expect("expected error").kind(); + assert!(kind == io::ErrorKind::WouldBlock || kind == io::ErrorKind::TimedOut); + } + + #[test] + fn test_read_with_timeout() { + let dir = tmpdir(); + let socket_path = dir.path().join("sock"); + + let listener = or_panic!(UnixListener::bind(&socket_path)); + + let mut stream = or_panic!(UnixStream::connect(&socket_path)); + or_panic!(stream.set_read_timeout(Some(Duration::from_millis(1000)))); + + let mut other_end = or_panic!(listener.accept()).0; + or_panic!(other_end.write_all(b"hello world")); + + let mut buf = [0; 11]; + or_panic!(stream.read(&mut buf)); + assert_eq!(b"hello world", &buf[..]); + + let kind = stream.read(&mut buf).err().expect("expected error").kind(); + assert!(kind == io::ErrorKind::WouldBlock || kind == io::ErrorKind::TimedOut); + } + + #[test] + fn test_unix_datagram() { + let dir = tmpdir(); + let path1 = dir.path().join("sock1"); + let path2 = dir.path().join("sock2"); + + let sock1 = or_panic!(UnixDatagram::bind(&path1)); + let sock2 = or_panic!(UnixDatagram::bind(&path2)); + + let msg = b"hello world"; + or_panic!(sock1.send_to(msg, &path2)); + let mut buf = [0; 11]; + or_panic!(sock2.recv_from(&mut buf)); + assert_eq!(msg, &buf[..]); + } + + #[test] + fn test_unnamed_unix_datagram() { + let dir = tmpdir(); + let path1 = dir.path().join("sock1"); + + let sock1 = or_panic!(UnixDatagram::bind(&path1)); + let sock2 = or_panic!(UnixDatagram::unbound()); + + let msg = b"hello world"; + or_panic!(sock2.send_to(msg, &path1)); + let mut buf = [0; 11]; + let (usize, addr) = or_panic!(sock1.recv_from(&mut buf)); + assert_eq!(usize, 11); + assert!(addr.is_unnamed()); + assert_eq!(msg, &buf[..]); + } + + #[test] + fn test_connect_unix_datagram() { + let dir = tmpdir(); + let path1 = dir.path().join("sock1"); + let path2 = dir.path().join("sock2"); + + let bsock1 = or_panic!(UnixDatagram::bind(&path1)); + let bsock2 = or_panic!(UnixDatagram::bind(&path2)); + let sock = or_panic!(UnixDatagram::unbound()); + or_panic!(sock.connect(&path1)); + + // Check send() + let msg = b"hello there"; + or_panic!(sock.send(msg)); + let mut buf = [0; 11]; + let (usize, addr) = or_panic!(bsock1.recv_from(&mut buf)); + assert_eq!(usize, 11); + assert!(addr.is_unnamed()); + assert_eq!(msg, &buf[..]); + + // Changing default socket works too + or_panic!(sock.connect(&path2)); + or_panic!(sock.send(msg)); + or_panic!(bsock2.recv_from(&mut buf)); + } + + #[test] + fn test_unix_datagram_recv() { + let dir = tmpdir(); + let path1 = dir.path().join("sock1"); + + let sock1 = or_panic!(UnixDatagram::bind(&path1)); + let sock2 = or_panic!(UnixDatagram::unbound()); + or_panic!(sock2.connect(&path1)); + + let msg = b"hello world"; + or_panic!(sock2.send(msg)); + let mut buf = [0; 11]; + let size = or_panic!(sock1.recv(&mut buf)); + assert_eq!(size, 11); + assert_eq!(msg, &buf[..]); + } + + #[test] + fn datagram_pair() { + let msg1 = b"hello"; + let msg2 = b"world!"; + + let (s1, s2) = or_panic!(UnixDatagram::pair()); + let thread = thread::spawn(move || { + // s1 must be moved in or the test will hang! + let mut buf = [0; 5]; + or_panic!(s1.recv(&mut buf)); + assert_eq!(&msg1[..], &buf[..]); + or_panic!(s1.send(msg2)); + }); + + or_panic!(s2.send(msg1)); + let mut buf = [0; 6]; + or_panic!(s2.recv(&mut buf)); + assert_eq!(&msg2[..], &buf[..]); + drop(s2); + + thread.join().unwrap(); + } + + #[test] + fn abstract_namespace_not_allowed() { + assert!(UnixStream::connect("\0asdf").is_err()); + } +} diff --git a/src/libstd/sys/unix/net.rs b/src/libstd/sys/unix/net.rs index acf501d5fda88..d7b353c9b2aab 100644 --- a/src/libstd/sys/unix/net.rs +++ b/src/libstd/sys/unix/net.rs @@ -57,13 +57,17 @@ impl Socket { SocketAddr::V4(..) => libc::AF_INET, SocketAddr::V6(..) => libc::AF_INET6, }; + Socket::new_raw(fam, ty) + } + + pub fn new_raw(fam: c_int, ty: c_int) -> io::Result { unsafe { // On linux we first attempt to pass the SOCK_CLOEXEC flag to // atomically create the socket and set it as CLOEXEC. Support for // this option, however, was added in 2.6.27, and we still support // 2.6.18 as a kernel, so if the returned error is EINVAL we // fallthrough to the fallback. - if cfg!(target_os = "linux") { + if cfg!(linux) { match cvt(libc::socket(fam, ty | SOCK_CLOEXEC, 0)) { Ok(fd) => return Ok(Socket(FileDesc::new(fd))), Err(ref e) if e.raw_os_error() == Some(libc::EINVAL) => {} @@ -78,6 +82,30 @@ impl Socket { } } + pub fn new_pair(fam: c_int, ty: c_int) -> io::Result<(Socket, Socket)> { + unsafe { + let mut fds = [0, 0]; + + // Like above, see if we can set cloexec atomically + if cfg!(linux) { + match cvt(libc::socketpair(fam, ty | SOCK_CLOEXEC, 0, fds.as_mut_ptr())) { + Ok(_) => { + return Ok((Socket(FileDesc::new(fds[0])), Socket(FileDesc::new(fds[1])))); + } + Err(ref e) if e.raw_os_error() == Some(libc::EINVAL) => {}, + Err(e) => return Err(e), + } + } + + try!(cvt(libc::socketpair(fam, ty, 0, fds.as_mut_ptr()))); + let a = FileDesc::new(fds[0]); + a.set_cloexec(); + let b = FileDesc::new(fds[1]); + b.set_cloexec(); + Ok((Socket(a), Socket(b))) + } + } + pub fn accept(&self, storage: *mut sockaddr, len: *mut socklen_t) -> io::Result { // Unfortunately the only known way right now to accept a socket and @@ -120,6 +148,10 @@ impl Socket { self.0.read_to_end(buf) } + pub fn write(&self, buf: &[u8]) -> io::Result { + self.0.write(buf) + } + pub fn set_timeout(&self, dur: Option, kind: libc::c_int) -> io::Result<()> { let timeout = match dur { Some(dur) => { @@ -186,6 +218,15 @@ impl Socket { let mut nonblocking = nonblocking as libc::c_ulong; cvt(unsafe { libc::ioctl(*self.as_inner(), libc::FIONBIO, &mut nonblocking) }).map(|_| ()) } + + pub fn take_error(&self) -> io::Result> { + let raw: c_int = try!(getsockopt(self, libc::SOL_SOCKET, libc::SO_ERROR)); + if raw == 0 { + Ok(None) + } else { + Ok(Some(io::Error::from_raw_os_error(raw as i32))) + } + } } impl AsInner for Socket { diff --git a/src/libstd/sys/windows/net.rs b/src/libstd/sys/windows/net.rs index bb3c79c5a84fd..3a2f06418cf8d 100644 --- a/src/libstd/sys/windows/net.rs +++ b/src/libstd/sys/windows/net.rs @@ -214,6 +214,15 @@ impl Socket { let raw: c::BYTE = try!(net::getsockopt(self, c::IPPROTO_TCP, c::TCP_NODELAY)); Ok(raw != 0) } + + pub fn take_error(&self) -> io::Result> { + let raw: c_int = try!(net::getsockopt(self, c::SOL_SOCKET, c::SO_ERROR)); + if raw == 0 { + Ok(None) + } else { + Ok(Some(io::Error::from_raw_os_error(raw as i32))) + } + } } #[unstable(reason = "not public", issue = "0", feature = "fd_read")]