From 69fcc676e98da959fcdf7c718ea3d477237b863b Mon Sep 17 00:00:00 2001 From: LAN Xingcan Date: Fri, 17 May 2024 22:52:56 +0800 Subject: [PATCH] proto: introduce ConnectionIdParser Currently `PlainHeader::decode` and `PartialDecoder::new` expect a `local_cid_len`, which means they cannot support variable length Connection ID format and make it less useful in various use cases, such as implementing a [QUIC-LB] confirming load balancer. [QUIC-LB]: https://www.ietf.org/archive/id/draft-ietf-quic-load-balancers-19.html --- fuzz/fuzz_targets/packet.rs | 4 +-- quinn-proto/src/connection/mod.rs | 6 ++-- quinn-proto/src/endpoint.rs | 6 ++-- quinn-proto/src/lib.rs | 5 +++- quinn-proto/src/packet.rs | 50 ++++++++++++++++++++++++------- 5 files changed, 51 insertions(+), 20 deletions(-) diff --git a/fuzz/fuzz_targets/packet.rs b/fuzz/fuzz_targets/packet.rs index 5524a36b34..a8320a87a6 100644 --- a/fuzz/fuzz_targets/packet.rs +++ b/fuzz/fuzz_targets/packet.rs @@ -5,7 +5,7 @@ extern crate proto; use libfuzzer_sys::fuzz_target; use proto::{ fuzzing::{PacketParams, PartialDecode}, - DEFAULT_SUPPORTED_VERSIONS, + FixedLengthConnectionIdParser, DEFAULT_SUPPORTED_VERSIONS, }; fuzz_target!(|data: PacketParams| { @@ -13,7 +13,7 @@ fuzz_target!(|data: PacketParams| { let supported_versions = DEFAULT_SUPPORTED_VERSIONS.to_vec(); if let Ok(decoded) = PartialDecode::new( data.buf, - data.local_cid_len, + &FixedLengthConnectionIdParser::new(data.local_cid_len), &supported_versions, data.grease_quic_bit, ) { diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index 49050dec55..8e476cae97 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -23,8 +23,8 @@ use crate::{ frame, frame::{Close, Datagram, FrameStruct}, packet::{ - Header, InitialHeader, InitialPacket, LongType, Packet, PacketNumber, PartialDecode, - SpaceId, + FixedLengthConnectionIdParser, Header, InitialHeader, InitialPacket, LongType, Packet, + PacketNumber, PartialDecode, SpaceId, }, range_set::ArrayRangeSet, shared::{ @@ -2101,7 +2101,7 @@ impl Connection { while let Some(data) = remaining { match PartialDecode::new( data, - self.local_cid_state.cid_len(), + &FixedLengthConnectionIdParser::new(self.local_cid_state.cid_len()), &[self.version], self.endpoint_config.grease_quic_bit, ) { diff --git a/quinn-proto/src/endpoint.rs b/quinn-proto/src/endpoint.rs index 3bfab440f8..2c21ca00d7 100644 --- a/quinn-proto/src/endpoint.rs +++ b/quinn-proto/src/endpoint.rs @@ -23,8 +23,8 @@ use crate::{ crypto::{self, Keys, UnsupportedVersion}, frame, packet::{ - Header, InitialHeader, InitialPacket, Packet, PacketDecodeError, PacketNumber, - PartialDecode, PlainInitialHeader, + FixedLengthConnectionIdParser, Header, InitialHeader, InitialPacket, Packet, + PacketDecodeError, PacketNumber, PartialDecode, PlainInitialHeader, }, shared::{ ConnectionEvent, ConnectionEventInner, ConnectionId, DatagramConnectionEvent, EcnCodepoint, @@ -144,7 +144,7 @@ impl Endpoint { let datagram_len = data.len(); let (first_decode, remaining) = match PartialDecode::new( data, - self.local_cid_generator.cid_len(), + &FixedLengthConnectionIdParser::new(self.local_cid_generator.cid_len()), &self.config.supported_versions, self.config.grease_quic_bit, ) { diff --git a/quinn-proto/src/lib.rs b/quinn-proto/src/lib.rs index 77f7ce0794..819aac4650 100644 --- a/quinn-proto/src/lib.rs +++ b/quinn-proto/src/lib.rs @@ -65,7 +65,10 @@ pub use crate::endpoint::{ }; mod packet; -pub use packet::{LongType, PacketDecodeError, PartialDecode, PlainHeader, PlainInitialHeader}; +pub use packet::{ + ConnectionIdParser, FixedLengthConnectionIdParser, LongType, PacketDecodeError, PartialDecode, + PlainHeader, PlainInitialHeader, +}; mod shared; pub use crate::shared::{ConnectionEvent, ConnectionId, EcnCodepoint, EndpointEvent}; diff --git a/quinn-proto/src/packet.rs b/quinn-proto/src/packet.rs index 9dff073256..c669a7b841 100644 --- a/quinn-proto/src/packet.rs +++ b/quinn-proto/src/packet.rs @@ -14,7 +14,7 @@ use crate::{ /// (which includes a variable-length packet number) without crypto context. /// The crypto context (represented by the `Crypto` type in Quinn) is usually /// part of the `Connection`, or can be derived from the destination CID for -// Initial packets. +/// Initial packets. /// /// To cope with this, we decode the invariant header (which should be stable /// across QUIC versions), which gives us the destination CID and allows us @@ -32,13 +32,13 @@ impl PartialDecode { /// Begin decoding a QUIC packet from `bytes`, returning any trailing data not part of that packet pub fn new( bytes: BytesMut, - local_cid_len: usize, + cid_parser: &impl ConnectionIdParser, supported_versions: &[u32], grease_quic_bit: bool, ) -> Result<(Self, Option), PacketDecodeError> { let mut buf = io::Cursor::new(bytes); let plain_header = - PlainHeader::decode(&mut buf, local_cid_len, supported_versions, grease_quic_bit)?; + PlainHeader::decode(&mut buf, cid_parser, supported_versions, grease_quic_bit)?; let dgram_len = buf.get_ref().len(); let packet_len = plain_header .payload_len() @@ -564,7 +564,7 @@ impl PlainHeader { /// Decode a plain header from given buffer, with given [`ConnectionIdParser`]. pub fn decode( buf: &mut io::Cursor, - local_cid_len: usize, + cid_parser: &impl ConnectionIdParser, supported_versions: &[u32], grease_quic_bit: bool, ) -> Result { @@ -574,13 +574,10 @@ impl PlainHeader { } if first & LONG_HEADER_FORM == 0 { let spin = first & SPIN_BIT != 0; - if buf.remaining() < local_cid_len { - return Err(PacketDecodeError::InvalidHeader("cid out of bounds")); - } Ok(Self::Short { spin, - dst_cid: ConnectionId::from_buf(buf, local_cid_len), + dst_cid: cid_parser.parse(buf)?, }) } else { let version = buf.get::()?; @@ -770,6 +767,32 @@ impl PacketNumber { } } +/// An [`ConnectionIdParser`] implementation that assumes the connection ID is of fixed length +pub struct FixedLengthConnectionIdParser { + expected_len: usize, +} + +impl FixedLengthConnectionIdParser { + /// Create a new instance of `FixedLengthConnectionIdParser` + pub fn new(expected_len: usize) -> Self { + Self { expected_len } + } +} + +impl ConnectionIdParser for FixedLengthConnectionIdParser { + fn parse(&self, buffer: &mut impl Buf) -> Result { + (buffer.remaining() >= self.expected_len) + .then(|| ConnectionId::from_buf(buffer, self.expected_len)) + .ok_or(PacketDecodeError::InvalidHeader("packet too small")) + } +} + +/// Parse connection id in short header packet +pub trait ConnectionIdParser { + /// Parse a connection id from given buffer + fn parse(&self, buf: &mut impl Buf) -> Result; +} + /// Long packet type including non-uniform cases #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub(crate) enum LongHeaderType { @@ -945,9 +968,14 @@ mod tests { let server = initial_keys(Version::V1, &dcid, Side::Server, &suite); let supported_versions = DEFAULT_SUPPORTED_VERSIONS.to_vec(); - let decode = PartialDecode::new(buf.as_slice().into(), 0, &supported_versions, false) - .unwrap() - .0; + let decode = PartialDecode::new( + buf.as_slice().into(), + &FixedLengthConnectionIdParser::new(0), + &supported_versions, + false, + ) + .unwrap() + .0; let mut packet = decode.finish(Some(&*server.header.remote)).unwrap(); assert_eq!( packet.header_data[..],