Skip to content

Commit

Permalink
proto: introduce ConnectionIdValidator and expose PlainHeader parsing…
Browse files Browse the repository at this point in the history
… stuffes
  • Loading branch information
LAN Xingcan committed May 17, 2024
1 parent 8bd0600 commit e09d8c9
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 18 deletions.
27 changes: 27 additions & 0 deletions quinn-proto/src/cid_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,33 @@ pub trait ConnectionIdGenerator: Send + Sync {
fn cid_lifetime(&self) -> Option<Duration>;
}

/// Validates connection id in short header packet
pub trait ConnectionIdValidator {
/// Validate and regconize a connection id from given buffer, return the length of connection id
/// or [`None`] if it is invalid
fn validate(&self, buffer: &[u8]) -> Option<usize>;
}

pub(crate) struct FixedLengthConnectionIdValidator {
expected_len: usize,
}

impl FixedLengthConnectionIdValidator {
pub(crate) fn new(expected_len: usize) -> Self {
Self { expected_len }
}
}

impl ConnectionIdValidator for FixedLengthConnectionIdValidator {
fn validate(&self, buffer: &[u8]) -> Option<usize> {
if buffer.len() <= self.expected_len {
Some(self.expected_len)
} else {
None
}
}
}

/// The connection ID was not recognized by the [`ConnectionIdGenerator`]
#[derive(Debug, Copy, Clone)]
pub struct InvalidCid;
Expand Down
6 changes: 5 additions & 1 deletion quinn-proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ mod cid_queue;
#[doc(hidden)]
pub mod coding;
mod constant_time;

mod packet;
pub use packet::{LongType, PacketDecodeError, PartialDecode, PlainHeader, PlainInitialHeader};

mod range_set;
#[cfg(all(test, feature = "rustls"))]
mod tests;
Expand Down Expand Up @@ -75,7 +78,8 @@ pub mod congestion;

mod cid_generator;
pub use crate::cid_generator::{
ConnectionIdGenerator, HashedConnectionIdGenerator, InvalidCid, RandomConnectionIdGenerator,
ConnectionIdGenerator, ConnectionIdValidator, HashedConnectionIdGenerator, InvalidCid,
RandomConnectionIdGenerator,
};

mod token;
Expand Down
103 changes: 87 additions & 16 deletions quinn-proto/src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use bytes::{Buf, BufMut, Bytes, BytesMut};
use thiserror::Error;

use crate::{
cid_generator::{ConnectionIdValidator, FixedLengthConnectionIdValidator},
coding::{self, BufExt, BufMutExt},
crypto, ConnectionId,
};
Expand All @@ -18,17 +19,17 @@ use crate::{
// across QUIC versions), which gives us the destination CID and allows us
// to inspect the version and packet type (which depends on the version).
// This information allows us to fully decode and decrypt the packet.
#[allow(unreachable_pub)] // fuzzing only
#[allow(unreachable_pub, missing_docs)] // fuzzing only
#[cfg_attr(test, derive(Clone))]
#[derive(Debug)]
pub struct PartialDecode {
plain_header: PlainHeader,
buf: io::Cursor<BytesMut>,
}

#[allow(clippy::len_without_is_empty)]
#[allow(clippy::len_without_is_empty, missing_docs)]
impl PartialDecode {
#[allow(unreachable_pub)] // fuzzing only
/// Create an instance of [`PartialDecode`], expecting a fixed-length local connection id
pub fn new(
bytes: BytesMut,
local_cid_len: usize,
Expand All @@ -55,12 +56,44 @@ impl PartialDecode {
}
}

/// Create an instance of [`PartialDecode`], with connection id validated by given
/// [`ConnectionIdValidator`]
pub fn new_with_cid_validator<CidValidator: ConnectionIdValidator>(
bytes: BytesMut,
cid_validator: &CidValidator,
supported_versions: &[u32],
grease_quic_bit: bool,
) -> Result<(Self, Option<BytesMut>), PacketDecodeError> {
let mut buf = io::Cursor::new(bytes);
let plain_header = PlainHeader::decode_with_cid_validator(
&mut buf,
cid_validator,
supported_versions,
grease_quic_bit,
)?;
let dgram_len = buf.get_ref().len();
let packet_len = plain_header
.payload_len()
.map(|len| (buf.position() + len) as usize)
.unwrap_or(dgram_len);
match dgram_len.cmp(&packet_len) {
Ordering::Equal => Ok((Self { plain_header, buf }, None)),
Ordering::Less => Err(PacketDecodeError::InvalidHeader(
"packet too short to contain payload length",
)),
Ordering::Greater => {
let rest = Some(buf.get_mut().split_off(packet_len));
Ok((Self { plain_header, buf }, rest))
}
}
}

/// The underlying partially-decoded packet data
pub(crate) fn data(&self) -> &[u8] {
self.buf.get_ref()
}

pub(crate) fn initial_header(&self) -> Option<&PlainInitialHeader> {
pub fn initial_header(&self) -> Option<&PlainInitialHeader> {
self.plain_header.as_initial()
}

Expand Down Expand Up @@ -96,7 +129,7 @@ impl PartialDecode {
}
}

pub(crate) fn dst_cid(&self) -> &ConnectionId {
pub fn dst_cid(&self) -> &ConnectionId {
self.plain_header.dst_cid()
}

Expand Down Expand Up @@ -486,8 +519,10 @@ impl PartialEncode {
}
}

/// Plain header of a packet
#[derive(Clone, Debug)]
pub(crate) enum PlainHeader {
#[allow(missing_docs)]
pub enum PlainHeader {
Initial(PlainInitialHeader),
Long {
ty: LongType,
Expand All @@ -497,7 +532,9 @@ pub(crate) enum PlainHeader {
version: u32,
},
Retry {
/// Destination Connection ID
dst_cid: ConnectionId,
/// Source Connection ID
src_cid: ConnectionId,
version: u32,
},
Expand All @@ -513,14 +550,15 @@ pub(crate) enum PlainHeader {
}

impl PlainHeader {
pub(crate) fn as_initial(&self) -> Option<&PlainInitialHeader> {
fn as_initial(&self) -> Option<&PlainInitialHeader> {
match self {
Self::Initial(x) => Some(x),
_ => None,
}
}

fn dst_cid(&self) -> &ConnectionId {
/// The destination Connection ID of the packet.
pub fn dst_cid(&self) -> &ConnectionId {
use self::PlainHeader::*;
match self {
Initial(header) => &header.dst_cid,
Expand All @@ -539,9 +577,10 @@ impl PlainHeader {
}
}

fn decode(
/// Decode a plain header from given buffer.
pub fn decode_with_cid_validator<CidValidator: ConnectionIdValidator>(
buf: &mut io::Cursor<BytesMut>,
local_cid_len: usize,
validator: &CidValidator,
supported_versions: &[u32],
grease_quic_bit: bool,
) -> Result<Self, PacketDecodeError> {
Expand All @@ -551,13 +590,11 @@ impl PlainHeader {
}
if first & LONG_HEADER_FORM == 0 {
let spin = first & SPIN_BIT != 0;
if buf.remaining() < local_cid_len {
return Err(PacketDecodeError::InvalidHeader("cid out of bounds"));
}

Ok(Self::Short {
spin,
dst_cid: ConnectionId::from_buf(buf, local_cid_len),
dst_cid: ConnectionId::from_buf_validated(buf, validator)
.ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?,
})
} else {
let version = buf.get::<u32>()?;
Expand Down Expand Up @@ -618,17 +655,43 @@ impl PlainHeader {
}
}
}

fn decode(
buf: &mut io::Cursor<BytesMut>,
local_cid_len: usize,
supported_versions: &[u32],
grease_quic_bit: bool,
) -> Result<Self, PacketDecodeError> {
Self::decode_with_cid_validator(
buf,
&FixedLengthConnectionIdValidator::new(local_cid_len),
supported_versions,
grease_quic_bit,
)
}
}

/// A Plain QUIC Header
#[derive(Clone, Debug)]
pub(crate) struct PlainInitialHeader {
pub struct PlainInitialHeader {
pub(crate) dst_cid: ConnectionId,
pub(crate) src_cid: ConnectionId,
pub(crate) token_pos: Range<usize>,
pub(crate) len: u64,
pub(crate) version: u32,
}

impl PlainInitialHeader {
/// The destination Connection ID of the packet.
pub fn dst_cid(&self) -> &ConnectionId {
&self.dst_cid
}
/// The source Connection ID of the packet.
pub fn src_cid(&self) -> &ConnectionId {
&self.src_cid
}
}

#[derive(Clone, Debug)]
pub(crate) struct InitialHeader {
pub(crate) dst_cid: ConnectionId,
Expand Down Expand Up @@ -777,20 +840,28 @@ impl From<LongHeaderType> for u8 {

/// Long packet types with uniform header structure
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum LongType {
pub enum LongType {
/// Long header type for Handshake packets
Handshake,
/// Long header type for 0-RTT packets
ZeroRtt,
}

#[allow(unreachable_pub)] // fuzzing only
#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
/// Packet decode error
pub enum PacketDecodeError {
/// The packet uses a QUIC version that is not supported
#[error("unsupported version {version:x}")]
UnsupportedVersion {
/// The source connection ID
src_cid: ConnectionId,
/// The destination connection ID
dst_cid: ConnectionId,
/// The version that was not supported
version: u32,
},
/// The header of the packet is invalid
#[error("invalid header: {0}")]
InvalidHeader(&'static str),
}
Expand Down
13 changes: 12 additions & 1 deletion quinn-proto/src/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use std::{fmt, net::SocketAddr, time::Instant};

use bytes::{Buf, BufMut, BytesMut};

use crate::{coding::BufExt, packet::PartialDecode, ResetToken, MAX_CID_SIZE};
use crate::{
cid_generator::ConnectionIdValidator, coding::BufExt, packet::PartialDecode, ResetToken,
MAX_CID_SIZE,
};

/// Events sent from an Endpoint to a Connection
#[derive(Debug)]
Expand Down Expand Up @@ -96,6 +99,14 @@ impl ConnectionId {
res
}

pub(crate) fn from_buf_validated<Validator: ConnectionIdValidator>(
buf: &mut impl Buf,
validator: &Validator,
) -> Option<Self> {
let len = validator.validate(buf.chunk())?;
Some(Self::from_buf(buf, len))
}

/// Decode from long header format
pub(crate) fn decode_long(buf: &mut impl Buf) -> Option<Self> {
let len = buf.get::<u8>().ok()? as usize;
Expand Down

0 comments on commit e09d8c9

Please sign in to comment.