Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ipv6: rewrite core ipv6 methods to operate on u128s #187

Merged
merged 3 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{error::Error, fmt};
use std::{error::Error, fmt, net::AddrParseError};

use crate::error::IpNetworkError::*;

Expand All @@ -9,7 +9,7 @@ pub enum IpNetworkError {
InvalidAddr(String),
InvalidPrefix,
InvalidCidrFormat(String),
NetworkSizeError(NetworkSizeError)
NetworkSizeError(NetworkSizeError),
}

impl fmt::Display for IpNetworkError {
Expand All @@ -18,7 +18,7 @@ impl fmt::Display for IpNetworkError {
InvalidAddr(ref s) => write!(f, "invalid address: {s}"),
InvalidPrefix => write!(f, "invalid prefix"),
InvalidCidrFormat(ref s) => write!(f, "invalid cidr format: {s}"),
NetworkSizeError(ref e) => write!(f, "network size error: {e}")
NetworkSizeError(ref e) => write!(f, "network size error: {e}"),
}
}
}
Expand All @@ -29,16 +29,22 @@ impl Error for IpNetworkError {
InvalidAddr(_) => "address is invalid",
InvalidPrefix => "prefix is invalid",
InvalidCidrFormat(_) => "cidr is invalid",
NetworkSizeError(_) => "network size error"
NetworkSizeError(_) => "network size error",
}
}
}

impl From<AddrParseError> for IpNetworkError {
fn from(e: AddrParseError) -> Self {
InvalidAddr(e.to_string())
}
}

/// Cannot convert an IPv6 network size to a u32 as it is a 128-bit value.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum NetworkSizeError {
NetworkIsTooLarge
NetworkIsTooLarge,
}

impl fmt::Display for NetworkSizeError {
Expand Down
39 changes: 31 additions & 8 deletions src/ipv4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,23 @@ impl Ipv4Network {

/// Constructs without checking prefix a new `Ipv4Network` from any `Ipv4Addr,
/// and a prefix denoting the network size.
///
/// # Safety
///
/// The caller must ensure that the prefix is less than or equal to 32.
///
/// # Examples
///
/// ```
/// use std::net::Ipv4Addr;
/// use ipnetwork::Ipv4Network;
///
/// let prefix = 24;
/// let addr = Ipv4Addr::new(192, 168, 1, 1);
///
/// debug_assert!(prefix <= 32);
/// let network = unsafe { Ipv4Network::new_unchecked(addr, prefix) };
/// ```
pub const unsafe fn new_unchecked(addr: Ipv4Addr, prefix: u8) -> Ipv4Network {
Ipv4Network { addr, prefix }
}
Expand Down Expand Up @@ -128,8 +145,9 @@ impl Ipv4Network {
/// Checks if the given `Ipv4Network` is partly contained in other.
pub fn overlaps(self, other: Ipv4Network) -> bool {
other.contains(self.ip())
|| (other.contains(self.broadcast())
|| (self.contains(other.ip()) || (self.contains(other.broadcast()))))
|| other.contains(self.broadcast())
|| self.contains(other.ip())
|| self.contains(other.broadcast())
}

/// Returns the mask for this `Ipv4Network`.
Expand All @@ -147,9 +165,11 @@ impl Ipv4Network {
/// assert_eq!(net.mask(), Ipv4Addr::new(255, 255, 0, 0));
/// ```
pub fn mask(&self) -> Ipv4Addr {
let mask = !(0xffff_ffff_u64 >> u64::from(self.prefix)) as u32;
debug_assert!(self.prefix <= 32);

let mask = u32::MAX << (IPV4_BITS - self.prefix);
Ipv4Addr::from(mask)
}
}

/// Returns the address of the network denoted by this `Ipv4Network`.
/// This means the lowest possible IPv4 address inside of the network.
Expand Down Expand Up @@ -201,6 +221,8 @@ impl Ipv4Network {
/// ```
#[inline]
pub fn contains(&self, ip: Ipv4Addr) -> bool {
debug_assert!(self.prefix <= IPV4_BITS);

let mask = !(0xffff_ffff_u64 >> self.prefix) as u32;
let net = u32::from(self.addr) & mask;
(u32::from(ip) & mask) == net
Expand All @@ -221,8 +243,9 @@ impl Ipv4Network {
/// assert_eq!(tinynet.size(), 1);
/// ```
pub fn size(self) -> u32 {
1 << (u32::from(IPV4_BITS - self.prefix))
}
debug_assert!(self.prefix <= 32);
1 << (IPV4_BITS - self.prefix)
}

/// Returns the `n`:th address within this network.
/// The adresses are indexed from 0 and `n` must be smaller than the size of the network.
Expand Down Expand Up @@ -274,8 +297,7 @@ impl FromStr for Ipv4Network {
type Err = IpNetworkError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let (addr_str, prefix_str) = cidr_parts(s)?;
let addr = Ipv4Addr::from_str(addr_str)
.map_err(|_| IpNetworkError::InvalidAddr(addr_str.to_string()))?;
let addr = Ipv4Addr::from_str(addr_str)?;
let prefix = match prefix_str {
Some(v) => {
if let Ok(netmask) = Ipv4Addr::from_str(v) {
Expand Down Expand Up @@ -458,6 +480,7 @@ mod test {
}

#[test]
#[allow(dropping_copy_types)]
fn copy_compatibility_v4() {
let net = Ipv4Network::new(Ipv4Addr::new(127, 0, 0, 1), 16).unwrap();
mem::drop(net);
Expand Down
143 changes: 75 additions & 68 deletions src/ipv6.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::error::IpNetworkError;
use crate::parse::{cidr_parts, parse_prefix};
use std::{cmp, convert::TryFrom, fmt, net::Ipv6Addr, str::FromStr};
use std::{convert::TryFrom, fmt, net::Ipv6Addr, str::FromStr};

const IPV6_BITS: u8 = 128;
const IPV6_SEGMENT_BITS: u8 = 16;
Expand Down Expand Up @@ -87,6 +87,23 @@ impl Ipv6Network {

/// Constructs without checking prefix a new `Ipv6Network` from any `Ipv6Addr,
/// and a prefix denoting the network size.
///
/// # Safety
///
/// The caller must ensure that the prefix is less than or equal to 32.
///
/// # Examples
///
/// ```
/// use std::net::Ipv6Addr;
/// use ipnetwork::Ipv6Network;
///
/// let prefix = 64;
/// let addr = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0);
///
/// debug_assert!(prefix <= 128);
/// let net = unsafe { Ipv6Network::new_unchecked(addr, prefix) };
/// ```
pub const unsafe fn new_unchecked(addr: Ipv6Addr, prefix: u8) -> Ipv6Network {
Ipv6Network { addr, prefix }
}
Expand All @@ -106,6 +123,10 @@ impl Ipv6Network {
/// Returns an iterator over `Ipv6Network`. Each call to `next` will return the next
/// `Ipv6Addr` in the given network. `None` will be returned when there are no more
/// addresses.
///
/// # Warning
///
/// This can return up to 2^128 addresses, which will take a _long_ time to iterate over.
pub fn iter(&self) -> Ipv6NetworkIterator {
let dec = u128::from(self.addr);
let max = u128::max_value();
Expand All @@ -123,42 +144,6 @@ impl Ipv6Network {
}
}

/// Returns the address of the network denoted by this `Ipv6Network`.
/// This means the lowest possible IPv6 address inside of the network.
///
/// # Examples
///
/// ```
/// use std::net::Ipv6Addr;
/// use ipnetwork::Ipv6Network;
///
/// let net: Ipv6Network = "2001:db8::/96".parse().unwrap();
/// assert_eq!(net.network(), Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0));
/// ```
pub fn network(&self) -> Ipv6Addr {
let mask = u128::from(self.mask());
let ip = u128::from(self.addr) & mask;
Ipv6Addr::from(ip)
}

/// Returns the broadcast address of this `Ipv6Network`.
/// This means the highest possible IPv4 address inside of the network.
///
/// # Examples
///
/// ```
/// use std::net::Ipv6Addr;
/// use ipnetwork::Ipv6Network;
///
/// let net: Ipv6Network = "2001:db8::/96".parse().unwrap();
/// assert_eq!(net.broadcast(), Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0xffff, 0xffff));
/// ```
pub fn broadcast(&self) -> Ipv6Addr {
let mask = u128::from(self.mask());
let broadcast = u128::from(self.addr) | !mask;
Ipv6Addr::from(broadcast)
}

pub fn ip(&self) -> Ipv6Addr {
self.addr
}
Expand All @@ -180,8 +165,9 @@ impl Ipv6Network {
/// Checks if the given `Ipv6Network` is partly contained in other.
pub fn overlaps(self, other: Ipv6Network) -> bool {
other.contains(self.ip())
|| (other.contains(self.broadcast())
|| (self.contains(other.ip()) || (self.contains(other.broadcast()))))
|| other.contains(self.broadcast())
|| self.contains(other.ip())
|| self.contains(other.broadcast())
}

/// Returns the mask for this `Ipv6Network`.
Expand All @@ -199,17 +185,47 @@ impl Ipv6Network {
/// assert_eq!(net.mask(), Ipv6Addr::new(0xffff, 0xffff, 0, 0, 0, 0, 0, 0));
/// ```
pub fn mask(&self) -> Ipv6Addr {
let mut segments = [0; 16];
for (i, chunk) in segments.chunks_mut(2).enumerate() {
let bits_remaining = self.prefix.saturating_sub(i as u8 * 16);
let set_bits = cmp::min(bits_remaining, 16);
let mask = !(0xffff >> set_bits) as u16;
chunk[0] = (mask >> 8) as u8;
chunk[1] = mask as u8;
}
Ipv6Addr::from(segments)
debug_assert!(self.prefix <= IPV6_BITS);

let mask = u128::MAX << (IPV6_BITS - self.prefix);
Ipv6Addr::from(mask)
}

/// Returns the address of the network denoted by this `Ipv6Network`.
/// This means the lowest possible IPv6 address inside of the network.
///
/// # Examples
///
/// ```
/// use std::net::Ipv6Addr;
/// use ipnetwork::Ipv6Network;
///
/// let net: Ipv6Network = "2001:db8::/96".parse().unwrap();
/// assert_eq!(net.network(), Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0));
/// ```
pub fn network(&self) -> Ipv6Addr {
let mask = u128::from(self.mask());
let network = u128::from(self.addr) & mask;
Ipv6Addr::from(network)
}

/// Returns the broadcast address of this `Ipv6Network`.
/// This means the highest possible IPv4 address inside of the network.
///
/// # Examples
///
/// ```
/// use std::net::Ipv6Addr;
/// use ipnetwork::Ipv6Network;
///
/// let net: Ipv6Network = "2001:db8::/96".parse().unwrap();
/// assert_eq!(net.broadcast(), Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0xffff, 0xffff));
/// ```
pub fn broadcast(&self) -> Ipv6Addr {
let mask = u128::from(self.mask());
let broadcast = u128::from(self.addr) | !mask;
Ipv6Addr::from(broadcast)
}


/// Checks if a given `Ipv6Addr` is in this `Ipv6Network`
///
Expand All @@ -225,14 +241,10 @@ impl Ipv6Network {
/// ```
#[inline]
pub fn contains(&self, ip: Ipv6Addr) -> bool {
let a = self.addr.segments();
let b = ip.segments();
let addrs = Iterator::zip(a.iter(), b.iter());
self.mask()
.segments()
.iter()
.zip(addrs)
.all(|(mask, (a, b))| a & mask == b & mask)
let ip = u128::from(ip);
let net = u128::from(self.network());
let mask = u128::from(self.mask());
(ip & mask) == net
}

/// Returns number of possible host addresses in this `Ipv6Network`.
Expand All @@ -250,12 +262,12 @@ impl Ipv6Network {
/// assert_eq!(tinynet.size(), 1);
/// ```
pub fn size(&self) -> u128 {
let host_bits = u32::from(IPV6_BITS - self.prefix);
2u128.pow(host_bits)
debug_assert!(self.prefix <= IPV6_BITS);
1 << (IPV6_BITS - self.prefix)
}

/// Returns the `n`:th address within this network.
/// The adresses are indexed from 0 and `n` must be smaller than the size of the network.
/// The addresses are indexed from 0 and `n` must be smaller than the size of the network.
///
/// # Examples
///
Expand Down Expand Up @@ -296,14 +308,12 @@ impl FromStr for Ipv6Network {
type Err = IpNetworkError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let (addr_str, prefix_str) = cidr_parts(s)?;
let addr = Ipv6Addr::from_str(addr_str).map_err(|e| IpNetworkError::InvalidAddr(e.to_string()))?;
let addr = Ipv6Addr::from_str(addr_str)?;
let prefix = parse_prefix(prefix_str.unwrap_or(&IPV6_BITS.to_string()), IPV6_BITS)?;
Ipv6Network::new(addr, prefix)
}
}



impl TryFrom<&str> for Ipv6Network {
type Error = IpNetworkError;

Expand Down Expand Up @@ -694,7 +704,7 @@ mod test {
let other: Ipv6Network = "2001:DB8:ACAD::1/64".parse().unwrap();
let other2: Ipv6Network = "2001:DB8:ACAD::20:2/64".parse().unwrap();

assert_eq!(other2.overlaps(other), true);
assert!(other2.overlaps(other));
}

#[test]
Expand Down Expand Up @@ -726,9 +736,6 @@ mod test {
net.nth(65538).unwrap(),
Ipv6Addr::from_str("ff01::1:2").unwrap()
);
assert_eq!(
net.nth(net.size()),
None
);
assert_eq!(net.nth(net.size()), None);
}
}
Loading