From 6a99908eda215b215b44dda18691aae39406cd47 Mon Sep 17 00:00:00 2001 From: Alextopher Date: Mon, 29 Apr 2024 13:35:55 -0400 Subject: [PATCH 1/3] errors: adds `From` for `IpNetworkError` --- src/error.rs | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/error.rs b/src/error.rs index e58781b..4c8120b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,4 @@ -use std::{error::Error, fmt}; +use std::{error::Error, fmt, net::AddrParseError}; use crate::error::IpNetworkError::*; @@ -9,7 +9,7 @@ pub enum IpNetworkError { InvalidAddr(String), InvalidPrefix, InvalidCidrFormat(String), - NetworkSizeError(NetworkSizeError) + NetworkSizeError(NetworkSizeError), } impl fmt::Display for IpNetworkError { @@ -18,7 +18,7 @@ impl fmt::Display for IpNetworkError { InvalidAddr(ref s) => write!(f, "invalid address: {s}"), InvalidPrefix => write!(f, "invalid prefix"), InvalidCidrFormat(ref s) => write!(f, "invalid cidr format: {s}"), - NetworkSizeError(ref e) => write!(f, "network size error: {e}") + NetworkSizeError(ref e) => write!(f, "network size error: {e}"), } } } @@ -29,16 +29,22 @@ impl Error for IpNetworkError { InvalidAddr(_) => "address is invalid", InvalidPrefix => "prefix is invalid", InvalidCidrFormat(_) => "cidr is invalid", - NetworkSizeError(_) => "network size error" + NetworkSizeError(_) => "network size error", } } } +impl From for IpNetworkError { + fn from(e: AddrParseError) -> Self { + InvalidAddr(e.to_string()) + } +} + /// Cannot convert an IPv6 network size to a u32 as it is a 128-bit value. #[derive(Copy, Clone, Debug, PartialEq, Eq)] #[non_exhaustive] pub enum NetworkSizeError { - NetworkIsTooLarge + NetworkIsTooLarge, } impl fmt::Display for NetworkSizeError { From d2890843401c36ccea922519f3b758fb65efe789 Mon Sep 17 00:00:00 2001 From: Alextopher Date: Mon, 29 Apr 2024 13:37:40 -0400 Subject: [PATCH 2/3] ipv4: adds safety comment to `new_unchecked` and debug assertions to verify compliance --- src/ipv4.rs | 39 +++++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/src/ipv4.rs b/src/ipv4.rs index 3722fd0..363b29d 100644 --- a/src/ipv4.rs +++ b/src/ipv4.rs @@ -76,6 +76,23 @@ impl Ipv4Network { /// Constructs without checking prefix a new `Ipv4Network` from any `Ipv4Addr, /// and a prefix denoting the network size. + /// + /// # Safety + /// + /// The caller must ensure that the prefix is less than or equal to 32. + /// + /// # Examples + /// + /// ``` + /// use std::net::Ipv4Addr; + /// use ipnetwork::Ipv4Network; + /// + /// let prefix = 24; + /// let addr = Ipv4Addr::new(192, 168, 1, 1); + /// + /// debug_assert!(prefix <= 32); + /// let network = unsafe { Ipv4Network::new_unchecked(addr, prefix) }; + /// ``` pub const unsafe fn new_unchecked(addr: Ipv4Addr, prefix: u8) -> Ipv4Network { Ipv4Network { addr, prefix } } @@ -128,8 +145,9 @@ impl Ipv4Network { /// Checks if the given `Ipv4Network` is partly contained in other. pub fn overlaps(self, other: Ipv4Network) -> bool { other.contains(self.ip()) - || (other.contains(self.broadcast()) - || (self.contains(other.ip()) || (self.contains(other.broadcast())))) + || other.contains(self.broadcast()) + || self.contains(other.ip()) + || (self.contains(other.broadcast())) } /// Returns the mask for this `Ipv4Network`. @@ -147,9 +165,11 @@ impl Ipv4Network { /// assert_eq!(net.mask(), Ipv4Addr::new(255, 255, 0, 0)); /// ``` pub fn mask(&self) -> Ipv4Addr { - let mask = !(0xffff_ffff_u64 >> u64::from(self.prefix)) as u32; + debug_assert!(self.prefix <= 32); + + let mask = u32::MAX << (IPV4_BITS - self.prefix); Ipv4Addr::from(mask) - } + } /// Returns the address of the network denoted by this `Ipv4Network`. /// This means the lowest possible IPv4 address inside of the network. @@ -201,6 +221,8 @@ impl Ipv4Network { /// ``` #[inline] pub fn contains(&self, ip: Ipv4Addr) -> bool { + debug_assert!(self.prefix <= IPV4_BITS); + let mask = !(0xffff_ffff_u64 >> self.prefix) as u32; let net = u32::from(self.addr) & mask; (u32::from(ip) & mask) == net @@ -221,8 +243,9 @@ impl Ipv4Network { /// assert_eq!(tinynet.size(), 1); /// ``` pub fn size(self) -> u32 { - 1 << (u32::from(IPV4_BITS - self.prefix)) - } + debug_assert!(self.prefix <= 32); + 1 << (IPV4_BITS - self.prefix) + } /// Returns the `n`:th address within this network. /// The adresses are indexed from 0 and `n` must be smaller than the size of the network. @@ -274,8 +297,7 @@ impl FromStr for Ipv4Network { type Err = IpNetworkError; fn from_str(s: &str) -> Result { let (addr_str, prefix_str) = cidr_parts(s)?; - let addr = Ipv4Addr::from_str(addr_str) - .map_err(|_| IpNetworkError::InvalidAddr(addr_str.to_string()))?; + let addr = Ipv4Addr::from_str(addr_str)?; let prefix = match prefix_str { Some(v) => { if let Ok(netmask) = Ipv4Addr::from_str(v) { @@ -458,6 +480,7 @@ mod test { } #[test] + #[allow(dropping_copy_types)] fn copy_compatibility_v4() { let net = Ipv4Network::new(Ipv4Addr::new(127, 0, 0, 1), 16).unwrap(); mem::drop(net); From f926c45c92af38746bb59aa43fc67f372d1076f1 Mon Sep 17 00:00:00 2001 From: Alextopher Date: Mon, 29 Apr 2024 13:40:32 -0400 Subject: [PATCH 3/3] ipv6: rewrite core ipv6 methods to operate on `u128`s --- src/ipv4.rs | 2 +- src/ipv6.rs | 143 +++++++++++++++++++++++++++------------------------- 2 files changed, 76 insertions(+), 69 deletions(-) diff --git a/src/ipv4.rs b/src/ipv4.rs index 363b29d..4ad0b8c 100644 --- a/src/ipv4.rs +++ b/src/ipv4.rs @@ -147,7 +147,7 @@ impl Ipv4Network { other.contains(self.ip()) || other.contains(self.broadcast()) || self.contains(other.ip()) - || (self.contains(other.broadcast())) + || self.contains(other.broadcast()) } /// Returns the mask for this `Ipv4Network`. diff --git a/src/ipv6.rs b/src/ipv6.rs index 8c37817..afe319f 100644 --- a/src/ipv6.rs +++ b/src/ipv6.rs @@ -1,6 +1,6 @@ use crate::error::IpNetworkError; use crate::parse::{cidr_parts, parse_prefix}; -use std::{cmp, convert::TryFrom, fmt, net::Ipv6Addr, str::FromStr}; +use std::{convert::TryFrom, fmt, net::Ipv6Addr, str::FromStr}; const IPV6_BITS: u8 = 128; const IPV6_SEGMENT_BITS: u8 = 16; @@ -87,6 +87,23 @@ impl Ipv6Network { /// Constructs without checking prefix a new `Ipv6Network` from any `Ipv6Addr, /// and a prefix denoting the network size. + /// + /// # Safety + /// + /// The caller must ensure that the prefix is less than or equal to 32. + /// + /// # Examples + /// + /// ``` + /// use std::net::Ipv6Addr; + /// use ipnetwork::Ipv6Network; + /// + /// let prefix = 64; + /// let addr = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0); + /// + /// debug_assert!(prefix <= 128); + /// let net = unsafe { Ipv6Network::new_unchecked(addr, prefix) }; + /// ``` pub const unsafe fn new_unchecked(addr: Ipv6Addr, prefix: u8) -> Ipv6Network { Ipv6Network { addr, prefix } } @@ -106,6 +123,10 @@ impl Ipv6Network { /// Returns an iterator over `Ipv6Network`. Each call to `next` will return the next /// `Ipv6Addr` in the given network. `None` will be returned when there are no more /// addresses. + /// + /// # Warning + /// + /// This can return up to 2^128 addresses, which will take a _long_ time to iterate over. pub fn iter(&self) -> Ipv6NetworkIterator { let dec = u128::from(self.addr); let max = u128::max_value(); @@ -123,42 +144,6 @@ impl Ipv6Network { } } - /// Returns the address of the network denoted by this `Ipv6Network`. - /// This means the lowest possible IPv6 address inside of the network. - /// - /// # Examples - /// - /// ``` - /// use std::net::Ipv6Addr; - /// use ipnetwork::Ipv6Network; - /// - /// let net: Ipv6Network = "2001:db8::/96".parse().unwrap(); - /// assert_eq!(net.network(), Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0)); - /// ``` - pub fn network(&self) -> Ipv6Addr { - let mask = u128::from(self.mask()); - let ip = u128::from(self.addr) & mask; - Ipv6Addr::from(ip) - } - - /// Returns the broadcast address of this `Ipv6Network`. - /// This means the highest possible IPv4 address inside of the network. - /// - /// # Examples - /// - /// ``` - /// use std::net::Ipv6Addr; - /// use ipnetwork::Ipv6Network; - /// - /// let net: Ipv6Network = "2001:db8::/96".parse().unwrap(); - /// assert_eq!(net.broadcast(), Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0xffff, 0xffff)); - /// ``` - pub fn broadcast(&self) -> Ipv6Addr { - let mask = u128::from(self.mask()); - let broadcast = u128::from(self.addr) | !mask; - Ipv6Addr::from(broadcast) - } - pub fn ip(&self) -> Ipv6Addr { self.addr } @@ -180,8 +165,9 @@ impl Ipv6Network { /// Checks if the given `Ipv6Network` is partly contained in other. pub fn overlaps(self, other: Ipv6Network) -> bool { other.contains(self.ip()) - || (other.contains(self.broadcast()) - || (self.contains(other.ip()) || (self.contains(other.broadcast())))) + || other.contains(self.broadcast()) + || self.contains(other.ip()) + || self.contains(other.broadcast()) } /// Returns the mask for this `Ipv6Network`. @@ -199,17 +185,47 @@ impl Ipv6Network { /// assert_eq!(net.mask(), Ipv6Addr::new(0xffff, 0xffff, 0, 0, 0, 0, 0, 0)); /// ``` pub fn mask(&self) -> Ipv6Addr { - let mut segments = [0; 16]; - for (i, chunk) in segments.chunks_mut(2).enumerate() { - let bits_remaining = self.prefix.saturating_sub(i as u8 * 16); - let set_bits = cmp::min(bits_remaining, 16); - let mask = !(0xffff >> set_bits) as u16; - chunk[0] = (mask >> 8) as u8; - chunk[1] = mask as u8; - } - Ipv6Addr::from(segments) + debug_assert!(self.prefix <= IPV6_BITS); + + let mask = u128::MAX << (IPV6_BITS - self.prefix); + Ipv6Addr::from(mask) + } + + /// Returns the address of the network denoted by this `Ipv6Network`. + /// This means the lowest possible IPv6 address inside of the network. + /// + /// # Examples + /// + /// ``` + /// use std::net::Ipv6Addr; + /// use ipnetwork::Ipv6Network; + /// + /// let net: Ipv6Network = "2001:db8::/96".parse().unwrap(); + /// assert_eq!(net.network(), Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0)); + /// ``` + pub fn network(&self) -> Ipv6Addr { + let mask = u128::from(self.mask()); + let network = u128::from(self.addr) & mask; + Ipv6Addr::from(network) + } + + /// Returns the broadcast address of this `Ipv6Network`. + /// This means the highest possible IPv4 address inside of the network. + /// + /// # Examples + /// + /// ``` + /// use std::net::Ipv6Addr; + /// use ipnetwork::Ipv6Network; + /// + /// let net: Ipv6Network = "2001:db8::/96".parse().unwrap(); + /// assert_eq!(net.broadcast(), Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0xffff, 0xffff)); + /// ``` + pub fn broadcast(&self) -> Ipv6Addr { + let mask = u128::from(self.mask()); + let broadcast = u128::from(self.addr) | !mask; + Ipv6Addr::from(broadcast) } - /// Checks if a given `Ipv6Addr` is in this `Ipv6Network` /// @@ -225,14 +241,10 @@ impl Ipv6Network { /// ``` #[inline] pub fn contains(&self, ip: Ipv6Addr) -> bool { - let a = self.addr.segments(); - let b = ip.segments(); - let addrs = Iterator::zip(a.iter(), b.iter()); - self.mask() - .segments() - .iter() - .zip(addrs) - .all(|(mask, (a, b))| a & mask == b & mask) + let ip = u128::from(ip); + let net = u128::from(self.network()); + let mask = u128::from(self.mask()); + (ip & mask) == net } /// Returns number of possible host addresses in this `Ipv6Network`. @@ -250,12 +262,12 @@ impl Ipv6Network { /// assert_eq!(tinynet.size(), 1); /// ``` pub fn size(&self) -> u128 { - let host_bits = u32::from(IPV6_BITS - self.prefix); - 2u128.pow(host_bits) + debug_assert!(self.prefix <= IPV6_BITS); + 1 << (IPV6_BITS - self.prefix) } /// Returns the `n`:th address within this network. - /// The adresses are indexed from 0 and `n` must be smaller than the size of the network. + /// The addresses are indexed from 0 and `n` must be smaller than the size of the network. /// /// # Examples /// @@ -296,14 +308,12 @@ impl FromStr for Ipv6Network { type Err = IpNetworkError; fn from_str(s: &str) -> Result { let (addr_str, prefix_str) = cidr_parts(s)?; - let addr = Ipv6Addr::from_str(addr_str).map_err(|e| IpNetworkError::InvalidAddr(e.to_string()))?; + let addr = Ipv6Addr::from_str(addr_str)?; let prefix = parse_prefix(prefix_str.unwrap_or(&IPV6_BITS.to_string()), IPV6_BITS)?; Ipv6Network::new(addr, prefix) } } - - impl TryFrom<&str> for Ipv6Network { type Error = IpNetworkError; @@ -694,7 +704,7 @@ mod test { let other: Ipv6Network = "2001:DB8:ACAD::1/64".parse().unwrap(); let other2: Ipv6Network = "2001:DB8:ACAD::20:2/64".parse().unwrap(); - assert_eq!(other2.overlaps(other), true); + assert!(other2.overlaps(other)); } #[test] @@ -726,9 +736,6 @@ mod test { net.nth(65538).unwrap(), Ipv6Addr::from_str("ff01::1:2").unwrap() ); - assert_eq!( - net.nth(net.size()), - None - ); + assert_eq!(net.nth(net.size()), None); } }