Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport supress stateless packet to 0.10.x #1598

Merged
42 changes: 40 additions & 2 deletions quinn-proto/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use std::{
fmt, iter,
net::{IpAddr, SocketAddr},
ops::{Index, IndexMut},
sync::Arc,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
time::{Instant, SystemTime},
};

Expand Down Expand Up @@ -63,8 +66,14 @@ pub struct Endpoint {
server_config: Option<Arc<ServerConfig>>,
/// Whether the underlying UDP socket promises not to fragment packets
allow_mtud: bool,
transmit_queue_contents_len: Arc<AtomicUsize>,
}

/// The maximum size of content length of packets in the outgoing transmit queue. Transmit packets
/// generated from the endpoint (retry or initial close) can be dropped when this limit is being execeeded.
/// Chose to represent 100 MB of data.
const MAX_TRANSMIT_QUEUE_CONTENTS_LEN: usize = 100_000_000;

impl Endpoint {
/// Create a new endpoint
///
Expand All @@ -75,6 +84,7 @@ impl Endpoint {
config: Arc<EndpointConfig>,
server_config: Option<Arc<ServerConfig>>,
allow_mtud: bool,
transmit_queue_contents_len: Arc<AtomicUsize>,
lijunwangs marked this conversation as resolved.
Show resolved Hide resolved
) -> Self {
Self {
rng: StdRng::from_entropy(),
Expand All @@ -88,13 +98,19 @@ impl Endpoint {
config,
server_config,
allow_mtud,
transmit_queue_contents_len,
}
}

/// Get the next packet to transmit
#[must_use]
pub fn poll_transmit(&mut self) -> Option<Transmit> {
self.transmits.pop_front()
let t = self.transmits.pop_front();
self.transmit_queue_contents_len.fetch_sub(
t.as_ref().map_or(0, |t| t.contents.len()),
Ordering::Relaxed,
);
t
}

/// Replace the server configuration, affecting new incoming connections only
Expand Down Expand Up @@ -193,6 +209,8 @@ impl Endpoint {
for &version in &self.config.supported_versions {
buf.write(version);
}
self.transmit_queue_contents_len
.fetch_add(buf.len(), Ordering::Relaxed);
self.transmits.push_back(Transmit {
destination: remote,
ecn: None,
Expand Down Expand Up @@ -355,6 +373,8 @@ impl Endpoint {
buf.extend_from_slice(&ResetToken::new(&*self.config.reset_key, dst_cid));

debug_assert!(buf.len() < inciting_dgram_len);
self.transmit_queue_contents_len
.fetch_add(buf.len(), Ordering::Relaxed);

self.transmits.push_back(Transmit {
destination: addresses.remote,
Expand Down Expand Up @@ -447,6 +467,14 @@ impl Endpoint {
}
}

/// Limiting the memory usage for items queued in the outgoing queue from endpoint
/// generated packets. Otherwise, we may see a build-up of the queue under test with
/// flood of initial packets against the endpoint. The sender with the sender-limiter
/// may not keep up the pace of these packets queued into the queue.
fn to_supresss_stateless_packets(&self) -> bool {
lijunwangs marked this conversation as resolved.
Show resolved Hide resolved
self.transmit_queue_contents_len.load(Ordering::Relaxed) >= MAX_TRANSMIT_QUEUE_CONTENTS_LEN
}

fn handle_first_packet(
&mut self,
now: Instant,
Expand Down Expand Up @@ -521,6 +549,9 @@ impl Endpoint {

let (retry_src_cid, orig_dst_cid) = if server_config.use_retry {
if token.is_empty() {
if self.to_supresss_stateless_packets() {
return None;
}
// First Initial
let mut random_bytes = vec![0u8; RetryToken::RANDOM_BYTES_LEN];
self.rng.fill_bytes(&mut random_bytes);
Expand All @@ -544,6 +575,8 @@ impl Endpoint {
buf.extend_from_slice(&server_config.crypto.retry_tag(version, &dst_cid, &buf));
encode.finish(&mut buf, &*crypto.header.local, None);

self.transmit_queue_contents_len
.fetch_add(buf.len(), Ordering::Relaxed);
self.transmits.push_back(Transmit {
destination: addresses.remote,
ecn: None,
Expand Down Expand Up @@ -680,6 +713,9 @@ impl Endpoint {
local_id: &ConnectionId,
reason: TransportError,
) {
if self.to_supresss_stateless_packets() {
return;
}
let number = PacketNumber::U8(0);
let header = Header::Initial {
dst_cid: *remote_id,
Expand All @@ -700,6 +736,8 @@ impl Endpoint {
&*crypto.header.local,
Some((0, &*crypto.packet.local)),
);
self.transmit_queue_contents_len
.fetch_add(buf.len(), Ordering::Relaxed);
self.transmits.push_back(Transmit {
destination: addresses.remote,
ecn: None,
Expand Down
60 changes: 53 additions & 7 deletions quinn-proto/src/tests/mod.rs
lijunwangs marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{
convert::TryInto,
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
sync::Arc,
sync::{atomic::AtomicUsize, Arc},
time::{Duration, Instant},
};

Expand All @@ -25,7 +25,14 @@ use util::*;
fn version_negotiate_server() {
let _guard = subscribe();
let client_addr = "[::2]:7890".parse().unwrap();
let mut server = Endpoint::new(Default::default(), Some(Arc::new(server_config())), true);
let transmit_queue_contents_len = Arc::new(AtomicUsize::default());

let mut server = Endpoint::new(
Default::default(),
Some(Arc::new(server_config())),
true,
transmit_queue_contents_len,
);
let now = Instant::now();
let event = server.handle(
now,
Expand Down Expand Up @@ -55,13 +62,15 @@ fn version_negotiate_client() {
let server_addr = "[::2]:7890".parse().unwrap();
let cid_generator_factory: fn() -> Box<dyn ConnectionIdGenerator> =
|| Box::new(RandomConnectionIdGenerator::new(0));
let transmit_queue_contents_len = Arc::new(AtomicUsize::default());
let mut client = Endpoint::new(
Arc::new(EndpointConfig {
connection_id_generator_factory: Arc::new(cid_generator_factory),
..Default::default()
}),
None,
true,
transmit_queue_contents_len,
);
let (_, mut client_ch) = client
.connect(client_config(), server_addr, "localhost")
Expand Down Expand Up @@ -173,11 +182,17 @@ fn server_stateless_reset() {
let reset_key = hmac::Key::new(hmac::HMAC_SHA256, &reset_key);

let endpoint_config = Arc::new(EndpointConfig::new(Arc::new(reset_key)));
let transmit_queue_contents_len = Arc::new(AtomicUsize::default());

let mut pair = Pair::new(endpoint_config.clone(), server_config());
let (client_ch, _) = pair.connect();
pair.drive(); // Flush any post-handshake frames
pair.server.endpoint = Endpoint::new(endpoint_config, Some(Arc::new(server_config())), true);
pair.server.endpoint = Endpoint::new(
endpoint_config,
Some(Arc::new(server_config())),
true,
transmit_queue_contents_len,
);
// Force the server to generate the smallest possible stateless reset
pair.client.connections.get_mut(&client_ch).unwrap().ping();
info!("resetting");
Expand All @@ -202,7 +217,14 @@ fn client_stateless_reset() {

let mut pair = Pair::new(endpoint_config.clone(), server_config());
let (_, server_ch) = pair.connect();
pair.client.endpoint = Endpoint::new(endpoint_config, Some(Arc::new(server_config())), true);
let transmit_queue_contents_len = Arc::new(AtomicUsize::default());

pair.client.endpoint = Endpoint::new(
endpoint_config,
Some(Arc::new(server_config())),
true,
transmit_queue_contents_len,
);
// Send something big enough to allow room for a smaller stateless reset.
pair.server.connections.get_mut(&server_ch).unwrap().close(
pair.time,
Expand Down Expand Up @@ -1316,6 +1338,7 @@ fn cid_rotation() {

let cid_generator_factory: fn() -> Box<dyn ConnectionIdGenerator> =
|| Box::new(*RandomConnectionIdGenerator::new(8).set_lifetime(CID_TIMEOUT));
let transmit_queue_contents_len = Arc::new(AtomicUsize::default());

// Only test cid rotation on server side to have a clear output trace
let server = Endpoint::new(
Expand All @@ -1325,8 +1348,15 @@ fn cid_rotation() {
}),
Some(Arc::new(server_config())),
true,
transmit_queue_contents_len,
);
let transmit_queue_contents_len = Arc::new(AtomicUsize::default());
let client = Endpoint::new(
Arc::new(EndpointConfig::default()),
None,
true,
transmit_queue_contents_len,
);
let client = Endpoint::new(Arc::new(EndpointConfig::default()), None, true);

let mut pair = Pair::new_from_endpoint(client, server);
let (_, server_ch) = pair.connect();
Expand Down Expand Up @@ -1904,7 +1934,14 @@ fn big_cert_and_key() -> (rustls::Certificate, rustls::PrivateKey) {
fn malformed_token_len() {
let _guard = subscribe();
let client_addr = "[::2]:7890".parse().unwrap();
let mut server = Endpoint::new(Default::default(), Some(Arc::new(server_config())), true);
let transmit_queue_contents_len = Arc::new(AtomicUsize::default());

let mut server = Endpoint::new(
Default::default(),
Some(Arc::new(server_config())),
true,
transmit_queue_contents_len,
);
server.handle(
Instant::now(),
client_addr,
Expand Down Expand Up @@ -1976,16 +2013,25 @@ fn migrate_detects_new_mtu_and_respects_original_peer_max_udp_payload_size() {

// Set up a client with a max payload size of 1400 (and use the defaults for the server)
let server_endpoint_config = EndpointConfig::default();
let transmit_queue_contents_len = Arc::new(AtomicUsize::default());

let server = Endpoint::new(
Arc::new(server_endpoint_config),
Some(Arc::new(server_config())),
true,
transmit_queue_contents_len,
);
let client_endpoint_config = EndpointConfig {
max_udp_payload_size: VarInt::from(client_max_udp_payload_size),
..EndpointConfig::default()
};
let client = Endpoint::new(Arc::new(client_endpoint_config), None, true);
let transmit_queue_contents_len = Arc::new(AtomicUsize::default());
let client = Endpoint::new(
Arc::new(client_endpoint_config),
None,
true,
transmit_queue_contents_len,
);
let mut pair = Pair::new_from_endpoint(client, server);
pair.mtu = 1300;

Expand Down
14 changes: 11 additions & 3 deletions quinn-proto/src/tests/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{
net::{Ipv6Addr, SocketAddr, UdpSocket},
ops::RangeFrom,
str,
sync::{Arc, Mutex},
sync::{atomic::AtomicUsize, Arc, Mutex},
time::{Duration, Instant},
};

Expand Down Expand Up @@ -36,8 +36,16 @@ pub(super) struct Pair {

impl Pair {
pub(super) fn new(endpoint_config: Arc<EndpointConfig>, server_config: ServerConfig) -> Self {
let server = Endpoint::new(endpoint_config.clone(), Some(Arc::new(server_config)), true);
let client = Endpoint::new(endpoint_config, None, true);
let transmit_queue_contents_len = Arc::new(AtomicUsize::default());

let server = Endpoint::new(
endpoint_config.clone(),
Some(Arc::new(server_config)),
true,
transmit_queue_contents_len,
);
let transmit_queue_contents_len = Arc::new(AtomicUsize::default());
let client = Endpoint::new(endpoint_config, None, true, transmit_queue_contents_len);

Self::new_from_endpoint(client, server)
}
Expand Down
26 changes: 23 additions & 3 deletions quinn/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ use std::{
net::{SocketAddr, SocketAddrV6},
pin::Pin,
str,
sync::{Arc, Mutex},
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
},
task::{Context, Poll, Waker},
time::Instant,
};
Expand Down Expand Up @@ -115,11 +118,18 @@ impl Endpoint {
) -> io::Result<Self> {
let addr = socket.local_addr()?;
let allow_mtud = !socket.may_fragment();
let transmit_queue_contents_len = Arc::new(AtomicUsize::default());
let rc = EndpointRef::new(
socket,
proto::Endpoint::new(Arc::new(config), server_config.map(Arc::new), allow_mtud),
proto::Endpoint::new(
Arc::new(config),
server_config.map(Arc::new),
allow_mtud,
transmit_queue_contents_len.clone(),
),
addr.is_ipv6(),
runtime.clone(),
transmit_queue_contents_len,
);
let driver = EndpointDriver(rc.clone());
runtime.spawn(Box::pin(async {
Expand Down Expand Up @@ -379,6 +389,8 @@ pub(crate) struct State {
recv_buf: Box<[u8]>,
send_limiter: WorkLimiter,
runtime: Arc<dyn Runtime>,
/// The aggregateed contents length of the packets in the transmit queue
transmit_queue_contents_len: Arc<AtomicUsize>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -486,7 +498,10 @@ impl State {
.poll_send(&self.udp_state, cx, self.outgoing.as_slices().0)
{
Poll::Ready(Ok(n)) => {
self.outgoing.drain(..n);
let contents_len: usize =
self.outgoing.drain(..n).map(|t| t.contents.len()).sum();
self.transmit_queue_contents_len
.fetch_sub(contents_len, Ordering::Relaxed);
// We count transmits instead of `poll_send` calls since the cost
// of a `sendmmsg` still linearily increases with number of packets.
self.send_limiter.record_work(n);
Expand Down Expand Up @@ -540,6 +555,9 @@ impl State {
}

fn queue_transmit(&mut self, t: proto::Transmit) {
let contents_len = t.contents.len();
self.transmit_queue_contents_len
.fetch_add(contents_len, Ordering::Relaxed);
self.outgoing.push_back(udp::Transmit {
destination: t.destination,
ecn: t.ecn.map(udp_ecn),
Expand Down Expand Up @@ -655,6 +673,7 @@ impl EndpointRef {
inner: proto::Endpoint,
ipv6: bool,
runtime: Arc<dyn Runtime>,
transmit_queue_contents_len: Arc<AtomicUsize>,
) -> Self {
let udp_state = Arc::new(UdpState::new());
let recv_buf = vec![
Expand Down Expand Up @@ -689,6 +708,7 @@ impl EndpointRef {
recv_limiter: WorkLimiter::new(RECV_TIME_BOUND),
send_limiter: WorkLimiter::new(SEND_TIME_BOUND),
runtime,
transmit_queue_contents_len,
}),
}))
}
Expand Down