diff --git a/.gitignore b/.gitignore index 36b713bbd..1e9e04ba2 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ Cargo.lock .idea .DS_Store .vscode +.zed cargo-test-* tarpaulin-report.html diff --git a/Cargo.toml b/Cargo.toml index b78d81eaa..eaff6d2ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ crc = "3" directories-next = "2" futures-io = "0.3.19" getrandom = { version = "0.2", default-features = false } +fastbloom = "0.8" hdrhistogram = { version = "7.2", default-features = false } hex-literal = "0.4" lazy_static = "1" diff --git a/quinn-proto/Cargo.toml b/quinn-proto/Cargo.toml index 8a91c14df..10d364039 100644 --- a/quinn-proto/Cargo.toml +++ b/quinn-proto/Cargo.toml @@ -11,7 +11,7 @@ categories.workspace = true workspace = ".." [features] -default = ["rustls-ring", "log"] +default = ["rustls-ring", "log", "fastbloom"] aws-lc-rs = ["dep:aws-lc-rs", "aws-lc-rs?/aws-lc-sys", "aws-lc-rs?/prebuilt-nasm"] aws-lc-rs-fips = ["aws-lc-rs", "aws-lc-rs?/fips"] # For backwards compatibility, `rustls` forwards to `rustls-ring` @@ -34,6 +34,7 @@ rustls-log = ["rustls?/logging"] arbitrary = { workspace = true, optional = true } aws-lc-rs = { workspace = true, optional = true } bytes = { workspace = true } +fastbloom = { workspace = true, optional = true } rustc-hash = { workspace = true } rand = { workspace = true } ring = { workspace = true, optional = true } @@ -55,6 +56,7 @@ web-time = { workspace = true } [dev-dependencies] assert_matches = { workspace = true } hex-literal = { workspace = true } +rand_pcg = "0.3" rcgen = { workspace = true } tracing-subscriber = { workspace = true } lazy_static = "1" diff --git a/quinn-proto/src/bloom_token_log.rs b/quinn-proto/src/bloom_token_log.rs new file mode 100644 index 000000000..eeb62aa6b --- /dev/null +++ b/quinn-proto/src/bloom_token_log.rs @@ -0,0 +1,326 @@ +use std::{ + collections::HashSet, + f64::consts::LN_2, + hash::{BuildHasher, Hasher}, + mem::{size_of, swap}, + sync::Mutex, +}; + +use 
fastbloom::BloomFilter; +use rustc_hash::FxBuildHasher; +use tracing::{trace, warn}; + +use crate::{Duration, SystemTime, TokenLog, TokenReuseError, UNIX_EPOCH}; + +/// Bloom filter-based `TokenLog` +/// +/// Parameterizable over an approximate maximum number of bytes to allocate. Starts out by storing +/// used tokens in a hash set. Once the hash set becomes too large, converts it to a bloom filter. +/// This achieves a memory profile of linear growth with an upper bound. +/// +/// Divides time into periods based on `lifetime` and stores two filters at any given moment, for +/// each of the two periods currently non-expired tokens could expire in. As such, turns over +/// filters as time goes on to avoid bloom filter false positive rate increasing infinitely over +/// time. +pub struct BloomTokenLog(Mutex<State>); + +impl BloomTokenLog { + /// Construct with an approximate maximum memory usage and expected number of validation token + /// usages per expiration period + /// + /// Calculates the optimal bloom filter k number automatically. + /// + /// Panics if: + /// - `max_bytes` < 2 + pub fn new_expected_items(max_bytes: usize, expected_hits: u64) -> Self { + Self::new(max_bytes, optimal_k_num(max_bytes, expected_hits)) + } + + /// Construct with an approximate maximum memory usage and a bloom filter k number + /// + /// If choosing a custom k number, note that `BloomTokenLog` always maintains two filters + /// between them and divides the allocation budget of `max_bytes` evenly between them. As such, + /// each bloom filter will contain `max_bytes * 4` bits. 
+ /// + /// Panics if: + /// - `max_bytes` < 2 + /// - `k_num` < 1 + pub fn new(max_bytes: usize, k_num: u32) -> Self { + assert!(max_bytes >= 2, "BloomTokenLog max_bytes too low"); + assert!(k_num >= 1, "BloomTokenLog k_num must be at least 1"); + + Self(Mutex::new(State { + config: FilterConfig { + filter_max_bytes: max_bytes / 2, + k_num, + }, + period_idx_1: 0, + filter_1: Filter::new(), + filter_2: Filter::new(), + })) + } +} + +fn optimal_k_num(num_bytes: usize, expected_hits: u64) -> u32 { + // be more forgiving rather than panickey here. excessively high num_bits may occur if the user + // wishes it to be unbounded, so just saturate. expected_hits of 0 would cause divide-by-zero, + // so just fudge it up to 1 in that case. + let num_bits = (num_bytes as u64).saturating_mul(8); + let expected_hits = expected_hits.max(1); + (((num_bits as f64 / expected_hits as f64) * LN_2).round() as u32).max(1) +} + +/// Lockable state of [`BloomTokenLog`] +struct State { + config: FilterConfig, + // filter_1 covers tokens that expire in the period starting at + // UNIX_EPOCH + period_idx_1 * lifetime and extending lifetime after. + // filter_2 covers tokens for the next lifetime after that. 
+ period_idx_1: u128, + filter_1: Filter, + filter_2: Filter, +} + +impl TokenLog for BloomTokenLog { + fn check_and_insert( + &self, + rand: u128, + issued: SystemTime, + lifetime: Duration, + ) -> Result<(), TokenReuseError> { + trace!(%rand, "check_and_insert"); + let mut guard = self.0.lock().unwrap(); + let state = &mut *guard; + let fingerprint = rand_to_fingerprint(rand); + + // calculate period index for token + let period_idx = (issued + lifetime) + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos() + / lifetime.as_nanos(); + + // get relevant filter + let filter = if period_idx < state.period_idx_1 { + // shouldn't happen unless time travels backwards or new_token_lifetime changes + warn!("BloomTokenLog presented with token too far in past"); + return Err(TokenReuseError); + } else if period_idx == state.period_idx_1 { + &mut state.filter_1 + } else if period_idx == state.period_idx_1 + 1 { + &mut state.filter_2 + } else { + // turn over filters + if period_idx == state.period_idx_1 + 2 { + swap(&mut state.filter_1, &mut state.filter_2); + } else { + state.filter_1 = Filter::new(); + } + state.filter_2 = Filter::new(); + state.period_idx_1 = period_idx - 1; + + &mut state.filter_2 + }; + + filter.check_and_insert(fingerprint, &state.config) + } +} + +/// The token's rand needs to guarantee uniqueness because of the role it plays in the encryption +/// of the tokens, so it is 128 bits. But since the token log can tolerate both false positives and +/// false negatives, we trim it down to 64 bits, which would still only have a small collision rate +/// even at significant amounts of usage, while allowing us to store twice as many in the hash set +/// variant. +/// +/// Token rand values are uniformly randomly generated server-side and cryptographically integrity- +/// checked, so we don't need to employ secure hashing for this, we can simply truncate. 
+fn rand_to_fingerprint(rand: u128) -> u64 { + (rand & 0xffffffffffffffff) as u64 +} + +const DEFAULT_MAX_BYTES: usize = 10 << 20; +const DEFAULT_EXPECTED_HITS: u64 = 1_000_000; + +/// Default to 10 MiB max memory consumption and expected one million hits +impl Default for BloomTokenLog { + fn default() -> Self { + Self::new_expected_items(DEFAULT_MAX_BYTES, DEFAULT_EXPECTED_HITS) + } +} + +/// Unchanging parameters governing [`Filter`] behavior +struct FilterConfig { + filter_max_bytes: usize, + k_num: u32, +} + +/// Period filter within [`State`] +enum Filter { + Set(IdentityHashSet), + Bloom(FxBloomFilter), +} + +impl Filter { + fn new() -> Self { + Self::Set(HashSet::default()) + } + + fn check_and_insert( + &mut self, + fingerprint: u64, + config: &FilterConfig, + ) -> Result<(), TokenReuseError> { + match *self { + Self::Set(ref mut hset) => { + if !hset.insert(fingerprint) { + return Err(TokenReuseError); + } + + if hset.capacity() * size_of::<u64>() > config.filter_max_bytes { + // convert to bloom + let mut bloom = BloomFilter::with_num_bits(config.filter_max_bytes * 8) + .hasher(FxBuildHasher) + .hashes(config.k_num); + for item in hset.iter() { + bloom.insert(item); + } + *self = Self::Bloom(bloom); + } + } + Self::Bloom(ref mut bloom) => { + if bloom.insert(&fingerprint) { + return Err(TokenReuseError); + } + } + } + Ok(()) + } +} + +/// Bloom filter that uses `FxHasher`s +type FxBloomFilter = BloomFilter<512, FxBuildHasher>; + +/// `BuildHasher` of `IdentityHasher` +#[derive(Default)] +struct IdentityBuildHasher; + +impl BuildHasher for IdentityBuildHasher { + type Hasher = IdentityHasher; + + fn build_hasher(&self) -> Self::Hasher { + IdentityHasher::default() + } +} + +/// Hasher that is the identity operation--it assumes that exactly 8 bytes will be hashed, and the +/// resultant hash is those bytes as a `u64` +#[derive(Default)] +struct IdentityHasher { + data: [u8; 8], + #[cfg(debug_assertions)] + wrote_8_byte_slice: bool, +} + +impl Hasher for IdentityHasher 
{ + fn write(&mut self, bytes: &[u8]) { + #[cfg(debug_assertions)] + { + assert!(!self.wrote_8_byte_slice); + assert_eq!(bytes.len(), 8); + self.wrote_8_byte_slice = true; + } + self.data.copy_from_slice(bytes); + } + + fn finish(&self) -> u64 { + #[cfg(debug_assertions)] + assert!(self.wrote_8_byte_slice); + u64::from_ne_bytes(self.data) + } +} + +/// Hash set of `u64` which are assumed to already be uniformly randomly distributed, and thus +/// effectively pre-hashed +type IdentityHashSet = HashSet<u64, IdentityBuildHasher>; + +#[cfg(test)] +mod test { + use super::*; + use rand::prelude::*; + use rand_pcg::Pcg32; + + fn new_rng() -> impl Rng { + Pcg32::from_seed(0xdeadbeefdeadbeefdeadbeefdeadbeefu128.to_le_bytes()) + } + + #[test] + fn identity_hash_test() { + let mut rng = new_rng(); + let builder = IdentityBuildHasher; + for _ in 0..100 { + let n = rng.gen::<u64>(); + let hash = builder.hash_one(n); + assert_eq!(hash, n); + } + } + + #[test] + fn optimal_k_num_test() { + assert_eq!(optimal_k_num(10 << 20, 1_000_000), 58); + assert_eq!(optimal_k_num(10 << 20, 1_000_000_000_000_000), 1); + // assert that these don't panic: + optimal_k_num(10 << 20, 0); + optimal_k_num(usize::MAX, 1_000_000); + } + + #[test] + fn bloom_token_log_conversion() { + let mut rng = new_rng(); + let log = BloomTokenLog::new_expected_items(800, 200); + + let issued = SystemTime::now(); + let lifetime = Duration::from_secs(1_000_000); + + for i in 0..200 { + let token = rng.gen::<u128>(); + let result = log.check_and_insert(token, issued, lifetime); + { + let filter = &log.0.lock().unwrap().filter_2; + if let Filter::Set(ref hset) = *filter { + assert!(hset.capacity() * size_of::<u64>() <= 800); + assert_eq!(hset.len(), i + 1); + assert!(result.is_ok()); + } else { + assert!(i > 10, "definitely bloomed too early"); + } + } + assert!(log.check_and_insert(token, issued, lifetime).is_err()); + } + } + + #[test] + fn turn_over() { + let mut rng = new_rng(); + let log = BloomTokenLog::new_expected_items(800, 200); + let lifetime = 
Duration::from_secs(1_000); + let mut old = Vec::default(); + let mut accepted = 0; + + for i in 0..200 { + let token = rng.gen::<u128>(); + let now = UNIX_EPOCH + lifetime * 10 + lifetime * i / 10; + let issued = now - lifetime.mul_f32(rng.gen_range(0.0..3.0)); + let result = log.check_and_insert(token, issued, lifetime); + if result.is_ok() { + accepted += 1; + } + old.push((token, issued)); + let old_idx = rng.gen::<usize>() % old.len(); + let (old_token, old_issued) = old[old_idx]; + assert!(log + .check_and_insert(old_token, old_issued, lifetime) + .is_err()); + } + assert!(accepted > 0); + } +} diff --git a/quinn-proto/src/config/mod.rs b/quinn-proto/src/config/mod.rs index 3ec56f571..bcf4f72c7 100644 --- a/quinn-proto/src/config/mod.rs +++ b/quinn-proto/src/config/mod.rs @@ -11,14 +11,16 @@ use rustls::client::WebPkiServerVerifier; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use thiserror::Error; +#[cfg(feature = "fastbloom")] +use crate::bloom_token_log::BloomTokenLog; #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))] use crate::crypto::rustls::{configured_provider, QuicServerConfig}; use crate::{ cid_generator::{ConnectionIdGenerator, HashedConnectionIdGenerator}, crypto::{self, HandshakeTokenKey, HmacKey}, shared::ConnectionId, - Duration, RandomConnectionIdGenerator, VarInt, VarIntBoundsExceeded, - DEFAULT_SUPPORTED_VERSIONS, MAX_CID_SIZE, + Duration, RandomConnectionIdGenerator, TokenLog, TokenMemoryCache, TokenStore, VarInt, + VarIntBoundsExceeded, DEFAULT_SUPPORTED_VERSIONS, MAX_CID_SIZE, }; mod transport; @@ -159,13 +161,13 @@ impl EndpointConfig { impl fmt::Debug for EndpointConfig { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("EndpointConfig") - .field("reset_key", &"[ elided ]") + // reset_key not debug .field("max_udp_payload_size", &self.max_udp_payload_size) - .field("cid_generator_factory", &"[ elided ]") + // cid_generator_factory not debug .field("supported_versions", 
&self.supported_versions) .field("grease_quic_bit", &self.grease_quic_bit) .field("rng_seed", &self.rng_seed) - .finish() + .finish_non_exhaustive() } } @@ -193,7 +195,7 @@ pub struct ServerConfig { /// Transport configuration to use for incoming connections pub transport: Arc<TransportConfig>, - /// TLS configuration used for incoming connections. + /// TLS configuration used for incoming connections /// /// Must be set to use TLS 1.3 only. pub crypto: Arc<dyn crypto::ServerConfig>, @@ -201,7 +203,7 @@ pub struct ServerConfig { /// Used to generate one-time AEAD keys to protect handshake tokens pub(crate) token_key: Arc<dyn HandshakeTokenKey>, - /// Microseconds after a stateless retry token was issued for which it's considered valid. + /// Duration after a retry token was issued for which it's considered valid pub(crate) retry_token_lifetime: Duration, /// Whether to allow clients to migrate to new addresses @@ -210,6 +212,16 @@ pub struct ServerConfig { /// rebinding. Enabled by default. pub(crate) migration: bool, + /// Duration after an address validation token was issued for which it's considered valid + pub(crate) validation_token_lifetime: Duration, + + /// Responsible for limiting clients' ability to reuse tokens from NEW_TOKEN frames + pub(crate) validation_token_log: Option<Arc<dyn TokenLog>>, + + /// Number of address validation tokens sent to a client via NEW_TOKEN frames when its path is + /// validated + pub(crate) validation_tokens_sent: u32, + pub(crate) preferred_address_v4: Option<SocketAddr>, pub(crate) preferred_address_v6: Option<SocketAddr>, @@ -224,6 +236,10 @@ impl ServerConfig { crypto: Arc<dyn crypto::ServerConfig>, token_key: Arc<dyn HandshakeTokenKey>, ) -> Self { + #[cfg(feature = "fastbloom")] + let validation_token_log = Some(Arc::new(BloomTokenLog::default()) as _); + #[cfg(not(feature = "fastbloom"))] + let validation_token_log = None; Self { transport: Arc::new(TransportConfig::default()), crypto, @@ -233,6 +249,10 @@ impl ServerConfig { migration: true, + validation_token_lifetime: Duration::from_secs(2 * 7 * 24 * 60 * 60), + validation_token_log, + validation_tokens_sent: 2, + 
preferred_address_v4: None, preferred_address_v6: None, @@ -248,18 +268,52 @@ impl ServerConfig { self } - /// Private key used to authenticate data included in handshake tokens. + /// Private key used to authenticate data included in handshake tokens pub fn token_key(&mut self, value: Arc) -> &mut Self { self.token_key = value; self } - /// Duration after a stateless retry token was issued for which it's considered valid. + /// Duration after a retry token was issued for which it's considered valid + /// + /// Defaults to 15 seconds. pub fn retry_token_lifetime(&mut self, value: Duration) -> &mut Self { self.retry_token_lifetime = value; self } + /// Duration after an address validation token was issued for which it's considered valid + /// + /// This refers only to tokens sent in NEW_TOKEN frames, in contrast to retry tokens. + /// + /// Defaults to 2 weeks. + pub fn validation_token_lifetime(&mut self, value: Duration) -> &mut Self { + self.validation_token_lifetime = value; + self + } + + /// Set a custom [`TokenLog`] + /// + /// Setting this to `None` makes the server ignore all address validation tokens (that is, + /// tokens originating from NEW_TOKEN frames--retry tokens may still be accepted). + /// + /// Defaults to a default [`BloomTokenLog`], unless the `fastbloom` default feature is + /// disabled, in which case this defaults to `None`. + pub fn validation_token_log(&mut self, log: Option>) -> &mut Self { + self.validation_token_log = log; + self + } + + /// Number of address validation tokens sent to a client when its path is validated + /// + /// This refers only to tokens sent in NEW_TOKEN frames, in contrast to retry tokens. + /// + /// Defaults to 2. 
+ pub fn validation_tokens_sent(&mut self, value: u32) -> &mut Self { + self.validation_tokens_sent = value; + self + } + /// Whether to allow clients to migrate to new addresses /// /// Improves behavior for clients that move between different internet connections or suffer NAT @@ -269,14 +323,16 @@ impl ServerConfig { self } - /// The preferred IPv4 address that will be communicated to clients during handshaking. + /// The preferred IPv4 address that will be communicated to clients during handshaking + /// /// If the client is able to reach this address, it will switch to it. pub fn preferred_address_v4(&mut self, address: Option) -> &mut Self { self.preferred_address_v4 = address; self } - /// The preferred IPv6 address that will be communicated to clients during handshaking. + /// The preferred IPv6 address that will be communicated to clients during handshaking + /// /// If the client is able to reach this address, it will switch to it. pub fn preferred_address_v6(&mut self, address: Option) -> &mut Self { self.preferred_address_v6 = address; @@ -370,11 +426,14 @@ impl ServerConfig { impl fmt::Debug for ServerConfig { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("ServerConfig") + fmt.debug_struct("ServerConfig") .field("transport", &self.transport) - .field("crypto", &"ServerConfig { elided }") - .field("token_key", &"[ elided ]") + // crypto not debug + // token not debug .field("retry_token_lifetime", &self.retry_token_lifetime) + .field("validation_token_lifetime", &self.validation_token_lifetime) + // validation_token_log not debug + .field("validation_tokens_sent", &self.validation_tokens_sent) .field("migration", &self.migration) .field("preferred_address_v4", &self.preferred_address_v4) .field("preferred_address_v6", &self.preferred_address_v6) @@ -384,7 +443,7 @@ impl fmt::Debug for ServerConfig { "incoming_buffer_size_total", &self.incoming_buffer_size_total, ) - .finish() + .finish_non_exhaustive() } } @@ -400,6 
+459,9 @@ pub struct ClientConfig { /// Cryptographic configuration to use pub(crate) crypto: Arc, + /// Validation token store to use + pub(crate) token_store: Option>, + /// Provider that populates the destination connection ID of Initial Packets pub(crate) initial_dst_cid_provider: Arc ConnectionId + Send + Sync>, @@ -413,6 +475,7 @@ impl ClientConfig { Self { transport: Default::default(), crypto, + token_store: Some(Arc::new(TokenMemoryCache::<2>::default())), initial_dst_cid_provider: Arc::new(|| { RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid() }), @@ -421,7 +484,7 @@ impl ClientConfig { } /// Configure how to populate the destination CID of the initial packet when attempting to - /// establish a new connection. + /// establish a new connection /// /// By default, it's populated with random bytes with reasonable length, so unless you have /// a good reason, you do not need to change it. @@ -442,6 +505,19 @@ impl ClientConfig { self } + /// Set a custom [`TokenStore`] + /// + /// Defaults to a [`TokenMemoryCache`] limited to 256 servers and 2 tokens per server. This + /// default is chosen to complement `rustls`'s default [`ClientSessionStore`]. + /// + /// [`ClientSessionStore`]: rustls::client::ClientSessionStore + /// + /// Setting to `None` disables the use of tokens from NEW_TOKEN frames as a client. 
+ pub fn token_store(&mut self, store: Option>) -> &mut Self { + self.token_store = store; + self + } + /// Set the QUIC version to use pub fn version(&mut self, version: u32) -> &mut Self { self.version = version; @@ -471,9 +547,10 @@ impl ClientConfig { impl fmt::Debug for ClientConfig { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("ClientConfig") + fmt.debug_struct("ClientConfig") .field("transport", &self.transport) - .field("crypto", &"ClientConfig { elided }") + // token_store not debug + // crypto not debug .field("version", &self.version) .finish_non_exhaustive() } diff --git a/quinn-proto/src/config/transport.rs b/quinn-proto/src/config/transport.rs index d3fa86d75..82ddd1d19 100644 --- a/quinn-proto/src/config/transport.rs +++ b/quinn-proto/src/config/transport.rs @@ -410,9 +410,9 @@ impl fmt::Debug for TransportConfig { .field("allow_spin", allow_spin) .field("datagram_receive_buffer_size", datagram_receive_buffer_size) .field("datagram_send_buffer_size", datagram_send_buffer_size) - .field("congestion_controller_factory", &"[ opaque ]") + // congestion_controller_factory not debug .field("enable_segmentation_offload", enable_segmentation_offload) - .finish() + .finish_non_exhaustive() } } @@ -610,7 +610,7 @@ impl Default for MtuDiscoveryConfig { } } -/// Maximum duration of inactivity to accept before timing out the connection. +/// Maximum duration of inactivity to accept before timing out the connection /// /// This wraps an underlying [`VarInt`], representing the duration in milliseconds. Values can be /// constructed by converting directly from `VarInt`, or using `TryFrom`. 
diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index babb0d757..30f224309 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -19,7 +19,7 @@ use crate::{ coding::BufMutExt, config::{ServerConfig, TransportConfig}, crypto::{self, KeyPair, Keys, PacketKey}, - frame::{self, Close, Datagram, FrameStruct}, + frame::{self, Close, Datagram, FrameStruct, NewToken}, packet::{ FixedLengthConnectionIdParser, Header, InitialHeader, InitialPacket, LongType, Packet, PacketNumber, PartialDecode, SpaceId, @@ -29,11 +29,11 @@ use crate::{ ConnectionEvent, ConnectionEventInner, ConnectionId, DatagramConnectionEvent, EcnCodepoint, EndpointEvent, EndpointEventInner, }, - token::ResetToken, + token::{ResetToken, Token, TokenInner, ValidationTokenInner}, transport_parameters::TransportParameters, - Dir, Duration, EndpointConfig, Frame, Instant, Side, StreamId, Transmit, TransportError, - TransportErrorCode, VarInt, INITIAL_MTU, MAX_CID_SIZE, MAX_STREAM_COUNT, MIN_INITIAL_SIZE, - TIMER_GRANULARITY, + Dir, Duration, EndpointConfig, Frame, Instant, Side, StreamId, SystemTime, TokenStore, + Transmit, TransportError, TransportErrorCode, VarInt, INITIAL_MTU, MAX_CID_SIZE, + MAX_STREAM_COUNT, MIN_INITIAL_SIZE, TIMER_GRANULARITY, }; mod ack_frequency; @@ -129,7 +129,6 @@ use timer::{Timer, TimerTable}; /// events or timeouts with different instants must not be interleaved. pub struct Connection { endpoint_config: Arc, - server_config: Option>, config: Arc, rng: StdRng, crypto: Box, @@ -145,7 +144,7 @@ pub struct Connection { allow_mtud: bool, prev_path: Option<(ConnectionId, PathData)>, state: State, - side: Side, + side_state: SideState, /// Whether or not 0-RTT was enabled during the handshake. Does not imply acceptance. zero_rtt_enabled: bool, /// Set if 0-RTT is supported, then cleared when no longer needed. 
@@ -191,9 +190,6 @@ pub struct Connection { authentication_failures: u64, /// Why the connection was lost, if it has been error: Option, - /// Sent in every outgoing Initial packet. Always empty for servers and after Initial keys are - /// discarded. - retry_token: Bytes, /// Identifies Data-space packet numbers to skip. Not used in earlier spaces. packet_number_filter: PacketNumberFilter, @@ -239,15 +235,63 @@ pub struct Connection { version: u32, } +/// Fields of `Connection` specific to it being client-side or server-side +enum SideState { + Client { + /// Sent in every outgoing Initial packet. Always empty after Initial keys are discarded + token: Bytes, + token_store: Option>, + server_name: String, + }, + Server { + server_config: Arc, + }, +} + +impl SideState { + fn side(&self) -> Side { + match *self { + Self::Client { .. } => Side::Client, + Self::Server { .. } => Side::Server, + } + } +} + +/// Parameters to `Connection::new` specific to it being client-side or server-side +pub(crate) enum SideArgs { + Client { + token_store: Option>, + server_name: String, + }, + Server { + server_config: Arc, + pref_addr_cid: Option, + path_validated: bool, + }, +} + +impl SideArgs { + pub(crate) fn side(&self) -> Side { + match *self { + Self::Client { .. } => Side::Client, + Self::Server { .. } => Side::Server, + } + } + pub(crate) fn pref_addr_cid(&self) -> Option { + match *self { + Self::Client { .. } => None, + Self::Server { pref_addr_cid, .. 
} => pref_addr_cid, + } + } +} + impl Connection { pub(crate) fn new( endpoint_config: Arc, - server_config: Option>, config: Arc, init_cid: ConnectionId, loc_cid: ConnectionId, rem_cid: ConnectionId, - pref_addr_cid: Option, remote: SocketAddr, local_ip: Option, crypto: Box, @@ -256,15 +300,37 @@ impl Connection { version: u32, allow_mtud: bool, rng_seed: [u8; 32], - path_validated: bool, + side_args: SideArgs, ) -> Self { - let side = if server_config.is_some() { - Side::Server - } else { - Side::Client + let (side_state, pref_addr_cid, path_validated) = match side_args { + SideArgs::Client { + token_store, + server_name, + } => ( + SideState::Client { + token: token_store + .as_ref() + .and_then(|store| store.take(&server_name)) + .unwrap_or_default(), + token_store, + server_name, + }, + None, + true, + ), + SideArgs::Server { + server_config, + pref_addr_cid, + path_validated, + } => ( + SideState::Server { server_config }, + pref_addr_cid, + path_validated, + ), }; + let side = side_state.side(); let initial_space = PacketSpace { - crypto: Some(crypto.initial_keys(&init_cid, side)), + crypto: Some(crypto.initial_keys(&init_cid, side_state.side())), ..PacketSpace::new(now) }; let state = State::Handshake(state::Handshake { @@ -275,7 +341,6 @@ impl Connection { let mut rng = StdRng::from_seed(rng_seed); let mut this = Self { endpoint_config, - server_config, crypto, handshake_cid: loc_cid, rem_handshake_cid: rem_cid, @@ -289,8 +354,8 @@ impl Connection { allow_mtud, local_ip, prev_path: None, - side, state, + side_state, zero_rtt_enabled: false, zero_rtt_crypto: None, key_phase: false, @@ -323,7 +388,6 @@ impl Connection { timers: TimerTable::default(), authentication_failures: 0, error: None, - retry_token: Bytes::new(), #[cfg(test)] packet_number_filter: match config.deterministic_packet_numbers { false => PacketNumberFilter::new(&mut rng), @@ -360,6 +424,9 @@ impl Connection { stats: ConnectionStats::default(), version, }; + if path_validated { + 
this.on_path_validated(); + } if side.is_client() { // Kick off the connection this.write_crypto(); @@ -420,7 +487,7 @@ impl Connection { /// Provide control over streams #[must_use] pub fn recv_stream(&mut self, id: StreamId) -> RecvStream<'_> { - assert!(id.dir() == Dir::Bi || id.initiator() != self.side); + assert!(id.dir() == Dir::Bi || id.initiator() != self.side_state.side()); RecvStream { id, state: &mut self.streams, @@ -431,7 +498,7 @@ impl Connection { /// Provide control over streams #[must_use] pub fn send_stream(&mut self, id: StreamId) -> SendStream<'_> { - assert!(id.dir() == Dir::Bi || id.initiator() == self.side); + assert!(id.dir() == Dir::Bi || id.initiator() == self.side_state.side()); SendStream { id, state: &mut self.streams, @@ -798,7 +865,7 @@ impl Connection { if self.spaces[SpaceId::Initial].crypto.is_some() && space_id == SpaceId::Handshake - && self.side.is_client() + && self.side_state.side().is_client() { // A client stops both sending and processing Initial packets when it // sends its first Handshake packet. @@ -826,8 +893,8 @@ impl Connection { coalesce = coalesce && !builder.short_header; // https://tools.ietf.org/html/draft-ietf-quic-transport-34#section-14.1 - pad_datagram |= - space_id == SpaceId::Initial && (self.side.is_client() || ack_eliciting); + pad_datagram |= space_id == SpaceId::Initial + && (self.side_state.side().is_client() || ack_eliciting); if close { trace!("sending CONNECTION_CLOSE"); @@ -1045,7 +1112,7 @@ impl Connection { if self.spaces[space_id].crypto.is_none() && (space_id != SpaceId::Data || self.zero_rtt_crypto.is_none() - || self.side.is_server()) + || self.side_state.side().is_server()) { // No keys available for this space return SendableFrames::empty(); @@ -1076,7 +1143,10 @@ impl Connection { // forbids migration, drop the datagram. This could be relaxed to heuristically // permit NAT-rebinding-like migration. 
if remote != self.path.remote - && self.server_config.as_ref().map_or(true, |x| !x.migration) + && match self.side_state { + SideState::Server { ref server_config } => !server_config.migration, + SideState::Client { .. } => true, + } { trace!("discarding packet from unrecognized peer {}", remote); return; @@ -1297,7 +1367,7 @@ impl Connection { /// Look up whether we're the client or server of this Connection pub fn side(&self) -> Side { - self.side + self.side_state.side() } /// The latest socket address for this connection's peer @@ -1774,7 +1844,7 @@ impl Connection { #[allow(clippy::suspicious_operation_groupings)] fn peer_completed_address_validation(&self) -> bool { - if self.side.is_server() || self.state.is_closed() { + if self.side_state.side().is_server() || self.state.is_closed() { return true; } // The server is guaranteed to have validated our address if any of our handshake or 1-RTT @@ -1859,7 +1929,7 @@ impl Connection { Some(x) => x, None => return, }; - if self.side.is_server() { + if self.side_state.side().is_server() { if self.spaces[SpaceId::Initial].crypto.is_some() && space_id == SpaceId::Handshake { // A server stops sending and processing Initial packets when it receives its first Handshake packet. 
self.discard_space(now, SpaceId::Initial); @@ -1874,7 +1944,7 @@ impl Connection { if packet >= space.rx_packet { space.rx_packet = packet; // Update outgoing spin bit, inverting iff we're the client - self.spin = self.side.is_client() ^ spin; + self.spin = self.side_state.side().is_client() ^ spin; } } @@ -1920,7 +1990,7 @@ impl Connection { ) -> Result<(), ConnectionError> { let span = trace_span!("first recv"); let _guard = span.enter(); - debug_assert!(self.side.is_server()); + debug_assert!(self.side_state.side().is_server()); let len = packet.header_data.len() + packet.payload.len(); self.path.total_recvd = len as u64; @@ -1952,7 +2022,7 @@ impl Connection { Some(x) => x, None => return, }; - if self.side.is_client() { + if self.side_state.side().is_client() { match self.crypto.transport_parameters() { Ok(params) => { let params = params @@ -2058,7 +2128,7 @@ impl Connection { let offset = self.spaces[space].crypto_offset; let outgoing = Bytes::from(outgoing); if let State::Handshake(ref mut state) = self.state { - if space == SpaceId::Initial && offset == 0 && self.side.is_client() { + if space == SpaceId::Initial && offset == 0 && self.side_state.side().is_client() { state.client_hello = Some(outgoing.clone()); } } @@ -2090,7 +2160,7 @@ impl Connection { self.spaces[space].crypto = Some(crypto); debug_assert!(space as usize > self.highest_space as usize); self.highest_space = space; - if space == SpaceId::Data && self.side.is_client() { + if space == SpaceId::Data && self.side_state.side().is_client() { // Discard 0-RTT keys because 1-RTT keys are available. self.zero_rtt_crypto = None; } @@ -2101,7 +2171,9 @@ impl Connection { trace!("discarding {:?} keys", space_id); if space_id == SpaceId::Initial { // No longer needed - self.retry_token = Bytes::new(); + if let SideState::Client { ref mut token, .. 
} = self.side_state { + *token = Bytes::new(); + } } let space = &mut self.spaces[space_id]; space.crypto = None; @@ -2236,7 +2308,7 @@ impl Connection { } else { if let Header::Initial(InitialHeader { ref token, .. }) = packet.header { if let State::Handshake(ref hs) = self.state { - if self.side.is_server() && token != &hs.expected_token { + if self.side_state.side().is_server() && token != &hs.expected_token { // Clients must send the same retry token in every Initial. Initial // packets can be spoofed, so we discard rather than killing the // connection. @@ -2362,7 +2434,7 @@ impl Connection { Header::Retry { src_cid: rem_cid, .. } => { - if self.side.is_server() { + if self.side_state.side().is_server() { return Err(TransportError::PROTOCOL_VIOLATION("client sent Retry").into()); } @@ -2398,7 +2470,7 @@ impl Connection { self.discard_space(now, SpaceId::Initial); // Make sure we clean up after any retransmitted Initials self.spaces[SpaceId::Initial] = PacketSpace { - crypto: Some(self.crypto.initial_keys(&rem_cid, self.side)), + crypto: Some(self.crypto.initial_keys(&rem_cid, self.side_state.side())), next_packet_number: self.spaces[SpaceId::Initial].next_packet_number, crypto_offset: client_hello.len() as u64, ..PacketSpace::new(now) @@ -2420,7 +2492,10 @@ impl Connection { self.streams.retransmit_all_for_0rtt(); let token_len = packet.payload.len() - 16; - self.retry_token = packet.payload.freeze().split_to(token_len); + let SideState::Client { ref mut token, .. 
} = self.side_state else { + unreachable!("we already short-circuited if we're server"); + }; + *token = packet.payload.freeze().split_to(token_len); self.state = State::Handshake(state::Handshake { expected_token: Bytes::new(), rem_cid_set: false, @@ -2440,7 +2515,7 @@ impl Connection { ); return Ok(()); } - self.path.validated = true; + self.on_path_validated(); self.process_early_payload(now, packet)?; if self.state.is_closed() { @@ -2452,7 +2527,7 @@ impl Connection { return Ok(()); } - if self.side.is_client() { + if self.side_state.side().is_client() { // Client-only because server params were set from the client's Initial let params = self.crypto @@ -2465,7 +2540,7 @@ impl Connection { if self.has_0rtt() { if !self.crypto.early_data_accepted().unwrap() { - debug_assert!(self.side.is_client()); + debug_assert!(self.side_state.side().is_client()); debug!("0-RTT rejected"); self.accepted_0rtt = false; self.streams.zero_rtt_rejected(); @@ -2523,7 +2598,7 @@ impl Connection { let starting_space = self.highest_space; self.process_early_payload(now, packet)?; - if self.side.is_server() + if self.side_state.side().is_server() && starting_space == SpaceId::Initial && self.highest_space != SpaceId::Initial { @@ -2718,7 +2793,7 @@ impl Connection { trace!("new path validated"); self.timers.stop(Timer::PathValidation); self.path.challenge = None; - self.path.validated = true; + self.on_path_validated(); if let Some((_, ref mut prev_path)) = self.prev_path { prev_path.challenge = None; prev_path.challenge_pending = false; @@ -2745,7 +2820,7 @@ impl Connection { debug!(offset, "peer claims to be blocked at connection level"); } Frame::StreamDataBlocked { id, offset } => { - if id.initiator() == self.side && id.dir() == Dir::Uni { + if id.initiator() == self.side_state.side() && id.dir() == Dir::Uni { debug!("got STREAM_DATA_BLOCKED on send-only {}", id); return Err(TransportError::STREAM_STATE_ERROR( "STREAM_DATA_BLOCKED on send-only stream", @@ -2768,7 +2843,7 @@ impl 
Connection { ); } Frame::StopSending(frame::StopSending { id, error_code }) => { - if id.initiator() != self.side { + if id.initiator() != self.side_state.side() { if id.dir() == Dir::Uni { debug!("got STOP_SENDING on recv-only {}", id); return Err(TransportError::STREAM_STATE_ERROR( @@ -2848,21 +2923,28 @@ impl Connection { } }; - if self.side.is_server() && self.rem_cids.active_seq() == 0 { + if self.side_state.side().is_server() && self.rem_cids.active_seq() == 0 { // We're a server still using the initial remote CID for the client, so // let's switch immediately to enable clientside stateless resets. self.update_rem_cid(); } } - Frame::NewToken { token } => { - if self.side.is_server() { + Frame::NewToken(NewToken { token }) => { + if self.side_state.side().is_server() { return Err(TransportError::PROTOCOL_VIOLATION("client sent NEW_TOKEN")); } if token.is_empty() { return Err(TransportError::FRAME_ENCODING_ERROR("empty token")); } trace!("got new token"); - // TODO: Cache, or perhaps forward to user? + if let SideState::Client { + token_store: Some(ref store), + ref server_name, + .. 
+ } = self.side_state + { + store.insert(server_name, token); + } } Frame::Datagram(datagram) => { if self @@ -2900,7 +2982,7 @@ impl Connection { .set_immediate_ack_required(); } Frame::HandshakeDone => { - if self.side.is_server() { + if self.side_state.side().is_server() { return Err(TransportError::PROTOCOL_VIOLATION( "client sent HANDSHAKE_DONE", )); @@ -2938,11 +3020,11 @@ impl Connection { && !is_probing_packet && number == self.spaces[SpaceId::Data].rx_packet { + let SideState::Server { ref server_config } = self.side_state else { + panic!("packets from unknown remote should be dropped by clients"); + }; debug_assert!( - self.server_config - .as_ref() - .expect("packets from unknown remote should be dropped by clients") - .migration, + server_config.migration, "migration-initiating packets should have been dropped immediately" ); self.migrate(now, remote); @@ -3247,6 +3329,39 @@ impl Connection { self.datagrams.send_blocked = false; } + // NEW_TOKEN + while let Some(remote_addr) = space.pending.new_tokens.pop() { + debug_assert_eq!(space_id, SpaceId::Data); + let SideState::Server { ref server_config } = self.side_state else { + panic!("NEW_TOKEN frames should not be enqueued by clients"); + }; + + if remote_addr != self.path.remote { + continue; + } + + let token_inner = TokenInner::Validation(ValidationTokenInner { + issued: SystemTime::now(), + }); + let token = Token::new(&mut self.rng, token_inner) + .encode(&*server_config.token_key, &self.path.remote); + let new_token = NewToken { + token: token.into(), + }; + + if buf.len() + new_token.size() >= max_size { + space.pending.new_tokens.push(remote_addr); + break; + } + + new_token.encode(buf); + sent.retransmits + .get_or_create() + .new_tokens + .push(remote_addr); + self.stats.frame_tx.new_token += 1; + } + // STREAM if space_id == SpaceId::Data { sent.stream_frames = @@ -3312,7 +3427,7 @@ impl Connection { /// Handle transport parameters received from the peer fn handle_peer_params(&mut self, 
params: TransportParameters) -> Result<(), TransportError> { if Some(self.orig_rem_cid) != params.initial_src_cid - || (self.side.is_client() + || (self.side_state.side().is_client() && (Some(self.initial_dst_cid) != params.original_dst_cid || self.retry_src_cid != params.retry_src_cid)) { @@ -3608,6 +3723,18 @@ impl Connection { // but that would needlessly prevent sending datagrams during 0-RTT. key.map_or(16, |x| x.tag_len()) } + + /// Mark the path as validated, and enqueue NEW_TOKEN frames to be sent as appropriate + fn on_path_validated(&mut self) { + self.path.validated = true; + if let SideState::Server { ref server_config } = self.side_state { + let new_tokens = &mut self.spaces[SpaceId::Data as usize].pending.new_tokens; + new_tokens.clear(); + for _ in 0..server_config.validation_tokens_sent { + new_tokens.push(self.path.remote); + } + } + } } impl fmt::Debug for Connection { diff --git a/quinn-proto/src/connection/packet_builder.rs b/quinn-proto/src/connection/packet_builder.rs index 868a8c7ca..10bccb8da 100644 --- a/quinn-proto/src/connection/packet_builder.rs +++ b/quinn-proto/src/connection/packet_builder.rs @@ -4,6 +4,7 @@ use tracing::{trace, trace_span}; use super::{spaces::SentPacket, Connection, SentFrames}; use crate::{ + connection::SideState, frame::{self, Close}, packet::{Header, InitialHeader, LongType, PacketNumber, PartialEncode, SpaceId, FIXED_BIT}, ConnectionId, Instant, TransportError, TransportErrorCode, @@ -113,7 +114,10 @@ impl PacketBuilder { SpaceId::Initial => Header::Initial(InitialHeader { src_cid: conn.handshake_cid, dst_cid, - token: conn.retry_token.clone(), + token: match conn.side_state { + SideState::Client { ref token, .. } => token.clone(), + SideState::Server { .. 
} => Bytes::new(), + }, number, version, }), diff --git a/quinn-proto/src/connection/spaces.rs b/quinn-proto/src/connection/spaces.rs index ed58b51c1..9c91cf875 100644 --- a/quinn-proto/src/connection/spaces.rs +++ b/quinn-proto/src/connection/spaces.rs @@ -2,6 +2,7 @@ use std::{ cmp, collections::{BTreeMap, VecDeque}, mem, + net::SocketAddr, ops::{Bound, Index, IndexMut}, }; @@ -309,6 +310,13 @@ pub struct Retransmits { pub(super) retire_cids: Vec, pub(super) ack_frequency: bool, pub(super) handshake_done: bool, + /// Two notable things about `new_tokens`: + /// + /// - NEW_TOKEN frames from an old path are not retransmitted on a new path + /// - If a token is lost, a new randomly generated token is re-transmitted rather than the + /// original. This is so that if both transmissions end up being received, the client won't + /// risk sending the same token twice. + pub(super) new_tokens: Vec, } impl Retransmits { @@ -326,6 +334,7 @@ impl Retransmits { && self.retire_cids.is_empty() && !self.ack_frequency && !self.handshake_done + && self.new_tokens.is_empty() } } @@ -347,6 +356,7 @@ impl ::std::ops::BitOrAssign for Retransmits { self.retire_cids.extend(rhs.retire_cids); self.ack_frequency |= rhs.ack_frequency; self.handshake_done |= rhs.handshake_done; + self.new_tokens.extend_from_slice(&rhs.new_tokens); } } diff --git a/quinn-proto/src/endpoint.rs b/quinn-proto/src/endpoint.rs index bd18b011a..a3c72226e 100644 --- a/quinn-proto/src/endpoint.rs +++ b/quinn-proto/src/endpoint.rs @@ -18,7 +18,7 @@ use crate::{ cid_generator::ConnectionIdGenerator, coding::BufMutExt, config::{ClientConfig, EndpointConfig, ServerConfig}, - connection::{Connection, ConnectionError}, + connection::{Connection, ConnectionError, SideArgs}, crypto::{self, Keys, UnsupportedVersion}, frame, packet::{ @@ -29,9 +29,9 @@ use crate::{ ConnectionEvent, ConnectionEventInner, ConnectionId, DatagramConnectionEvent, EcnCodepoint, EndpointEvent, EndpointEventInner, IssuedCid, }, - 
token::TokenDecodeError, + token::{IncomingTokenState, RetryTokenInner, TokenInner, ValidationError}, transport_parameters::{PreferredAddress, TransportParameters}, - Duration, Instant, ResetToken, RetryToken, Side, SystemTime, Transmit, TransportConfig, + Duration, Instant, ResetToken, Side, SystemTime, Token, Transmit, TransportConfig, TransportError, INITIAL_MTU, MAX_CID_SIZE, MIN_INITIAL_SIZE, RESET_TOKEN_SIZE, }; @@ -423,16 +423,17 @@ impl Endpoint { remote_id, loc_cid, remote_id, - None, FourTuple { remote, local_ip: None, }, now, tls, - None, config.transport, - true, + SideArgs::Client { + token_store: config.token_store, + server_name: server_name.into(), + }, ); Ok((ch, conn)) } @@ -496,28 +497,16 @@ impl Endpoint { let server_config = self.server_config.as_ref().unwrap().clone(); - let (retry_src_cid, orig_dst_cid) = if header.token.is_empty() { - (None, header.dst_cid) + let token_state = if header.token.is_empty() { + IncomingTokenState::default(&header) } else { - match RetryToken::from_bytes( - &*server_config.token_key, - &addresses.remote, - &header.dst_cid, - &header.token, - ) { - Ok(token) - if token.issued + server_config.retry_token_lifetime > SystemTime::now() => - { - (Some(header.dst_cid), token.orig_dst_cid) - } - Err(TokenDecodeError::UnknownToken) => { - // Token may have been generated by an incompatible endpoint, e.g. a - // different version or a neighbor behind the same load balancer. We - // can't interpret it, so we proceed as if there was no token. 
- (None, header.dst_cid) - } - _ => { - debug!("rejecting invalid stateless retry token"); + let token = Token::decode(&*server_config.token_key, &addresses.remote, &header.token); + let token_state = token.and_then(|token| token.validate(&header, &server_config)); + match token_state { + Ok(token_state) => token_state, + Err(ValidationError::Ignore) => IncomingTokenState::default(&header), + Err(ValidationError::InvalidRetry) => { + debug!("rejecting invalid retry token"); return Some(DatagramEvent::Response(self.initial_close( header.version, addresses, @@ -545,8 +534,7 @@ impl Endpoint { }, rest, crypto, - retry_src_cid, - orig_dst_cid, + token_state, incoming_idx, improper_drop_warner: IncomingImproperDropWarner, })) @@ -636,8 +624,8 @@ impl Endpoint { &mut self.rng, ); params.stateless_reset_token = Some(ResetToken::new(&*self.config.reset_key, &loc_cid)); - params.original_dst_cid = Some(incoming.orig_dst_cid); - params.retry_src_cid = incoming.retry_src_cid; + params.original_dst_cid = Some(incoming.token_state.orig_dst_cid); + params.retry_src_cid = incoming.token_state.retry_src_cid; let mut pref_addr_cid = None; if server_config.preferred_address_v4.is_some() || server_config.preferred_address_v6.is_some() @@ -660,13 +648,15 @@ impl Endpoint { dst_cid, loc_cid, src_cid, - pref_addr_cid, incoming.addresses, incoming.received_at, tls, - Some(server_config), transport_config, - remote_address_validated, + SideArgs::Server { + server_config, + pref_addr_cid, + path_validated: remote_address_validated, + }, ); self.index.insert_initial(dst_cid, ch); @@ -753,9 +743,9 @@ impl Endpoint { /// Respond with a retry packet, requiring the client to retry with address validation /// - /// Errors if `incoming.remote_address_validated()` is true. + /// Errors if `incoming.may_retry()` is false. 
pub fn retry(&mut self, incoming: Incoming, buf: &mut Vec) -> Result { - if incoming.remote_address_validated() { + if !incoming.may_retry() { return Err(RetryError(incoming)); } @@ -772,15 +762,12 @@ impl Endpoint { // retried by the application layer. let loc_cid = self.local_cid_generator.generate_cid(); - let token = RetryToken { + let token_inner = RetryTokenInner { orig_dst_cid: incoming.packet.header.dst_cid, issued: SystemTime::now(), - } - .encode( - &*server_config.token_key, - &incoming.addresses.remote, - &loc_cid, - ); + }; + let token = Token::new(&mut self.rng, TokenInner::Retry(token_inner)) + .encode(&*server_config.token_key, &incoming.addresses.remote); let header = Header::Retry { src_cid: loc_cid, @@ -829,28 +816,22 @@ impl Endpoint { init_cid: ConnectionId, loc_cid: ConnectionId, rem_cid: ConnectionId, - pref_addr_cid: Option, addresses: FourTuple, now: Instant, tls: Box, - server_config: Option>, transport_config: Arc, - path_validated: bool, + side_args: SideArgs, ) -> Connection { let mut rng_seed = [0; 32]; self.rng.fill_bytes(&mut rng_seed); - let side = match server_config.is_some() { - true => Side::Server, - false => Side::Client, - }; + let side = side_args.side(); + let pref_addr_cid = side_args.pref_addr_cid(); let conn = Connection::new( self.config.clone(), - server_config, transport_config, init_cid, loc_cid, rem_cid, - pref_addr_cid, addresses.remote, addresses.local_ip, tls, @@ -859,7 +840,7 @@ impl Endpoint { version, self.allow_mtud, rng_seed, - path_validated, + side_args, ); let mut cids_issued = 0; @@ -1202,37 +1183,46 @@ pub struct Incoming { packet: InitialPacket, rest: Option, crypto: Keys, - retry_src_cid: Option, - orig_dst_cid: ConnectionId, + token_state: IncomingTokenState, incoming_idx: usize, improper_drop_warner: IncomingImproperDropWarner, } impl Incoming { - /// The local IP address which was used when the peer established - /// the connection + /// The local IP address which was used when the peer established 
the connection /// - /// This has the same behavior as [`Connection::local_ip`] + /// This has the same behavior as [`Connection::local_ip`]. pub fn local_ip(&self) -> Option { self.addresses.local_ip } - /// The peer's UDP address. + /// The peer's UDP address pub fn remote_address(&self) -> SocketAddr { self.addresses.remote } - /// Whether the socket address that is initiating this connection has been validated. + /// Whether the socket address that is initiating this connection has been validated /// /// This means that the sender of the initial packet has proved that they can receive traffic /// sent to `self.remote_address()`. + /// + /// If `self.remote_address_validated()` is false, `self.may_retry()` is guaranteed to be true. + /// The inverse is not guaranteed. pub fn remote_address_validated(&self) -> bool { - self.retry_src_cid.is_some() + self.token_state.validated + } + + /// Whether it is legal to respond with a retry packet + /// + /// If `self.remote_address_validated()` is false, `self.may_retry()` is guaranteed to be true. + /// The inverse is not guaranteed. 
+ pub fn may_retry(&self) -> bool { + self.token_state.retry_src_cid.is_none() } /// The original destination connection ID sent by the client pub fn orig_dst_cid(&self) -> &ConnectionId { - &self.orig_dst_cid + &self.token_state.orig_dst_cid } } @@ -1243,8 +1233,7 @@ impl fmt::Debug for Incoming { .field("ecn", &self.ecn) // packet doesn't implement debug // rest is too big and not meaningful enough - .field("retry_src_cid", &self.retry_src_cid) - .field("orig_dst_cid", &self.orig_dst_cid) + .field("token_state", &self.token_state) .field("incoming_idx", &self.incoming_idx) // improper drop warner contains no information .finish_non_exhaustive() @@ -1308,8 +1297,7 @@ pub struct AcceptError { pub response: Option, } -/// Error for attempting to retry an [`Incoming`] which already bears an address -/// validation token from a previous retry +/// Error for attempting to retry an [`Incoming`] which already bears a token from a previous retry #[derive(Debug, Error)] #[error("retry() with validated Incoming")] pub struct RetryError(Incoming); diff --git a/quinn-proto/src/frame.rs b/quinn-proto/src/frame.rs index 0bc7f34ad..421acc313 100644 --- a/quinn-proto/src/frame.rs +++ b/quinn-proto/src/frame.rs @@ -147,7 +147,7 @@ pub(crate) enum Frame { ResetStream(ResetStream), StopSending(StopSending), Crypto(Crypto), - NewToken { token: Bytes }, + NewToken(NewToken), Stream(Stream), MaxData(VarInt), MaxStreamData { id: StreamId, offset: u64 }, @@ -200,7 +200,7 @@ impl Frame { PathResponse(_) => FrameType::PATH_RESPONSE, NewConnectionId { .. } => FrameType::NEW_CONNECTION_ID, Crypto(_) => FrameType::CRYPTO, - NewToken { .. 
} => FrameType::NEW_TOKEN, + NewToken(_) => FrameType::NEW_TOKEN, Datagram(_) => FrameType(*DATAGRAM_TYS.start()), AckFrequency(_) => FrameType::ACK_FREQUENCY, ImmediateAck => FrameType::IMMEDIATE_ACK, @@ -525,6 +525,23 @@ impl Crypto { } } +#[derive(Debug, Clone)] +pub(crate) struct NewToken { + pub(crate) token: Bytes, +} + +impl NewToken { + pub(crate) fn encode(&self, out: &mut W) { + out.write(FrameType::NEW_TOKEN); + out.write_var(self.token.len() as u64); + out.put_slice(&self.token); + } + + pub(crate) fn size(&self) -> usize { + 1 + VarInt::from_u64(self.token.len() as u64).unwrap().size() + self.token.len() + } +} + pub(crate) struct Iter { // TODO: ditch io::Cursor after bytes 0.5 bytes: io::Cursor, @@ -676,9 +693,9 @@ impl Iter { offset: self.bytes.get_var()?, data: self.take_len()?, }), - FrameType::NEW_TOKEN => Frame::NewToken { + FrameType::NEW_TOKEN => Frame::NewToken(NewToken { token: self.take_len()?, - }, + }), FrameType::HANDSHAKE_DONE => Frame::HandshakeDone, FrameType::ACK_FREQUENCY => Frame::AckFrequency(AckFrequency { sequence: self.bytes.get()?, diff --git a/quinn-proto/src/lib.rs b/quinn-proto/src/lib.rs index 79eddc827..dd1d2aa2c 100644 --- a/quinn-proto/src/lib.rs +++ b/quinn-proto/src/lib.rs @@ -86,7 +86,16 @@ pub use crate::cid_generator::{ }; mod token; -use token::{ResetToken, RetryToken}; +use token::{ResetToken, Token}; +pub use token::{TokenLog, TokenReuseError}; + +mod token_store; +pub use token_store::{TokenMemoryCache, TokenStore}; + +#[cfg(feature = "fastbloom")] +mod bloom_token_log; +#[cfg(feature = "fastbloom")] +pub use bloom_token_log::BloomTokenLog; #[cfg(feature = "arbitrary")] use arbitrary::Arbitrary; diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs index cb39b351f..7dc9dcbb4 100644 --- a/quinn-proto/src/tests/mod.rs +++ b/quinn-proto/src/tests/mod.rs @@ -2,7 +2,7 @@ use std::{ convert::TryInto, mem, net::{Ipv4Addr, Ipv6Addr, SocketAddr}, - sync::Arc, + sync::{Arc, Mutex}, }; use 
assert_matches::assert_matches; @@ -186,7 +186,7 @@ fn draft_version_compat() { fn stateless_retry() { let _guard = subscribe(); let mut pair = Pair::default(); - pair.server.incoming_connection_behavior = IncomingConnectionBehavior::Validate; + pair.server.handle_incoming = Box::new(validate_incoming); let (client_ch, _server_ch) = pair.connect(); pair.client .connections @@ -200,6 +200,203 @@ fn stateless_retry() { assert_eq!(pair.server.known_cids(), 0); } +#[cfg(feature = "fastbloom")] +#[test] +fn use_token() { + let _guard = subscribe(); + let mut pair = Pair::default(); + let client_config = client_config(); + let (client_ch, _server_ch) = pair.connect_with(client_config.clone()); + pair.client + .connections + .get_mut(&client_ch) + .unwrap() + .close(pair.time, VarInt(42), Bytes::new()); + pair.drive(); + assert_eq!(pair.client.known_connections(), 0); + assert_eq!(pair.client.known_cids(), 0); + assert_eq!(pair.server.known_connections(), 0); + assert_eq!(pair.server.known_cids(), 0); + + pair.server.handle_incoming = Box::new(|incoming| { + assert!(incoming.remote_address_validated()); + assert!(incoming.may_retry()); + IncomingConnectionBehavior::Accept + }); + let (client_ch_2, _server_ch_2) = pair.connect_with(client_config); + pair.client + .connections + .get_mut(&client_ch_2) + .unwrap() + .close(pair.time, VarInt(42), Bytes::new()); + pair.drive(); + assert_eq!(pair.client.known_connections(), 0); + assert_eq!(pair.client.known_cids(), 0); + assert_eq!(pair.server.known_connections(), 0); + assert_eq!(pair.server.known_cids(), 0); +} + +#[cfg(feature = "fastbloom")] +#[test] +fn retry_then_use_token() { + let _guard = subscribe(); + let mut pair = Pair::default(); + let client_config = client_config(); + pair.server.handle_incoming = Box::new(validate_incoming); + let (client_ch, _server_ch) = pair.connect_with(client_config.clone()); + pair.client + .connections + .get_mut(&client_ch) + .unwrap() + .close(pair.time, VarInt(42), Bytes::new()); + 
pair.drive(); + assert_eq!(pair.client.known_connections(), 0); + assert_eq!(pair.client.known_cids(), 0); + assert_eq!(pair.server.known_connections(), 0); + assert_eq!(pair.server.known_cids(), 0); + + pair.server.handle_incoming = Box::new(|incoming| { + assert!(incoming.remote_address_validated()); + assert!(incoming.may_retry()); + IncomingConnectionBehavior::Accept + }); + let (client_ch_2, _server_ch_2) = pair.connect_with(client_config); + pair.client + .connections + .get_mut(&client_ch_2) + .unwrap() + .close(pair.time, VarInt(42), Bytes::new()); + pair.drive(); + assert_eq!(pair.client.known_connections(), 0); + assert_eq!(pair.client.known_cids(), 0); + assert_eq!(pair.server.known_connections(), 0); + assert_eq!(pair.server.known_cids(), 0); +} + +#[cfg(feature = "fastbloom")] +#[test] +fn use_token_then_retry() { + let _guard = subscribe(); + let mut pair = Pair::default(); + let client_config = client_config(); + let (client_ch, _server_ch) = pair.connect_with(client_config.clone()); + pair.client + .connections + .get_mut(&client_ch) + .unwrap() + .close(pair.time, VarInt(42), Bytes::new()); + pair.drive(); + assert_eq!(pair.client.known_connections(), 0); + assert_eq!(pair.client.known_cids(), 0); + assert_eq!(pair.server.known_connections(), 0); + assert_eq!(pair.server.known_cids(), 0); + + pair.server.handle_incoming = Box::new({ + let mut i = 0; + move |incoming| { + if i == 0 { + assert!(incoming.remote_address_validated()); + assert!(incoming.may_retry()); + i += 1; + IncomingConnectionBehavior::Retry + } else if i == 1 { + assert!(incoming.remote_address_validated()); + assert!(!incoming.may_retry()); + i += 1; + IncomingConnectionBehavior::Accept + } else { + panic!("too many handle_incoming iterations") + } + } + }); + let (client_ch_2, _server_ch_2) = pair.connect_with(client_config); + pair.client + .connections + .get_mut(&client_ch_2) + .unwrap() + .close(pair.time, VarInt(42), Bytes::new()); + pair.drive(); + 
assert_eq!(pair.client.known_connections(), 0); + assert_eq!(pair.client.known_cids(), 0); + assert_eq!(pair.server.known_connections(), 0); + assert_eq!(pair.server.known_cids(), 0); +} + +#[cfg(feature = "fastbloom")] +#[test] +fn use_same_token_twice() { + #[derive(Default)] + struct EvilTokenStore(Mutex); + + impl TokenStore for EvilTokenStore { + fn insert(&self, _server_name: &str, token: Bytes) { + let mut lock = self.0.lock().unwrap(); + if lock.is_empty() { + *lock = token; + } + } + + fn take(&self, _server_name: &str) -> Option { + let lock = self.0.lock().unwrap(); + if lock.is_empty() { + None + } else { + Some(lock.clone()) + } + } + } + + let _guard = subscribe(); + let mut pair = Pair::default(); + let mut client_config = client_config(); + client_config.token_store(Some(Arc::new(EvilTokenStore::default()))); + let (client_ch, _server_ch) = pair.connect_with(client_config.clone()); + pair.client + .connections + .get_mut(&client_ch) + .unwrap() + .close(pair.time, VarInt(42), Bytes::new()); + pair.drive(); + assert_eq!(pair.client.known_connections(), 0); + assert_eq!(pair.client.known_cids(), 0); + assert_eq!(pair.server.known_connections(), 0); + assert_eq!(pair.server.known_cids(), 0); + + pair.server.handle_incoming = Box::new(|incoming| { + assert!(incoming.remote_address_validated()); + assert!(incoming.may_retry()); + IncomingConnectionBehavior::Accept + }); + let (client_ch_2, _server_ch_2) = pair.connect_with(client_config.clone()); + pair.client + .connections + .get_mut(&client_ch_2) + .unwrap() + .close(pair.time, VarInt(42), Bytes::new()); + pair.drive(); + assert_eq!(pair.client.known_connections(), 0); + assert_eq!(pair.client.known_cids(), 0); + assert_eq!(pair.server.known_connections(), 0); + assert_eq!(pair.server.known_cids(), 0); + + pair.server.handle_incoming = Box::new(|incoming| { + assert!(!incoming.remote_address_validated()); + assert!(incoming.may_retry()); + IncomingConnectionBehavior::Accept + }); + let (client_ch_3, 
_server_ch_3) = pair.connect_with(client_config); + pair.client + .connections + .get_mut(&client_ch_3) + .unwrap() + .close(pair.time, VarInt(42), Bytes::new()); + pair.drive(); + assert_eq!(pair.client.known_connections(), 0); + assert_eq!(pair.client.known_cids(), 0); + assert_eq!(pair.server.known_connections(), 0); + assert_eq!(pair.server.known_cids(), 0); +} + #[test] fn server_stateless_reset() { let _guard = subscribe(); @@ -554,7 +751,7 @@ fn high_latency_handshake() { fn zero_rtt_happypath() { let _guard = subscribe(); let mut pair = Pair::default(); - pair.server.incoming_connection_behavior = IncomingConnectionBehavior::Validate; + pair.server.handle_incoming = Box::new(validate_incoming); let config = client_config(); // Establish normal connection @@ -723,7 +920,7 @@ fn test_zero_rtt_incoming_limit(configure_server: CLIENT_PORTS.lock().unwrap().next().unwrap(), ); info!("resuming session"); - pair.server.incoming_connection_behavior = IncomingConnectionBehavior::Wait; + pair.server.handle_incoming = Box::new(|_| IncomingConnectionBehavior::Wait); let client_ch = pair.begin_connect(config); assert!(pair.client_conn_mut(client_ch).has_0rtt()); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); @@ -2993,7 +3190,7 @@ fn pure_sender_voluntarily_acks() { fn reject_manually() { let _guard = subscribe(); let mut pair = Pair::default(); - pair.server.incoming_connection_behavior = IncomingConnectionBehavior::RejectAll; + pair.server.handle_incoming = Box::new(|_| IncomingConnectionBehavior::Reject); // The server should now reject incoming connections. 
let client_ch = pair.begin_connect(client_config()); @@ -3013,7 +3210,20 @@ fn reject_manually() { fn validate_then_reject_manually() { let _guard = subscribe(); let mut pair = Pair::default(); - pair.server.incoming_connection_behavior = IncomingConnectionBehavior::ValidateThenReject; + pair.server.handle_incoming = Box::new({ + let mut i = 0; + move |incoming| { + if incoming.remote_address_validated() { + assert_eq!(i, 1); + i += 1; + IncomingConnectionBehavior::Reject + } else { + assert_eq!(i, 0); + i += 1; + IncomingConnectionBehavior::Retry + } + } + }); // The server should now retry and reject incoming connections. let client_ch = pair.begin_connect(client_config()); diff --git a/quinn-proto/src/tests/util.rs b/quinn-proto/src/tests/util.rs index 7e927e203..f8f9bcb32 100644 --- a/quinn-proto/src/tests/util.rs +++ b/quinn-proto/src/tests/util.rs @@ -297,19 +297,26 @@ pub(super) struct TestEndpoint { conn_events: HashMap>, pub(super) captured_packets: Vec>, pub(super) capture_inbound_packets: bool, - pub(super) incoming_connection_behavior: IncomingConnectionBehavior, + pub(super) handle_incoming: Box IncomingConnectionBehavior>, pub(super) waiting_incoming: Vec, } #[derive(Debug, Copy, Clone)] pub(super) enum IncomingConnectionBehavior { - AcceptAll, - RejectAll, - Validate, - ValidateThenReject, + Accept, + Reject, + Retry, Wait, } +pub(super) fn validate_incoming(incoming: &Incoming) -> IncomingConnectionBehavior { + if incoming.remote_address_validated() { + IncomingConnectionBehavior::Accept + } else { + IncomingConnectionBehavior::Retry + } +} + impl TestEndpoint { fn new(endpoint: Endpoint, addr: SocketAddr) -> Self { let socket = if env::var_os("SSLKEYLOGFILE").is_some() { @@ -334,7 +341,7 @@ impl TestEndpoint { conn_events: HashMap::default(), captured_packets: Vec::new(), capture_inbound_packets: false, - incoming_connection_behavior: IncomingConnectionBehavior::AcceptAll, + handle_incoming: Box::new(|_| IncomingConnectionBehavior::Accept), 
waiting_incoming: Vec::new(), } } @@ -364,26 +371,15 @@ impl TestEndpoint { { match event { DatagramEvent::NewConnection(incoming) => { - match self.incoming_connection_behavior { - IncomingConnectionBehavior::AcceptAll => { + match (self.handle_incoming)(&incoming) { + IncomingConnectionBehavior::Accept => { let _ = self.try_accept(incoming, now); } - IncomingConnectionBehavior::RejectAll => { + IncomingConnectionBehavior::Reject => { self.reject(incoming); } - IncomingConnectionBehavior::Validate => { - if incoming.remote_address_validated() { - let _ = self.try_accept(incoming, now); - } else { - self.retry(incoming); - } - } - IncomingConnectionBehavior::ValidateThenReject => { - if incoming.remote_address_validated() { - self.reject(incoming); - } else { - self.retry(incoming); - } + IncomingConnectionBehavior::Retry => { + self.retry(incoming); } IncomingConnectionBehavior::Wait => { self.waiting_incoming.push(incoming); diff --git a/quinn-proto/src/token.rs b/quinn-proto/src/token.rs index d4de5200c..feb687cae 100644 --- a/quinn-proto/src/token.rs +++ b/quinn-proto/src/token.rs @@ -1,82 +1,295 @@ use std::{ fmt, io, + mem::size_of, net::{IpAddr, SocketAddr}, }; use bytes::{Buf, BufMut}; +use rand::Rng; +use tracing::*; use crate::{ coding::{BufExt, BufMutExt}, crypto::{CryptoError, HandshakeTokenKey, HmacKey}, + packet::InitialHeader, shared::ConnectionId, - Duration, SystemTime, RESET_TOKEN_SIZE, UNIX_EPOCH, + Duration, ServerConfig, SystemTime, RESET_TOKEN_SIZE, UNIX_EPOCH, }; -pub(crate) struct RetryToken { - /// The destination connection ID set in the very first packet from the client +/// Error for when a validation token may have been reused +pub struct TokenReuseError; + +/// Responsible for limiting clients' ability to reuse validation tokens +/// +/// [_RFC 9000 § 8.1.4:_](https://www.rfc-editor.org/rfc/rfc9000.html#section-8.1.4) +/// +/// > Attackers could replay tokens to use servers as amplifiers in DDoS attacks. 
To protect +/// > against such attacks, servers MUST ensure that replay of tokens is prevented or limited. +/// > Servers SHOULD ensure that tokens sent in Retry packets are only accepted for a short time, +/// > as they are returned immediately by clients. Tokens that are provided in NEW_TOKEN frames +/// > (Section 19.7) need to be valid for longer but SHOULD NOT be accepted multiple times. +/// > Servers are encouraged to allow tokens to be used only once, if possible; tokens MAY include +/// > additional information about clients to further narrow applicability or reuse. +/// +/// `TokenLog` pertains only to tokens provided in NEW_TOKEN frames. +pub trait TokenLog: Send + Sync { + /// Record that the token was used and, ideally, return a token reuse error if the token was + /// already used previously + /// + /// False negatives and false positives are both permissible. Called when a client uses an + /// address validation token. + /// + /// Parameters: + /// - `rand`: A server-generated random unique value for the token. + /// - `issued`: The time the server issued the token. + /// - `lifetime`: The expiration time of address validation tokens sent via NEW_TOKEN frames, + /// as configured by [`ServerConfig::validation_token_lifetime`][1]. 
+ /// + /// [1]: crate::ServerConfig::validation_token_lifetime + fn check_and_insert( + &self, + rand: u128, + issued: SystemTime, + lifetime: Duration, + ) -> Result<(), TokenReuseError>; +} + +/// State in an `Incoming` determined by a token or lack thereof +#[derive(Debug)] +pub(crate) struct IncomingTokenState { + pub(crate) retry_src_cid: Option, pub(crate) orig_dst_cid: ConnectionId, - /// The time at which this token was issued - pub(crate) issued: SystemTime, + pub(crate) validated: bool, } -impl RetryToken { - pub(crate) fn encode( - &self, - key: &dyn HandshakeTokenKey, - address: &SocketAddr, - retry_src_cid: &ConnectionId, - ) -> Vec { - let aead_key = key.aead_from_hkdf(retry_src_cid); +impl IncomingTokenState { + /// Construct for an `Incoming` which is not validated by a token + pub(crate) fn default(header: &InitialHeader) -> Self { + Self { + retry_src_cid: None, + orig_dst_cid: header.dst_cid, + validated: false, + } + } +} +/// An address validation / retry token +/// +/// The data in this struct is encoded and encrypted in the context of not only a handshake token +/// key, but also a client socket address. 
+pub(crate) struct Token { + /// Randomly generated value, which must be unique, and is visible to the client + pub(crate) rand: u128, + /// Content which is encrypted from the client + pub(crate) inner: TokenInner, +} + +impl Token { + /// Construct with newly sampled randomness + pub(crate) fn new(rng: &mut R, inner: TokenInner) -> Self { + Self { + rand: rng.gen(), + inner, + } + } + + /// Encode and encrypt + pub(crate) fn encode(&self, key: &dyn HandshakeTokenKey, address: &SocketAddr) -> Vec { let mut buf = Vec::new(); - encode_addr(&mut buf, address); - self.orig_dst_cid.encode_long(&mut buf); - buf.write::( - self.issued - .duration_since(UNIX_EPOCH) - .map(|x| x.as_secs()) - .unwrap_or(0), - ); + self.inner.encode(&mut buf, address); + let aead_key = key.aead_from_hkdf(&self.rand.to_le_bytes()); aead_key.seal(&mut buf, &[]).unwrap(); + buf.extend(&self.rand.to_le_bytes()); buf } - pub(crate) fn from_bytes( + pub(crate) fn decode( key: &dyn HandshakeTokenKey, address: &SocketAddr, - retry_src_cid: &ConnectionId, raw_token_bytes: &[u8], - ) -> Result { - let aead_key = key.aead_from_hkdf(retry_src_cid); - let mut sealed_token = raw_token_bytes.to_vec(); - - let data = aead_key.open(&mut sealed_token, &[])?; - let mut reader = io::Cursor::new(data); - let token_addr = decode_addr(&mut reader).ok_or(TokenDecodeError::UnknownToken)?; - if token_addr != *address { - return Err(TokenDecodeError::WrongAddress); + ) -> Result { + let rand_slice_start = raw_token_bytes + .len() + .checked_sub(size_of::()) + .ok_or(ValidationError::Ignore)?; + let mut rand_bytes = [0; size_of::()]; + rand_bytes.copy_from_slice(&raw_token_bytes[rand_slice_start..]); + let rand = u128::from_le_bytes(rand_bytes); + + let aead_key = key.aead_from_hkdf(&rand_bytes); + let mut sealed_inner = raw_token_bytes[..rand_slice_start].to_vec(); + let encoded = aead_key.open(&mut sealed_inner, &[])?; + + let mut cursor = io::Cursor::new(encoded); + let inner = TokenInner::decode(&mut cursor, 
address)?; + if cursor.has_remaining() { + return Err(ValidationError::Ignore); } - let orig_dst_cid = - ConnectionId::decode_long(&mut reader).ok_or(TokenDecodeError::UnknownToken)?; - let issued = UNIX_EPOCH - + Duration::new( - reader - .get::() - .map_err(|_| TokenDecodeError::UnknownToken)?, - 0, - ); + Ok(Self { rand, inner }) + } + + /// Ensure that this token validates an `Incoming`, and construct its token state + pub(crate) fn validate( + &self, + header: &InitialHeader, + server_config: &ServerConfig, + ) -> Result { + self.inner.validate(self.rand, header, server_config) + } +} + +/// Content of [`Token`] depending on how token originated that is encrypted from the client +pub(crate) enum TokenInner { + Retry(RetryTokenInner), + Validation(ValidationTokenInner), +} + +impl TokenInner { + /// Encode without encryption + fn encode(&self, buf: &mut Vec, address: &SocketAddr) { + match *self { + Self::Retry(ref inner) => { + buf.push(0); + inner.encode(buf, address); + } + Self::Validation(ref inner) => { + buf.push(1); + inner.encode(buf, address); + } + } + } + + /// Try to decode without encryption, but do validate that the address is acceptable + fn decode(buf: &mut B, address: &SocketAddr) -> Result { + match buf.get::().ok().ok_or(ValidationError::Ignore)? 
{ + 0 => RetryTokenInner::decode(buf, address).map(Self::Retry), + 1 => ValidationTokenInner::decode(buf, address).map(Self::Validation), + _ => Err(ValidationError::Ignore), + } + } + + /// Ensure that this token validates an `Incoming`, and construct its token state + pub(crate) fn validate( + &self, + rand: u128, + header: &InitialHeader, + server_config: &ServerConfig, + ) -> Result { + match *self { + Self::Retry(ref inner) => inner.validate(header, server_config), + Self::Validation(ref inner) => inner.validate(rand, header, server_config), + } + } +} + +/// Content of [`Token`] originating from Retry packet that is encrypted from the client +pub(crate) struct RetryTokenInner { + /// The destination connection ID set in the very first packet from the client + pub(crate) orig_dst_cid: ConnectionId, + /// The time at which this token was issued + pub(crate) issued: SystemTime, +} + +impl RetryTokenInner { + /// Encode without encryption + fn encode(&self, buf: &mut Vec, address: &SocketAddr) { + encode_socket_addr(buf, address); + self.orig_dst_cid.encode_long(buf); + encode_time(buf, self.issued); + } + + /// Try to decode without encryption, but do validate that the address is acceptable + fn decode(buf: &mut B, address: &SocketAddr) -> Result { + let token_address = decode_socket_addr(buf).ok_or(ValidationError::Ignore)?; + if token_address != *address { + return Err(ValidationError::InvalidRetry); + } + let orig_dst_cid = ConnectionId::decode_long(buf).ok_or(ValidationError::Ignore)?; + let issued = decode_time(buf).ok_or(ValidationError::Ignore)?; Ok(Self { orig_dst_cid, issued, }) } + + /// Ensure that this token validates an `Incoming`, and construct its token state + pub(crate) fn validate( + &self, + header: &InitialHeader, + server_config: &ServerConfig, + ) -> Result { + if self.issued + server_config.retry_token_lifetime > SystemTime::now() { + Ok(IncomingTokenState { + retry_src_cid: Some(header.dst_cid), + orig_dst_cid: self.orig_dst_cid, + 
validated: true, + }) + } else { + Err(ValidationError::InvalidRetry) + } + } +} + +/// Content of [`Token`] originating from NEW_TOKEN frame that is encrypted from the client +pub(crate) struct ValidationTokenInner { + /// The time at which this token was issued + pub(crate) issued: SystemTime, +} + +impl ValidationTokenInner { + /// Encode without encryption + fn encode(&self, buf: &mut Vec, address: &SocketAddr) { + encode_ip_addr(buf, &address.ip()); + encode_time(buf, self.issued); + } + + /// Try to decode without encryption, but do validate that the address is acceptable + fn decode(buf: &mut B, address: &SocketAddr) -> Result { + let token_address = decode_ip_addr(buf).ok_or(ValidationError::Ignore)?; + if token_address != address.ip() { + return Err(ValidationError::Ignore); + } + let issued = decode_time(buf).ok_or(ValidationError::Ignore)?; + Ok(Self { issued }) + } + + /// Ensure that this token validates an `Incoming`, and construct its token state + pub(crate) fn validate( + &self, + rand: u128, + header: &InitialHeader, + server_config: &ServerConfig, + ) -> Result { + let Some(ref log) = server_config.validation_token_log else { + return Err(ValidationError::Ignore); + }; + let log_result = + log.check_and_insert(rand, self.issued, server_config.validation_token_lifetime); + if log_result.is_err() { + debug!("rejecting token from NEW_TOKEN frame because detected as reuse"); + Err(ValidationError::Ignore) + } else if self.issued + server_config.validation_token_lifetime < SystemTime::now() { + Err(ValidationError::Ignore) + } else { + Ok(IncomingTokenState { + retry_src_cid: None, + orig_dst_cid: header.dst_cid, + validated: true, + }) + } + } } -fn encode_addr(buf: &mut Vec, address: &SocketAddr) { - match address.ip() { +fn encode_socket_addr(buf: &mut Vec, address: &SocketAddr) { + encode_ip_addr(buf, &address.ip()); + buf.put_u16(address.port()); +} + +fn encode_ip_addr(buf: &mut Vec, address: &IpAddr) { + match address { IpAddr::V4(x) => { 
buf.put_u8(0); buf.put_slice(&x.octets()); @@ -86,32 +299,64 @@ fn encode_addr(buf: &mut Vec, address: &SocketAddr) { buf.put_slice(&x.octets()); } } - buf.put_u16(address.port()); } -fn decode_addr(buf: &mut B) -> Option { - let ip = match buf.get_u8() { +fn decode_socket_addr(buf: &mut B) -> Option { + let ip = decode_ip_addr(buf)?; + let port = buf.get::().ok()?; + Some(SocketAddr::new(ip, port)) +} + +fn decode_ip_addr(buf: &mut B) -> Option { + Some(match buf.get::().ok()? { 0 => IpAddr::V4(buf.get().ok()?), 1 => IpAddr::V6(buf.get().ok()?), _ => return None, - }; - let port = buf.get_u16(); - Some(SocketAddr::new(ip, port)) + }) } -/// Reasons why a retry token might fail to validate a client's address +fn encode_time(buf: &mut Vec, time: SystemTime) { + let unix_secs = time + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + buf.write::(unix_secs); +} + +fn decode_time(buf: &mut B) -> Option { + Some(UNIX_EPOCH + Duration::from_secs(buf.get::().ok()?)) +} + +/// Error for a token failing to validate a client's address #[derive(Debug, Copy, Clone)] -pub(crate) enum TokenDecodeError { - /// Token was not recognized. It should be silently ignored. - UnknownToken, - /// Token was well-formed but associated with an incorrect address. The connection cannot be - /// established. - WrongAddress, +pub(crate) enum ValidationError { + /// Token may have come from a NEW_TOKEN frame (including from a different server or a previous + /// run of this server with different keys), and was not valid + /// + /// It should be silently ignored. + /// + /// In cases where a token cannot be decrypted/decoded, we must allow for the possibility that + /// this is caused not by client malfeasance, but by the token having been generated by an + /// incompatible endpoint, e.g. a different version or a neighbor behind the same load + /// balancer. In such cases we proceed as if there was no token. 
+ /// + /// [_RFC 9000 § 8.1.3:_](https://www.rfc-editor.org/rfc/rfc9000.html#section-8.1.3-10) + /// + /// > If the token is invalid, then the server SHOULD proceed as if the client did not have a + /// > validated address, including potentially sending a Retry packet. + /// + /// That said, this may also be used when a token _can_ be unambiguously decrypted/decoded as a + /// token from a NEW_TOKEN frame, but is simply not valid. + Ignore, + /// Token was unambiguously from a Retry packet, and was not valid + /// + /// The connection cannot be established. + InvalidRetry, } -impl From for TokenDecodeError { +impl From for ValidationError { fn from(CryptoError: CryptoError) -> Self { - Self::UnknownToken + Self::Ignore } } @@ -165,47 +410,69 @@ impl fmt::Display for ResetToken { #[cfg(all(test, any(feature = "aws-lc-rs", feature = "ring")))] mod test { + use super::*; #[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))] use aws_lc_rs::hkdf; + use rand::prelude::*; #[cfg(feature = "ring")] use ring::hkdf; - #[test] - fn token_sanity() { - use super::*; + fn token_round_trip(inner: TokenInner) -> TokenInner { + let rng = &mut rand::thread_rng(); + let token = Token::new(rng, inner); + let mut master_key = [0; 64]; + rng.fill_bytes(&mut master_key); + let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key); + let addr = SocketAddr::new(rng.gen::().to_ne_bytes().into(), rng.gen::()); + let encoded = token.encode(&prk, &addr); + let decoded = Token::decode(&prk, &addr, &encoded).expect("token didn't decrypt / decode"); + assert_eq!(token.rand, decoded.rand); + decoded.inner + } + + fn retry_token_sanity() { use crate::cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator}; use crate::MAX_CID_SIZE; use crate::{Duration, UNIX_EPOCH}; - use rand::RngCore; - use std::net::Ipv6Addr; + let orig_dst_cid_1 = RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid(); + let issued_1 = UNIX_EPOCH + Duration::new(42, 0); // Fractional seconds 
would be lost + + let inner_1 = TokenInner::Retry(RetryTokenInner { + orig_dst_cid: orig_dst_cid_1, + issued: issued_1, + }); + let inner_2 = token_round_trip(inner_1); + let TokenInner::Retry(RetryTokenInner { + orig_dst_cid: orig_dst_cid_2, + issued: issued_2, + }) = inner_2 + else { + panic!("token decoded as wrong variant") + }; - let rng = &mut rand::thread_rng(); + assert_eq!(orig_dst_cid_1, orig_dst_cid_2); + assert_eq!(issued_1, issued_2); + } - let mut master_key = [0; 64]; - rng.fill_bytes(&mut master_key); + #[test] + fn validation_token_sanity() { + use crate::{Duration, UNIX_EPOCH}; - let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key); + let issued_1 = UNIX_EPOCH + Duration::new(42, 0); // Fractional seconds would be lost - let addr = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 4433); - let retry_src_cid = RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid(); - let token = RetryToken { - orig_dst_cid: RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid(), - issued: UNIX_EPOCH + Duration::new(42, 0), // Fractional seconds would be lost + let inner_1 = TokenInner::Validation(ValidationTokenInner { issued: issued_1 }); + let inner_2 = token_round_trip(inner_1); + let TokenInner::Validation(ValidationTokenInner { issued: issued_2 }) = inner_2 else { + panic!("token decoded as wrong variant") }; - let encoded = token.encode(&prk, &addr, &retry_src_cid); - let decoded = RetryToken::from_bytes(&prk, &addr, &retry_src_cid, &encoded) - .expect("token didn't validate"); - assert_eq!(token.orig_dst_cid, decoded.orig_dst_cid); - assert_eq!(token.issued, decoded.issued); + assert_eq!(issued_1, issued_2); } #[test] fn invalid_token_returns_err() { use super::*; - use crate::cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator}; - use crate::MAX_CID_SIZE; use rand::RngCore; use std::net::Ipv6Addr; @@ -217,7 +484,6 @@ mod test { let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key); let addr = 
SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 4433); - let retry_src_cid = RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid(); let mut invalid_token = Vec::new(); @@ -226,6 +492,6 @@ mod test { invalid_token.put_slice(&random_data); // Assert: garbage sealed data returns err - assert!(RetryToken::from_bytes(&prk, &addr, &retry_src_cid, &invalid_token).is_err()); + assert!(Token::decode(&prk, &addr, &invalid_token).is_err()); } } diff --git a/quinn-proto/src/token_store.rs b/quinn-proto/src/token_store.rs new file mode 100644 index 000000000..d80aaffd3 --- /dev/null +++ b/quinn-proto/src/token_store.rs @@ -0,0 +1,514 @@ +//! Storing tokens sent from servers in NEW_TOKEN frames and using them in subsequent connections + +use bytes::Bytes; +use slab::Slab; +use std::{ + collections::{hash_map, HashMap}, + mem::take, + sync::{Arc, Mutex}, +}; +use tracing::trace; + +/// Responsible for storing address validation tokens received from servers and retrieving them for +/// use in subsequent connections +pub trait TokenStore: Send + Sync { + /// Potentially store a token for later one-time use + /// + /// Called when a NEW_TOKEN frame is received from the server. + fn insert(&self, server_name: &str, token: Bytes); + + /// Try to find and take a token that was stored with the given server name + /// + /// The same token must never be returned from `take` twice, as doing so can be used to + /// de-anonymize a client's traffic. + /// + /// Called when trying to connect to a server. It is always ok for this to return `None`. 
+ fn take(&self, server_name: &str) -> Option; +} + +/// `TokenStore` implementation that stores up to `N` tokens per server name for up to a +/// limited number of server names, in-memory +#[derive(Debug)] +pub struct TokenMemoryCache(Mutex>); + +impl TokenMemoryCache { + /// Construct empty + pub fn new(max_server_names: usize) -> Self { + Self(Mutex::new(State::new(max_server_names))) + } +} + +impl TokenStore for TokenMemoryCache { + fn insert(&self, server_name: &str, token: Bytes) { + trace!(%server_name, "storing token"); + self.0.lock().unwrap().store(server_name, token) + } + + fn take(&self, server_name: &str) -> Option { + let token = self.0.lock().unwrap().take(server_name); + trace!(%server_name, found=%token.is_some(), "taking token"); + token + } +} + +/// Defaults to a maximum of 256 servers +impl Default for TokenMemoryCache { + fn default() -> Self { + Self::new(256) + } +} + +/// Lockable inner state of `TokenMemoryCache` +#[derive(Debug)] +struct State { + max_server_names: usize, + // map from server name to slab index in linked + lookup: HashMap, usize>, + linked: LinkedCache, +} + +impl State { + fn new(max_server_names: usize) -> Self { + assert!(max_server_names > 0, "size limit cannot be 0"); + Self { + max_server_names, + lookup: HashMap::new(), + linked: LinkedCache::default(), + } + } + + fn store(&mut self, server_name: &str, token: Bytes) { + let server_name = Arc::::from(server_name); + let idx = match self.lookup.entry(server_name.clone()) { + hash_map::Entry::Occupied(hmap_entry) => { + // key already exists, add the new token to its token stack + let entry = &mut self.linked.entries[*hmap_entry.get()]; + entry.tokens.push(token); + + // unlink the entry and set it up to be linked as the most recently used + self.linked.unlink(*hmap_entry.get()); + *hmap_entry.get() + } + hash_map::Entry::Vacant(hmap_entry) => { + // key does not yet exist, create a new one, evicting the oldest if necessary + let removed_key = if 
self.linked.entries.len() >= self.max_server_names { + // unwrap safety: max_server_names is > 0, so there's at least one entry, so + // oldest_newest is some + let oldest = self.linked.oldest_newest.unwrap().0; + self.linked.unlink(oldest); + Some(self.linked.entries.remove(oldest).server_name) + } else { + None + }; + + let cache_entry = CacheEntry::new(server_name, token); + let idx = self.linked.entries.insert(cache_entry); + hmap_entry.insert(idx); + + // for borrowing reasons, we must defer removing the evicted hmap entry + if let Some(removed_key) = removed_key { + let removed = self.lookup.remove(&removed_key); + debug_assert!(removed.is_some()); + } + + idx + } + }; + + // link it as the newest entry + self.linked.link(idx); + } + + fn take(&mut self, server_name: &str) -> Option { + if let hash_map::Entry::Occupied(hmap_entry) = self.lookup.entry(server_name.into()) { + let entry = &mut self.linked.entries[*hmap_entry.get()]; + // pop from entry's token stack + let token = entry.tokens.pop(); + if entry.tokens.len > 0 { + // re-link entry as most recently used + self.linked.unlink(*hmap_entry.get()); + self.linked.link(*hmap_entry.get()); + } else { + // token stack emptied, remove entry + self.linked.unlink(*hmap_entry.get()); + self.linked.entries.remove(*hmap_entry.get()); + hmap_entry.remove(); + } + Some(token) + } else { + None + } + } +} + +/// Slab-based linked LRU cache of `CacheEntry` +#[derive(Debug, Default)] +struct LinkedCache { + entries: Slab>, + oldest_newest: Option<(usize, usize)>, +} + +impl LinkedCache { + /// Re-link an entry's neighbors around it + fn unlink(&mut self, idx: usize) { + // unwrap safety: we assume entries[idx] is linked, therefore oldest_newest is some + let &mut (ref mut oldest, ref mut newest) = self.oldest_newest.as_mut().unwrap(); + if *oldest == idx && *newest == idx { + // edge case where the list becomes empty + self.oldest_newest = None; + } else { + let older = self.entries[idx].older; + let newer = 
self.entries[idx].newer; + // re-link older's newer + if let Some(older) = older { + self.entries[older].newer = newer; + } else { + // unwrap safety: if both older and newer were None, we would've entered the branch + // where the list becomes empty instead + *oldest = newer.unwrap(); + } + // re-link newer's older + if let Some(newer) = newer { + self.entries[newer].older = older; + } else { + // unwrap safety: if both older and newer were None, we would've entered the branch + // where the list becomes empty instead + *newest = older.unwrap(); + } + } + } + + /// Link an unlinked entry as the most recently used entry + fn link(&mut self, idx: usize) { + self.entries[idx].newer = None; + self.entries[idx].older = self.oldest_newest.map(|(_, newest)| newest); + if let Some((_, ref mut newest)) = self.oldest_newest.as_mut() { + self.entries[*newest].newer = Some(idx); + *newest = idx; + } else { + self.oldest_newest = Some((idx, idx)); + } + } +} + +/// Cache entry within `LinkedCache` +#[derive(Debug)] +struct CacheEntry { + older: Option, + newer: Option, + server_name: Arc, + tokens: Queue, +} + +impl CacheEntry { + /// Construct with a single token, not linked + fn new(server_name: Arc, token: Bytes) -> Self { + let mut tokens = Queue::new(); + tokens.push(token); + Self { + server_name, + older: None, + newer: None, + tokens, + } + } +} + +/// In-place vector queue of up to `N` `Bytes` +#[derive(Debug)] +struct Queue { + elems: [Bytes; N], + // if len > 0, front is elems[start] + // invariant: start < N + start: usize, + // if len > 0, back is elems[(start + len - 1) % N] + len: usize, +} + +impl Queue { + /// Construct empty + fn new() -> Self { + const EMPTY_BYTES: Bytes = Bytes::new(); + Self { + elems: [EMPTY_BYTES; N], + start: 0, + len: 0, + } + } + + /// Push to back, popping from front first if already at capacity + fn push(&mut self, elem: Bytes) { + self.elems[(self.start + self.len) % N] = elem; + if self.len < N { + self.len += 1; + } else { + 
self.start += 1; + self.start %= N; + } + } + + /// Pop from front, panicking if empty + fn pop(&mut self) -> Bytes { + const PANIC_MSG: &str = "TokenMemoryCache popped from empty Queue, this is a bug!"; + self.len = self.len.checked_sub(1).expect(PANIC_MSG); + let elem = take(&mut self.elems[self.start]); + self.start += 1; + self.start %= N; + elem + } +} + +#[cfg(test)] +mod tests { + use std::collections::VecDeque; + + use super::*; + use rand::prelude::*; + use rand_pcg::Pcg32; + + fn new_rng() -> impl Rng { + Pcg32::from_seed(0xdeadbeefdeadbeefdeadbeefdeadbeefu128.to_le_bytes()) + } + + #[test] + fn queue_test() { + let mut rng = new_rng(); + const N: usize = 2; + + for _ in 0..100 { + let mut queue_1 = VecDeque::new(); + let mut queue_2 = Queue::::new(); + + for i in 0..10 { + if rng.gen::() { + // push + let token = Bytes::from(vec![i]); + println!("PUSH {:?}", token); + queue_1.push_back(token.clone()); + if queue_1.len() > N { + queue_1.pop_front(); + } + queue_2.push(token); + } else { + // pop + if let Some(token) = queue_1.pop_front() { + println!("POP {:?}", token); + assert_eq!(queue_2.pop(), token); + } else { + println!("POP nothing"); + assert_eq!(queue_2.len, 0); + } + } + // assert equivalent + println!("queue_1 = {:?}", queue_1); + println!("queue_2 = {:?}", queue_2); + assert_eq!(queue_1.len(), queue_2.len); + for (j, token) in queue_1.iter().enumerate() { + let k = (queue_2.start + j) % N; + assert_eq!(queue_2.elems[k], token); + } + } + } + } + + #[test] + fn linked_test() { + let mut rng = new_rng(); + const N: usize = 2; + + for _ in 0..10 { + let mut cache_1: Vec = Vec::new(); // keep it sorted oldest to newest + let mut cache_2: LinkedCache = LinkedCache::default(); + for i in 0..100 { + match rng.gen::() % 4 { + 0 | 1 => { + // insert + println!("INSERT {}", i); + let entry_2 = CacheEntry::new(i.to_string().into(), Bytes::new()); + cache_1.push(i); + let slab_idx = cache_2.entries.insert(entry_2); + cache_2.link(slab_idx); + } + 2 => { 
+ if cache_1.is_empty() { + println!("SKIP BECAUSE EMPTY"); + continue; + } + // hit + let idx = rng.gen::() % cache_1.len(); + let entry_1 = cache_1.remove(idx); + println!("HIT {}", entry_1); + let (slab_idx, _) = cache_2 + .entries + .iter() + .find(|(_, entry_2)| { + entry_2.server_name.as_ref() == entry_1.to_string().as_str() + }) + .unwrap(); + cache_1.push(entry_1); + cache_2.unlink(slab_idx); + cache_2.link(slab_idx); + } + 3 => { + if cache_1.is_empty() { + println!("SKIP BECAUSE EMPTY"); + continue; + } + // remove + let idx = rng.gen::() % cache_1.len(); + let entry_1 = cache_1.remove(idx); + println!("REMOVE {}", entry_1); + let (slab_idx, _) = cache_2 + .entries + .iter() + .find(|(_, entry_2)| { + entry_2.server_name.as_ref() == entry_1.to_string().as_str() + }) + .unwrap(); + cache_2.unlink(slab_idx); + cache_2.entries.remove(slab_idx); + } + _ => unreachable!(), + } + // assert equivalent + println!("cache_1 = {:#?}", cache_1); + println!("cache_2 = {:#?}", cache_2); + assert_eq!(cache_1.len(), cache_2.entries.len()); + let mut prev_slab_idx = None; + let mut slab_idx = cache_2.oldest_newest.map(|(oldest, _)| oldest); + for (i, entry_1) in cache_1.iter().enumerate() { + let entry_2 = &cache_2.entries + [slab_idx.unwrap_or_else(|| panic!("next link missing at index {}", i))]; + assert_eq!( + entry_2.server_name.as_ref(), + entry_1.to_string().as_str(), + "discrepancy at idx {}", + i + ); + assert_eq!( + entry_2.older, prev_slab_idx, + "backlink discrepancy at idx {}", + i + ); + prev_slab_idx = slab_idx; + slab_idx = entry_2.newer; + } + assert_eq!(slab_idx, None, "newest item has newer link"); + } + } + } + + #[test] + fn cache_test() { + let mut rng = new_rng(); + const N: usize = 2; + + for _ in 0..10 { + let mut cache_1: Vec<(u32, VecDeque)> = Vec::new(); // keep it sorted oldest to newest + let cache_2: TokenMemoryCache = TokenMemoryCache::new(20); + + for i in 0..200 { + let server_name = rng.gen::() % 10; + if rng.gen_bool(0.666) { + // store 
+ let token = Bytes::from(vec![i]); + println!("STORE {} {:?}", server_name, token); + if let Some((j, _)) = cache_1 + .iter() + .enumerate() + .find(|&(_, &(server_name_2, _))| server_name_2 == server_name) + { + let (_, mut queue) = cache_1.remove(j); + queue.push_back(token.clone()); + if queue.len() > N { + queue.pop_front(); + } + cache_1.push((server_name, queue)); + } else { + let mut queue = VecDeque::new(); + queue.push_back(token.clone()); + cache_1.push((server_name, queue)); + if cache_1.len() > 20 { + cache_1.remove(0); + } + } + cache_2.insert(&server_name.to_string(), token); + } else { + // take + println!("TAKE {}", server_name); + let expecting = cache_1 + .iter() + .enumerate() + .find(|&(_, &(server_name_2, _))| server_name_2 == server_name) + .map(|(j, _)| j) + .map(|j| { + let (_, mut queue) = cache_1.remove(j); + let token = queue.pop_front().unwrap(); + if !queue.is_empty() { + cache_1.push((server_name, queue)); + } + token + }); + println!("EXPECTING {:?}", expecting); + assert_eq!(cache_2.take(&server_name.to_string()), expecting); + } + // assert equivalent + println!("cache_1 = {:#?}", cache_1); + println!("cache_2 = {:#?}", cache_2); + let cache_2 = cache_2.0.lock().unwrap(); + assert_eq!(cache_1.len(), cache_2.lookup.len(), "cache len discrepancy"); + assert_eq!( + cache_2.lookup.len(), + cache_2.linked.entries.len(), + "cache lookup hmap wrong len" + ); + let mut prev_slab_idx = None; + let mut slab_idx = cache_2.linked.oldest_newest.map(|(oldest, _)| oldest); + for (i, (server_name_1, queue_1)) in cache_1.iter().enumerate() { + let entry_2 = &cache_2.linked.entries + [slab_idx.unwrap_or_else(|| panic!("next link missing at index {}", i))]; + assert_eq!( + server_name_1.to_string().as_str(), + entry_2.server_name.as_ref(), + "server name discrepancy at idx {}", + i + ); + assert_eq!( + entry_2.older, prev_slab_idx, + "backlink discrepancy at idx {}", + i + ); + assert_eq!( + queue_1.len(), + entry_2.tokens.len, + "queue len 
discrepancy at idx {}", + i + ); + for (j, token) in queue_1.iter().enumerate() { + let k = (entry_2.tokens.start + j) % N; + assert_eq!( + entry_2.tokens.elems[k], token, + "queue item discrepancy at idx {} queue idx {}", + i, j + ); + } + assert_eq!( + *cache_2 + .lookup + .get(&Arc::::from(server_name_1.to_string())) + .unwrap_or_else(|| panic!( + "server name missing from hmap at idx {}", + i + )), + slab_idx.unwrap(), + "server name in hmap pointing to wrong slab entry at idx {}", + i + ); + prev_slab_idx = slab_idx; + slab_idx = entry_2.newer; + } + assert_eq!(slab_idx, None, "newest item has newer link"); + } + } + } +} diff --git a/quinn/Cargo.toml b/quinn/Cargo.toml index a061520d7..9730e6ed3 100644 --- a/quinn/Cargo.toml +++ b/quinn/Cargo.toml @@ -12,7 +12,9 @@ edition.workspace = true rust-version.workspace = true [features] -default = ["log", "platform-verifier", "runtime-tokio", "rustls-ring"] +default = ["log", "platform-verifier", "runtime-tokio", "rustls-ring", "fastbloom"] +# Enables BloomTokenLog +fastbloom = ["proto/fastbloom"] # Enables `Endpoint::client` and `Endpoint::server` conveniences aws-lc-rs = ["proto/aws-lc-rs"] aws-lc-rs-fips = ["proto/aws-lc-rs-fips"] diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index 2563f3771..4f5cc3db3 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -1,5 +1,6 @@ use std::{ collections::VecDeque, + fmt, future::Future, io, io::IoSliceMut, @@ -728,7 +729,6 @@ impl std::ops::Deref for EndpointRef { } /// State directly involved in handling incoming packets -#[derive(Debug)] struct RecvState { incoming: VecDeque, connections: ConnectionSet, @@ -850,6 +850,17 @@ impl RecvState { } } +impl fmt::Debug for RecvState { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("RecvState") + .field("incoming", &self.incoming) + .field("connections", &self.connections) + // recv_buf too large + .field("recv_limiter", &self.recv_limiter) + .finish_non_exhaustive() + } +} + 
#[derive(Default)] struct PollProgress { /// Whether a datagram was routed to an existing connection diff --git a/quinn/src/incoming.rs b/quinn/src/incoming.rs index e306a368d..6b41d25bf 100644 --- a/quinn/src/incoming.rs +++ b/quinn/src/incoming.rs @@ -29,11 +29,9 @@ impl Incoming { state.endpoint.accept(state.inner, None) } - /// Accept this incoming connection using a custom configuration. + /// Accept this incoming connection using a custom configuration /// - /// See [`accept()`] for more details. - /// - /// [`accept()`]: Incoming::accept + /// See [`accept()`][Incoming::accept] for more details. pub fn accept_with( mut self, server_config: Arc, @@ -50,7 +48,7 @@ impl Incoming { /// Respond with a retry packet, requiring the client to retry with address validation /// - /// Errors if `remote_address_validated()` is true. + /// Errors if `may_retry()` is false. pub fn retry(mut self) -> Result<(), RetryError> { let state = self.0.take().unwrap(); state.endpoint.retry(state.inner).map_err(|e| { @@ -67,8 +65,7 @@ impl Incoming { state.endpoint.ignore(state.inner); } - /// The local IP address which was used when the peer established - /// the connection + /// The local IP address which was used when the peer established the connection pub fn local_ip(&self) -> Option { self.0.as_ref().unwrap().inner.local_ip() } @@ -82,10 +79,21 @@ impl Incoming { /// /// This means that the sender of the initial packet has proved that they can receive traffic /// sent to `self.remote_address()`. + /// + /// If `self.remote_address_validated()` is false, `self.may_retry()` is guaranteed to be true. + /// The inverse is not guaranteed. pub fn remote_address_validated(&self) -> bool { self.0.as_ref().unwrap().inner.remote_address_validated() } + /// Whether it is legal to respond with a retry packet + /// + /// If `self.remote_address_validated()` is false, `self.may_retry()` is guaranteed to be true. + /// The inverse is not guaranteed. 
+ pub fn may_retry(&self) -> bool { + self.0.as_ref().unwrap().inner.may_retry() + } + /// The original destination CID when initiating the connection pub fn orig_dst_cid(&self) -> ConnectionId { *self.0.as_ref().unwrap().inner.orig_dst_cid() @@ -107,8 +115,7 @@ struct State { endpoint: EndpointRef, } -/// Error for attempting to retry an [`Incoming`] which already bears an address -/// validation token from a previous retry +/// Error for attempting to retry an [`Incoming`] which already bears a token from a previous retry #[derive(Debug, Error)] #[error("retry() with validated Incoming")] pub struct RetryError(Incoming); diff --git a/quinn/src/lib.rs b/quinn/src/lib.rs index 84a3821f6..4cb0b2caa 100644 --- a/quinn/src/lib.rs +++ b/quinn/src/lib.rs @@ -61,12 +61,15 @@ mod runtime; mod send_stream; mod work_limiter; +#[cfg(feature = "fastbloom")] +pub use proto::BloomTokenLog; pub use proto::{ congestion, crypto, AckFrequencyConfig, ApplicationClose, Chunk, ClientConfig, ClosedStream, ConfigError, ConnectError, ConnectionClose, ConnectionError, ConnectionId, ConnectionIdGenerator, ConnectionStats, Dir, EcnCodepoint, EndpointConfig, FrameStats, - FrameType, IdleTimeout, MtuDiscoveryConfig, PathStats, ServerConfig, Side, StreamId, Transmit, - TransportConfig, TransportErrorCode, UdpStats, VarInt, VarIntBoundsExceeded, Written, + FrameType, IdleTimeout, MtuDiscoveryConfig, PathStats, ServerConfig, Side, StreamId, TokenLog, + TokenMemoryCache, TokenReuseError, TokenStore, Transmit, TransportConfig, TransportErrorCode, + UdpStats, VarInt, VarIntBoundsExceeded, Written, }; #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))] pub use rustls;