diff --git a/neqo-client/src/main.rs b/neqo-client/src/main.rs index f465d7c206..868a86ddd7 100644 --- a/neqo-client/src/main.rs +++ b/neqo-client/src/main.rs @@ -763,7 +763,7 @@ fn to_headers(values: &[impl AsRef]) -> Vec
{ struct ClientRunner<'a> { local_addr: SocketAddr, - socket: &'a udp::Socket, + socket: &'a mut udp::Socket, client: Http3Client, handler: Handler<'a>, timeout: Option>>, @@ -773,7 +773,7 @@ struct ClientRunner<'a> { impl<'a> ClientRunner<'a> { fn new( args: &'a mut Args, - socket: &'a udp::Socket, + socket: &'a mut udp::Socket, local_addr: SocketAddr, remote_addr: SocketAddr, hostname: &str, @@ -998,7 +998,7 @@ async fn main() -> Res<()> { SocketAddr::V6(..) => SocketAddr::new(IpAddr::V6(Ipv6Addr::from([0; 16])), 0), }; - let socket = udp::Socket::bind(local_addr)?; + let mut socket = udp::Socket::bind(local_addr)?; let real_local = socket.local_addr().unwrap(); println!( "{} Client connecting: {:?} -> {:?}", @@ -1022,7 +1022,7 @@ async fn main() -> Res<()> { token = if args.use_old_http { old::ClientRunner::new( &args, - &socket, + &mut socket, real_local, remote_addr, &hostname, @@ -1034,7 +1034,7 @@ async fn main() -> Res<()> { } else { ClientRunner::new( &mut args, - &socket, + &mut socket, real_local, remote_addr, &hostname, @@ -1249,7 +1249,7 @@ mod old { pub struct ClientRunner<'a> { local_addr: SocketAddr, - socket: &'a udp::Socket, + socket: &'a mut udp::Socket, client: Connection, handler: HandlerOld<'a>, timeout: Option>>, @@ -1259,7 +1259,7 @@ mod old { impl<'a> ClientRunner<'a> { pub fn new( args: &'a Args, - socket: &'a udp::Socket, + socket: &'a mut udp::Socket, local_addr: SocketAddr, remote_addr: SocketAddr, origin: &str, diff --git a/neqo-common/src/udp.rs b/neqo-common/src/udp.rs index 7ad0b97625..64fc356760 100644 --- a/neqo-common/src/udp.rs +++ b/neqo-common/src/udp.rs @@ -21,6 +21,7 @@ use crate::{Datagram, IpTos}; pub struct Socket { socket: tokio::net::UdpSocket, state: UdpSocketState, + recv_buf: Vec, } impl Socket { @@ -31,6 +32,7 @@ impl Socket { Ok(Self { state: quinn_udp::UdpSocketState::new((&socket).into())?, socket: tokio::net::UdpSocket::from_std(socket)?, + recv_buf: vec![0; u16::MAX as usize], }) } @@ -70,15 +72,13 @@ impl Socket { } /// Receive a UDP datagram on the specified socket. - pub fn recv(&self, local_address: &SocketAddr) -> Result, io::Error> { - let mut buf = [0; u16::MAX as usize]; - + pub fn recv(&mut self, local_address: &SocketAddr) -> Result, io::Error> { let mut meta = RecvMeta::default(); match self.socket.try_io(Interest::READABLE, || { self.state.recv( (&self.socket).into(), - &mut [IoSliceMut::new(&mut buf)], + &mut [IoSliceMut::new(&mut self.recv_buf)], slice::from_mut(&mut meta), ) }) { @@ -101,8 +101,11 @@ impl Socket { return Ok(None); } - if meta.len == buf.len() { - eprintln!("Might have received more than {} bytes", buf.len()); + if meta.len == self.recv_buf.len() { + eprintln!( + "Might have received more than {} bytes", + self.recv_buf.len() + ); } Ok(Some(Datagram::new( @@ -110,7 +113,7 @@ impl Socket { *local_address, meta.ecn.map(|n| IpTos::from(n as u8)).unwrap_or_default(), None, // TODO: get the real TTL https://github.com/quinn-rs/quinn/issues/1749 - &buf[..meta.len], + &self.recv_buf[..meta.len], ))) } } @@ -124,7 +127,7 @@ mod tests { async fn datagram_tos() -> Result<(), io::Error> { let sender = Socket::bind("127.0.0.1:0")?; let receiver_addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); - let receiver = Socket::bind(receiver_addr)?; + let mut receiver = Socket::bind(receiver_addr)?; let datagram = Datagram::new( sender.local_addr()?,