Skip to content

Commit

Permalink
Use SockaddrLike trait rather than (deprecated) SockAddr enum.
Browse files Browse the repository at this point in the history
This fixes warnings, but is a breaking API change.
  • Loading branch information
qwandor committed Jun 7, 2022
1 parent b4a73ef commit 23b9fb3
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 55 deletions.
9 changes: 3 additions & 6 deletions echo_server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use std::io::Read;
use std::io::Write;
use std::net::Shutdown;
use std::thread;
use vsock::{SockAddr, VsockAddr, VsockListener};
use vsock::{VsockAddr, VsockListener};

const BLOCK_SIZE: usize = 16384;

Expand Down Expand Up @@ -48,11 +48,8 @@ fn main() {
.parse::<u32>()
.expect("port must be a valid integer");

let listener = VsockListener::bind(&SockAddr::Vsock(VsockAddr::new(
libc::VMADDR_CID_ANY,
listen_port,
)))
.expect("bind and listen failed");
let listener = VsockListener::bind(&VsockAddr::new(libc::VMADDR_CID_ANY, listen_port))
.expect("bind and listen failed");

println!("Server listening for connections on port {}", listen_port);
for stream in listener.incoming() {
Expand Down
66 changes: 20 additions & 46 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
//! Virtio socket support for Rust.
use libc::*;
use nix::ioctl_read_bad;
use nix::{ioctl_read_bad, sys::socket::AddressFamily};
use std::ffi::c_void;
use std::fs::File;
use std::io::{Error, ErrorKind, Read, Result, Write};
Expand All @@ -28,7 +28,7 @@ use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use std::time::Duration;

pub use libc::{VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, VMADDR_CID_LOCAL};
pub use nix::sys::socket::{SockAddr, VsockAddr};
pub use nix::sys::socket::{SockaddrLike, VsockAddr};

fn new_socket() -> libc::c_int {
unsafe { socket(AF_VSOCK, SOCK_STREAM | SOCK_CLOEXEC, 0) }
Expand Down Expand Up @@ -56,28 +56,20 @@ pub struct VsockListener {

impl VsockListener {
/// Create a new VsockListener which is bound and listening on the socket address.
pub fn bind(addr: &SockAddr) -> Result<VsockListener> {
let mut vsock_addr = if let SockAddr::Vsock(addr) = addr {
addr.as_ref().to_owned()
} else {
pub fn bind(addr: &impl SockaddrLike) -> Result<VsockListener> {
if addr.family() != Some(AddressFamily::Vsock) {
return Err(Error::new(
ErrorKind::Other,
"requires a virtio socket address",
));
};
}

let socket = new_socket();
if socket < 0 {
return Err(Error::last_os_error());
}

let res = unsafe {
bind(
socket,
&mut vsock_addr as *mut _ as *mut sockaddr,
size_of::<sockaddr_vm>() as socklen_t,
)
};
let res = unsafe { bind(socket, addr.as_ptr(), addr.len()) };
if res < 0 {
return Err(Error::last_os_error());
}
Expand All @@ -93,11 +85,11 @@ impl VsockListener {

/// Create a new VsockListener with specified cid and port.
pub fn bind_with_cid_port(cid: u32, port: u32) -> Result<VsockListener> {
Self::bind(&SockAddr::Vsock(VsockAddr::new(cid, port)))
Self::bind(&VsockAddr::new(cid, port))
}

/// The local socket address of the listener.
pub fn local_addr(&self) -> Result<SockAddr> {
pub fn local_addr(&self) -> Result<VsockAddr> {
let mut vsock_addr = sockaddr_vm {
svm_family: AF_VSOCK as sa_family_t,
svm_reserved1: 0,
Expand All @@ -116,10 +108,7 @@ impl VsockListener {
{
Err(Error::last_os_error())
} else {
Ok(SockAddr::Vsock(VsockAddr::new(
vsock_addr.svm_cid,
vsock_addr.svm_port,
)))
Ok(VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port))
}
}

Expand All @@ -129,7 +118,7 @@ impl VsockListener {
}

/// Accept a new incoming connection from this listener.
pub fn accept(&self) -> Result<(VsockStream, SockAddr)> {
pub fn accept(&self) -> Result<(VsockStream, VsockAddr)> {
let mut vsock_addr = sockaddr_vm {
svm_family: AF_VSOCK as sa_family_t,
svm_reserved1: 0,
Expand All @@ -151,7 +140,7 @@ impl VsockListener {
} else {
Ok((
unsafe { VsockStream::from_raw_fd(socket as RawFd) },
SockAddr::Vsock(VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port)),
VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port),
))
}
}
Expand Down Expand Up @@ -230,28 +219,19 @@ pub struct VsockStream {

impl VsockStream {
/// Open a connection to a remote host.
pub fn connect(addr: &SockAddr) -> Result<Self> {
let vsock_addr = if let SockAddr::Vsock(addr) = addr {
addr.as_ref()
} else {
pub fn connect(addr: &VsockAddr) -> Result<Self> {
if addr.family() != Some(AddressFamily::Vsock) {
return Err(Error::new(
ErrorKind::Other,
"requires a virtio socket address",
));
};
}

let sock = new_socket();
if sock < 0 {
return Err(Error::last_os_error());
}
if unsafe {
connect(
sock,
vsock_addr as *const _ as *const sockaddr,
size_of::<sockaddr_vm>() as socklen_t,
)
} < 0
{
if unsafe { connect(sock, addr.as_ptr(), addr.len()) } < 0 {
Err(Error::last_os_error())
} else {
Ok(unsafe { VsockStream::from_raw_fd(sock) })
Expand All @@ -260,11 +240,11 @@ impl VsockStream {

/// Open a connection to a remote host with specified cid and port.
pub fn connect_with_cid_port(cid: u32, port: u32) -> Result<Self> {
Self::connect(&SockAddr::Vsock(VsockAddr::new(cid, port)))
Self::connect(&VsockAddr::new(cid, port))
}

/// Virtio socket address of the remote peer associated with this connection.
pub fn peer_addr(&self) -> Result<SockAddr> {
pub fn peer_addr(&self) -> Result<VsockAddr> {
let mut vsock_addr = sockaddr_vm {
svm_family: AF_VSOCK as sa_family_t,
svm_reserved1: 0,
Expand All @@ -283,15 +263,12 @@ impl VsockStream {
{
Err(Error::last_os_error())
} else {
Ok(SockAddr::Vsock(VsockAddr::new(
vsock_addr.svm_cid,
vsock_addr.svm_port,
)))
Ok(VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port))
}
}

/// Virtio socket address of the local address associated with this connection.
pub fn local_addr(&self) -> Result<SockAddr> {
pub fn local_addr(&self) -> Result<VsockAddr> {
let mut vsock_addr = sockaddr_vm {
svm_family: AF_VSOCK as sa_family_t,
svm_reserved1: 0,
Expand All @@ -310,10 +287,7 @@ impl VsockStream {
{
Err(Error::last_os_error())
} else {
Ok(SockAddr::Vsock(VsockAddr::new(
vsock_addr.svm_cid,
vsock_addr.svm_port,
)))
Ok(VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port))
}
}

Expand Down
5 changes: 2 additions & 3 deletions tests/vsock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
use rand::RngCore;
use sha2::{Digest, Sha256};
use std::io::{Read, Write};
use vsock::{get_local_cid, SockAddr, VsockAddr, VsockStream, VMADDR_CID_HOST};
use vsock::{get_local_cid, VsockAddr, VsockStream, VMADDR_CID_HOST};

const TEST_BLOB_SIZE: usize = 1_000_000;
const TEST_BLOCK_SIZE: usize = 5_000;
Expand All @@ -39,8 +39,7 @@ fn test_vsock() {
rx_blob.resize(TEST_BLOB_SIZE, 0);
rng.fill_bytes(&mut blob);

let mut stream =
VsockStream::connect(&SockAddr::Vsock(VsockAddr::new(3, 8000))).expect("connection failed");
let mut stream = VsockStream::connect(&VsockAddr::new(3, 8000)).expect("connection failed");

while tx_pos < TEST_BLOB_SIZE {
let written_bytes = stream
Expand Down

0 comments on commit 23b9fb3

Please sign in to comment.