diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index 1630d015f5..8beff764c0 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -3,6 +3,7 @@ use std::{ future::Future, io, io::IoSliceMut, + mem, net::{SocketAddr, SocketAddrV6}, pin::Pin, str, @@ -215,7 +216,7 @@ impl Endpoint { let addr = socket.local_addr()?; let socket = self.runtime.wrap_udp_socket(socket)?; let mut inner = self.inner.state.lock().unwrap(); - inner.socket = socket; + inner.prev_socket = Some(mem::replace(&mut inner.socket, socket)); inner.ipv6 = addr.is_ipv6(); // Generate some activity so peers notice the rebind @@ -409,6 +410,9 @@ impl EndpointInner { #[derive(Debug)] pub(crate) struct State { socket: Arc, + /// During an active migration, abandoned_socket receives traffic + /// until the first packet arrives on the new socket. + prev_socket: Option>, inner: proto::Endpoint, transmit_state: TransmitState, recv_state: RecvState, @@ -431,6 +435,19 @@ pub(crate) struct Shared { impl State { fn drive_recv(&mut self, cx: &mut Context, now: Instant) -> Result { self.recv_state.recv_limiter.start_cycle(); + if let Some(socket) = &self.prev_socket { + // We don't care about the `PollProgress` from old sockets. + let poll_res = self.recv_state.poll_socket( + cx, + &mut self.inner, + &mut self.transmit_state, + &**socket, + now, + ); + if poll_res.is_err() { + self.prev_socket = None; + } + }; let poll_res = self.recv_state.poll_socket( cx, &mut self.inner, @@ -439,7 +456,13 @@ impl State { now, ); self.recv_state.recv_limiter.finish_cycle(); - poll_res.map(|x| x.keep_going) + let poll_res = poll_res?; + if poll_res.received_connection_packet { + // Traffic has arrived on self.socket, therefore there is no need for the abandoned + // one anymore. TODO: Account for multiple outgoing connections. + self.prev_socket = None; + } + Ok(poll_res.keep_going) } fn drive_send(&mut self, cx: &mut Context) -> Result { @@ -681,6 +704,7 @@ impl EndpointRef { }, state: Mutex::new(State { socket, + prev_socket: None, inner, transmit_state: TransmitState::default(), ipv6, @@ -767,6 +791,7 @@ impl RecvState { socket: &dyn AsyncUdpSocket, now: Instant, ) -> Result { + let mut received_connection_packet = false; let mut metas = [RecvMeta::default(); BATCH_SIZE]; let mut iovs: [IoSliceMut; BATCH_SIZE] = { let mut bufs = self @@ -807,6 +832,7 @@ impl RecvState { } Some(DatagramEvent::ConnectionEvent(handle, event)) => { // Ignoring errors from dropped connections that haven't yet been cleaned up + received_connection_packet = true; let _ = self .connections .senders @@ -823,7 +849,10 @@ impl RecvState { } } Poll::Pending => { - return Ok(PollProgress { keep_going: false }); + return Ok(PollProgress { + received_connection_packet, + keep_going: false, + }); } // Ignore ECONNRESET as it's undefined in QUIC and may be injected by an // attacker @@ -835,7 +864,10 @@ impl RecvState { } } if !self.recv_limiter.allow_work() { - return Ok(PollProgress { keep_going: true }); + return Ok(PollProgress { + received_connection_packet, + keep_going: true, + }); } } } @@ -843,6 +875,8 @@ impl RecvState { #[derive(Default)] struct PollProgress { + /// Whether a datagram was routed to an existing connection + received_connection_packet: bool, /// Whether datagram handling was interrupted early by the work limiter for fairness keep_going: bool, }