diff --git a/src/adapter.rs b/src/adapter.rs index ee3db75..2adda06 100644 --- a/src/adapter.rs +++ b/src/adapter.rs @@ -197,13 +197,16 @@ impl Adapter { /// Set `MTU` of this adapter pub fn set_mtu(&self, mtu: usize) -> Result<(), Error> { let name = self.get_name()?; - Ok(util::set_adapter_mtu(&name, mtu)?) + util::set_adapter_mtu(&name, mtu, false)?; + // FIXME: Here we set the IPv6 MTU as well for consistency, but for some users it may not be expected. + util::set_adapter_mtu(&name, mtu, true)?; + Ok(()) } /// Returns `MTU` of this adapter pub fn get_mtu(&self) -> Result { - let luid = self.get_luid(); - Ok(util::get_adapter_mtu(&luid)?) + // FIXME: Here we get the IPv4 MTU only, but for some users it may not be expected. + Ok(util::get_mtu_by_index(self.index, false)? as _) } /// Returns the Win32 interface index of this adapter. Useful for specifying the interface diff --git a/src/util.rs b/src/util.rs index 3f54558..53f5d64 100644 --- a/src/util.rs +++ b/src/util.rs @@ -9,12 +9,13 @@ use windows_sys::{ }, NetworkManagement::{ IpHelper::{ - FreeMibTable, GetAdaptersAddresses, GetIfTable2, GetInterfaceInfo, DNS_INTERFACE_SETTINGS, + FreeMibTable, GetAdaptersAddresses, GetInterfaceInfo, DNS_INTERFACE_SETTINGS, DNS_INTERFACE_SETTINGS_VERSION1, DNS_SETTING_NAMESERVER, GAA_FLAG_INCLUDE_GATEWAYS, GAA_FLAG_INCLUDE_PREFIX, IF_TYPE_ETHERNET_CSMACD, IF_TYPE_IEEE80211, IP_ADAPTER_ADDRESSES_LH, - IP_ADAPTER_INDEX_MAP, IP_INTERFACE_INFO, MIB_IF_ROW2, MIB_IF_TABLE2, + IP_ADAPTER_INDEX_MAP, IP_INTERFACE_INFO, }, - Ndis::{IfOperStatusUp, NET_LUID_LH}, + IpHelper::{GetIpInterfaceTable, MIB_IPINTERFACE_ROW, MIB_IPINTERFACE_TABLE}, + Ndis::IfOperStatusUp, }, Networking::WinSock::{AF_INET, AF_INET6, AF_UNSPEC, SOCKADDR, SOCKADDR_IN, SOCKADDR_IN6, SOCKET_ADDRESS}, System::{ @@ -424,19 +425,20 @@ pub(crate) fn get_os_error_from_id(id: i32) -> std::io::Result<()> { } } -pub fn set_adapter_mtu(name: &str, mtu: usize) -> std::io::Result<()> { - if let Err(e) = set_adapter_mtu_cmd(name, mtu) { +pub fn set_adapter_mtu(name: &str, mtu: usize, is_ipv6: bool) -> std::io::Result<()> { + if let Err(e) = set_adapter_mtu_cmd(name, mtu, is_ipv6) { log::error!("Failed to set MTU for adapter: {}", e); set_adapter_mtu_api(name, mtu)?; } Ok(()) } -pub fn set_adapter_mtu_cmd(name: &str, mtu: usize) -> std::io::Result<()> { +pub fn set_adapter_mtu_cmd(name: &str, mtu: usize, is_ipv6: bool) -> std::io::Result<()> { // command line: `netsh interface ipv4 set subinterface "MyAdapter" mtu=1500` + let ip_str = if is_ipv6 { "ipv6" } else { "ipv4" }; let args = &[ "interface", - "ipv4", + ip_str, "set", "subinterface", &format!("\"{}\"", name), @@ -494,6 +496,9 @@ pub fn run_command(command: &str, args: &[&str]) -> std::io::Result> { Ok(out.stdout) } +/* +use windows_sys::Win32::NetworkManagement::IpHelper::{GetIfTable2, MIB_IF_ROW2, MIB_IF_TABLE2}; +use windows_sys::Win32::NetworkManagement::Ndis::NET_LUID_LH; pub(crate) fn get_adapter_mtu(luid: &NET_LUID_LH) -> std::io::Result { unsafe { let mut if_table: *mut MIB_IF_TABLE2 = std::ptr::null_mut(); @@ -524,12 +529,55 @@ pub(crate) fn get_adapter_mtu(luid: &NET_LUID_LH) -> std::io::Result { mtu.ok_or(std::io::Error::new(std::io::ErrorKind::NotFound, "Adapter not found")) } } +// */ + +pub(crate) fn get_mtu_by_index(index: u32, is_ipv6: bool) -> std::io::Result { + let mut mtu = None; + get_ip_interface_table( + |item| { + if item.InterfaceIndex == index { + mtu = Some(item.NlMtu); + } + Ok(()) + }, + is_ipv6, + )?; + let Some(mtu) = mtu else { + return Err(std::io::Error::from(std::io::ErrorKind::NotFound)); + }; + Ok(mtu) +} pub fn decode_utf16(string: &[u16]) -> String { let end = string.iter().position(|b| *b == 0).unwrap_or(string.len()); String::from_utf16_lossy(&string[..end]) } +pub fn get_ip_interface_table(mut callback: F, is_ipv6: bool) -> std::io::Result<()> +where + F: FnMut(&MIB_IPINTERFACE_ROW) -> std::io::Result<()>, +{ + let mut if_table: *mut MIB_IPINTERFACE_TABLE = std::ptr::null_mut(); + unsafe { + if GetIpInterfaceTable(if is_ipv6 { AF_INET6 } else { AF_INET }, &mut if_table as _) != NO_ERROR { + return Err(std::io::Error::last_os_error()); + } + if if_table.is_null() { + return Err(std::io::Error::from(std::io::ErrorKind::NotFound)); + } + use std::slice::from_raw_parts; + let ifaces = from_raw_parts::(&(*if_table).Table[0], (*if_table).NumEntries as usize); + for item in ifaces { + if let Err(e) = callback(item) { + FreeMibTable(if_table as _); + return Err(e); + } + } + FreeMibTable(if_table as _); + } + Ok(()) +} + #[repr(C, align(1))] #[derive(c2rust_bitfields::BitfieldStruct)] #[allow(non_snake_case)]