diff --git a/Cargo.lock b/Cargo.lock index 763a8acf..231b5337 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -507,6 +507,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "downcast" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" + [[package]] name = "either" version = "1.9.0" @@ -591,6 +597,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fragile" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa" + [[package]] name = "futures-channel" version = "0.3.30" @@ -1030,6 +1042,33 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "mockall" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43766c2b5203b10de348ffe19f7e54564b64f3d6018ff7648d1e2d6d3a0f0a48" +dependencies = [ + "cfg-if", + "downcast", + "fragile", + "lazy_static", + "mockall_derive", + "predicates", + "predicates-tree", +] + +[[package]] +name = "mockall_derive" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af7cbce79ec385a1d4f54baa90a76401eb15d9cab93685f62e7e9f942aa00ae2" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "nix" version = "0.27.1" @@ -1161,6 +1200,32 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "predicates" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68b87bfd4605926cdfefc1c3b5f8fe560e3feca9d5552cf68c466d3d8236c7e8" +dependencies = [ + "anstyle", + "predicates-core", +] + +[[package]] +name = "predicates-core" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b794032607612e7abeb4db69adb4e33590fa6cf1149e95fd7cb00e634b92f174" + +[[package]] +name = "predicates-tree" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368ba315fb8c5052ab692e68a0eefec6ec57b23a36959c14496f0b0df2c0cecf" +dependencies = [ + "predicates-core", + "termtree", +] + [[package]] name = "pretty_assertions" version = "1.4.0" @@ -1582,6 +1647,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "termtree" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" + [[package]] name = "test-case" version = "3.3.1" @@ -1877,6 +1948,7 @@ dependencies = [ "ipnetwork", "itertools", "maxminddb", + "mockall", "nix", "parking_lot", "paste", diff --git a/Cargo.toml b/Cargo.toml index 31f5891d..7352b76e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -88,6 +88,7 @@ serde_yaml = "0.9.31" tokio = { version = "1.35.1", features = [ "full" ] } tokio-util = "0.7.10" ipnetwork = "0.20.0" +mockall = "0.12.1" # see https://github.com/meh/rust-tun/pull/74 [target.'cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))'.dev-dependencies] diff --git a/src/tracing/net/ipv4.rs b/src/tracing/net/ipv4.rs index 8b1fa9d2..2549af5a 100644 --- a/src/tracing/net/ipv4.rs +++ b/src/tracing/net/ipv4.rs @@ -9,7 +9,7 @@ use crate::tracing::packet::icmpv4::destination_unreachable::DestinationUnreacha use crate::tracing::packet::icmpv4::echo_reply::EchoReplyPacket; use crate::tracing::packet::icmpv4::echo_request::EchoRequestPacket; use crate::tracing::packet::icmpv4::time_exceeded::TimeExceededPacket; -use crate::tracing::packet::icmpv4::{IcmpCode, IcmpPacket, IcmpType}; +use crate::tracing::packet::icmpv4::{IcmpCode, IcmpPacket, IcmpTimeExceededCode, IcmpType}; use crate::tracing::packet::ipv4::Ipv4Packet; use crate::tracing::packet::tcp::TcpPacket; use crate::tracing::packet::udp::UdpPacket; @@ -352,23 +352,32 @@ fn extract_probe_resp( let recv = SystemTime::now(); let src = IpAddr::V4(ipv4.get_source()); let icmp_v4 = IcmpPacket::new_view(ipv4.payload())?; - Ok(match icmp_v4.get_icmp_type() { + let icmp_type = icmp_v4.get_icmp_type(); + let icmp_code = icmp_v4.get_icmp_code(); + Ok(match icmp_type { IcmpType::TimeExceeded => { - let packet = TimeExceededPacket::new_view(icmp_v4.packet())?; - let (nested_ipv4, extension) = match icmp_extension_mode { - IcmpExtensionParseMode::Enabled => { - let ipv4 = Ipv4Packet::new_view(packet.payload())?; - let ext = packet.extension().map(Extensions::try_from).transpose()?; - (ipv4, ext) - } - IcmpExtensionParseMode::Disabled => { - let ipv4 = Ipv4Packet::new_view(packet.payload_raw())?; - (ipv4, None) - } - }; - extract_probe_resp_seq(&nested_ipv4, protocol)?.map(|resp_seq| { - ProbeResponse::TimeExceeded(ProbeResponseData::new(recv, src, resp_seq), extension) - }) + if IcmpTimeExceededCode::from(icmp_code) == IcmpTimeExceededCode::TtlExpired { + let packet = TimeExceededPacket::new_view(icmp_v4.packet())?; + let (nested_ipv4, extension) = match icmp_extension_mode { + IcmpExtensionParseMode::Enabled => { + let ipv4 = Ipv4Packet::new_view(packet.payload())?; + let ext = packet.extension().map(Extensions::try_from).transpose()?; + (ipv4, ext) + } + IcmpExtensionParseMode::Disabled => { + let ipv4 = Ipv4Packet::new_view(packet.payload_raw())?; + (ipv4, None) + } + }; + extract_probe_resp_seq(&nested_ipv4, protocol)?.map(|resp_seq| { + ProbeResponse::TimeExceeded( + ProbeResponseData::new(recv, src, resp_seq), + extension, + ) + }) + } else { + None + } } IcmpType::DestinationUnreachable => { let packet = DestinationUnreachablePacket::new_view(icmp_v4.packet())?; @@ -473,3 +482,85 @@ fn extract_tcp_packet(ipv4: &Ipv4Packet<'_>) -> TraceResult<(u16, u16)> { Ok((tcp_packet.get_source(), tcp_packet.get_destination())) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::mocket_read; + use crate::tracing::error::IoResult; + use crate::tracing::net::socket::MockSocket; + use crate::tracing::{Port, Round, TimeToLive}; + use mockall::predicate; + use std::str::FromStr; + + // Test dispatching a IPv4/ICMP probe. + #[test] + fn test_dispatch_icmp_probe_no_payload() -> anyhow::Result<()> { + let probe = Probe::new( + Sequence(33000), + TraceId(1234), + Port(0), + Port(0), + TimeToLive(10), + Round(0), + SystemTime::now(), + ); + let src_addr = Ipv4Addr::from_str("1.2.3.4")?; + let dest_addr = Ipv4Addr::from_str("5.6.7.8")?; + let packet_size = PacketSize(28); + let payload_pattern = PayloadPattern(0x00); + let ipv4_byte_order = platform::PlatformIpv4FieldByteOrder::Network; + let expected_send_to_buf = hex_literal::hex!( + " + 45 00 00 1c 00 00 40 00 0a 01 00 00 01 02 03 04 + 05 06 07 08 08 00 72 45 04 d2 80 e8 + " + ); + let expected_send_to_addr = SocketAddr::new(IpAddr::V4(dest_addr), 0); + + let mut mocket = MockSocket::new(); + mocket + .expect_send_to() + .with( + predicate::eq(expected_send_to_buf), + predicate::eq(expected_send_to_addr), + ) + .times(1) + .returning(|_, _| Ok(())); + + dispatch_icmp_probe( + &mut mocket, + probe, + src_addr, + dest_addr, + packet_size, + payload_pattern, + ipv4_byte_order, + )?; + Ok(()) + } + + // This IPv4/ICMP TimeExceeded packet has code 1 ("Fragment reassembly + // time exceeded") and must be ignored. + // + // Note this is not real packet and so the length and checksum are not + // accurate. + #[test] + fn test_icmp_time_exceeded_fragment_reassembly_ignored() -> anyhow::Result<()> { + let expected_read_buf = hex_literal::hex!( + " + 45 20 2c 02 e4 5c 00 00 72 01 2e 04 67 4b 0b 34 + c0 a8 01 15 0b 01 1c 38 00 00 00 00 45 00 8c 05 + 85 4e 20 00 30 11 ab d6 c0 a8 01 15 67 4b 0b 34 + " + ); + let mut mocket = MockSocket::new(); + mocket + .expect_read() + .times(1) + .returning(mocket_read!(expected_read_buf)); + let resp = recv_icmp_probe(&mut mocket, Protocol::Udp, IcmpExtensionParseMode::Enabled)?; + assert!(resp.is_none()); + Ok(()) + } +} diff --git a/src/tracing/net/ipv6.rs b/src/tracing/net/ipv6.rs index 099b8428..2923e1a9 100644 --- a/src/tracing/net/ipv6.rs +++ b/src/tracing/net/ipv6.rs @@ -9,7 +9,7 @@ use crate::tracing::packet::icmpv6::destination_unreachable::DestinationUnreacha use crate::tracing::packet::icmpv6::echo_reply::EchoReplyPacket; use crate::tracing::packet::icmpv6::echo_request::EchoRequestPacket; use crate::tracing::packet::icmpv6::time_exceeded::TimeExceededPacket; -use crate::tracing::packet::icmpv6::{IcmpCode, IcmpPacket, IcmpType}; +use crate::tracing::packet::icmpv6::{IcmpCode, IcmpPacket, IcmpTimeExceededCode, IcmpType}; use crate::tracing::packet::ipv6::Ipv6Packet; use crate::tracing::packet::tcp::TcpPacket; use crate::tracing::packet::udp::UdpPacket; @@ -302,23 +302,32 @@ fn extract_probe_resp( ) -> TraceResult> { let recv = SystemTime::now(); let ip = IpAddr::V6(src); - Ok(match icmp_v6.get_icmp_type() { + let icmp_type = icmp_v6.get_icmp_type(); + let icmp_code = icmp_v6.get_icmp_code(); + Ok(match icmp_type { IcmpType::TimeExceeded => { - let packet = TimeExceededPacket::new_view(icmp_v6.packet())?; - let (nested_ipv6, extension) = match icmp_extension_mode { - IcmpExtensionParseMode::Enabled => { - let ipv6 = Ipv6Packet::new_view(packet.payload())?; - let ext = packet.extension().map(Extensions::try_from).transpose()?; - (ipv6, ext) - } - IcmpExtensionParseMode::Disabled => { - let ipv6 = Ipv6Packet::new_view(packet.payload_raw())?; - (ipv6, None) - } - }; - extract_probe_resp_seq(&nested_ipv6, protocol)?.map(|resp_seq| { - ProbeResponse::TimeExceeded(ProbeResponseData::new(recv, ip, resp_seq), extension) - }) + if IcmpTimeExceededCode::from(icmp_code) == IcmpTimeExceededCode::TtlExpired { + let packet = TimeExceededPacket::new_view(icmp_v6.packet())?; + let (nested_ipv6, extension) = match icmp_extension_mode { + IcmpExtensionParseMode::Enabled => { + let ipv6 = Ipv6Packet::new_view(packet.payload())?; + let ext = packet.extension().map(Extensions::try_from).transpose()?; + (ipv6, ext) + } + IcmpExtensionParseMode::Disabled => { + let ipv6 = Ipv6Packet::new_view(packet.payload_raw())?; + (ipv6, None) + } + }; + extract_probe_resp_seq(&nested_ipv6, protocol)?.map(|resp_seq| { + ProbeResponse::TimeExceeded( + ProbeResponseData::new(recv, ip, resp_seq), + extension, + ) + }) + } else { + None + } } IcmpType::DestinationUnreachable => { let packet = DestinationUnreachablePacket::new_view(icmp_v6.packet())?; @@ -419,3 +428,93 @@ fn extract_tcp_packet(ipv6: &Ipv6Packet<'_>) -> TraceResult<(u16, u16)> { let tcp_packet = TcpPacket::new_view(ipv6.payload())?; Ok((tcp_packet.get_source(), tcp_packet.get_destination())) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::mocket_recv_from; + use crate::tracing::error::IoResult; + use crate::tracing::net::socket::MockSocket; + use crate::tracing::{Port, Round, TimeToLive}; + use mockall::predicate; + use std::str::FromStr; + + // Test dispatching an IPv6/ICMP probe. + #[test] + fn test_dispatch_icmp_probe_no_payload() -> anyhow::Result<()> { + let probe = Probe::new( + Sequence(33000), + TraceId(1234), + Port(0), + Port(0), + TimeToLive(10), + Round(0), + SystemTime::now(), + ); + let src_addr = Ipv6Addr::from_str("fd7a:115c:a1e0:ab12:4843:cd96:6263:82a")?; + let dest_addr = Ipv6Addr::from_str("2a00:1450:4009:815::200e")?; + let packet_size = PacketSize(48); + let payload_pattern = PayloadPattern(0x00); + let expected_send_to_buf = hex_literal::hex!("80 00 77 54 04 d2 80 e8"); + let expected_send_to_addr = SocketAddr::new(IpAddr::V6(dest_addr), 0); + + let mut mocket = MockSocket::new(); + mocket + .expect_send_to() + .with( + predicate::eq(expected_send_to_buf), + predicate::eq(expected_send_to_addr), + ) + .times(1) + .returning(|_, _| Ok(())); + mocket + .expect_set_unicast_hops_v6() + .times(1) + .with(predicate::eq(10)) + .returning(|_| Ok(())); + + dispatch_icmp_probe( + &mut mocket, + probe, + src_addr, + dest_addr, + packet_size, + payload_pattern, + )?; + Ok(()) + } + + // This ICMPv6 packet has code 1 ("Fragment reassembly time exceeded") + // and must be ignored. + // + // Note this is not real packet and so the length and checksum are not + // accurate. + #[test] + fn test_icmp_time_exceeded_fragment_reassembly_ignored() -> anyhow::Result<()> { + let expected_recv_from_buf = hex_literal::hex!( + " + 03 01 da 90 00 00 00 00 60 0f 02 00 00 2c 11 01 + fd 7a 11 5c a1 e0 ab 12 48 43 cd 96 62 63 08 2a + 2a 00 14 50 40 09 08 15 00 00 00 00 00 00 20 0e + 95 ce 81 24 00 2c 65 f5 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 + " + ); + let expected_recv_from_addr = SocketAddr::new( + IpAddr::V6(Ipv6Addr::from_str("2604:a880:ffff:6:1::41c").unwrap()), + 0, + ); + let mut mocket = MockSocket::new(); + mocket + .expect_recv_from() + .times(1) + .returning(mocket_recv_from!( + expected_recv_from_buf, + expected_recv_from_addr + )); + let resp = recv_icmp_probe(&mut mocket, Protocol::Udp, IcmpExtensionParseMode::Enabled)?; + assert!(resp.is_none()); + Ok(()) + } +} diff --git a/src/tracing/net/socket.rs b/src/tracing/net/socket.rs index 82418733..60f0a2d2 100644 --- a/src/tracing/net/socket.rs +++ b/src/tracing/net/socket.rs @@ -2,6 +2,7 @@ use crate::tracing::error::IoResult as Result; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::time::Duration; +#[cfg_attr(test, mockall::automock)] pub trait Socket where Self: Sized, @@ -46,3 +47,26 @@ where fn icmp_error_info(&mut self) -> Result; fn close(&mut self) -> Result<()>; } + +#[cfg(test)] +pub mod tests { + #[macro_export] + macro_rules! mocket_read { + ($packet: expr) => { + move |buf: &mut [u8]| -> IoResult { + buf[..$packet.len()].copy_from_slice(&$packet); + Ok(buf.len()) + } + }; + } + + #[macro_export] + macro_rules! mocket_recv_from { + ($packet: expr, $addr: expr) => { + move |buf: &mut [u8]| -> IoResult<(usize, Option)> { + buf[..$packet.len()].copy_from_slice(&$packet); + Ok((buf.len(), Some($addr))) + } + }; + } +} diff --git a/src/tracing/packet/icmpv4.rs b/src/tracing/packet/icmpv4.rs index 1efe94be..fe682b6d 100644 --- a/src/tracing/packet/icmpv4.rs +++ b/src/tracing/packet/icmpv4.rs @@ -47,6 +47,27 @@ impl From for IcmpCode { } } +/// The code for `TimeExceeded` ICMP packet type. +#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] +pub enum IcmpTimeExceededCode { + /// TTL expired in transit. + TtlExpired, + /// Fragment reassembly time exceeded. + FragmentReassembly, + /// An unknown code. + Unknown(u8), +} + +impl From for IcmpTimeExceededCode { + fn from(val: IcmpCode) -> Self { + match val { + IcmpCode(0) => Self::TtlExpired, + IcmpCode(1) => Self::FragmentReassembly, + IcmpCode(id) => Self::Unknown(id), + } + } +} + const TYPE_OFFSET: usize = 0; const CODE_OFFSET: usize = 1; const CHECKSUM_OFFSET: usize = 2; diff --git a/src/tracing/packet/icmpv6.rs b/src/tracing/packet/icmpv6.rs index 3e169f79..fa720f7f 100644 --- a/src/tracing/packet/icmpv6.rs +++ b/src/tracing/packet/icmpv6.rs @@ -47,6 +47,27 @@ impl From for IcmpCode { } } +/// The code for `TimeExceeded` `ICMPv6` packet type. +#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] +pub enum IcmpTimeExceededCode { + /// Hop limit exceeded in transit. + TtlExpired, + /// Fragment reassembly time exceeded. + FragmentReassembly, + /// An unknown code. + Unknown(u8), +} + +impl From for IcmpTimeExceededCode { + fn from(val: IcmpCode) -> Self { + match val { + IcmpCode(0) => Self::TtlExpired, + IcmpCode(1) => Self::FragmentReassembly, + IcmpCode(id) => Self::Unknown(id), + } + } +} + const TYPE_OFFSET: usize = 0; const CODE_OFFSET: usize = 1; const CHECKSUM_OFFSET: usize = 2;