diff --git a/src/tracing/net/ipv4.rs b/src/tracing/net/ipv4.rs index c20cc964..26b4dc68 100644 --- a/src/tracing/net/ipv4.rs +++ b/src/tracing/net/ipv4.rs @@ -331,17 +331,17 @@ fn extract_probe_resp( Ok(match icmp_v4.get_icmp_type() { IcmpType::TimeExceeded => { let packet = TimeExceededPacket::new_view(icmp_v4.packet()).req()?; - let resp_seq = extract_probe_resp_seq(packet.payload(), protocol)?; - Some(ProbeResponse::TimeExceeded(ProbeResponseData::new( - recv, src, resp_seq, - ))) + let nested_ipv4 = Ipv4Packet::new_view(packet.payload()).req()?; + extract_probe_resp_seq(&nested_ipv4, protocol)?.map(|resp_seq| { + ProbeResponse::TimeExceeded(ProbeResponseData::new(recv, src, resp_seq)) + }) } IcmpType::DestinationUnreachable => { let packet = DestinationUnreachablePacket::new_view(icmp_v4.packet()).req()?; - let resp_seq = extract_probe_resp_seq(packet.payload(), protocol)?; - Some(ProbeResponse::DestinationUnreachable( - ProbeResponseData::new(recv, src, resp_seq), - )) + let nested_ipv4 = Ipv4Packet::new_view(packet.payload()).req()?; + extract_probe_resp_seq(&nested_ipv4, protocol)?.map(|resp_seq| { + ProbeResponse::DestinationUnreachable(ProbeResponseData::new(recv, src, resp_seq)) + }) } IcmpType::EchoReply => match protocol { TracerProtocol::Icmp => { @@ -361,50 +361,48 @@ fn extract_probe_resp( #[instrument] fn extract_probe_resp_seq( - payload: &[u8], + ipv4: &Ipv4Packet<'_>, protocol: TracerProtocol, -) -> TraceResult { - Ok(match protocol { - TracerProtocol::Icmp => { - let echo_request = extract_echo_request(payload)?; +) -> TraceResult> { + Ok(match (protocol, ipv4.get_protocol()) { + (TracerProtocol::Icmp, IpProtocol::Icmp) => { + let echo_request = extract_echo_request(ipv4)?; let identifier = echo_request.get_identifier(); let sequence = echo_request.get_sequence(); - ProbeResponseSeq::Icmp(ProbeResponseSeqIcmp::new(identifier, sequence)) + Some(ProbeResponseSeq::Icmp(ProbeResponseSeqIcmp::new( + identifier, sequence, + ))) } - TracerProtocol::Udp => { - let (src_port, dest_port, checksum, identifier) = extract_udp_packet(payload)?; - ProbeResponseSeq::Udp(ProbeResponseSeqUdp::new( + (TracerProtocol::Udp, IpProtocol::Udp) => { + let (src_port, dest_port, checksum, identifier) = extract_udp_packet(ipv4)?; + Some(ProbeResponseSeq::Udp(ProbeResponseSeqUdp::new( identifier, src_port, dest_port, checksum, - )) + ))) } - TracerProtocol::Tcp => { - let (src_port, dest_port) = extract_tcp_packet(payload)?; - ProbeResponseSeq::Tcp(ProbeResponseSeqTcp::new(src_port, dest_port)) + (TracerProtocol::Tcp, IpProtocol::Tcp) => { + let (src_port, dest_port) = extract_tcp_packet(ipv4)?; + Some(ProbeResponseSeq::Tcp(ProbeResponseSeqTcp::new( + src_port, dest_port, + ))) } + _ => None, }) } #[instrument] -fn extract_echo_request(payload: &[u8]) -> TraceResult> { - let ip4 = Ipv4Packet::new_view(payload).req()?; - let header_len = usize::from(ip4.get_header_length() * 4); - let nested_icmp = &payload[header_len..]; - let nested_echo = EchoRequestPacket::new_view(nested_icmp).req()?; - Ok(nested_echo) +fn extract_echo_request<'a>(ipv4: &'a Ipv4Packet<'a>) -> TraceResult> { + Ok(EchoRequestPacket::new_view(ipv4.payload()).req()?) } /// Get the src and dest ports from the original `UdpPacket` packet embedded in the payload. #[instrument] -fn extract_udp_packet(payload: &[u8]) -> TraceResult<(u16, u16, u16, u16)> { - let ip4 = Ipv4Packet::new_view(payload).req()?; - let header_len = usize::from(ip4.get_header_length() * 4); - let nested_udp = &payload[header_len..]; - let nested = UdpPacket::new_view(nested_udp).req()?; +fn extract_udp_packet(ipv4: &Ipv4Packet<'_>) -> TraceResult<(u16, u16, u16, u16)> { + let nested = UdpPacket::new_view(ipv4.payload()).req()?; Ok(( nested.get_source(), nested.get_destination(), nested.get_checksum(), - ip4.get_identification(), + ipv4.get_identification(), )) } @@ -420,10 +418,8 @@ fn extract_udp_packet(payload: &[u8]) -> TraceResult<(u16, u16, u16, u16)> { /// We therefore have to detect this situation and ensure we provide buffer a large enough for a /// complete TCP packet header. #[instrument] -fn extract_tcp_packet(payload: &[u8]) -> TraceResult<(u16, u16)> { - let ip4 = Ipv4Packet::new_view(payload).req()?; - let header_len = usize::from(ip4.get_header_length() * 4); - let nested_tcp = &payload[header_len..]; +fn extract_tcp_packet(ipv4: &Ipv4Packet<'_>) -> TraceResult<(u16, u16)> { + let nested_tcp = ipv4.payload(); if nested_tcp.len() < TcpPacket::minimum_packet_size() { let mut buf = [0_u8; TcpPacket::minimum_packet_size()]; buf[..nested_tcp.len()].copy_from_slice(nested_tcp); diff --git a/src/tracing/net/ipv6.rs b/src/tracing/net/ipv6.rs index 8b7f128a..9ef9fb48 100644 --- a/src/tracing/net/ipv6.rs +++ b/src/tracing/net/ipv6.rs @@ -12,6 +12,7 @@ use crate::tracing::packet::icmpv6::{IcmpCode, IcmpPacket, IcmpType}; use crate::tracing::packet::ipv6::Ipv6Packet; use crate::tracing::packet::tcp::TcpPacket; use crate::tracing::packet::udp::UdpPacket; +use crate::tracing::packet::IpProtocol; use crate::tracing::probe::{ ProbeResponse, ProbeResponseData, ProbeResponseSeq, ProbeResponseSeqIcmp, ProbeResponseSeqTcp, ProbeResponseSeqUdp, @@ -261,17 +262,17 @@ fn extract_probe_resp( Ok(match icmp_v6.get_icmp_type() { IcmpType::TimeExceeded => { let packet = TimeExceededPacket::new_view(icmp_v6.packet()).req()?; - let resp_seq = extract_probe_resp_seq(packet.payload(), protocol)?; - Some(ProbeResponse::TimeExceeded(ProbeResponseData::new( - recv, ip, resp_seq, - ))) + let nested_ipv6 = Ipv6Packet::new_view(packet.payload()).req()?; + extract_probe_resp_seq(&nested_ipv6, protocol)?.map(|resp_seq| { + ProbeResponse::TimeExceeded(ProbeResponseData::new(recv, ip, resp_seq)) + }) } IcmpType::DestinationUnreachable => { let packet = DestinationUnreachablePacket::new_view(icmp_v6.packet()).req()?; - let resp_seq = extract_probe_resp_seq(packet.payload(), protocol)?; - Some(ProbeResponse::DestinationUnreachable( - ProbeResponseData::new(recv, ip, resp_seq), - )) + let nested_ipv6 = Ipv6Packet::new_view(packet.payload()).req()?; + extract_probe_resp_seq(&nested_ipv6, protocol)?.map(|resp_seq| { + ProbeResponse::DestinationUnreachable(ProbeResponseData::new(recv, ip, resp_seq)) + }) } IcmpType::EchoReply => match protocol { TracerProtocol::Icmp => { @@ -290,27 +291,33 @@ fn extract_probe_resp( } fn extract_probe_resp_seq( - payload: &[u8], + ipv6: &Ipv6Packet<'_>, protocol: TracerProtocol, -) -> TraceResult { - Ok(match protocol { - TracerProtocol::Icmp => { - let (identifier, sequence) = extract_echo_request(payload)?; - ProbeResponseSeq::Icmp(ProbeResponseSeqIcmp::new(identifier, sequence)) +) -> TraceResult> { + Ok(match (protocol, ipv6.get_next_header()) { + (TracerProtocol::Icmp, IpProtocol::IcmpV6) => { + let (identifier, sequence) = extract_echo_request(ipv6)?; + Some(ProbeResponseSeq::Icmp(ProbeResponseSeqIcmp::new( + identifier, sequence, + ))) } - TracerProtocol::Udp => { - let (src_port, dest_port) = extract_udp_packet(payload)?; - ProbeResponseSeq::Udp(ProbeResponseSeqUdp::new(0, src_port, dest_port, 0)) + (TracerProtocol::Udp, IpProtocol::Udp) => { + let (src_port, dest_port) = extract_udp_packet(ipv6)?; + Some(ProbeResponseSeq::Udp(ProbeResponseSeqUdp::new( + 0, src_port, dest_port, 0, + ))) } - TracerProtocol::Tcp => { - let (src_port, dest_port) = extract_tcp_packet(payload)?; - ProbeResponseSeq::Tcp(ProbeResponseSeqTcp::new(src_port, dest_port)) + (TracerProtocol::Tcp, IpProtocol::Tcp) => { + let (src_port, dest_port) = extract_tcp_packet(ipv6)?; + Some(ProbeResponseSeq::Tcp(ProbeResponseSeqTcp::new( + src_port, dest_port, + ))) } + _ => None, }) } -fn extract_echo_request(ipv6_bytes: &[u8]) -> TraceResult<(u16, u16)> { - let ipv6 = Ipv6Packet::new_view(ipv6_bytes).req()?; +fn extract_echo_request(ipv6: &Ipv6Packet<'_>) -> TraceResult<(u16, u16)> { let echo_request_packet = EchoRequestPacket::new_view(ipv6.payload()).req()?; Ok(( echo_request_packet.get_identifier(), @@ -318,8 +325,7 @@ fn extract_echo_request(ipv6_bytes: &[u8]) -> TraceResult<(u16, u16)> { )) } -fn extract_udp_packet(ipv6_bytes: &[u8]) -> TraceResult<(u16, u16)> { - let ipv6 = Ipv6Packet::new_view(ipv6_bytes).req()?; +fn extract_udp_packet(ipv6: &Ipv6Packet<'_>) -> TraceResult<(u16, u16)> { let udp_packet = UdpPacket::new_view(ipv6.payload()).req()?; Ok((udp_packet.get_source(), udp_packet.get_destination())) } @@ -343,8 +349,7 @@ fn extract_udp_packet(ipv6_bytes: &[u8]) -> TraceResult<(u16, u16)> { /// /// [rfc4443]: https://datatracker.ietf.org/doc/html/rfc4443#section-2.4 /// [rfc2460]: https://datatracker.ietf.org/doc/html/rfc2460#section-5 -fn extract_tcp_packet(ipv6_bytes: &[u8]) -> TraceResult<(u16, u16)> { - let ipv6 = Ipv6Packet::new_view(ipv6_bytes).req()?; +fn extract_tcp_packet(ipv6: &Ipv6Packet<'_>) -> TraceResult<(u16, u16)> { let tcp_packet = TcpPacket::new_view(ipv6.payload()).req()?; Ok((tcp_packet.get_source(), tcp_packet.get_destination())) }