From cc19539b3d69f595377b92210edaebc1ee2c7476 Mon Sep 17 00:00:00 2001 From: Andrew Walbran Date: Tue, 7 Jun 2022 14:44:55 +0100 Subject: [PATCH] Use SockaddrLike trait rather than (deprecated) SockAddr enum. This fixes warnings, but is a breaking API change. Signed-off-by: Andrew Walbran --- echo_server/src/main.rs | 9 ++---- src/lib.rs | 66 +++++++++++++---------------------------- tests/vsock.rs | 5 ++-- 3 files changed, 25 insertions(+), 55 deletions(-) diff --git a/echo_server/src/main.rs b/echo_server/src/main.rs index 2e60cd8..1906f35 100644 --- a/echo_server/src/main.rs +++ b/echo_server/src/main.rs @@ -19,7 +19,7 @@ use std::io::Read; use std::io::Write; use std::net::Shutdown; use std::thread; -use vsock::{SockAddr, VsockAddr, VsockListener}; +use vsock::{VsockAddr, VsockListener}; const BLOCK_SIZE: usize = 16384; @@ -48,11 +48,8 @@ fn main() { .parse::() .expect("port must be a valid integer"); - let listener = VsockListener::bind(&SockAddr::Vsock(VsockAddr::new( - libc::VMADDR_CID_ANY, - listen_port, - ))) - .expect("bind and listen failed"); + let listener = VsockListener::bind(&VsockAddr::new(libc::VMADDR_CID_ANY, listen_port)) + .expect("bind and listen failed"); println!("Server listening for connections on port {}", listen_port); for stream in listener.incoming() { diff --git a/src/lib.rs b/src/lib.rs index d8c9371..2341e46 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,7 +18,7 @@ //! Virtio socket support for Rust. use libc::*; -use nix::ioctl_read_bad; +use nix::{ioctl_read_bad, sys::socket::AddressFamily}; use std::ffi::c_void; use std::fs::File; use std::io::{Error, ErrorKind, Read, Result, Write}; @@ -28,7 +28,7 @@ use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; use std::time::Duration; pub use libc::{VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, VMADDR_CID_LOCAL}; -pub use nix::sys::socket::{SockAddr, VsockAddr}; +pub use nix::sys::socket::{SockaddrLike, VsockAddr}; fn new_socket() -> libc::c_int { unsafe { socket(AF_VSOCK, SOCK_STREAM | SOCK_CLOEXEC, 0) } @@ -56,28 +56,20 @@ pub struct VsockListener { impl VsockListener { /// Create a new VsockListener which is bound and listening on the socket address. - pub fn bind(addr: &SockAddr) -> Result { - let mut vsock_addr = if let SockAddr::Vsock(addr) = addr { - addr.as_ref().to_owned() - } else { + pub fn bind(addr: &impl SockaddrLike) -> Result { + if addr.family() != Some(AddressFamily::Vsock) { return Err(Error::new( ErrorKind::Other, "requires a virtio socket address", )); - }; + } let socket = new_socket(); if socket < 0 { return Err(Error::last_os_error()); } - let res = unsafe { - bind( - socket, - &mut vsock_addr as *mut _ as *mut sockaddr, - size_of::() as socklen_t, - ) - }; + let res = unsafe { bind(socket, addr.as_ptr(), addr.len()) }; if res < 0 { return Err(Error::last_os_error()); } @@ -93,11 +85,11 @@ impl VsockListener { /// Create a new VsockListener with specified cid and port. pub fn bind_with_cid_port(cid: u32, port: u32) -> Result { - Self::bind(&SockAddr::Vsock(VsockAddr::new(cid, port))) + Self::bind(&VsockAddr::new(cid, port)) } /// The local socket address of the listener. - pub fn local_addr(&self) -> Result { + pub fn local_addr(&self) -> Result { let mut vsock_addr = sockaddr_vm { svm_family: AF_VSOCK as sa_family_t, svm_reserved1: 0, @@ -116,10 +108,7 @@ impl VsockListener { { Err(Error::last_os_error()) } else { - Ok(SockAddr::Vsock(VsockAddr::new( - vsock_addr.svm_cid, - vsock_addr.svm_port, - ))) + Ok(VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port)) } } @@ -129,7 +118,7 @@ impl VsockListener { } /// Accept a new incoming connection from this listener. - pub fn accept(&self) -> Result<(VsockStream, SockAddr)> { + pub fn accept(&self) -> Result<(VsockStream, VsockAddr)> { let mut vsock_addr = sockaddr_vm { svm_family: AF_VSOCK as sa_family_t, svm_reserved1: 0, @@ -151,7 +140,7 @@ impl VsockListener { } else { Ok(( unsafe { VsockStream::from_raw_fd(socket as RawFd) }, - SockAddr::Vsock(VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port)), + VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port), )) } } @@ -230,28 +219,19 @@ pub struct VsockStream { impl VsockStream { /// Open a connection to a remote host. - pub fn connect(addr: &SockAddr) -> Result { - let vsock_addr = if let SockAddr::Vsock(addr) = addr { - addr.as_ref() - } else { + pub fn connect(addr: &VsockAddr) -> Result { + if addr.family() != Some(AddressFamily::Vsock) { return Err(Error::new( ErrorKind::Other, "requires a virtio socket address", )); - }; + } let sock = new_socket(); if sock < 0 { return Err(Error::last_os_error()); } - if unsafe { - connect( - sock, - vsock_addr as *const _ as *const sockaddr, - size_of::() as socklen_t, - ) - } < 0 - { + if unsafe { connect(sock, addr.as_ptr(), addr.len()) } < 0 { Err(Error::last_os_error()) } else { Ok(unsafe { VsockStream::from_raw_fd(sock) }) @@ -260,11 +240,11 @@ impl VsockStream { /// Open a connection to a remote host with specified cid and port. pub fn connect_with_cid_port(cid: u32, port: u32) -> Result { - Self::connect(&SockAddr::Vsock(VsockAddr::new(cid, port))) + Self::connect(&VsockAddr::new(cid, port)) } /// Virtio socket address of the remote peer associated with this connection. - pub fn peer_addr(&self) -> Result { + pub fn peer_addr(&self) -> Result { let mut vsock_addr = sockaddr_vm { svm_family: AF_VSOCK as sa_family_t, svm_reserved1: 0, @@ -283,15 +263,12 @@ impl VsockStream { { Err(Error::last_os_error()) } else { - Ok(SockAddr::Vsock(VsockAddr::new( - vsock_addr.svm_cid, - vsock_addr.svm_port, - ))) + Ok(VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port)) } } /// Virtio socket address of the local address associated with this connection. - pub fn local_addr(&self) -> Result { + pub fn local_addr(&self) -> Result { let mut vsock_addr = sockaddr_vm { svm_family: AF_VSOCK as sa_family_t, svm_reserved1: 0, @@ -310,10 +287,7 @@ impl VsockStream { { Err(Error::last_os_error()) } else { - Ok(SockAddr::Vsock(VsockAddr::new( - vsock_addr.svm_cid, - vsock_addr.svm_port, - ))) + Ok(VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port)) } } diff --git a/tests/vsock.rs b/tests/vsock.rs index 52d908b..99d47f7 100644 --- a/tests/vsock.rs +++ b/tests/vsock.rs @@ -17,7 +17,7 @@ use rand::RngCore; use sha2::{Digest, Sha256}; use std::io::{Read, Write}; -use vsock::{get_local_cid, SockAddr, VsockAddr, VsockStream, VMADDR_CID_HOST}; +use vsock::{get_local_cid, VsockAddr, VsockStream, VMADDR_CID_HOST}; const TEST_BLOB_SIZE: usize = 1_000_000; const TEST_BLOCK_SIZE: usize = 5_000; @@ -39,8 +39,7 @@ fn test_vsock() { rx_blob.resize(TEST_BLOB_SIZE, 0); rng.fill_bytes(&mut blob); - let mut stream = - VsockStream::connect(&SockAddr::Vsock(VsockAddr::new(3, 8000))).expect("connection failed"); + let mut stream = VsockStream::connect(&VsockAddr::new(3, 8000)).expect("connection failed"); while tx_pos < TEST_BLOB_SIZE { let written_bytes = stream