Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wol] Extend wol to support sending magic pattern in udp payload #20523

Merged
merged 7 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/sonic-nettools/.gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
target/
bin/
.vscode/
.vscode/
*.pcap
5 changes: 3 additions & 2 deletions src/sonic-nettools/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion src/sonic-nettools/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,10 @@ endif

clean:
rm -rf target
rm -rf bin
rm -rf bin

test:
cargo test

fmt:
cargo clippy
1 change: 1 addition & 0 deletions src/sonic-nettools/wol/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ version = "0.0.1"
[dependencies]
pnet = "0.35.0"
clap = { version = "4.5.7", features = ["derive"] }
libc = "0.2.159"
2 changes: 2 additions & 0 deletions src/sonic-nettools/wol/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
mod wol;
mod socket;

extern crate clap;
extern crate pnet;
extern crate libc;

fn main() {
if let Err(e) = wol::build_and_send() {
Expand Down
336 changes: 336 additions & 0 deletions src/sonic-nettools/wol/src/socket.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,336 @@
use libc;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::{convert::TryInto, io, ptr, mem};
use std::ffi::CString;
use std::os::raw::c_int;
use std::result::Result;
use std::str::FromStr;

use crate::wol::{ WolErr, WolErrCode, vprintln};

const ANY_INTERFACE: &str = "ANY_INTERFACE";
const IPV4_ANY_ADDR : &str = "0.0.0.0";
const IPV6_ANY_ADDR : &str = "::";

type CSocket = c_int;

pub trait WolSocket {
fn get_socket(&self) -> CSocket;

fn send_magic_packet(&self, buf : &[u8]) -> Result<usize, WolErr> {
let res = unsafe {
libc::send(self.get_socket(), buf.as_ptr() as *const libc::c_void, buf.len(), 0)
};
if res < 0 {
Err(WolErr {
msg: format!("Failed to send packet, rc={}, error: {}", res, io::Error::last_os_error()),
code: WolErrCode::InternalError as i32
})
} else {
Ok(res as usize)
}
}
}

pub struct RawSocket{
pub cs: CSocket
}

impl RawSocket {
pub fn new(intf_name: &str) -> Result<RawSocket, WolErr> {
vprintln(format!("Creating raw socket for interface: {}", intf_name));
let res = unsafe {
libc::socket(libc::AF_PACKET, libc::SOCK_RAW, libc::ETH_P_ALL.to_be())
};
if res < 0 {
return Err(WolErr {
msg: format!("Failed to create raw socket, rc={}, error: {}", res, io::Error::last_os_error()),
code: WolErrCode::InternalError as i32
})
}
let _socket = RawSocket{ cs: res };
_socket.bind_to_intf(intf_name)?;
Ok(_socket)
}

fn bind_to_intf(&self, intf_name: &str) -> Result<(), WolErr> {
vprintln(format!("Binding raw socket to interface: {}", intf_name));
let addr_ll: libc::sockaddr_ll = RawSocket::generate_sockaddr_ll(intf_name)?;

vprintln(format!("Interface index={}, MAC={}", addr_ll.sll_ifindex, addr_ll.sll_addr.iter().map(|x| format!("{:02x}", x)).collect::<Vec<String>>().join(":")));

let res = unsafe {
libc::bind(
self.get_socket(),
&addr_ll as *const libc::sockaddr_ll as *const libc::sockaddr,
mem::size_of::<libc::sockaddr_ll>() as u32,
)
};

assert_return_code_is_zero(res, "Failed to bind raw socket to intface", WolErrCode::SocketError)?;

Ok(())
}

fn generate_sockaddr_ll(intf_name: &str) -> Result<libc::sockaddr_ll, WolErr> {
let mut addr_ll: libc::sockaddr_ll = unsafe { mem::zeroed() };
addr_ll.sll_family = libc::AF_PACKET as u16;
addr_ll.sll_protocol = (libc::ETH_P_ALL as u16).to_be();
addr_ll.sll_halen = 6; // MAC address length in bytes

let mut addrs: *mut libc::ifaddrs = ptr::null_mut();
let res = unsafe {
libc::getifaddrs(&mut addrs)
};
assert_return_code_is_zero(res, "Failed on getifaddrs function", WolErrCode::InternalError)?;

let mut addr = addrs;
while !addr.is_null() {
let addr_ref = unsafe { *addr };
if addr_ref.ifa_name.is_null() {
addr = addr_ref.ifa_next;
continue;
}

let _in = unsafe {
std::ffi::CStr::from_ptr(addr_ref.ifa_name).to_str().unwrap()
};
if _in == intf_name {
addr_ll.sll_ifindex = unsafe {
libc::if_nametoindex(addr_ref.ifa_name) as i32
};
addr_ll.sll_addr = unsafe { (*(addr_ref.ifa_addr as *const libc::sockaddr_ll)).sll_addr };
break;
w1nda marked this conversation as resolved.
Show resolved Hide resolved
}

addr = addr_ref.ifa_next;
}

if addr.is_null() {
return Err(WolErr {
msg: format!("Failed to find interface: {}", intf_name),
code: WolErrCode::InternalError as i32
});
}

unsafe {
libc::freeifaddrs(addrs);
}

Ok(addr_ll)
}
}

impl WolSocket for RawSocket {
fn get_socket(&self) -> CSocket { self.cs }
}

#[derive(Debug)]
pub struct UdpSocket{
pub cs: CSocket
}

impl UdpSocket {
pub fn new(intf_name: &str, dst_port: u16, ip_addr: &str) -> Result<UdpSocket, WolErr> {
vprintln(format!("Creating udp socket for interface: {}, destination port: {}, ip address: {}", intf_name, dst_port, ip_addr));
let res = match ip_addr.contains(':') {
true => unsafe {libc::socket(libc::AF_INET6, libc::SOCK_DGRAM, libc::IPPROTO_UDP)},
false => unsafe {libc::socket(libc::AF_INET, libc::SOCK_DGRAM, libc::IPPROTO_UDP)},
};
if res < 0 {
return Err(WolErr {
msg: format!("Failed to create udp socket, rc={}, error: {}", res, io::Error::last_os_error()),
code: WolErrCode::SocketError as i32
})
}
let _socket = UdpSocket{ cs: res };
_socket.enable_broadcast()?;
_socket.bind_to_intf(intf_name)?;
_socket.connect_to_addr(dst_port, ip_addr)?;
Ok(_socket)
}

fn enable_broadcast(&self) -> Result<(), WolErr> {
vprintln(String::from("Enabling broadcast on udp socket"));
let res = unsafe {
libc::setsockopt(
self.get_socket(),
libc::SOL_SOCKET,
libc::SO_BROADCAST,
&1 as *const i32 as *const libc::c_void,
mem::size_of::<i32>().try_into().unwrap(),
)
};
assert_return_code_is_zero(res, "Failed to enable broadcast on udp socket", WolErrCode::SocketError)?;

Ok(())
}

fn bind_to_intf(&self, intf: &str) -> Result<(), WolErr> {
vprintln(format!("Binding udp socket to interface: {}", intf));
let c_intf = CString::new(intf).map_err(|_| WolErr {
msg: String::from("Invalid interface name for binding"),
code: WolErrCode::SocketError as i32,
})?;
let res = unsafe {
libc::setsockopt(
self.get_socket(),
libc::SOL_SOCKET,
libc::SO_BINDTODEVICE,
c_intf.as_ptr() as *const libc::c_void,
c_intf.as_bytes_with_nul().len() as u32,
)
};
assert_return_code_is_zero(res, "Failed to bind udp socket to interface", WolErrCode::SocketError)?;

Ok(())
}

fn connect_to_addr(&self, port: u16, ip_addr: &str) -> Result<(), WolErr> {
vprintln(format!("Setting udp socket destination as address: {}, port: {}", ip_addr, port));
let (addr, addr_len) = match ip_addr.contains(':') {
true => (
&ipv6_addr(port, ip_addr, ANY_INTERFACE)? as *const libc::sockaddr_in6 as *const libc::sockaddr,
mem::size_of::<libc::sockaddr_in6>() as u32
),
false => (
&ipv4_addr(port, ip_addr)? as *const libc::sockaddr_in as *const libc::sockaddr,
mem::size_of::<libc::sockaddr_in>() as u32
),
};
let res = unsafe {
libc::connect(
self.get_socket(),
addr,
addr_len
)
};
assert_return_code_is_zero(res,"Failed to connect udp socket to address", WolErrCode::SocketError)?;

Ok(())
}
}

impl WolSocket for UdpSocket {
fn get_socket(&self) -> CSocket { self.cs }
}

fn ipv4_addr(port: u16, addr: &str) -> Result<libc::sockaddr_in, WolErr> {
let _addr = match addr == IPV4_ANY_ADDR {
true => libc::in_addr { s_addr: libc::INADDR_ANY },
false => libc::in_addr { s_addr: u32::from(Ipv4Addr::from_str(addr).map_err(|e|
WolErr{
msg: format!("Failed to parse ipv4 address: {}", e),
code: WolErrCode::SocketError as i32
}
)?).to_be()
},
};
Ok(
libc::sockaddr_in {
sin_family: libc::AF_INET as u16,
sin_port: port.to_be(),
sin_addr: _addr,
sin_zero: [0; 8],
}
)
}

fn ipv6_addr(port: u16, addr: &str, intf_name: &str) -> Result<libc::sockaddr_in6, WolErr> {
let _addr = match addr == IPV6_ANY_ADDR {
true => libc::IN6ADDR_ANY_INIT,
false => libc::in6_addr { s6_addr: Ipv6Addr::from_str(addr).map_err(|e|
WolErr{
msg: format!("Failed to parse ipv6 address: {}", e),
code: WolErrCode::SocketError as i32
}
)?.octets()
},
};
let _scope_id= match intf_name == ANY_INTERFACE {
true => 0,
false => unsafe { libc::if_nametoindex(CString::new(intf_name).map_err(|_|
WolErr{
msg: String::from("Invalid interface name for binding"),
code: WolErrCode::SocketError as i32
}
)?.as_ptr()) as u32
}
};
Ok(
libc::sockaddr_in6 {
sin6_family: libc::AF_INET6 as u16,
sin6_port: port.to_be(),
sin6_flowinfo: 0,
sin6_addr: _addr,
sin6_scope_id: _scope_id.to_be(),
}
)
}

fn assert_return_code_is_zero(rc: i32, msg: &str, err_code: WolErrCode) -> Result<(), WolErr> {
if rc != 0 {
Err(WolErr {
msg: format!("{}, rc={},error: {}", msg, rc, io::Error::last_os_error()),
code: err_code as i32,
})
} else {
Ok(())
}
}


#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_ipv4_addr() {
let port = 1234;
let addr = ipv4_addr(port, IPV4_ANY_ADDR).unwrap();
assert_eq!(addr.sin_family , libc::AF_INET as u16);
assert_eq!(addr.sin_port.to_le() , port.to_be());
assert_eq!(addr.sin_addr.s_addr , libc::INADDR_ANY);
assert_eq!(addr.sin_zero , [0; 8]);

let ip = "1.1.1.1";
let addr = ipv4_addr(port, &ip).unwrap();
assert_eq!(addr.sin_family , libc::AF_INET as u16);
assert_eq!(addr.sin_port.to_le() , port.to_be());
assert_eq!(addr.sin_addr.s_addr , u32::from(Ipv4Addr::from_str(ip).unwrap()).to_be());
assert_eq!(addr.sin_zero , [0; 8]);
}

#[test]
fn test_ipv6_addr() {
let port = 1234;
let addr = ipv6_addr(port, IPV6_ANY_ADDR, ANY_INTERFACE).unwrap();
assert_eq!(addr.sin6_family , libc::AF_INET6 as u16);
assert_eq!(addr.sin6_port.to_le() , port.to_be());
assert_eq!(addr.sin6_flowinfo , 0);
assert_eq!(addr.sin6_addr.s6_addr , libc::IN6ADDR_ANY_INIT.s6_addr);
assert_eq!(addr.sin6_scope_id , 0);

let ip = "2001:db8::1";
let addr = ipv6_addr(port, &ip, ANY_INTERFACE).unwrap();
assert_eq!(addr.sin6_family , libc::AF_INET6 as u16);
assert_eq!(addr.sin6_port.to_le() , port.to_be());
assert_eq!(addr.sin6_flowinfo , 0);
assert_eq!(addr.sin6_addr.s6_addr , Ipv6Addr::from_str(ip).unwrap().octets());
assert_eq!(addr.sin6_scope_id , 0);
}

#[test]
fn test_assert_return_code_is_zero() {
let rc = 0;
assert_eq!(assert_return_code_is_zero(rc, "", WolErrCode::InternalError).is_ok(), true);

let rc = -1;
let msg = "test";
let err_code = WolErrCode::InternalError;
let result = assert_return_code_is_zero(rc, msg, err_code);
assert_eq!(result.is_err(), true);
assert_eq!(result.as_ref().unwrap_err().code, WolErrCode::InternalError as i32);
assert_eq!(result.unwrap_err().msg, format!("{}, rc=-1,error: {}", msg, io::Error::last_os_error()));
}
}
Loading
Loading