diff --git a/src/sonic-nettools/.gitignore b/src/sonic-nettools/.gitignore index 046d25633d83..5a87e35f5600 100644 --- a/src/sonic-nettools/.gitignore +++ b/src/sonic-nettools/.gitignore @@ -1,3 +1,4 @@ target/ bin/ -.vscode/ \ No newline at end of file +.vscode/ +*.pcap \ No newline at end of file diff --git a/src/sonic-nettools/Cargo.lock b/src/sonic-nettools/Cargo.lock index 695393102228..2f75b62942c2 100644 --- a/src/sonic-nettools/Cargo.lock +++ b/src/sonic-nettools/Cargo.lock @@ -135,9 +135,9 @@ checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800" [[package]] name = "libc" -version = "0.2.155" +version = "0.2.159" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" [[package]] name = "memchr" @@ -438,5 +438,6 @@ name = "wol" version = "0.0.1" dependencies = [ "clap", + "libc", "pnet", ] diff --git a/src/sonic-nettools/Makefile b/src/sonic-nettools/Makefile index 021b1eb805c6..db8242cf4ff0 100644 --- a/src/sonic-nettools/Makefile +++ b/src/sonic-nettools/Makefile @@ -14,4 +14,10 @@ endif clean: rm -rf target - rm -rf bin \ No newline at end of file + rm -rf bin + +test: + cargo test + +fmt: + cargo clippy diff --git a/src/sonic-nettools/wol/Cargo.toml b/src/sonic-nettools/wol/Cargo.toml index 6b2276e79724..26c8fd1c4f94 100644 --- a/src/sonic-nettools/wol/Cargo.toml +++ b/src/sonic-nettools/wol/Cargo.toml @@ -5,3 +5,4 @@ version = "0.0.1" [dependencies] pnet = "0.35.0" clap = { version = "4.5.7", features = ["derive"] } +libc = "0.2.159" diff --git a/src/sonic-nettools/wol/src/main.rs b/src/sonic-nettools/wol/src/main.rs index c04ba9411bcd..368f35a3c8ba 100644 --- a/src/sonic-nettools/wol/src/main.rs +++ b/src/sonic-nettools/wol/src/main.rs @@ -1,7 +1,9 @@ mod wol; +mod socket; extern crate clap; extern crate pnet; +extern crate libc; fn main() { if let Err(e) = wol::build_and_send() { diff --git a/src/sonic-nettools/wol/src/socket.rs b/src/sonic-nettools/wol/src/socket.rs new file mode 100644 index 000000000000..c67bcba82a32 --- /dev/null +++ b/src/sonic-nettools/wol/src/socket.rs @@ -0,0 +1,336 @@ +use libc; +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::{convert::TryInto, io, ptr, mem}; +use std::ffi::CString; +use std::os::raw::c_int; +use std::result::Result; +use std::str::FromStr; + +use crate::wol::{ WolErr, WolErrCode, vprintln}; + +const ANY_INTERFACE: &str = "ANY_INTERFACE"; +const IPV4_ANY_ADDR : &str = "0.0.0.0"; +const IPV6_ANY_ADDR : &str = "::"; + +type CSocket = c_int; + +pub trait WolSocket { + fn get_socket(&self) -> CSocket; + + fn send_magic_packet(&self, buf : &[u8]) -> Result { + let res = unsafe { + libc::send(self.get_socket(), buf.as_ptr() as *const libc::c_void, buf.len(), 0) + }; + if res < 0 { + Err(WolErr { + msg: format!("Failed to send packet, rc={}, error: {}", res, io::Error::last_os_error()), + code: WolErrCode::InternalError as i32 + }) + } else { + Ok(res as usize) + } + } +} + +pub struct RawSocket{ + pub cs: CSocket +} + +impl RawSocket { + pub fn new(intf_name: &str) -> Result { + vprintln(format!("Creating raw socket for interface: {}", intf_name)); + let res = unsafe { + libc::socket(libc::AF_PACKET, libc::SOCK_RAW, libc::ETH_P_ALL.to_be()) + }; + if res < 0 { + return Err(WolErr { + msg: format!("Failed to create raw socket, rc={}, error: {}", res, io::Error::last_os_error()), + code: WolErrCode::InternalError as i32 + }) + } + let _socket = RawSocket{ cs: res }; + _socket.bind_to_intf(intf_name)?; + Ok(_socket) + } + + fn bind_to_intf(&self, intf_name: &str) -> Result<(), WolErr> { + vprintln(format!("Binding raw socket to interface: {}", intf_name)); + let addr_ll: libc::sockaddr_ll = RawSocket::generate_sockaddr_ll(intf_name)?; + + vprintln(format!("Interface index={}, MAC={}", addr_ll.sll_ifindex, addr_ll.sll_addr.iter().map(|x| format!("{:02x}", x)).collect::>().join(":"))); + + let res = unsafe { + libc::bind( + self.get_socket(), + &addr_ll as *const libc::sockaddr_ll as *const libc::sockaddr, + mem::size_of::() as u32, + ) + }; + + assert_return_code_is_zero(res, "Failed to bind raw socket to intface", WolErrCode::SocketError)?; + + Ok(()) + } + + fn generate_sockaddr_ll(intf_name: &str) -> Result { + let mut addr_ll: libc::sockaddr_ll = unsafe { mem::zeroed() }; + addr_ll.sll_family = libc::AF_PACKET as u16; + addr_ll.sll_protocol = (libc::ETH_P_ALL as u16).to_be(); + addr_ll.sll_halen = 6; // MAC address length in bytes + + let mut addrs: *mut libc::ifaddrs = ptr::null_mut(); + let res = unsafe { + libc::getifaddrs(&mut addrs) + }; + assert_return_code_is_zero(res, "Failed on getifaddrs function", WolErrCode::InternalError)?; + + let mut addr = addrs; + while !addr.is_null() { + let addr_ref = unsafe { *addr }; + if addr_ref.ifa_name.is_null() { + addr = addr_ref.ifa_next; + continue; + } + + let _in = unsafe { + std::ffi::CStr::from_ptr(addr_ref.ifa_name).to_str().unwrap() + }; + if _in == intf_name { + addr_ll.sll_ifindex = unsafe { + libc::if_nametoindex(addr_ref.ifa_name) as i32 + }; + addr_ll.sll_addr = unsafe { (*(addr_ref.ifa_addr as *const libc::sockaddr_ll)).sll_addr }; + break; + } + + addr = addr_ref.ifa_next; + } + + if addr.is_null() { + return Err(WolErr { + msg: format!("Failed to find interface: {}", intf_name), + code: WolErrCode::InternalError as i32 + }); + } + + unsafe { + libc::freeifaddrs(addrs); + } + + Ok(addr_ll) + } +} + +impl WolSocket for RawSocket { + fn get_socket(&self) -> CSocket { self.cs } +} + +#[derive(Debug)] +pub struct UdpSocket{ + pub cs: CSocket +} + +impl UdpSocket { + pub fn new(intf_name: &str, dst_port: u16, ip_addr: &str) -> Result { + vprintln(format!("Creating udp socket for interface: {}, destination port: {}, ip address: {}", intf_name, dst_port, ip_addr)); + let res = match ip_addr.contains(':') { + true => unsafe {libc::socket(libc::AF_INET6, libc::SOCK_DGRAM, libc::IPPROTO_UDP)}, + false => unsafe {libc::socket(libc::AF_INET, libc::SOCK_DGRAM, libc::IPPROTO_UDP)}, + }; + if res < 0 { + return Err(WolErr { + msg: format!("Failed to create udp socket, rc={}, error: {}", res, io::Error::last_os_error()), + code: WolErrCode::SocketError as i32 + }) + } + let _socket = UdpSocket{ cs: res }; + _socket.enable_broadcast()?; + _socket.bind_to_intf(intf_name)?; + _socket.connect_to_addr(dst_port, ip_addr)?; + Ok(_socket) + } + + fn enable_broadcast(&self) -> Result<(), WolErr> { + vprintln(String::from("Enabling broadcast on udp socket")); + let res = unsafe { + libc::setsockopt( + self.get_socket(), + libc::SOL_SOCKET, + libc::SO_BROADCAST, + &1 as *const i32 as *const libc::c_void, + mem::size_of::().try_into().unwrap(), + ) + }; + assert_return_code_is_zero(res, "Failed to enable broadcast on udp socket", WolErrCode::SocketError)?; + + Ok(()) + } + + fn bind_to_intf(&self, intf: &str) -> Result<(), WolErr> { + vprintln(format!("Binding udp socket to interface: {}", intf)); + let c_intf = CString::new(intf).map_err(|_| WolErr { + msg: String::from("Invalid interface name for binding"), + code: WolErrCode::SocketError as i32, + })?; + let res = unsafe { + libc::setsockopt( + self.get_socket(), + libc::SOL_SOCKET, + libc::SO_BINDTODEVICE, + c_intf.as_ptr() as *const libc::c_void, + c_intf.as_bytes_with_nul().len() as u32, + ) + }; + assert_return_code_is_zero(res, "Failed to bind udp socket to interface", WolErrCode::SocketError)?; + + Ok(()) + } + + fn connect_to_addr(&self, port: u16, ip_addr: &str) -> Result<(), WolErr> { + vprintln(format!("Setting udp socket destination as address: {}, port: {}", ip_addr, port)); + let (addr, addr_len) = match ip_addr.contains(':') { + true => ( + &ipv6_addr(port, ip_addr, ANY_INTERFACE)? as *const libc::sockaddr_in6 as *const libc::sockaddr, + mem::size_of::() as u32 + ), + false => ( + &ipv4_addr(port, ip_addr)? as *const libc::sockaddr_in as *const libc::sockaddr, + mem::size_of::() as u32 + ), + }; + let res = unsafe { + libc::connect( + self.get_socket(), + addr, + addr_len + ) + }; + assert_return_code_is_zero(res,"Failed to connect udp socket to address", WolErrCode::SocketError)?; + + Ok(()) + } +} + +impl WolSocket for UdpSocket { + fn get_socket(&self) -> CSocket { self.cs } +} + +fn ipv4_addr(port: u16, addr: &str) -> Result { + let _addr = match addr == IPV4_ANY_ADDR { + true => libc::in_addr { s_addr: libc::INADDR_ANY }, + false => libc::in_addr { s_addr: u32::from(Ipv4Addr::from_str(addr).map_err(|e| + WolErr{ + msg: format!("Failed to parse ipv4 address: {}", e), + code: WolErrCode::SocketError as i32 + } + )?).to_be() + }, + }; + Ok( + libc::sockaddr_in { + sin_family: libc::AF_INET as u16, + sin_port: port.to_be(), + sin_addr: _addr, + sin_zero: [0; 8], + } + ) +} + +fn ipv6_addr(port: u16, addr: &str, intf_name: &str) -> Result { + let _addr = match addr == IPV6_ANY_ADDR { + true => libc::IN6ADDR_ANY_INIT, + false => libc::in6_addr { s6_addr: Ipv6Addr::from_str(addr).map_err(|e| + WolErr{ + msg: format!("Failed to parse ipv6 address: {}", e), + code: WolErrCode::SocketError as i32 + } + )?.octets() + }, + }; + let _scope_id= match intf_name == ANY_INTERFACE { + true => 0, + false => unsafe { libc::if_nametoindex(CString::new(intf_name).map_err(|_| + WolErr{ + msg: String::from("Invalid interface name for binding"), + code: WolErrCode::SocketError as i32 + } + )?.as_ptr()) as u32 + } + }; + Ok( + libc::sockaddr_in6 { + sin6_family: libc::AF_INET6 as u16, + sin6_port: port.to_be(), + sin6_flowinfo: 0, + sin6_addr: _addr, + sin6_scope_id: _scope_id.to_be(), + } + ) +} + +fn assert_return_code_is_zero(rc: i32, msg: &str, err_code: WolErrCode) -> Result<(), WolErr> { + if rc != 0 { + Err(WolErr { + msg: format!("{}, rc={},error: {}", msg, rc, io::Error::last_os_error()), + code: err_code as i32, + }) + } else { + Ok(()) + } +} + + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ipv4_addr() { + let port = 1234; + let addr = ipv4_addr(port, IPV4_ANY_ADDR).unwrap(); + assert_eq!(addr.sin_family , libc::AF_INET as u16); + assert_eq!(addr.sin_port.to_le() , port.to_be()); + assert_eq!(addr.sin_addr.s_addr , libc::INADDR_ANY); + assert_eq!(addr.sin_zero , [0; 8]); + + let ip = "1.1.1.1"; + let addr = ipv4_addr(port, &ip).unwrap(); + assert_eq!(addr.sin_family , libc::AF_INET as u16); + assert_eq!(addr.sin_port.to_le() , port.to_be()); + assert_eq!(addr.sin_addr.s_addr , u32::from(Ipv4Addr::from_str(ip).unwrap()).to_be()); + assert_eq!(addr.sin_zero , [0; 8]); + } + + #[test] + fn test_ipv6_addr() { + let port = 1234; + let addr = ipv6_addr(port, IPV6_ANY_ADDR, ANY_INTERFACE).unwrap(); + assert_eq!(addr.sin6_family , libc::AF_INET6 as u16); + assert_eq!(addr.sin6_port.to_le() , port.to_be()); + assert_eq!(addr.sin6_flowinfo , 0); + assert_eq!(addr.sin6_addr.s6_addr , libc::IN6ADDR_ANY_INIT.s6_addr); + assert_eq!(addr.sin6_scope_id , 0); + + let ip = "2001:db8::1"; + let addr = ipv6_addr(port, &ip, ANY_INTERFACE).unwrap(); + assert_eq!(addr.sin6_family , libc::AF_INET6 as u16); + assert_eq!(addr.sin6_port.to_le() , port.to_be()); + assert_eq!(addr.sin6_flowinfo , 0); + assert_eq!(addr.sin6_addr.s6_addr , Ipv6Addr::from_str(ip).unwrap().octets()); + assert_eq!(addr.sin6_scope_id , 0); + } + + #[test] + fn test_assert_return_code_is_zero() { + let rc = 0; + assert_eq!(assert_return_code_is_zero(rc, "", WolErrCode::InternalError).is_ok(), true); + + let rc = -1; + let msg = "test"; + let err_code = WolErrCode::InternalError; + let result = assert_return_code_is_zero(rc, msg, err_code); + assert_eq!(result.is_err(), true); + assert_eq!(result.as_ref().unwrap_err().code, WolErrCode::InternalError as i32); + assert_eq!(result.unwrap_err().msg, format!("{}, rc=-1,error: {}", msg, io::Error::last_os_error())); + } +} \ No newline at end of file diff --git a/src/sonic-nettools/wol/src/wol.rs b/src/sonic-nettools/wol/src/wol.rs index b5d598d369f3..ff04c017cb78 100644 --- a/src/sonic-nettools/wol/src/wol.rs +++ b/src/sonic-nettools/wol/src/wol.rs @@ -1,15 +1,26 @@ use clap::builder::ArgPredicate; use clap::Parser; -use pnet::datalink::Channel::Ethernet; -use pnet::datalink::{self, DataLinkSender, MacAddr, NetworkInterface}; +use pnet::datalink; use std::fs::read_to_string; +use std::net::IpAddr; use std::result::Result; use std::str::FromStr; +use std::sync::Mutex; use std::thread; use std::time::Duration; +use crate::socket::{WolSocket, RawSocket, UdpSocket}; + const BROADCAST_MAC: [u8; 6] = [0xff, 0xff, 0xff, 0xff, 0xff, 0xff]; +pub static VERBOSE_OUTPUT: Mutex = Mutex::new(false); + +pub fn vprintln(msg: String) { + if *VERBOSE_OUTPUT.lock().unwrap() { + println!("{}", msg); + } +} + #[derive(Parser, Debug)] #[command( next_line_help = true, @@ -20,7 +31,13 @@ Examples: wol Ethernet10 00:11:22:33:44:55 wol Ethernet10 00:11:22:33:44:55 -b wol Vlan1000 00:11:22:33:44:55,11:33:55:77:99:bb -p 00:22:44:66:88:aa - wol Vlan1000 00:11:22:33:44:55,11:33:55:77:99:bb -p 192.168.1.1 -c 3 -i 2000" + wol Vlan1000 00:11:22:33:44:55,11:33:55:77:99:bb -p 192.168.1.1 -c 3 -i 2000 + wol Ethernet10 00:11:22:33:44:55,11:33:55:77:99:bb -u + wol Vlan1000 00:11:22:33:44:55,11:33:55:77:99:bb -u -c 3 -i 2000 + wol Vlan1000 00:11:22:33:44:55,11:33:55:77:99:bb -u -a 192.168.255.255 + wol Vlan1000 00:11:22:33:44:55,11:33:55:77:99:bb -u -a ffff::ffff + wol Vlan1000 00:11:22:33:44:55,11:33:55:77:99:bb -u -a +" )] struct WolArgs { /// The name of the network interface to send the magic packet through @@ -30,29 +47,31 @@ struct WolArgs { target_mac: String, /// The flag to indicate if use broadcast MAC address instead of target device's MAC address as Destination MAC Address in Ethernet Frame Header [default: false] - #[arg(short, long, default_value_t = false)] + #[arg(short, long, default_value_t = false, conflicts_with("udp"))] broadcast: bool, + /// The flag to indicate if send udp packet [default: false] + #[arg(short, long, default_value_t = false, conflicts_with("broadcast"))] + udp: bool, + + /// The destination ip address, both IPv4 address and IPv6 address are supported + #[arg(short = 'a', long, default_value_t = String::from("255.255.255.255"), requires_if(ArgPredicate::IsPresent, "udp"))] + ip_address: String, + + /// The destination udp port. + #[arg(short = 't', long, default_value_t = 9, requires_if(ArgPredicate::IsPresent, "udp"))] + udp_port: u16, + /// An optional 4 or 6 byte password, in ethernet hex format or quad-dotted decimal (e.g. "127.0.0.1" or "00:11:22:33:44:55") #[arg(short, long, value_parser = parse_password)] password: Option, - /// For each target MAC address, the count of magic packets to send. count must between 1 and 5. This param must use with -i. [default: 1] - #[arg( - short, - long, - default_value_t = 1, - requires_if(ArgPredicate::IsPresent, "interval") - )] + /// For each target MAC address, the count of magic packets to send. count must between 1 and 5. This param must use with -i. + #[arg(short, long, default_value_t = 1, requires_if(ArgPredicate::IsPresent, "interval"))] count: u8, - /// Wait interval milliseconds between sending each magic packet. interval must between 0 and 2000. This param must use with -c. [default: 0] - #[arg( - short, - long, - default_value_t = 0, - requires_if(ArgPredicate::IsPresent, "count") - )] + /// Wait interval milliseconds between sending each magic packet. interval must between 0 and 2000. This param must use with -c. + #[arg(short, long, default_value_t = 0, requires_if(ArgPredicate::IsPresent, "count"))] interval: u64, /// The flag to indicate if we should print verbose output @@ -83,42 +102,31 @@ impl std::fmt::Display for WolErr { } } -enum WolErrCode { +pub enum WolErrCode { SocketError = 1, InvalidArguments = 2, - UnknownError = 999, + InternalError = 255, } pub fn build_and_send() -> Result<(), WolErr> { let args = WolArgs::parse(); let target_macs = parse_target_macs(&args)?; valide_arguments(&args)?; + *VERBOSE_OUTPUT.lock().unwrap() = args.verbose; let src_mac = get_interface_mac(&args.interface)?; - let mut tx = open_tx_channel(&args.interface)?; + let socket = create_wol_socket(&args)?; for target_mac in target_macs { - if args.verbose { - println!( + vprintln(format!( "Building and sending packet to target mac address {}", target_mac .iter() .map(|b| format!("{:02X}", b)) .collect::>() .join(":") - ); - } - let dst_mac = if args.broadcast { - BROADCAST_MAC - } else { - target_mac - }; - let magic_bytes = build_magic_packet(&src_mac, &dst_mac, &target_mac, &args.password)?; - send_magic_packet( - &mut tx, - magic_bytes, - &args.count, - &args.interval, - &args.verbose, - )?; + ) + ); + let magic_bytes = build_magic_bytes(&args, &src_mac, &target_mac, &args.password)?; + send_magic_packet(socket.as_ref(), magic_bytes, &args.count, &args.interval)?; } Ok(()) @@ -149,11 +157,18 @@ fn valide_arguments(args: &WolArgs) -> Result<(), WolErr> { }); } + if IpAddr::from_str(&args.ip_address).is_err() { + return Err(WolErr { + msg: String::from("Invalid ip address"), + code: WolErrCode::InvalidArguments as i32, + }); + } + Ok(()) } fn parse_mac_addr(mac_str: &str) -> Result<[u8; 6], WolErr> { - MacAddr::from_str(mac_str) + datalink::MacAddr::from_str(mac_str) .map(|mac| mac.octets()) .map_err(|_| WolErr { msg: String::from("Invalid MAC address"), @@ -230,14 +245,14 @@ fn is_ipv4_address_valid(ipv4_str: &str) -> bool { fn get_interface_mac(interface_name: &String) -> Result<[u8; 6], WolErr> { if let Some(interface) = datalink::interfaces() .into_iter() - .find(|iface: &NetworkInterface| iface.name == *interface_name) + .find(|iface: &datalink::NetworkInterface| iface.name == *interface_name) { if let Some(mac) = interface.mac { Ok(mac.octets()) } else { Err(WolErr { msg: String::from("Could not get MAC address of target interface"), - code: WolErrCode::UnknownError as i32, + code: WolErrCode::InternalError as i32, }) } } else { @@ -248,91 +263,76 @@ fn get_interface_mac(interface_name: &String) -> Result<[u8; 6], WolErr> { } } -fn build_magic_packet( +fn build_magic_bytes( + args: &WolArgs, src_mac: &[u8; 6], - dst_mac: &[u8; 6], target_mac: &[u8; 6], password: &Option, ) -> Result, WolErr> { let password_len = password.as_ref().map_or(0, |p| p.ref_bytes().len()); - let mut pkt = vec![0u8; 116 + password_len]; - pkt[0..6].copy_from_slice(dst_mac); - pkt[6..12].copy_from_slice(src_mac); - pkt[12..14].copy_from_slice(&[0x08, 0x42]); - pkt[14..20].copy_from_slice(&[0xff; 6]); - pkt[20..116].copy_from_slice(&target_mac.repeat(16)); + let mut mbs = vec![0u8; 102 + password_len]; + mbs[0..6].copy_from_slice(&[0xff; 6]); + mbs[6..102].copy_from_slice(&target_mac.repeat(16)); if let Some(p) = password { - pkt[116..116 + password_len].copy_from_slice(p.ref_bytes()); + mbs[102..102 + password_len].copy_from_slice(p.ref_bytes()); + } + if !args.udp { + let mut _ether_header = vec![0u8; 14]; + _ether_header[0..6].copy_from_slice( if args.broadcast { &BROADCAST_MAC } else { target_mac }); + _ether_header[6..12].copy_from_slice(src_mac); + _ether_header[12..14].copy_from_slice(&[0x08, 0x42]); // EtherType for WOL + mbs.splice(0..0, _ether_header); } - Ok(pkt) + Ok(mbs) } fn send_magic_packet( - tx: &mut Box, - packet: Vec, + socket: &dyn WolSocket, + payload: Vec, count: &u8, - interval: &u64, - verbose: &bool, -) -> Result<(), WolErr> { + interval: &u64 +) -> Result<(), WolErr> +{ for nth in 0..*count { - match tx.send_to(&packet, None) { - Some(Ok(_)) => {} - Some(Err(e)) => { + match socket.send_magic_packet(&payload) { + Ok(_) => {} + Err(e) => { return Err(WolErr { msg: format!("Network is down: {}", e), code: WolErrCode::SocketError as i32, }); } - None => { - return Err(WolErr { - msg: String::from("Network is down"), - code: WolErrCode::SocketError as i32, - }); - } } - if *verbose { - println!( + + vprintln( + format!( " | -> Sent the {}th packet and sleep for {} seconds", &nth + 1, &interval - ); - println!( - " | -> Packet bytes in hex {}", - &packet + ) + ); + vprintln( + format!( + " | -> paylod bytes in hex {}", + &payload .iter() .fold(String::new(), |acc, b| acc + &format!("{:02X}", b)) ) - } + ); thread::sleep(Duration::from_millis(*interval)); } Ok(()) } -fn open_tx_channel(interface: &str) -> Result, WolErr> { - if let Some(interface) = datalink::interfaces() - .into_iter() - .find(|iface: &NetworkInterface| iface.name == interface) - { - match datalink::channel(&interface, Default::default()) { - Ok(Ethernet(tx, _)) => Ok(tx), - Ok(_) => Err(WolErr { - msg: String::from("Network is down"), - code: WolErrCode::SocketError as i32, - }), - Err(e) => Err(WolErr { - msg: format!("Network is down: {}", e), - code: WolErrCode::SocketError as i32, - }), - } + +fn create_wol_socket(args: &WolArgs) -> Result, WolErr> { + let _socket: Box = if args.udp { + Box::new(UdpSocket::new(&args.interface, args.udp_port, &args.ip_address)?) } else { - Err(WolErr { - msg: format!( - "Invalid value for \"INTERFACE\": interface {} is not up", - interface - ), - code: WolErrCode::InvalidArguments as i32, - }) - } + Box::new(RawSocket::new(&args.interface)?) + }; + + Ok(_socket) } #[cfg(test)] @@ -460,6 +460,9 @@ mod tests { interface: "Ethernet10".to_string(), target_mac: "00:11:22:33:44:55".to_string(), broadcast: false, + udp: false, + ip_address: String::from(""), + udp_port: 9, password: None, count: 1, interval: 0, @@ -509,12 +512,13 @@ mod tests { } #[test] - fn test_build_magic_packet() { + fn test_build_magic_bytes() { + let args = WolArgs::try_parse_from(&["wol", "dontcare", "dontcare"]).unwrap(); let src_mac = [0x00, 0x11, 0x22, 0x33, 0x44, 0x55]; let target_mac = [0xff, 0xff, 0xff, 0xff, 0xff, 0xff]; let four_bytes_password = Some(Password(vec![0x00, 0x11, 0x22, 0x33])); let magic_packet = - build_magic_packet(&src_mac, &target_mac, &target_mac, &four_bytes_password).unwrap(); + build_magic_bytes(&args, &src_mac, &target_mac, &four_bytes_password).unwrap(); assert_eq!(magic_packet.len(), 120); assert_eq!(&magic_packet[0..6], &target_mac); assert_eq!(&magic_packet[6..12], &src_mac); @@ -524,7 +528,7 @@ mod tests { assert_eq!(&magic_packet[116..120], &[0x00, 0x11, 0x22, 0x33]); let six_bytes_password = Some(Password(vec![0x00, 0x11, 0x22, 0x33, 0x44, 0x55])); let magic_packet = - build_magic_packet(&src_mac, &target_mac, &target_mac, &six_bytes_password).unwrap(); + build_magic_bytes(&args, &src_mac, &target_mac, &six_bytes_password).unwrap(); assert_eq!(magic_packet.len(), 122); assert_eq!(&magic_packet[0..6], &target_mac); assert_eq!(&magic_packet[6..12], &src_mac); @@ -538,13 +542,23 @@ mod tests { } #[test] - fn test_build_magic_packet_without_password() { + fn test_build_magic_bytes_without_password() { + let args = WolArgs::try_parse_from(&["wol", "dontcare", "dontcare"]).unwrap(); let src_mac = [0x00, 0x11, 0x22, 0x33, 0x44, 0x55]; - let dst_mac = [0xff, 0xff, 0xff, 0xff, 0xff, 0xff]; let target_mac = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06]; - let magic_packet = build_magic_packet(&src_mac, &dst_mac, &target_mac, &None).unwrap(); + let magic_packet = build_magic_bytes(&args, &src_mac, &target_mac, &None).unwrap(); assert_eq!(magic_packet.len(), 116); - assert_eq!(&magic_packet[0..6], &dst_mac); + assert_eq!(&magic_packet[0..6], &target_mac); + assert_eq!(&magic_packet[6..12], &src_mac); + assert_eq!(&magic_packet[12..14], &[0x08, 0x42]); + assert_eq!(&magic_packet[14..20], &[0xff; 6]); + assert_eq!(&magic_packet[20..116], target_mac.repeat(16)); + let args = WolArgs::try_parse_from(&["wol", "dontcare", "dontcare", "-b"]).unwrap(); + let src_mac = [0x00, 0x11, 0x22, 0x33, 0x44, 0x55]; + let target_mac = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06]; + let magic_packet = build_magic_bytes(&args, &src_mac, &target_mac, &None).unwrap(); + assert_eq!(magic_packet.len(), 116); + assert_eq!(&magic_packet[0..6], BROADCAST_MAC); assert_eq!(&magic_packet[6..12], &src_mac); assert_eq!(&magic_packet[12..14], &[0x08, 0x42]); assert_eq!(&magic_packet[14..20], &[0xff; 6]); @@ -679,9 +693,59 @@ mod tests { "error: the following required arguments were not provided:\n --interval \n\nUsage: wol --count --interval \n\nFor more information, try '--help'.\n" ); // Verbose can be set - let args = - WolArgs::try_parse_from(&["wol", "eth0", "00:01:02:03:04:05", "-b", "--verbose"]) - .unwrap(); + let args = WolArgs::try_parse_from(&["wol", "eth0", "00:01:02:03:04:05", "-b", "--verbose"]).unwrap(); assert_eq!(args.verbose, true); + // Ip address should be valid + let args = WolArgs::try_parse_from(&["wol", "eth0", "00:01:02:03:04:05", "-b", "-a", "xxx"]); + let result = valide_arguments(&args.unwrap()); + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string(), + "Error: Invalid ip address" + ); + // Udp port should be in 0-65535 + let args = WolArgs::try_parse_from(&["wol", "eth0", "00:01:02:03:04:05", "-u", "-t", "65535"]) + .unwrap(); + assert_eq!(args.udp_port, 65535); + let result = WolArgs::try_parse_from(&["wol", "eth0", "00:01:02:03:04:05", "-u", "-t", "65536"]); + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string(), + "error: invalid value '65536' for '--udp-port ': 65536 is not in 0..=65535\n\nFor more information, try '--help'.\n" + ); + // Udp port should be specified with udp flag + let result = WolArgs::try_parse_from(&["wol", "eth0", "00:01:02:03:04:05", "-t", "9"]); + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string(), + "error: the following required arguments were not provided:\n --udp\n\nUsage: wol --udp --udp-port \n\nFor more information, try '--help'.\n" + ); + // Ip address should be specified with udp flag + let result = WolArgs::try_parse_from(&["wol", "eth0", "00:01:02:03:04:05", "-a", "192.168.1.1"]); + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string(), + "error: the following required arguments were not provided:\n --udp\n\nUsage: wol --udp --ip-address \n\nFor more information, try '--help'.\n" + ); + // Broadcast and udp flags are mutually exclusive + let result = WolArgs::try_parse_from(&["wol", "eth0", "00:01:02:03:04:05", "-b", "-u"]); + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string(), + "error: the argument '--broadcast' cannot be used with '--udp'\n\nUsage: wol --broadcast \n\nFor more information, try '--help'.\n" + ); + } + + #[test] + fn verify_args_default_value(){ + let args = WolArgs::try_parse_from(&["wol", "eth0", "00:01:02:03:04:05"]).unwrap(); + assert_eq!(args.broadcast, false); + assert_eq!(args.udp, false); + assert_eq!(args.ip_address, "255.255.255.255"); + assert_eq!(args.udp_port, 9); + assert_eq!(args.password.is_none(), true); + assert_eq!(args.count, 1); + assert_eq!(args.interval, 0); + assert_eq!(args.verbose, false); } }