diff --git a/.clusterfuzzlite/Dockerfile b/.clusterfuzzlite/Dockerfile new file mode 100644 index 000000000..b57da8caa --- /dev/null +++ b/.clusterfuzzlite/Dockerfile @@ -0,0 +1,21 @@ +FROM gcr.io/oss-fuzz-base/base-builder-go:v1 + +ARG TARGETPLATFORM +RUN echo "TARGETPLATFORM: ${TARGETPLATFORM}" + +ENV GOVERSION=1.20.7 + +RUN platform=$(echo ${TARGETPLATFORM} | tr '/' '-') && \ + filename="go${GOVERSION}.${platform}.tar.gz" && \ + wget https://dl.google.com/go/${filename} && \ + mkdir temp-go && \ + rm -rf /root/.go/* && \ + tar -C temp-go/ -xzf ${filename} && \ + mv temp-go/go/* /root/.go/ && \ + rm -r ${filename} temp-go + +RUN apt-get update && apt-get install -y make autoconf automake libtool + +COPY . $SRC/quic-go +WORKDIR quic-go +COPY .clusterfuzzlite/build.sh $SRC/ diff --git a/.clusterfuzzlite/build.sh b/.clusterfuzzlite/build.sh new file mode 100755 index 000000000..e7a9d4113 --- /dev/null +++ b/.clusterfuzzlite/build.sh @@ -0,0 +1,9 @@ +#!/bin/bash -eu + +export CXX="${CXX} -lresolv" # required by Go 1.20 + +compile_go_fuzzer github.com/refraction-networking/uquic/fuzzing/frames Fuzz frame_fuzzer +compile_go_fuzzer github.com/refraction-networking/uquic/fuzzing/header Fuzz header_fuzzer +compile_go_fuzzer github.com/refraction-networking/uquic/fuzzing/transportparameters Fuzz transportparameter_fuzzer +compile_go_fuzzer github.com/refraction-networking/uquic/fuzzing/tokens Fuzz token_fuzzer +compile_go_fuzzer github.com/refraction-networking/uquic/fuzzing/handshake Fuzz handshake_fuzzer diff --git a/.clusterfuzzlite/project.yaml b/.clusterfuzzlite/project.yaml new file mode 100644 index 000000000..4f2ee4d97 --- /dev/null +++ b/.clusterfuzzlite/project.yaml @@ -0,0 +1 @@ +language: go diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000..7de30a1d7 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,13 @@ +# These are supported funding model platforms + +github: [marten-seemann] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] +patreon: # Replace with a single Patreon username +open_collective: # Replace with a single Open Collective username +ko_fi: # Replace with a single Ko-fi username +tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel +community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry +liberapay: # Replace with a single Liberapay username +issuehunt: # Replace with a single IssueHunt username +otechie: # Replace with a single Otechie username +lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry +custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/README.md b/README.md index e2aea327f..a594a4a4c 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![Ginkgo Test Status](https://github.com/refraction-networking/uquic/actions/workflows/ginkgo_test.yml/badge.svg?branch=master)](https://github.com/refraction-networking/uquic/actions/workflows/ginkgo_test.yml) [![godoc](https://img.shields.io/badge/godoc-reference-blue.svg)](https://godoc.org/github.com/refraction-networking/uquic) --- -uQUIC is a fork of [quic-go](https://github.com/quic-go/quic-go), which provides Initial Packet fingerprinting resistance and other features. While the handshake is still performed by quic-go, this library provides interface to customize the unencrypted Initial Packet which may reveal fingerprint-able information. +uQUIC is a fork of [quic-go](https://github.com/refraction-networking/uquic), which provides Initial Packet fingerprinting resistance and other features. While the handshake is still performed by quic-go, this library provides interface to customize the unencrypted Initial Packet which may reveal fingerprint-able information. Golang 1.20+ is required. @@ -32,7 +32,7 @@ If you are interested in our research, please stay tuned for our paper. - [ ] QUIC ACK Frame (on hold) - [x] TLS ClientHello Message (by [uTLS](https://github.com/refraction-networking/utls)) - [x] QUIC Transport Parameters (in a uTLS extension) -- [ ] Customize Initial ACK behavior ([#1](https://github.com/refraction-networking/uquic/issues/1), [quic-go#4007](https://github.com/quic-go/quic-go/issues/4007)) +- [ ] Customize Initial ACK behavior ([#1](https://github.com/refraction-networking/uquic/issues/1), [quic-go#4007](https://github.com/refraction-networking/uquic/issues/4007)) - [ ] Customize Initial Retry behavior ([#2](https://github.com/refraction-networking/uquic/issues/2)) - [ ] Add preset QUIC parrots - [x] Google Chrome parrot (call for parrots w/ `Token/PSK`) diff --git a/client.go b/client.go index cc8c1bd10..c70937f56 100644 --- a/client.go +++ b/client.go @@ -35,7 +35,7 @@ type client struct { conn quicConn - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer tracingID uint64 logger utils.Logger } @@ -155,7 +155,7 @@ func dial( if c.config.Tracer != nil { c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID) } - if c.tracer != nil { + if c.tracer != nil && c.tracer.StartedConnection != nil { c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID) } @@ -238,7 +238,7 @@ func (c *client) dial(ctx context.Context) error { select { case <-ctx.Done(): c.conn.shutdown() - return ctx.Err() + return context.Cause(ctx) case err := <-errorChan: return err case recreateErr := <-recreateChan: diff --git a/client_test.go b/client_test.go index b12fec931..c55953f38 100644 --- a/client_test.go +++ b/client_test.go @@ -44,7 +44,7 @@ var _ = Describe("Client", func() { initialPacketNumber protocol.PacketNumber, enable0RTT bool, hasNegotiatedVersion bool, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, tracingID uint64, logger utils.Logger, v protocol.VersionNumber, @@ -55,10 +55,11 @@ var _ = Describe("Client", func() { tlsConf = &tls.Config{NextProtos: []string{"proto1"}} connID = protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0x13, 0x37}) originalClientConnConstructor = newClientConnection - tracer = mocklogging.NewMockConnectionTracer(mockCtrl) + var tr *logging.ConnectionTracer + tr, tracer = mocklogging.NewMockConnectionTracer(mockCtrl) config = &Config{ - Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) logging.ConnectionTracer { - return tracer + Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) *logging.ConnectionTracer { + return tr }, Versions: []protocol.VersionNumber{protocol.Version1}, } @@ -71,7 +72,7 @@ var _ = Describe("Client", func() { destConnID: connID, version: protocol.Version1, sendConn: packetConn, - tracer: tracer, + tracer: tr, logger: utils.DefaultLogger, } getMultiplexer() // make the sync.Once execute @@ -122,7 +123,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, enable0RTT bool, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -159,7 +160,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, enable0RTT bool, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -196,7 +197,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, _ bool, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -281,7 +282,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, _ bool, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, versionP protocol.VersionNumber, @@ -324,7 +325,7 @@ var _ = Describe("Client", func() { pn protocol.PacketNumber, _ bool, hasNegotiatedVersion bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, versionP protocol.VersionNumber, diff --git a/codecov.yml b/codecov.yml index 694351509..a24c7a15e 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,19 +1,10 @@ coverage: round: nearest ignore: - - streams_map_incoming_bidi.go - - streams_map_incoming_uni.go - - streams_map_outgoing_bidi.go - - streams_map_outgoing_uni.go - http3/gzip_reader.go - interop/ - - internal/ackhandler/packet_linkedlist.go - internal/handshake/cipher_suite.go - - internal/utils/byteinterval_linkedlist.go - - internal/utils/newconnectionid_linkedlist.go - - internal/utils/packetinterval_linkedlist.go - internal/utils/linkedlist/linkedlist.go - - logging/null_tracer.go - fuzzing/ - metrics/ status: diff --git a/config.go b/config.go index cf468d4fa..501ed1a07 100644 --- a/config.go +++ b/config.go @@ -6,7 +6,6 @@ import ( "time" "github.com/refraction-networking/uquic/internal/protocol" - "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/quicvarint" ) @@ -17,7 +16,11 @@ func (c *Config) Clone() *Config { } func (c *Config) handshakeTimeout() time.Duration { - return utils.Max(protocol.DefaultHandshakeTimeout, 2*c.HandshakeIdleTimeout) + return 2 * c.HandshakeIdleTimeout +} + +func (c *Config) maxRetryTokenAge() time.Duration { + return c.handshakeTimeout() } func validateConfig(config *Config) error { @@ -50,12 +53,6 @@ func validateConfig(config *Config) error { // it may be called with nil func populateServerConfig(config *Config) *Config { config = populateConfig(config) - if config.MaxTokenAge == 0 { - config.MaxTokenAge = protocol.TokenValidity - } - if config.MaxRetryTokenAge == 0 { - config.MaxRetryTokenAge = protocol.RetryTokenValidity - } if config.RequireAddressValidation == nil { config.RequireAddressValidation = func(net.Addr) bool { return false } } @@ -110,27 +107,24 @@ func populateConfig(config *Config) *Config { } return &Config{ - GetConfigForClient: config.GetConfigForClient, - Versions: versions, - HandshakeIdleTimeout: handshakeIdleTimeout, - MaxIdleTimeout: idleTimeout, - MaxTokenAge: config.MaxTokenAge, - MaxRetryTokenAge: config.MaxRetryTokenAge, - RequireAddressValidation: config.RequireAddressValidation, - KeepAlivePeriod: config.KeepAlivePeriod, - InitialStreamReceiveWindow: initialStreamReceiveWindow, - MaxStreamReceiveWindow: maxStreamReceiveWindow, - InitialConnectionReceiveWindow: initialConnectionReceiveWindow, - MaxConnectionReceiveWindow: maxConnectionReceiveWindow, - AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, - MaxIncomingStreams: maxIncomingStreams, - MaxIncomingUniStreams: maxIncomingUniStreams, - TokenStore: config.TokenStore, - EnableDatagrams: config.EnableDatagrams, - DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, - DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets, - Allow0RTT: config.Allow0RTT, - Tracer: config.Tracer, + GetConfigForClient: config.GetConfigForClient, + Versions: versions, + HandshakeIdleTimeout: handshakeIdleTimeout, + MaxIdleTimeout: idleTimeout, + RequireAddressValidation: config.RequireAddressValidation, + KeepAlivePeriod: config.KeepAlivePeriod, + InitialStreamReceiveWindow: initialStreamReceiveWindow, + MaxStreamReceiveWindow: maxStreamReceiveWindow, + InitialConnectionReceiveWindow: initialConnectionReceiveWindow, + MaxConnectionReceiveWindow: maxConnectionReceiveWindow, + AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, + MaxIncomingStreams: maxIncomingStreams, + MaxIncomingUniStreams: maxIncomingUniStreams, + TokenStore: config.TokenStore, + EnableDatagrams: config.EnableDatagrams, + DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, + Allow0RTT: config.Allow0RTT, + Tracer: config.Tracer, } } diff --git a/config_test.go b/config_test.go index 02db453fa..e0eef4304 100644 --- a/config_test.go +++ b/config_test.go @@ -78,10 +78,6 @@ var _ = Describe("Config", func() { f.Set(reflect.ValueOf(time.Second)) case "MaxIdleTimeout": f.Set(reflect.ValueOf(time.Hour)) - case "MaxTokenAge": - f.Set(reflect.ValueOf(2 * time.Hour)) - case "MaxRetryTokenAge": - f.Set(reflect.ValueOf(2 * time.Minute)) case "TokenStore": f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3))) case "InitialStreamReceiveWindow": @@ -115,12 +111,7 @@ var _ = Describe("Config", func() { return c } - It("uses 10s handshake timeout for short handshake idle timeouts", func() { - c := &Config{HandshakeIdleTimeout: time.Second} - Expect(c.handshakeTimeout()).To(Equal(protocol.DefaultHandshakeTimeout)) - }) - - It("uses twice the handshake idle timeouts for the handshake timeout, for long handshake idle timeouts", func() { + It("uses twice the handshake idle timeouts for the handshake timeout", func() { c := &Config{HandshakeIdleTimeout: time.Second * 11 / 2} Expect(c.handshakeTimeout()).To(Equal(11 * time.Second)) }) @@ -132,7 +123,7 @@ var _ = Describe("Config", func() { GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") }, AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true }, - Tracer: func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer { + Tracer: func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer { calledTracer = true return nil }, @@ -192,7 +183,6 @@ var _ = Describe("Config", func() { Expect(c.MaxConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveConnectionFlowControlWindow)) Expect(c.MaxIncomingStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingStreams)) Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams)) - Expect(c.DisableVersionNegotiationPackets).To(BeFalse()) Expect(c.DisablePathMTUDiscovery).To(BeFalse()) Expect(c.GetConfigForClient).To(BeNil()) }) diff --git a/connection.go b/connection.go index 270f6ec0f..c7e0d4a35 100644 --- a/connection.go +++ b/connection.go @@ -209,7 +209,7 @@ type connection struct { connState ConnectionState logID string - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer logger utils.Logger } @@ -233,7 +233,7 @@ var newConnection = func( tlsConf *tls.Config, tokenGenerator *handshake.TokenGenerator, clientAddressValidated bool, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, tracingID uint64, logger utils.Logger, v protocol.VersionNumber, @@ -244,7 +244,7 @@ var newConnection = func( handshakeDestConnID: destConnID, srcConnIDLen: srcConnID.Len(), tokenGenerator: tokenGenerator, - oneRTTStream: newCryptoStream(true), + oneRTTStream: newCryptoStream(), perspective: protocol.PerspectiveServer, tracer: tracer, logger: logger, @@ -279,6 +279,7 @@ var newConnection = func( getMaxPacketSize(s.conn.RemoteAddr()), s.rttStats, clientAddressValidated, + s.conn.capabilities().ECN, s.perspective, s.tracer, s.logger, @@ -301,7 +302,7 @@ var newConnection = func( // different from protocol.DefaultActiveConnectionIDLimit. // If set to the default value, it will be omitted from the transport parameters, which will make // old quic-go versions interpret it as 0, instead of the default value of 2. - // See https://github.com/quic-go/quic-go/pull/3806. + // See https://github.com/refraction-networking/uquic/pull/3806. ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, InitialSourceConnectionID: srcConnID, RetrySourceConnectionID: retrySrcConnID, @@ -311,7 +312,7 @@ var newConnection = func( } else { params.MaxDatagramFrameSize = protocol.InvalidByteCount } - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentTransportParameters != nil { s.tracer.SentTransportParameters(params) } cs := handshake.NewCryptoSetupServer( @@ -345,7 +346,7 @@ var newClientConnection = func( initialPacketNumber protocol.PacketNumber, enable0RTT bool, hasNegotiatedVersion bool, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, tracingID uint64, logger utils.Logger, v protocol.VersionNumber, @@ -387,14 +388,15 @@ var newClientConnection = func( initialPacketNumber, getMaxPacketSize(s.conn.RemoteAddr()), s.rttStats, - false, /* has no effect */ + false, // has no effect + s.conn.capabilities().ECN, s.perspective, s.tracer, s.logger, ) s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) - oneRTTStream := newCryptoStream(true) + oneRTTStream := newCryptoStream() params := &wire.TransportParameters{ InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), @@ -410,7 +412,7 @@ var newClientConnection = func( // different from protocol.DefaultActiveConnectionIDLimit. // If set to the default value, it will be omitted from the transport parameters, which will make // old quic-go versions interpret it as 0, instead of the default value of 2. - // See https://github.com/quic-go/quic-go/pull/3806. + // See https://github.com/refraction-networking/uquic/pull/3806. ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, InitialSourceConnectionID: srcConnID, } @@ -419,8 +421,7 @@ var newClientConnection = func( } else { params.MaxDatagramFrameSize = protocol.InvalidByteCount } - - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentTransportParameters != nil { s.tracer.SentTransportParameters(params) } cs := handshake.NewCryptoSetupClient( @@ -452,8 +453,8 @@ var newClientConnection = func( } func (s *connection) preSetup() { - s.initialStream = newCryptoStream(false) - s.handshakeStream = newCryptoStream(false) + s.initialStream = newCryptoStream() + s.handshakeStream = newCryptoStream() s.sendQueue = newSendQueue(s.conn) s.retransmissionQueue = newRetransmissionQueue() s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams) @@ -645,8 +646,10 @@ runLoop: s.cryptoStreamHandler.Close() s.sendQueue.Close() // close the send queue before sending the CONNECTION_CLOSE s.handleCloseError(&closeErr) - if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil { - s.tracer.Close() + if s.tracer != nil && s.tracer.Close != nil { + if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) { + s.tracer.Close() + } } s.logger.Infof("Connection %s closed.", s.logID) s.timer.Stop() @@ -676,6 +679,7 @@ func (s *connection) ConnectionState() ConnectionState { cs := s.cryptoStreamHandler.ConnectionState() s.connState.TLS = cs.ConnectionState s.connState.Used0RTT = cs.Used0RTT + s.connState.GSO = s.conn.capabilities().GSO return s.connState } @@ -803,14 +807,14 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { var err error destConnID, err = wire.ParseConnectionID(p.data, s.srcConnIDLen) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError) } s.logger.Debugf("error parsing packet, couldn't parse connection ID: %s", err) break } if destConnID != lastConnID { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID) } s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", destConnID, lastConnID) @@ -821,7 +825,7 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { if wire.IsLongHeaderPacket(p.data[0]) { hdr, packetData, rest, err := wire.ParsePacket(p.data) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { dropReason := logging.PacketDropHeaderParseError if err == wire.ErrUnsupportedVersion { dropReason = logging.PacketDropUnsupportedVersion @@ -834,7 +838,7 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { lastConnID = hdr.DestConnectionID if hdr.Version != s.version { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion) } s.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, s.version) @@ -893,14 +897,14 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket, destConnID protoc if s.receivedPacketHandler.IsPotentiallyDuplicate(pn, protocol.Encryption1RTT) { s.logger.Debugf("Dropping (potentially) duplicate packet.") - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketType1RTT, p.Size(), logging.PacketDropDuplicate) } return false } var log func([]logging.Frame) - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedShortHeaderPacket != nil { log = func(frames []logging.Frame) { s.tracer.ReceivedShortHeaderPacket( &logging.ShortHeader{ @@ -910,6 +914,7 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket, destConnID protoc KeyPhase: keyPhase, }, p.Size(), + p.ecn, frames, ) } @@ -932,13 +937,13 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) }() if hdr.Type == protocol.PacketTypeRetry { - return s.handleRetryPacket(hdr, p.data) + return s.handleRetryPacket(hdr, p.data, p.rcvTime) } // The server can change the source connection ID with the first Handshake packet. // After this, all packets with a different source connection have to be ignored. if s.receivedFirstPacket && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID) } s.logger.Debugf("Dropping Initial packet (%d bytes) with unexpected source connection ID: %s (expected %s)", p.Size(), hdr.SrcConnectionID, s.handshakeDestConnID) @@ -946,7 +951,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) } // drop 0-RTT packets, if we are a client if s.perspective == protocol.PerspectiveClient && hdr.Type == protocol.PacketType0RTT { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketType0RTT, p.Size(), logging.PacketDropKeyUnavailable) } return false @@ -965,7 +970,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) if s.receivedPacketHandler.IsPotentiallyDuplicate(packet.hdr.PacketNumber, packet.encryptionLevel) { s.logger.Debugf("Dropping (potentially) duplicate packet.") - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropDuplicate) } return false @@ -981,7 +986,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.PacketType) (wasQueued bool) { switch err { case handshake.ErrKeysDropped: - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropKeyUnavailable) } s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", pt, p.Size()) @@ -997,7 +1002,7 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P }) case handshake.ErrDecryptionFailed: // This might be a packet injected by an attacker. Drop it. - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropPayloadDecryptError) } s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", pt, p.Size(), err) @@ -1005,7 +1010,7 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P var headerErr *headerParseError if errors.As(err, &headerErr) { // This might be a packet injected by an attacker. Drop it. - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", pt, p.Size(), err) @@ -1018,16 +1023,16 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P return false } -func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was this a valid Retry */ { +func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime time.Time) bool /* was this a valid Retry */ { if s.perspective == protocol.PerspectiveServer { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } s.logger.Debugf("Ignoring Retry.") return false } if s.receivedFirstPacket { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } s.logger.Debugf("Ignoring Retry, since we already received a packet.") @@ -1035,7 +1040,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa } destConnID := s.connIDManager.Get() if hdr.SrcConnectionID == destConnID { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } s.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.") @@ -1050,7 +1055,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa tag := handshake.GetRetryIntegrityTag(data[:len(data)-16], destConnID, hdr.Version) if !bytes.Equal(data[len(data)-16:], tag[:]) { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropPayloadDecryptError) } s.logger.Debugf("Ignoring spoofed Retry. Integrity Tag doesn't match.") @@ -1062,12 +1067,12 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) s.logger.Debugf("Switching destination connection ID to: %s", hdr.SrcConnectionID) } - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedRetry != nil { s.tracer.ReceivedRetry(hdr) } newDestConnID := hdr.SrcConnectionID s.receivedRetry = true - if err := s.sentPacketHandler.ResetForRetry(); err != nil { + if err := s.sentPacketHandler.ResetForRetry(rcvTime); err != nil { s.closeLocal(err) return false } @@ -1083,7 +1088,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) } return @@ -1091,7 +1096,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { src, dest, supportedVersions, err := wire.ParseVersionNegotiationPacket(p.data) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Error parsing Version Negotiation packet: %s", err) @@ -1100,7 +1105,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { for _, v := range supportedVersions { if v == s.version { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedVersion) } // The Version Negotiation packet contains the version that we offered. @@ -1110,7 +1115,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { } s.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", supportedVersions) - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedVersionNegotiationPacket != nil { s.tracer.ReceivedVersionNegotiationPacket(dest, src, supportedVersions) } newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, supportedVersions) @@ -1122,7 +1127,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { s.logger.Infof("No compatible QUIC version found.") return } - if s.tracer != nil { + if s.tracer != nil && s.tracer.NegotiatedVersion != nil { s.tracer.NegotiatedVersion(newVersion, s.config.Versions, supportedVersions) } @@ -1142,7 +1147,7 @@ func (s *connection) handleUnpackedLongHeaderPacket( ) error { if !s.receivedFirstPacket { s.receivedFirstPacket = true - if !s.versionNegotiated && s.tracer != nil { + if !s.versionNegotiated && s.tracer != nil && s.tracer.NegotiatedVersion != nil { var clientVersions, serverVersions []protocol.VersionNumber switch s.perspective { case protocol.PerspectiveClient: @@ -1169,7 +1174,7 @@ func (s *connection) handleUnpackedLongHeaderPacket( s.handshakeDestConnID = packet.hdr.SrcConnectionID s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID) } - if s.tracer != nil { + if s.tracer != nil && s.tracer.StartedConnection != nil { s.tracer.StartedConnection( s.conn.LocalAddr(), s.conn.RemoteAddr(), @@ -1193,9 +1198,9 @@ func (s *connection) handleUnpackedLongHeaderPacket( s.keepAlivePingSent = false var log func([]logging.Frame) - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedLongHeaderPacket != nil { log = func(frames []logging.Frame) { - s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, frames) + s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, ecn, frames) } } isAckEliciting, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log) @@ -1341,7 +1346,7 @@ func (s *connection) handlePacket(p receivedPacket) { select { case s.receivedPackets <- p: default: - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) } } @@ -1621,7 +1626,7 @@ func (s *connection) handleCloseError(closeErr *closeError) { s.datagramQueue.CloseWithError(e) } - if s.tracer != nil && !errors.As(e, &recreateErr) { + if s.tracer != nil && s.tracer.ClosedConnection != nil && !errors.As(e, &recreateErr) { s.tracer.ClosedConnection(e) } @@ -1648,7 +1653,7 @@ func (s *connection) handleCloseError(closeErr *closeError) { } func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedEncryptionLevel != nil { s.tracer.DroppedEncryptionLevel(encLevel) } s.sentPacketHandler.DropPackets(encLevel) @@ -1683,7 +1688,7 @@ func (s *connection) restoreTransportParameters(params *wire.TransportParameters } func (s *connection) handleTransportParameters(params *wire.TransportParameters) error { - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedTransportParameters != nil { s.tracer.ReceivedTransportParameters(params) } if err := s.checkTransportParameters(params); err != nil { @@ -1835,9 +1840,10 @@ func (s *connection) sendPackets(now time.Time) error { if err != nil { return err } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) - s.registerPackedShortHeaderPacket(p, now) - s.sendQueue.Send(buf, buf.Len()) + ecn := s.sentPacketHandler.ECNMode(true) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false) + s.registerPackedShortHeaderPacket(p, ecn, now) + s.sendQueue.Send(buf, 0, ecn) // This is kind of a hack. We need to trigger sending again somehow. s.pacingDeadline = deadlineSendImmediately return nil @@ -1857,7 +1863,7 @@ func (s *connection) sendPackets(now time.Time) error { return err } s.sentFirstPacket = true - if err := s.sendPackedCoalescedPacket(packet, now); err != nil { + if err := s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()), now); err != nil { return err } sendMode := s.sentPacketHandler.SendMode(now) @@ -1878,7 +1884,8 @@ func (s *connection) sendPackets(now time.Time) error { func (s *connection) sendPacketsWithoutGSO(now time.Time) error { for { buf := getPacketBuffer() - if _, err := s.appendPacket(buf, s.mtuDiscoverer.CurrentSize(), now); err != nil { + ecn := s.sentPacketHandler.ECNMode(true) + if _, err := s.appendOneShortHeaderPacket(buf, s.mtuDiscoverer.CurrentSize(), ecn, now); err != nil { if err == errNothingToPack { buf.Release() return nil @@ -1886,7 +1893,7 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error { return err } - s.sendQueue.Send(buf, buf.Len()) + s.sendQueue.Send(buf, 0, ecn) if s.sendQueue.WouldBlock() { return nil @@ -1911,9 +1918,10 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error { buf := getLargePacketBuffer() maxSize := s.mtuDiscoverer.CurrentSize() + ecn := s.sentPacketHandler.ECNMode(true) for { var dontSendMore bool - size, err := s.appendPacket(buf, maxSize, now) + size, err := s.appendOneShortHeaderPacket(buf, maxSize, ecn, now) if err != nil { if err != errNothingToPack { return err @@ -1935,15 +1943,19 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error { } } + // Don't send more packets in this batch if they require a different ECN marking than the previous ones. + nextECN := s.sentPacketHandler.ECNMode(true) + // Append another packet if // 1. The congestion controller and pacer allow sending more // 2. The last packet appended was a full-size packet - // 3. We still have enough space for another full-size packet in the buffer - if !dontSendMore && size == maxSize && buf.Len()+maxSize <= buf.Cap() { + // 3. The next packet will have the same ECN marking + // 4. We still have enough space for another full-size packet in the buffer + if !dontSendMore && size == maxSize && nextECN == ecn && buf.Len()+maxSize <= buf.Cap() { continue } - s.sendQueue.Send(buf, maxSize) + s.sendQueue.Send(buf, uint16(maxSize), ecn) if dontSendMore { return nil @@ -1972,6 +1984,7 @@ func (s *connection) resetPacingDeadline() { func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { if !s.handshakeConfirmed { + ecn := s.sentPacketHandler.ECNMode(false) packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { return err @@ -1979,9 +1992,10 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { if packet == nil { return nil } - return s.sendPackedCoalescedPacket(packet, time.Now()) + return s.sendPackedCoalescedPacket(packet, ecn, time.Now()) } + ecn := s.sentPacketHandler.ECNMode(true) p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { if err == errNothingToPack { @@ -1989,9 +2003,9 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { } return err } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) - s.registerPackedShortHeaderPacket(p, now) - s.sendQueue.Send(buf, buf.Len()) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false) + s.registerPackedShortHeaderPacket(p, ecn, now) + s.sendQueue.Send(buf, 0, ecn) return nil } @@ -2023,24 +2037,24 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) { return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel) } - return s.sendPackedCoalescedPacket(packet, now) + return s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()), now) } -// appendPacket appends a new packet to the given packetBuffer. +// appendOneShortHeaderPacket appends a new packet to the given packetBuffer. // If there was nothing to pack, the returned size is 0. -func (s *connection) appendPacket(buf *packetBuffer, maxSize protocol.ByteCount, now time.Time) (protocol.ByteCount, error) { +func (s *connection) appendOneShortHeaderPacket(buf *packetBuffer, maxSize protocol.ByteCount, ecn protocol.ECN, now time.Time) (protocol.ByteCount, error) { startLen := buf.Len() p, err := s.packer.AppendPacket(buf, maxSize, s.version) if err != nil { return 0, err } size := buf.Len() - startLen - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, size, false) - s.registerPackedShortHeaderPacket(p, now) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, size, false) + s.registerPackedShortHeaderPacket(p, ecn, now) return size, nil } -func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now time.Time) { +func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, ecn protocol.ECN, now time.Time) { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && (len(p.StreamFrames) > 0 || ackhandler.HasAckElicitingFrames(p.Frames)) { s.firstAckElicitingPacketAfterIdleSentTime = now } @@ -2049,12 +2063,12 @@ func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now ti if p.Ack != nil { largestAcked = p.Ack.LargestAcked() } - s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket) + s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, ecn, p.Length, p.IsPathMTUProbePacket) s.connIDManager.SentPacket() } -func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) error { - s.logCoalescedPacket(packet) +func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN, now time.Time) error { + s.logCoalescedPacket(packet, ecn) for _, p := range packet.longHdrPackets { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { s.firstAckElicitingPacketAfterIdleSentTime = now @@ -2063,7 +2077,7 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time if p.ack != nil { largestAcked = p.ack.LargestAcked() } - s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), p.length, false) + s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), ecn, p.length, false) if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake { // On the client side, Initial keys are dropped as soon as the first Handshake packet is sent. // See Section 4.9.1 of RFC 9001. @@ -2080,11 +2094,10 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time if p.Ack != nil { largestAcked = p.Ack.LargestAcked() } - s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket) + s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, ecn, p.Length, p.IsPathMTUProbePacket) } s.connIDManager.SentPacket() - s.sendQueue.Send(packet.buffer, packet.buffer.Len()) - + s.sendQueue.Send(packet.buffer, 0, ecn) return nil } @@ -2106,11 +2119,12 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) { if err != nil { return nil, err } - s.logCoalescedPacket(packet) - return packet.buffer.Data, s.conn.Write(packet.buffer.Data, packet.buffer.Len()) + ecn := s.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()) + s.logCoalescedPacket(packet, ecn) + return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0, ecn) } -func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { +func (s *connection) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN) { // quic-go logging if s.logger.Debug() { p.header.Log(s.logger) @@ -2126,7 +2140,7 @@ func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { } // tracing - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentLongHeaderPacket != nil { frames := make([]logging.Frame, 0, len(p.frames)) for _, f := range p.frames { frames = append(frames, logutils.ConvertFrame(f.Frame)) @@ -2138,7 +2152,7 @@ func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { if p.ack != nil { ack = logutils.ConvertAckFrame(p.ack) } - s.tracer.SentLongHeaderPacket(p.header, p.length, ack, frames) + s.tracer.SentLongHeaderPacket(p.header, p.length, ecn, ack, frames) } } @@ -2150,11 +2164,12 @@ func (s *connection) logShortHeaderPacket( pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit, + ecn protocol.ECN, size protocol.ByteCount, isCoalesced bool, ) { if s.logger.Debug() && !isCoalesced { - s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT", pn, size, s.logID) + s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT (ECN: %s)", pn, size, s.logID, ecn) } // quic-go logging if s.logger.Debug() { @@ -2171,7 +2186,7 @@ func (s *connection) logShortHeaderPacket( } // tracing - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentShortHeaderPacket != nil { fs := make([]logging.Frame, 0, len(frames)+len(streamFrames)) for _, f := range frames { fs = append(fs, logutils.ConvertFrame(f.Frame)) @@ -2191,13 +2206,14 @@ func (s *connection) logShortHeaderPacket( KeyPhase: kp, }, size, + ecn, ack, fs, ) } } -func (s *connection) logCoalescedPacket(packet *coalescedPacket) { +func (s *connection) logCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN) { if s.logger.Debug() { // There's a short period between dropping both Initial and Handshake keys and completion of the handshake, // during which we might call PackCoalescedPacket but just pack a short header packet. @@ -2210,6 +2226,7 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) { packet.shortHdrPacket.PacketNumber, packet.shortHdrPacket.PacketNumberLen, packet.shortHdrPacket.KeyPhase, + ecn, packet.shortHdrPacket.Length, false, ) @@ -2222,10 +2239,10 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) { } } for _, p := range packet.longHdrPackets { - s.logLongHeaderPacket(p) + s.logLongHeaderPacket(p, ecn) } if p := packet.shortHdrPacket; p != nil { - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Length, true) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, p.Length, true) } } @@ -2291,14 +2308,14 @@ func (s *connection) tryQueueingUndecryptablePacket(p receivedPacket, pt logging panic("shouldn't queue undecryptable packets after handshake completion") } if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropDOSPrevention) } s.logger.Infof("Dropping undecryptable packet (%d bytes). Undecryptable packet queue full.", p.Size()) return } s.logger.Infof("Queueing packet (%d bytes) for later decryption", p.Size()) - if s.tracer != nil { + if s.tracer != nil && s.tracer.BufferedPacket != nil { s.tracer.BufferedPacket(pt, p.Size()) } s.undecryptablePackets = append(s.undecryptablePackets, p) diff --git a/connection_test.go b/connection_test.go index de0983eaa..87cf05c95 100644 --- a/connection_test.go +++ b/connection_test.go @@ -101,9 +101,9 @@ var _ = Describe("Connection", func() { mconn.EXPECT().capabilities().DoAndReturn(func() connCapabilities { return capabilities }).AnyTimes() mconn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes() mconn.EXPECT().LocalAddr().Return(localAddr).AnyTimes() - tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) - Expect(err).ToNot(HaveOccurred()) - tracer = mocklogging.NewMockConnectionTracer(mockCtrl) + tokenGenerator := handshake.NewTokenGenerator([32]byte{0xa, 0xb, 0xc}) + var tr *logging.ConnectionTracer + tr, tracer = mocklogging.NewMockConnectionTracer(mockCtrl) tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().SentTransportParameters(gomock.Any()) tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes() @@ -122,7 +122,7 @@ var _ = Describe("Connection", func() { &tls.Config{}, tokenGenerator, false, - tracer, + tr, 1234, utils.DefaultLogger, protocol.Version1, @@ -455,7 +455,7 @@ var _ = Describe("Connection", func() { Expect(e.ErrorMessage).To(BeEmpty()) return &coalescedPacket{buffer: buffer}, nil }) - mconn.EXPECT().Write([]byte("connection close"), gomock.Any()) + mconn.EXPECT().Write([]byte("connection close"), gomock.Any(), gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { var appErr *ApplicationError @@ -476,7 +476,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() cryptoSetup.EXPECT().Close() packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -495,7 +495,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() cryptoSetup.EXPECT().Close() packer.EXPECT().PackApplicationClose(expectedErr, gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(expectedErr), tracer.EXPECT().Close(), @@ -517,7 +517,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(expectedErr, gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(expectedErr), tracer.EXPECT().Close(), @@ -566,7 +566,7 @@ var _ = Describe("Connection", func() { close(returned) }() Consistently(returned).ShouldNot(BeClosed()) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -591,7 +591,7 @@ var _ = Describe("Connection", func() { return 3, protocol.PacketNumberLen2, protocol.KeyPhaseOne, b, nil }) gomock.InOrder( - tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()), + tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), tracer.EXPECT().ClosedConnection(gomock.Any()), tracer.EXPECT().Close(), ) @@ -610,14 +610,15 @@ var _ = Describe("Connection", func() { conn.handshakeConfirmed = true sconn := NewMockSendConn(mockCtrl) sconn.EXPECT().capabilities().AnyTimes() - sconn.EXPECT().Write(gomock.Any(), gomock.Any()).Return(io.ErrClosedPipe).AnyTimes() + sconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Return(io.ErrClosedPipe).AnyTimes() conn.sendQueue = newSendQueue(sconn) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().Return(time.Now().Add(time.Hour)).AnyTimes() + sph.EXPECT().ECNMode(true).Return(protocol.ECT1).AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() // only expect a single SentPacket() call - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() streamManager.EXPECT().CloseWithError(gomock.Any()) @@ -778,7 +779,7 @@ var _ = Describe("Connection", func() { conn.receivedPacketHandler = rph packet.rcvTime = rcvTime tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), []logging.Frame{}) + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), logging.ECNCE, []logging.Frame{}) Expect(conn.handlePacketImpl(packet)).To(BeTrue()) }) @@ -796,7 +797,12 @@ var _ = Describe("Connection", func() { ) conn.receivedPacketHandler = rph packet.rcvTime = rcvTime - tracer.EXPECT().ReceivedShortHeaderPacket(&logging.ShortHeader{PacketNumber: 0x1337, PacketNumberLen: 2, KeyPhase: protocol.KeyPhaseZero}, protocol.ByteCount(len(packet.data)), []logging.Frame{&logging.PingFrame{}}) + tracer.EXPECT().ReceivedShortHeaderPacket( + &logging.ShortHeader{PacketNumber: 0x1337, PacketNumberLen: 2, KeyPhase: protocol.KeyPhaseZero}, + protocol.ByteCount(len(packet.data)), + logging.ECT1, + []logging.Frame{&logging.PingFrame{}}, + ) Expect(conn.handlePacketImpl(packet)).To(BeTrue()) }) @@ -838,7 +844,7 @@ var _ = Describe("Connection", func() { // make the go routine return tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) conn.closeLocal(errors.New("close")) Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -850,8 +856,7 @@ var _ = Describe("Connection", func() { pn++ return pn, protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* PADDING frame */, nil }).Times(3) - tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *logging.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) { - }).Times(3) + tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version) // only expect a single call for i := 0; i < 3; i++ { @@ -873,7 +878,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) conn.closeLocal(errors.New("close")) Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -886,8 +891,7 @@ var _ = Describe("Connection", func() { pn++ return pn, protocol.PacketNumberLen4, protocol.KeyPhaseZero, []byte{0} /* PADDING frame */, nil }).Times(3) - tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *logging.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) { - }).Times(3) + tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).Times(3) for i := 0; i < 3; i++ { @@ -909,7 +913,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) conn.closeLocal(errors.New("close")) Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -931,7 +935,7 @@ var _ = Describe("Connection", func() { close(done) }() expectReplaceWithClosed() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) packet := getShortHeaderPacket(srcConnID, 0x42, nil) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() @@ -959,7 +963,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) conn.shutdown() Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -981,7 +985,7 @@ var _ = Describe("Connection", func() { close(done) }() expectReplaceWithClosed() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.handlePacket(getShortHeaderPacket(srcConnID, 0x42, nil)) @@ -1021,7 +1025,7 @@ var _ = Describe("Connection", func() { }, nil) p1 := getLongHeaderPacket(hdr1, nil) tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(p1.data)), gomock.Any()) + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(p1.data)), gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(p1)).To(BeTrue()) // The next packet has to be ignored, since the source connection ID doesn't match. p2 := getLongHeaderPacket(hdr2, nil) @@ -1054,7 +1058,7 @@ var _ = Describe("Connection", func() { unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(10), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* one PADDING frame */, nil) packet := getShortHeaderPacket(srcConnID, 0x42, nil) packet.remoteAddr = &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} - tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any()) + tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(packet)).To(BeTrue()) }) }) @@ -1094,12 +1098,13 @@ var _ = Describe("Connection", func() { }) cryptoSetup.EXPECT().DiscardInitialKeys() tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionInitial) - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any()) + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(packet)).To(BeTrue()) }) It("handles coalesced packets", func() { hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) + packet1.ecn = protocol.ECT1 unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) { Expect(data).To(HaveLen(hdrLen1 + 456 - 3)) return &unpackedPacket{ @@ -1126,8 +1131,8 @@ var _ = Describe("Connection", func() { tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionInitial).AnyTimes() cryptoSetup.EXPECT().DiscardInitialKeys().AnyTimes() gomock.InOrder( - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any()), - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), gomock.Any()), + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), logging.ECT1, gomock.Any()), + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), logging.ECT1, gomock.Any()), ) packet1.data = append(packet1.data, packet2.data...) Expect(conn.handlePacketImpl(packet1)).To(BeTrue()) @@ -1152,7 +1157,7 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().DiscardInitialKeys().AnyTimes() gomock.InOrder( tracer.EXPECT().BufferedPacket(gomock.Any(), protocol.ByteCount(len(packet1.data))), - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), gomock.Any()), + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), gomock.Any(), gomock.Any()), ) packet1.data = append(packet1.data, packet2.data...) Expect(conn.handlePacketImpl(packet1)).To(BeTrue()) @@ -1178,7 +1183,7 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().DiscardInitialKeys().AnyTimes() // don't EXPECT any more calls to unpacker.UnpackLongHeader() gomock.InOrder( - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any()), + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any(), gomock.Any()), tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), logging.PacketDropUnknownConnectionID), ) packet1.data = append(packet1.data, packet2.data...) @@ -1191,6 +1196,7 @@ var _ = Describe("Connection", func() { var ( connDone chan struct{} sender *MockSender + sph *mockackhandler.MockSentPacketHandler ) BeforeEach(func() { @@ -1199,14 +1205,17 @@ var _ = Describe("Connection", func() { sender.EXPECT().WouldBlock().AnyTimes() conn.sendQueue = sender connDone = make(chan struct{}) + sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) + conn.sentPacketHandler = sph }) AfterEach(func() { streamManager.EXPECT().CloseWithError(gomock.Any()) + sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECNCE).MaxTimes(1) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() sender.EXPECT().Close() @@ -1227,12 +1236,11 @@ var _ = Describe("Connection", func() { It("sends packets", func() { conn.handshakeConfirmed = true - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - conn.sentPacketHandler = sph + sph.EXPECT().ECNMode(true).Return(protocol.ECNNon).AnyTimes() + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) runConn() p := shortHeaderPacket{ DestConnID: protocol.ParseConnectionID([]byte{1, 2, 3}), @@ -1244,19 +1252,22 @@ var _ = Describe("Connection", func() { packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() sent := make(chan struct{}) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) }) tracer.EXPECT().SentShortHeaderPacket(&logging.ShortHeader{ DestConnectionID: p.DestConnID, PacketNumber: p.PacketNumber, PacketNumberLen: p.PacketNumberLen, KeyPhase: p.KeyPhase, - }, gomock.Any(), nil, []logging.Frame{}) + }, gomock.Any(), gomock.Any(), nil, []logging.Frame{}) conn.scheduleSending() Eventually(sent).Should(BeClosed()) }) It("doesn't send packets if there's nothing to send", func() { conn.handshakeConfirmed = true + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode(true).AnyTimes() runConn() packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() conn.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true) @@ -1265,13 +1276,12 @@ var _ = Describe("Connection", func() { }) It("sends ACK only packets", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck) + sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes() done := make(chan struct{}) packer.EXPECT().PackCoalescedPacket(true, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.VersionNumber) { close(done) }) - conn.sentPacketHandler = sph runConn() conn.scheduleSending() Eventually(done).Should(BeClosed()) @@ -1279,12 +1289,11 @@ var _ = Describe("Connection", func() { It("adds a BLOCKED frame when it is connection-level flow control blocked", func() { conn.handshakeConfirmed = true - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - conn.sentPacketHandler = sph + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 13}, []byte("foobar")) @@ -1292,8 +1301,8 @@ var _ = Describe("Connection", func() { conn.connFlowController = fc runConn() sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(sent) }) - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), nil, []logging.Frame{}) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) }) + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), nil, []logging.Frame{}) conn.scheduleSending() Eventually(sent).Should(BeClosed()) frames, _ := conn.framer.AppendControlFrames(nil, 1000, protocol.Version1) @@ -1301,11 +1310,9 @@ var _ = Describe("Connection", func() { }) It("doesn't send when the SentPacketHandler doesn't allow it", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone).AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() - conn.sentPacketHandler = sph runConn() conn.scheduleSending() time.Sleep(50 * time.Millisecond) @@ -1334,50 +1341,45 @@ var _ = Describe("Connection", func() { }) It("sends a probe packet", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(sendMode) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) sph.EXPECT().QueueProbePacket(encLevel) + sph.EXPECT().ECNMode(gomock.Any()) p := getCoalescedPacket(123, enc != protocol.Encryption1RTT) packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil) - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, _ protocol.EncryptionLevel, _ protocol.ByteCount, _ bool) { - Expect(pn).To(Equal(protocol.PacketNumber(123))) - }) + sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(123), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.sentPacketHandler = sph runConn() sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) }) if enc == protocol.Encryption1RTT { - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any()) + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any(), gomock.Any()) } else { - tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), p.longHdrPackets[0].length, gomock.Any(), gomock.Any()) + tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), p.longHdrPackets[0].length, gomock.Any(), gomock.Any(), gomock.Any()) } conn.scheduleSending() Eventually(sent).Should(BeClosed()) }) It("sends a PING as a probe packet", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(sendMode) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) + sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT0) sph.EXPECT().QueueProbePacket(encLevel).Return(false) p := getCoalescedPacket(123, enc != protocol.Encryption1RTT) packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil) - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, _ protocol.EncryptionLevel, _ protocol.ByteCount, _ bool) { - Expect(pn).To(Equal(protocol.PacketNumber(123))) - }) - conn.sentPacketHandler = sph + sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(123), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) runConn() sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) }) if enc == protocol.Encryption1RTT { - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any()) + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, logging.ECT0, gomock.Any(), gomock.Any()) } else { - tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), p.longHdrPackets[0].length, gomock.Any(), gomock.Any()) + tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), p.longHdrPackets[0].length, logging.ECT0, gomock.Any(), gomock.Any()) } conn.scheduleSending() Eventually(sent).Should(BeClosed()) @@ -1396,7 +1398,7 @@ var _ = Describe("Connection", func() { ) BeforeEach(func() { - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() conn.handshakeConfirmed = true @@ -1410,10 +1412,11 @@ var _ = Describe("Connection", func() { AfterEach(func() { // make the go routine return + sph.EXPECT().ECNMode(gomock.Any()).MaxTimes(1) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() sender.EXPECT().Close() @@ -1422,17 +1425,18 @@ var _ = Describe("Connection", func() { }) It("sends multiple packets one by one immediately", func() { - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) + sph.EXPECT().ECNMode(gomock.Any()).Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, []byte("packet11")) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ protocol.ByteCount) { + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { Expect(b.Data).To(Equal([]byte("packet10"))) }) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ protocol.ByteCount) { + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { Expect(b.Data).To(Equal([]byte("packet11"))) }) go func() { @@ -1447,7 +1451,8 @@ var _ = Describe("Connection", func() { It("sends multiple packets one by one immediately, with GSO", func() { enableGSO() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) + sph.EXPECT().ECNMode(true).Return(protocol.ECT1).Times(4) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3) payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize()) rand.Read(payload1) @@ -1457,7 +1462,7 @@ var _ = Describe("Connection", func() { expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), conn.mtuDiscoverer.CurrentSize()).Do(func(b *packetBuffer, l protocol.ByteCount) { + sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { Expect(b.Data).To(Equal(append(payload1, payload2...))) }) go func() { @@ -1472,19 +1477,59 @@ var _ = Describe("Connection", func() { It("stops appending packets when a smaller packet is packed, with GSO", func() { enableGSO() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) + sph.EXPECT().ECNMode(true).Times(4) payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize()) rand.Read(payload1) payload2 := make([]byte, conn.mtuDiscoverer.CurrentSize()-1) rand.Read(payload2) + payload3 := make([]byte, conn.mtuDiscoverer.CurrentSize()) + rand.Read(payload3) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, payload1) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2) + expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 12}, payload3) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), conn.mtuDiscoverer.CurrentSize()).Do(func(b *packetBuffer, l protocol.ByteCount) { + sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { Expect(b.Data).To(Equal(append(payload1, payload2...))) }) + sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { + Expect(b.Data).To(Equal(payload3)) + }) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) + conn.run() + }() + conn.scheduleSending() + time.Sleep(50 * time.Millisecond) // make sure that only 2 packets are sent + }) + + It("stops appending packets when the ECN marking changes, with GSO", func() { + enableGSO() + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3) + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) + sph.EXPECT().ECNMode(true).Return(protocol.ECT1).Times(2) + sph.EXPECT().ECNMode(true).Return(protocol.ECT0).Times(2) + payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize()) + rand.Read(payload1) + payload2 := make([]byte, conn.mtuDiscoverer.CurrentSize()) + rand.Read(payload2) + payload3 := make([]byte, conn.mtuDiscoverer.CurrentSize()) + rand.Read(payload3) + expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, payload1) + expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2) + expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload3) + sender.EXPECT().WouldBlock().AnyTimes() + sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { + Expect(b.Data).To(Equal(append(payload1, payload2...))) + }) + sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { + Expect(b.Data).To(Equal(payload3)) + }) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1496,12 +1541,13 @@ var _ = Describe("Connection", func() { }) It("sends multiple packets, when the pacer allows immediate sending", func() { - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) + sph.EXPECT().ECNMode(gomock.Any()).Times(2) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1513,13 +1559,14 @@ var _ = Describe("Connection", func() { }) It("allows an ACK to be sent when pacing limited", func() { - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited) + sph.EXPECT().ECNMode(gomock.Any()) packer.EXPECT().PackAckOnlyPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{PacketNumber: 123}, getPacketBuffer(), nil) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1533,12 +1580,13 @@ var _ = Describe("Connection", func() { // when becoming congestion limited, at some point the SendMode will change from SendAny to SendAck // we shouldn't send the ACK in the same run It("doesn't send an ACK right after becoming congestion limited", func() { - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck) + sph.EXPECT().ECNMode(gomock.Any()).Times(2) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 100}, []byte("packet100")) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1553,19 +1601,21 @@ var _ = Describe("Connection", func() { pacingDelay := scaleDuration(100 * time.Millisecond) gomock.InOrder( sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), + sph.EXPECT().ECNMode(gomock.Any()), expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 100}, []byte("packet100")), - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), + sph.EXPECT().ECNMode(gomock.Any()), expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 101}, []byte("packet101")), - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)), ) written := make(chan struct{}, 2) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }).Times(2) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} }).Times(2) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1579,8 +1629,9 @@ var _ = Describe("Connection", func() { }) It("sends multiple packets at once", func() { - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3) + sph.EXPECT().ECNMode(gomock.Any()).Times(3) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) for pn := protocol.PacketNumber(1000); pn < 1003; pn++ { @@ -1588,7 +1639,7 @@ var _ = Describe("Connection", func() { } written := make(chan struct{}, 3) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }).Times(3) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} }).Times(3) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1619,11 +1670,12 @@ var _ = Describe("Connection", func() { written := make(chan struct{}) sender.EXPECT().WouldBlock().AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1000}, []byte("packet1000")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { close(written) }) available <- struct{}{} Eventually(written).Should(BeClosed()) }) @@ -1640,14 +1692,15 @@ var _ = Describe("Connection", func() { written := make(chan struct{}) sender.EXPECT().WouldBlock().AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(time.Time, protocol.PacketNumber, protocol.PacketNumber, []ackhandler.StreamFrame, []ackhandler.Frame, protocol.EncryptionLevel, protocol.ByteCount, bool) { + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(time.Time, protocol.PacketNumber, protocol.PacketNumber, []ackhandler.StreamFrame, []ackhandler.Frame, protocol.EncryptionLevel, protocol.ECN, protocol.ByteCount, bool) { sph.EXPECT().ReceivedBytes(gomock.Any()) conn.handlePacket(receivedPacket{buffer: getPacketBuffer()}) }) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { close(written) }) conn.scheduleSending() time.Sleep(scaleDuration(50 * time.Millisecond)) @@ -1656,13 +1709,14 @@ var _ = Describe("Connection", func() { }) It("stops sending when the send queue is full", func() { - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny) + sph.EXPECT().ECNMode(gomock.Any()) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1000}, []byte("packet1000")) written := make(chan struct{}, 1) sender.EXPECT().WouldBlock() sender.EXPECT().WouldBlock().Return(true).Times(2) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} }) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1676,12 +1730,13 @@ var _ = Describe("Connection", func() { time.Sleep(scaleDuration(50 * time.Millisecond)) // now make room in the send queue - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() sender.EXPECT().WouldBlock().AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1001}, []byte("packet1001")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} }) available <- struct{}{} Eventually(written).Should(Receive()) @@ -1692,6 +1747,7 @@ var _ = Describe("Connection", func() { It("doesn't set a pacing timer when there is no data to send", func() { sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() sender.EXPECT().WouldBlock().AnyTimes() packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) // don't EXPECT any calls to mconn.Write() @@ -1709,12 +1765,13 @@ var _ = Describe("Connection", func() { mtuDiscoverer := NewMockMTUDiscoverer(mockCtrl) conn.mtuDiscoverer = mtuDiscoverer conn.config.DisablePathMTUDiscovery = false - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny) + sph.EXPECT().ECNMode(true) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) written := make(chan struct{}, 1) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} }) mtuDiscoverer.EXPECT().ShouldSendProbe(gomock.Any()).Return(true) ping := ackhandler.Frame{Frame: &wire.PingFrame{}} mtuDiscoverer.EXPECT().GetPing().Return(ping, protocol.ByteCount(1234)) @@ -1748,7 +1805,7 @@ var _ = Describe("Connection", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) sender.EXPECT().Close() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() @@ -1761,8 +1818,9 @@ var _ = Describe("Connection", func() { sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.sentPacketHandler = sph expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1}, []byte("packet1")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) @@ -1777,8 +1835,8 @@ var _ = Describe("Connection", func() { time.Sleep(50 * time.Millisecond) // only EXPECT calls after scheduleSending is called written := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(written) }) - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(written) }) + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() conn.scheduleSending() Eventually(written).Should(BeClosed()) }) @@ -1789,9 +1847,8 @@ var _ = Describe("Connection", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, _ protocol.EncryptionLevel, _ protocol.ByteCount, _ bool) { - Expect(pn).To(Equal(protocol.PacketNumber(1234))) - }) + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() + sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(1234), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.sentPacketHandler = sph rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond)) @@ -1800,8 +1857,8 @@ var _ = Describe("Connection", func() { conn.receivedPacketHandler = rph written := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(written) }) - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(written) }) + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1842,30 +1899,23 @@ var _ = Describe("Connection", func() { sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode(false).Return(protocol.ECT1).AnyTimes() sph.EXPECT().TimeUntilSend().Return(time.Now()).AnyTimes() gomock.InOrder( - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, _ bool) { - Expect(encLevel).To(Equal(protocol.EncryptionInitial)) - Expect(pn).To(Equal(protocol.PacketNumber(13))) - Expect(size).To(BeEquivalentTo(123)) - }), - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, _ bool) { - Expect(encLevel).To(Equal(protocol.EncryptionHandshake)) - Expect(pn).To(Equal(protocol.PacketNumber(37))) - Expect(size).To(BeEquivalentTo(1234)) - }), + sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(13), gomock.Any(), gomock.Any(), gomock.Any(), protocol.EncryptionInitial, protocol.ECT1, protocol.ByteCount(123), gomock.Any()), + sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(37), gomock.Any(), gomock.Any(), gomock.Any(), protocol.EncryptionHandshake, protocol.ECT1, protocol.ByteCount(1234), gomock.Any()), ) gomock.InOrder( - tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ *wire.AckFrame, _ []logging.Frame) { + tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ logging.ECN, _ *wire.AckFrame, _ []logging.Frame) { Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial)) }), - tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ *wire.AckFrame, _ []logging.Frame) { + tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ logging.ECN, _ *wire.AckFrame, _ []logging.Frame) { Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake)) }), ) sent := make(chan struct{}) - mconn.EXPECT().Write([]byte("foobar"), protocol.ByteCount(6)).Do(func([]byte, protocol.ByteCount) { close(sent) }) + mconn.EXPECT().Write([]byte("foobar"), uint16(0), protocol.ECT1).Do(func([]byte, uint16, protocol.ECN) { close(sent) }) go func() { defer GinkgoRecover() @@ -1882,7 +1932,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -1953,7 +2003,7 @@ var _ = Describe("Connection", func() { }() handshakeCtx := conn.HandshakeComplete() Consistently(handshakeCtx).ShouldNot(BeClosed()) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) conn.closeLocal(errors.New("handshake error")) Consistently(handshakeCtx).ShouldNot(BeClosed()) Eventually(conn.Context().Done()).Should(BeClosed()) @@ -1962,12 +2012,13 @@ var _ = Describe("Connection", func() { It("sends a HANDSHAKE_DONE frame when the handshake completes", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SetHandshakeConfirmed() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.sentPacketHandler = sph done := make(chan struct{}) connRunner.EXPECT().Retire(clientDestConnID) @@ -1988,7 +2039,7 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) Expect(conn.handleHandshakeComplete()).To(Succeed()) conn.run() }() @@ -2017,7 +2068,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -2044,7 +2095,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() Expect(conn.CloseWithError(0x1337, testErr.Error())).To(Succeed()) @@ -2103,7 +2154,7 @@ var _ = Describe("Connection", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -2202,7 +2253,7 @@ var _ = Describe("Connection", func() { It("times out due to non-completed handshake", func() { conn.handshakeComplete = false - conn.creationTime = time.Now().Add(-protocol.DefaultHandshakeTimeout).Add(-time.Second) + conn.creationTime = time.Now().Add(-2 * protocol.DefaultHandshakeIdleTimeout).Add(-time.Second) connRunner.EXPECT().Remove(gomock.Any()).Times(2) cryptoSetup.EXPECT().Close() gomock.InOrder( @@ -2256,13 +2307,13 @@ var _ = Describe("Connection", func() { // make the go routine return expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) conn.shutdown() Eventually(conn.Context().Done()).Should(BeClosed()) }) It("closes the connection due to the idle timeout before handshake", func() { - conn.config.HandshakeIdleTimeout = 0 + conn.config.HandshakeIdleTimeout = scaleDuration(25 * time.Millisecond) packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).AnyTimes() connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() cryptoSetup.EXPECT().Close() @@ -2339,7 +2390,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -2491,7 +2542,8 @@ var _ = Describe("Client Connection", func() { tlsConf = &tls.Config{} } connRunner = NewMockConnRunner(mockCtrl) - tracer = mocklogging.NewMockConnectionTracer(mockCtrl) + var tr *logging.ConnectionTracer + tr, tracer = mocklogging.NewMockConnectionTracer(mockCtrl) tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().SentTransportParameters(gomock.Any()) tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes() @@ -2507,7 +2559,7 @@ var _ = Describe("Client Connection", func() { 42, // initial packet number false, false, - tracer, + tr, 1234, utils.DefaultLogger, protocol.Version1, @@ -2542,7 +2594,7 @@ var _ = Describe("Client Connection", func() { }, PacketNumberLen: protocol.PacketNumberLen2, }, []byte("foobar")) - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), p.Size(), []logging.Frame{}) + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), p.Size(), gomock.Any(), []logging.Frame{}) Expect(conn.handlePacketImpl(p)).To(BeTrue()) go func() { defer GinkgoRecover() @@ -2555,7 +2607,7 @@ var _ = Describe("Client Connection", func() { packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any(), gomock.Any()) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()).MaxTimes(1) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -2586,7 +2638,7 @@ var _ = Describe("Client Connection", func() { DestConnectionID: srcConnID, SrcConnectionID: destConnID, } - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) Expect(conn.handleLongHeaderPacket(receivedPacket{buffer: getPacketBuffer()}, hdr)).To(BeTrue()) }) @@ -2768,9 +2820,10 @@ var _ = Describe("Client Connection", func() { } It("handles Retry packets", func() { + now := time.Now() sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) conn.sentPacketHandler = sph - sph.EXPECT().ResetForRetry() + sph.EXPECT().ResetForRetry(now) sph.EXPECT().ReceivedBytes(gomock.Any()) cryptoSetup.EXPECT().ChangeConnectionID(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})) packer.EXPECT().SetToken([]byte("foobar")) @@ -2779,7 +2832,9 @@ var _ = Describe("Client Connection", func() { Expect(hdr.SrcConnectionID).To(Equal(retryHdr.SrcConnectionID)) Expect(hdr.Token).To(Equal(retryHdr.Token)) }) - Expect(conn.handlePacketImpl(getPacket(retryHdr, getRetryTag(retryHdr)))).To(BeTrue()) + p := getPacket(retryHdr, getRetryTag(retryHdr)) + p.rcvTime = now + Expect(conn.handlePacketImpl(p)).To(BeTrue()) }) It("ignores Retry packets after receiving a regular packet", func() { @@ -2849,7 +2904,7 @@ var _ = Describe("Client Connection", func() { packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) } cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(gomock.Any()), tracer.EXPECT().Close(), @@ -3095,7 +3150,7 @@ var _ = Describe("Client Connection", func() { hdr: hdr1, data: []byte{0}, // one PADDING frame }, nil) - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(getPacket(hdr1, nil))).To(BeTrue()) // The next packet has to be ignored, since the source connection ID doesn't match. tracer.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any()) @@ -3122,7 +3177,7 @@ var _ = Describe("Client Connection", func() { It("fails on Initial-level ACK for unsent packet", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, conn.version, destConnID, []wire.Frame{ack}) - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) }) @@ -3134,7 +3189,7 @@ var _ = Describe("Client Connection", func() { ReasonPhrase: "mitm attacker", } initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, conn.version, destConnID, []wire.Frame{connCloseFrame}) - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeTrue()) }) @@ -3144,7 +3199,7 @@ var _ = Describe("Client Connection", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) conn.sentPacketHandler = sph sph.EXPECT().ReceivedBytes(gomock.Any()).Times(2) - sph.EXPECT().ResetForRetry() + sph.EXPECT().ResetForRetry(gomock.Any()) newSrcConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) cryptoSetup.EXPECT().ChangeConnectionID(newSrcConnID) packer.EXPECT().SetToken([]byte("foobar")) diff --git a/crypto_stream.go b/crypto_stream.go index 0c991089c..4ad097ce5 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -30,17 +30,10 @@ type cryptoStreamImpl struct { writeOffset protocol.ByteCount writeBuf []byte - - // Reassemble TLS handshake messages before returning them from GetCryptoData. - // This is only needed because crypto/tls doesn't correctly handle post-handshake messages. - onlyCompleteMsg bool } -func newCryptoStream(onlyCompleteMsg bool) cryptoStream { - return &cryptoStreamImpl{ - queue: newFrameSorter(), - onlyCompleteMsg: onlyCompleteMsg, - } +func newCryptoStream() cryptoStream { + return &cryptoStreamImpl{queue: newFrameSorter()} } func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { @@ -78,20 +71,6 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { // GetCryptoData retrieves data that was received in CRYPTO frames func (s *cryptoStreamImpl) GetCryptoData() []byte { - if s.onlyCompleteMsg { - if len(s.msgBuf) < 4 { - return nil - } - msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3]) - if len(s.msgBuf) < msgLen { - return nil - } - msg := make([]byte, msgLen) - copy(msg, s.msgBuf[:msgLen]) - s.msgBuf = s.msgBuf[msgLen:] - return msg - } - b := s.msgBuf s.msgBuf = nil return b diff --git a/crypto_stream_test.go b/crypto_stream_test.go index af6ad986f..c875bcb53 100644 --- a/crypto_stream_test.go +++ b/crypto_stream_test.go @@ -1,7 +1,6 @@ package quic import ( - "crypto/rand" "fmt" "github.com/refraction-networking/uquic/internal/protocol" @@ -16,7 +15,7 @@ var _ = Describe("Crypto Stream", func() { var str cryptoStream BeforeEach(func() { - str = newCryptoStream(false) + str = newCryptoStream() }) Context("handling incoming data", func() { @@ -138,23 +137,4 @@ var _ = Describe("Crypto Stream", func() { Expect(f.Data).To(Equal([]byte("bar"))) }) }) - - It("reassembles data", func() { - str = newCryptoStream(true) - data := make([]byte, 1337) - l := len(data) - 4 - data[1] = uint8(l >> 16) - data[2] = uint8(l >> 8) - data[3] = uint8(l) - rand.Read(data[4:]) - - for i, b := range data { - Expect(str.GetCryptoData()).To(BeEmpty()) - Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ - Offset: protocol.ByteCount(i), - Data: []byte{b}, - })).To(Succeed()) - } - Expect(str.GetCryptoData()).To(Equal(data)) - }) }) diff --git a/example/client/main.go b/example/client/main.go index 2ebe162ff..18a09a191 100644 --- a/example/client/main.go +++ b/example/client/main.go @@ -59,7 +59,7 @@ func main() { var qconf quic.Config if *enableQlog { - qconf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + qconf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { filename := fmt.Sprintf("client_%x.qlog", connID) f, err := os.Create(filename) if err != nil { diff --git a/example/main.go b/example/main.go index 04cbb2afd..a77af3398 100644 --- a/example/main.go +++ b/example/main.go @@ -163,7 +163,7 @@ func main() { handler := setupHandler(*www) quicConf := &quic.Config{} if *enableQlog { - quicConf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + quicConf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { filename := fmt.Sprintf("server_%x.qlog", connID) f, err := os.Create(filename) if err != nil { diff --git a/fuzzing/tokens/fuzz.go b/fuzzing/tokens/fuzz.go index 61ac3f541..f72283e87 100644 --- a/fuzzing/tokens/fuzz.go +++ b/fuzzing/tokens/fuzz.go @@ -2,24 +2,22 @@ package tokens import ( "encoding/binary" - "math/rand" "net" "time" + "github.com/refraction-networking/uquic" "github.com/refraction-networking/uquic/internal/handshake" "github.com/refraction-networking/uquic/internal/protocol" ) func Fuzz(data []byte) int { - if len(data) < 8 { + if len(data) < 32 { return -1 } - seed := binary.BigEndian.Uint64(data[:8]) - data = data[8:] - tg, err := handshake.NewTokenGenerator(rand.New(rand.NewSource(int64(seed)))) - if err != nil { - panic(err) - } + var key quic.TokenGeneratorKey + copy(key[:], data[:32]) + data = data[32:] + tg := handshake.NewTokenGenerator(key) if len(data) < 1 { return -1 } diff --git a/go.mod b/go.mod index 4f30d6d98..a01659a9f 100644 --- a/go.mod +++ b/go.mod @@ -4,32 +4,32 @@ go 1.20 require ( github.com/francoispqt/gojay v1.2.13 - github.com/gaukas/clienthellod v0.4.0 - github.com/onsi/ginkgo/v2 v2.12.0 - github.com/onsi/gomega v1.27.10 + github.com/gaukas/clienthellod v0.4.2 + github.com/onsi/ginkgo/v2 v2.13.0 + github.com/onsi/gomega v1.29.0 github.com/quic-go/qpack v0.4.0 - github.com/refraction-networking/utls v1.5.2 - go.uber.org/mock v0.2.0 + github.com/refraction-networking/utls v1.5.4 + go.uber.org/mock v0.3.0 golang.org/x/crypto v0.14.0 - golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 + golang.org/x/exp v0.0.0-20231006140011-7918f672742d golang.org/x/net v0.17.0 - golang.org/x/sync v0.3.0 + golang.org/x/sync v0.4.0 golang.org/x/sys v0.13.0 ) require ( - github.com/andybalholm/brotli v1.0.5 // indirect - github.com/cloudflare/circl v1.3.3 // indirect + github.com/andybalholm/brotli v1.0.6 // indirect + github.com/cloudflare/circl v1.3.5 // indirect github.com/gaukas/godicttls v0.0.4 // indirect github.com/go-logr/logr v1.2.4 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect - github.com/google/go-cmp v0.5.9 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/google/gopacket v1.1.19 // indirect - github.com/google/pprof v0.0.0-20230821062121-407c9e7a662f // indirect - github.com/klauspost/compress v1.16.7 // indirect - github.com/quic-go/quic-go v0.38.1 // indirect - golang.org/x/mod v0.12.0 // indirect + github.com/google/pprof v0.0.0-20231023181126-ff6d637d2a7b // indirect + github.com/klauspost/compress v1.17.2 // indirect + github.com/quic-go/quic-go v0.39.2 // indirect + golang.org/x/mod v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect - golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 // indirect + golang.org/x/tools v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 651a8f037..288d66a92 100644 --- a/go.sum +++ b/go.sum @@ -8,15 +8,15 @@ dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1 dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= -github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= +github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/cloudflare/circl v1.3.3 h1:fE/Qz0QdIGqeWfnwq0RE0R7MI51s0M2E4Ga9kq5AEMs= -github.com/cloudflare/circl v1.3.3/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= +github.com/cloudflare/circl v1.3.5 h1:g+wWynZqVALYAlpSQFAa7TscDnUK8mKYtrxMpw6AUKo= +github.com/cloudflare/circl v1.3.5/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -26,8 +26,8 @@ github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI github.com/francoispqt/gojay v1.2.13 h1:d2m3sFjloqoIUQU3TsHBgj6qg/BVGlTBeHDUmyJnXKk= github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/gaukas/clienthellod v0.4.0 h1:DySeZT4c3Xw6OGMzHRlAuOHx9q1P7vQNjA7YkyHrqac= -github.com/gaukas/clienthellod v0.4.0/go.mod h1:gjt7a7cNNzZV4yTe0jKcXtj0a7u6RL2KQvijxFOvcZE= +github.com/gaukas/clienthellod v0.4.2 h1:LPJ+LSeqt99pqeCV4C0cllk+pyWmERisP7w6qWr7eqE= +github.com/gaukas/clienthellod v0.4.2/go.mod h1:M57+dsu0ZScvmdnNxaxsDPM46WhSEdPYAOdNgfL7IKA= github.com/gaukas/godicttls v0.0.4 h1:NlRaXb3J6hAnTmWdsEKb9bcSBD6BvcIjdGdeb0zfXbk= github.com/gaukas/godicttls v0.0.4/go.mod h1:l6EenT4TLWgTdwslVb4sEMOCf7Bv0JAK67deKr9/NCI= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= @@ -47,16 +47,16 @@ github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= -github.com/google/pprof v0.0.0-20230821062121-407c9e7a662f h1:pDhu5sgp8yJlEF/g6osliIIpF9K4F5jvkULXa4daRDQ= -github.com/google/pprof v0.0.0-20230821062121-407c9e7a662f/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= +github.com/google/pprof v0.0.0-20231023181126-ff6d637d2a7b h1:RMpPgZTSApbPf7xaVel+QkoGPRLFLrwFO89uDUHEGf0= +github.com/google/pprof v0.0.0-20231023181126-ff6d637d2a7b/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= @@ -66,8 +66,8 @@ github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0 github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I= -github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/klauspost/compress v1.17.2 h1:RlWWUY/Dr4fL8qk9YG7DTZ7PDgME2V4csBXA8L/ixi4= +github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -82,10 +82,10 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/onsi/ginkgo/v2 v2.12.0 h1:UIVDowFPwpg6yMUpPjGkYvf06K3RAiJXUhCxEwQVHRI= -github.com/onsi/ginkgo/v2 v2.12.0/go.mod h1:ZNEzXISYlqpb8S36iN71ifqLi3vVD1rVJGvWRCJOUpQ= -github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI= -github.com/onsi/gomega v1.27.10/go.mod h1:RsS8tutOdbdgzbPtzzATp12yT7kM5I5aElG3evPbQ0M= +github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4= +github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o= +github.com/onsi/gomega v1.29.0 h1:KIA/t2t5UBzoirT4H9tsML45GEbo3ouUnBHsCfD2tVg= +github.com/onsi/gomega v1.29.0/go.mod h1:9sxs+SwGrKI0+PWe4Fxa9tFQQBG5xSsSbMXOI8PPpoQ= github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -96,10 +96,10 @@ github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7q github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/quic-go v0.38.1 h1:M36YWA5dEhEeT+slOu/SwMEucbYd0YFidxG3KlGPZaE= -github.com/quic-go/quic-go v0.38.1/go.mod h1:ijnZM7JsFIkp4cRyjxJNIzdSfCLmUMg9wdyhGmg+SN4= -github.com/refraction-networking/utls v1.5.2 h1:l6diiLbEoRqdQ+/osPDO0z0lTc8O8VZV+p82N+Hi+ws= -github.com/refraction-networking/utls v1.5.2/go.mod h1:SPuDbBmgLGp8s+HLNc83FuavwZCFoMmExj+ltUHiHUw= +github.com/quic-go/quic-go v0.39.2 h1:hmwAf8zAHlvan0Y5PXxeeBFZEW17IW99sXLry8I2kjk= +github.com/quic-go/quic-go v0.39.2/go.mod h1:T09QsDQWjLiQ74ZmacDfqZmhY/NLnw5BC40MANNNZ1Q= +github.com/refraction-networking/utls v1.5.4 h1:9k6EO2b8TaOGsQ7Pl7p9w6PUhx18/ZCeT0WNTZ7Uw4o= +github.com/refraction-networking/utls v1.5.4/go.mod h1:SPuDbBmgLGp8s+HLNc83FuavwZCFoMmExj+ltUHiHUw= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= @@ -134,8 +134,8 @@ github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cb github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= -go.uber.org/mock v0.2.0 h1:TaP3xedm7JaAgScZO7tlvlKrqT0p7I6OsdGB5YNSMDU= -go.uber.org/mock v0.2.0/go.mod h1:J0y0rp9L3xiff1+ZBfKxlC1fz2+aO16tw0tsDOixfuM= +go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo= +go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -145,15 +145,15 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= -golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= -golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY= +golang.org/x/mod v0.13.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -176,8 +176,8 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= -golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ= +golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -197,8 +197,8 @@ golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 h1:Vve/L0v7CXXuxUmaMGIEK/dEeq7uiqb5qBgQrZzIE7E= -golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= +golang.org/x/tools v0.14.0 h1:jvNa2pY0M4r62jkRQ6RwEZZyPcymeL9XZMLBbV7U2nc= +golang.org/x/tools v0.14.0/go.mod h1:uYBEerGOWcJyEORxN+Ek8+TT266gXkNlHdJBwexUsBg= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= @@ -216,7 +216,7 @@ google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmE google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= diff --git a/http3/body.go b/http3/body.go index 315825b2b..d29e9f72a 100644 --- a/http3/body.go +++ b/http3/body.go @@ -63,7 +63,8 @@ func (r *body) wasStreamHijacked() bool { } func (r *body) Read(b []byte) (int, error) { - return r.str.Read(b) + n, err := r.str.Read(b) + return n, maybeReplaceError(err) } func (r *body) Close() error { @@ -106,7 +107,7 @@ func (r *hijackableBody) Read(b []byte) (int, error) { if err != nil { r.requestDone() } - return n, err + return n, maybeReplaceError(err) } func (r *hijackableBody) requestDone() { diff --git a/http3/client.go b/http3/client.go index 974a81e1e..bac73f325 100644 --- a/http3/client.go +++ b/http3/client.go @@ -321,13 +321,13 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon } conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) } - return nil, rerr.err + return nil, maybeReplaceError(rerr.err) } if opt.DontCloseRequestStream { close(reqDone) <-done } - return rsp, rerr.err + return rsp, maybeReplaceError(rerr.err) } // cancelingReader reads from the io.Reader. diff --git a/http3/error.go b/http3/error.go new file mode 100644 index 000000000..fc41f544e --- /dev/null +++ b/http3/error.go @@ -0,0 +1,58 @@ +package http3 + +import ( + "errors" + "fmt" + + quic "github.com/refraction-networking/uquic" +) + +// Error is returned from the round tripper (for HTTP clients) +// and inside the HTTP handler (for HTTP servers) if an HTTP/3 error occurs. +// See section 8 of RFC 9114. +type Error struct { + Remote bool + ErrorCode ErrCode + ErrorMessage string +} + +var _ error = &Error{} + +func (e *Error) Error() string { + s := e.ErrorCode.string() + if s == "" { + s = fmt.Sprintf("H3 error (%#x)", uint64(e.ErrorCode)) + } + // Usually errors are remote. Only make it explicit for local errors. + if !e.Remote { + s += " (local)" + } + if e.ErrorMessage != "" { + s += ": " + e.ErrorMessage + } + return s +} + +func maybeReplaceError(err error) error { + if err == nil { + return nil + } + + var ( + e Error + strErr *quic.StreamError + appErr *quic.ApplicationError + ) + switch { + default: + return err + case errors.As(err, &strErr): + e.Remote = strErr.Remote + e.ErrorCode = ErrCode(strErr.ErrorCode) + case errors.As(err, &appErr): + e.Remote = appErr.Remote + e.ErrorCode = ErrCode(appErr.ErrorCode) + e.ErrorMessage = appErr.ErrorMessage + } + return &e +} diff --git a/http3/error_codes.go b/http3/error_codes.go index 86e27ff76..db114c77f 100644 --- a/http3/error_codes.go +++ b/http3/error_codes.go @@ -30,6 +30,14 @@ const ( ) func (e ErrCode) String() string { + s := e.string() + if s != "" { + return s + } + return fmt.Sprintf("unknown error code: %#x", uint16(e)) +} + +func (e ErrCode) string() string { switch e { case ErrCodeNoError: return "H3_NO_ERROR" @@ -68,6 +76,6 @@ func (e ErrCode) String() string { case ErrCodeDatagramError: return "H3_DATAGRAM_ERROR" default: - return fmt.Sprintf("unknown error code: %#x", uint16(e)) + return "" } } diff --git a/http3/error_test.go b/http3/error_test.go new file mode 100644 index 000000000..56aa200bb --- /dev/null +++ b/http3/error_test.go @@ -0,0 +1,40 @@ +package http3 + +import ( + "errors" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + quic "github.com/refraction-networking/uquic" +) + +var _ = Describe("HTTP/3 errors", func() { + It("converts", func() { + Expect(maybeReplaceError(nil)).To(BeNil()) + Expect(maybeReplaceError(errors.New("foobar"))).To(MatchError("foobar")) + Expect(maybeReplaceError(&quic.StreamError{ + ErrorCode: 1337, + Remote: true, + })).To(Equal(&Error{ + Remote: true, + ErrorCode: 1337, + })) + Expect(maybeReplaceError(&quic.ApplicationError{ + ErrorCode: 42, + Remote: true, + ErrorMessage: "foobar", + })).To(Equal(&Error{ + Remote: true, + ErrorCode: 42, + ErrorMessage: "foobar", + })) + }) + + It("has a string representation", func() { + Expect((&Error{ErrorCode: 0x10c, Remote: true}).Error()).To(Equal("H3_REQUEST_CANCELLED")) + Expect((&Error{ErrorCode: 0x10c, Remote: true, ErrorMessage: "foobar"}).Error()).To(Equal("H3_REQUEST_CANCELLED: foobar")) + Expect((&Error{ErrorCode: 0x10c, Remote: false}).Error()).To(Equal("H3_REQUEST_CANCELLED (local)")) + Expect((&Error{ErrorCode: 0x10c, Remote: false, ErrorMessage: "foobar"}).Error()).To(Equal("H3_REQUEST_CANCELLED (local): foobar")) + Expect((&Error{ErrorCode: 0x1337, Remote: true}).Error()).To(Equal("H3 error (0x1337)")) + }) +}) diff --git a/http3/mock_quic_early_listener_test.go b/http3/mock_quic_early_listener_test.go index b43b3550d..3995c5302 100644 --- a/http3/mock_quic_early_listener_test.go +++ b/http3/mock_quic_early_listener_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic/http3 (interfaces: QUICEarlyListener) - +// +// Generated by this command: +// +// mockgen -package http3 -destination mock_quic_early_listener_test.go github.com/refraction-networking/uquic/http3 QUICEarlyListener +// // Package http3 is a generated GoMock package. package http3 @@ -46,7 +50,7 @@ func (m *MockQUICEarlyListener) Accept(arg0 context.Context) (quic.EarlyConnecti } // Accept indicates an expected call of Accept. -func (mr *MockQUICEarlyListenerMockRecorder) Accept(arg0 interface{}) *gomock.Call { +func (mr *MockQUICEarlyListenerMockRecorder) Accept(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*MockQUICEarlyListener)(nil).Accept), arg0) } diff --git a/http3/mock_roundtripcloser_test.go b/http3/mock_roundtripcloser_test.go index 6550b2554..9f580b881 100644 --- a/http3/mock_roundtripcloser_test.go +++ b/http3/mock_roundtripcloser_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic/http3 (interfaces: RoundTripCloser) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package http3 -destination mock_roundtripcloser_test.go github.com/refraction-networking/uquic/http3 RoundTripCloser +// // Package http3 is a generated GoMock package. package http3 @@ -72,7 +76,7 @@ func (m *MockRoundTripCloser) RoundTripOpt(arg0 *http.Request, arg1 RoundTripOpt } // RoundTripOpt indicates an expected call of RoundTripOpt. -func (mr *MockRoundTripCloserMockRecorder) RoundTripOpt(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockRoundTripCloserMockRecorder) RoundTripOpt(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RoundTripOpt", reflect.TypeOf((*MockRoundTripCloser)(nil).RoundTripOpt), arg0, arg1) } diff --git a/http3/response_writer.go b/http3/response_writer.go index 8eb592a99..ed58e9d1b 100644 --- a/http3/response_writer.go +++ b/http3/response_writer.go @@ -166,9 +166,10 @@ func (w *responseWriter) Write(p []byte) (int, error) { w.buf = w.buf[:0] w.buf = df.Append(w.buf) if _, err := w.bufferedStr.Write(w.buf); err != nil { - return 0, err + return 0, maybeReplaceError(err) } - return w.bufferedStr.Write(p) + n, err := w.bufferedStr.Write(p) + return n, maybeReplaceError(err) } func (w *responseWriter) FlushError() error { @@ -177,7 +178,7 @@ func (w *responseWriter) FlushError() error { } if !w.written { if err := w.writeHeader(); err != nil { - return err + return maybeReplaceError(err) } w.written = true } diff --git a/integrationtests/gomodvendor/go.sum b/integrationtests/gomodvendor/go.sum index e24232c0b..29427a8ce 100644 --- a/integrationtests/gomodvendor/go.sum +++ b/integrationtests/gomodvendor/go.sum @@ -134,8 +134,8 @@ github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7q github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/qtls-go1-20 v0.3.3 h1:17/glZSLI9P9fDAeyCHBFSWSqJcwx1byhLwP5eUIDCM= -github.com/quic-go/qtls-go1-20 v0.3.3/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= +github.com/quic-go/qtls-go1-20 v0.3.4 h1:MfFAPULvst4yoMgY9QmtpYmfij/em7O8UUi+bNVm7Cg= +github.com/quic-go/qtls-go1-20 v0.3.4/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= @@ -174,8 +174,8 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= -go.uber.org/mock v0.2.0 h1:TaP3xedm7JaAgScZO7tlvlKrqT0p7I6OsdGB5YNSMDU= -go.uber.org/mock v0.2.0/go.mod h1:J0y0rp9L3xiff1+ZBfKxlC1fz2+aO16tw0tsDOixfuM= +go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo= +go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -194,15 +194,15 @@ golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTk golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= +golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -305,7 +305,6 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.1.8/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 6ab814a82..c7e4cee7c 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -7,11 +7,11 @@ import ( "net" "time" - tls "github.com/refraction-networking/utls" - quic "github.com/refraction-networking/uquic" + quicproxy "github.com/refraction-networking/uquic/integrationtests/tools/proxy" "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/qerr" + tls "github.com/refraction-networking/utls" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -80,6 +80,26 @@ var _ = Describe("Handshake tests", func() { }() } + It("returns the cancellation reason when a dial is canceled", func() { + ctx, cancel := context.WithCancelCause(context.Background()) + errChan := make(chan error, 1) + go func() { + _, err := quic.DialAddr( + ctx, + "localhost:1234", // nobody is listening on this port, but we're going to cancel this dial anyway + getTLSClientConfig(), + getQuicConfig(nil), + ) + errChan <- err + }() + + cancel(errors.New("application cancelled")) + var err error + Eventually(errChan).Should(Receive(&err)) + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError("application cancelled")) + }) + // Context("using different cipher suites", func() { // for n, id := range map[string]uint16{ // "TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256, @@ -453,16 +473,31 @@ var _ = Describe("Handshake tests", func() { }) It("rejects invalid Retry token with the INVALID_TOKEN error", func() { + const rtt = 10 * time.Millisecond serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } - serverConfig.MaxRetryTokenAge = -time.Second + // The validity period of the retry token is the handshake timeout, + // which is twice the handshake idle timeout. + // By setting the handshake timeout shorter than the RTT, the token will have expired by the time + // it reaches the server. + serverConfig.HandshakeIdleTimeout = rtt / 5 server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) defer server.Close() + serverPort := server.Addr().(*net.UDPAddr).Port + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), + DelayPacket: func(quicproxy.Direction, []byte) time.Duration { + return rtt / 2 + }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() + _, err = quic.DialAddr( context.Background(), - fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), nil, ) diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 698f3160c..c0d4f22ec 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -307,9 +307,10 @@ var _ = Describe("HTTP tests", func() { for { if _, err := w.Write([]byte("foobar")); err != nil { Expect(r.Context().Done()).To(BeClosed()) - var strErr *quic.StreamError - Expect(errors.As(err, &strErr)).To(BeTrue()) - Expect(strErr.ErrorCode).To(Equal(quic.StreamErrorCode(0x10c))) + var http3Err *http3.Error + Expect(errors.As(err, &http3Err)).To(BeTrue()) + Expect(http3Err.ErrorCode).To(Equal(http3.ErrCode(0x10c))) + Expect(http3Err.Error()).To(Equal("H3_REQUEST_CANCELLED")) return } } @@ -325,7 +326,10 @@ var _ = Describe("HTTP tests", func() { cancel() Eventually(handlerCalled).Should(BeClosed()) _, err = resp.Body.Read([]byte{0}) - Expect(err).To(HaveOccurred()) + var http3Err *http3.Error + Expect(errors.As(err, &http3Err)).To(BeTrue()) + Expect(http3Err.ErrorCode).To(Equal(http3.ErrCode(0x10c))) + Expect(http3Err.Error()).To(Equal("H3_REQUEST_CANCELLED (local)")) }) It("allows streamed HTTP requests", func() { diff --git a/integrationtests/self/key_update_test.go b/integrationtests/self/key_update_test.go index af1ba5bdb..1341f25d6 100644 --- a/integrationtests/self/key_update_test.go +++ b/integrationtests/self/key_update_test.go @@ -38,16 +38,13 @@ func countKeyPhases() (sent, received int) { return } -type keyUpdateConnTracer struct { - logging.NullConnectionTracer -} - -func (t *keyUpdateConnTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ *logging.AckFrame, _ []logging.Frame) { - sentHeaders = append(sentHeaders, hdr) -} - -func (t *keyUpdateConnTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, frames []logging.Frame) { - receivedHeaders = append(receivedHeaders, hdr) +var keyUpdateConnTracer = &logging.ConnectionTracer{ + SentShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ *logging.AckFrame, _ []logging.Frame) { + sentHeaders = append(sentHeaders, hdr) + }, + ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ []logging.Frame) { + receivedHeaders = append(receivedHeaders, hdr) + }, } var _ = Describe("Key Update tests", func() { @@ -75,8 +72,8 @@ var _ = Describe("Key Update tests", func() { context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), - getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { - return &keyUpdateConnTracer{} + getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { + return keyUpdateConnTracer }}), ) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/packetization_test.go b/integrationtests/self/packetization_test.go index 7a35c32a5..5da97513f 100644 --- a/integrationtests/self/packetization_test.go +++ b/integrationtests/self/packetization_test.go @@ -21,7 +21,7 @@ var _ = Describe("Packetization", func() { It("bundles ACKs", func() { const numMsg = 100 - serverTracer := newPacketTracer() + serverCounter, serverTracer := newPacketTracer() server, err := quic.ListenAddr( "localhost:0", getTLSConfig(), @@ -43,7 +43,7 @@ var _ = Describe("Packetization", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() - clientTracer := newPacketTracer() + clientCounter, clientTracer := newPacketTracer() conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), @@ -104,8 +104,8 @@ var _ = Describe("Packetization", func() { return } - numBundledIncoming := countBundledPackets(clientTracer.getRcvdShortHeaderPackets()) - numBundledOutgoing := countBundledPackets(serverTracer.getRcvdShortHeaderPackets()) + numBundledIncoming := countBundledPackets(clientCounter.getRcvdShortHeaderPackets()) + numBundledOutgoing := countBundledPackets(serverCounter.getRcvdShortHeaderPackets()) fmt.Fprintf(GinkgoWriter, "bundled incoming packets: %d / %d\n", numBundledIncoming, numMsg) fmt.Fprintf(GinkgoWriter, "bundled outgoing packets: %d / %d\n", numBundledOutgoing, numMsg) Expect(numBundledIncoming).To(And( diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index f2efc18c7..d63bedcec 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -87,7 +87,7 @@ var ( logBuf *syncedBuffer versionParam string - qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer + qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer enableQlog bool version quic.VersionNumber @@ -178,10 +178,16 @@ func getQuicConfig(conf *quic.Config) *quic.Config { } if enableQlog { if conf.Tracer == nil { - conf.Tracer = qlogTracer + conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { + return logging.NewMultiplexedConnectionTracer( + qlogTracer(ctx, p, connID), + // multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere + &logging.ConnectionTracer{}, + ) + } } else if qlogTracer != nil { origTracer := conf.Tracer - conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { return logging.NewMultiplexedConnectionTracer( qlogTracer(ctx, p, connID), origTracer(ctx, p, connID), @@ -243,8 +249,8 @@ func scaleDuration(d time.Duration) time.Duration { return time.Duration(scaleFactor) * d } -func newTracer(tracer logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { - return func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { return tracer } +func newTracer(tracer *logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { + return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return tracer } } type packet struct { @@ -259,49 +265,46 @@ type shortHeaderPacket struct { frames []logging.Frame } -type packetTracer struct { - logging.NullConnectionTracer +type packetCounter struct { closed chan struct{} sentShortHdr, rcvdShortHdr []shortHeaderPacket rcvdLongHdr []packet } -func newPacketTracer() *packetTracer { - return &packetTracer{closed: make(chan struct{})} -} - -func (t *packetTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, _ logging.ByteCount, frames []logging.Frame) { - t.rcvdLongHdr = append(t.rcvdLongHdr, packet{time: time.Now(), hdr: hdr, frames: frames}) -} - -func (t *packetTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, frames []logging.Frame) { - t.rcvdShortHdr = append(t.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames}) -} - -func (t *packetTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, ack *wire.AckFrame, frames []logging.Frame) { - if ack != nil { - frames = append(frames, ack) - } - t.sentShortHdr = append(t.sentShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames}) -} - -func (t *packetTracer) Close() { close(t.closed) } - -func (t *packetTracer) getSentShortHeaderPackets() []shortHeaderPacket { +func (t *packetCounter) getSentShortHeaderPackets() []shortHeaderPacket { <-t.closed return t.sentShortHdr } -func (t *packetTracer) getRcvdLongHeaderPackets() []packet { +func (t *packetCounter) getRcvdLongHeaderPackets() []packet { <-t.closed return t.rcvdLongHdr } -func (t *packetTracer) getRcvdShortHeaderPackets() []shortHeaderPacket { +func (t *packetCounter) getRcvdShortHeaderPackets() []shortHeaderPacket { <-t.closed return t.rcvdShortHdr } +func newPacketTracer() (*packetCounter, *logging.ConnectionTracer) { + c := &packetCounter{closed: make(chan struct{})} + return c, &logging.ConnectionTracer{ + ReceivedLongHeaderPacket: func(hdr *logging.ExtendedHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) { + c.rcvdLongHdr = append(c.rcvdLongHdr, packet{time: time.Now(), hdr: hdr, frames: frames}) + }, + ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) { + c.rcvdShortHdr = append(c.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames}) + }, + SentShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, ack *wire.AckFrame, frames []logging.Frame) { + if ack != nil { + frames = append(frames, ack) + } + c.sentShortHdr = append(c.sentShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames}) + }, + Close: func() { close(c.closed) }, + } +} + func TestSelf(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Self integration tests") diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index b70a6fe3f..fc5db68ae 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -57,7 +57,7 @@ var _ = Describe("Timeout tests", func() { context.Background(), "localhost:12345", getTLSClientConfig(), - getQuicConfig(&quic.Config{HandshakeIdleTimeout: 10 * time.Millisecond}), + getQuicConfig(&quic.Config{HandshakeIdleTimeout: scaleDuration(50 * time.Millisecond)}), ) errChan <- err }() @@ -194,7 +194,7 @@ var _ = Describe("Timeout tests", func() { close(serverConnClosed) }() - tr := newPacketTracer() + counter, tr := newPacketTracer() conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), @@ -215,7 +215,7 @@ var _ = Describe("Timeout tests", func() { }() Eventually(done, 2*idleTimeout).Should(BeClosed()) var lastAckElicitingPacketSentAt time.Time - for _, p := range tr.getSentShortHeaderPackets() { + for _, p := range counter.getSentShortHeaderPackets() { var hasAckElicitingFrame bool for _, f := range p.frames { if _, ok := f.(*logging.AckFrame); ok { @@ -228,7 +228,7 @@ var _ = Describe("Timeout tests", func() { lastAckElicitingPacketSentAt = p.time } } - rcvdPackets := tr.getRcvdShortHeaderPackets() + rcvdPackets := counter.getRcvdShortHeaderPackets() lastPacketRcvdAt := rcvdPackets[len(rcvdPackets)-1].time // We're ignoring here that only the first ack-eliciting packet sent resets the idle timeout. // This is ok since we're dealing with a lossless connection here, diff --git a/integrationtests/self/tracer_test.go b/integrationtests/self/tracer_test.go index db0ba5981..cdd5c0d8c 100644 --- a/integrationtests/self/tracer_test.go +++ b/integrationtests/self/tracer_test.go @@ -26,9 +26,9 @@ var _ = Describe("Handshake tests", func() { fmt.Fprintf(GinkgoWriter, "%s using qlog: %t, custom: %t\n", pers, enableQlog, enableCustomTracer) - var tracerConstructors []func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer + var tracerConstructors []func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer if enableQlog { - tracerConstructors = append(tracerConstructors, func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + tracerConstructors = append(tracerConstructors, func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { if mrand.Int()%2 == 0 { // simulate that a qlog collector might only want to log some connections fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %x\n", p, connID) return nil @@ -38,13 +38,13 @@ var _ = Describe("Handshake tests", func() { }) } if enableCustomTracer { - tracerConstructors = append(tracerConstructors, func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { - return logging.NullConnectionTracer{} + tracerConstructors = append(tracerConstructors, func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { + return &logging.ConnectionTracer{} }) } c := conf.Clone() - c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { - tracers := make([]logging.ConnectionTracer, 0, len(tracerConstructors)) + c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { + tracers := make([]*logging.ConnectionTracer, 0, len(tracerConstructors)) for _, c := range tracerConstructors { if tr := c(ctx, p, connID); tr != nil { tracers = append(tracers, tr) diff --git a/integrationtests/self/zero_rtt_oldgo_test.go b/integrationtests/self/zero_rtt_oldgo_test.go index 101c1ffdf..b59726529 100644 --- a/integrationtests/self/zero_rtt_oldgo_test.go +++ b/integrationtests/self/zero_rtt_oldgo_test.go @@ -12,7 +12,7 @@ import ( "sync/atomic" "time" - "github.com/refraction-networking/uquic" + quic "github.com/refraction-networking/uquic" tls "github.com/refraction-networking/utls" quicproxy "github.com/refraction-networking/uquic/integrationtests/tools/proxy" @@ -203,7 +203,7 @@ var _ = Describe("0-RTT", func() { Eventually(conn.Context().Done()).Should(BeClosed()) } - // can be used to extract 0-RTT from a packetTracer + // can be used to extract 0-RTT from a packetCounter get0RTTPackets := func(packets []packet) []protocol.PacketNumber { var zeroRTTPackets []protocol.PacketNumber for _, p := range packets { @@ -220,7 +220,7 @@ var _ = Describe("0-RTT", func() { It(fmt.Sprintf("transfers 0-RTT data, with %d byte connection IDs", connIDLen), func() { tlsConf, clientTLSConf := dialAndReceiveSessionTicket(nil) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -245,7 +245,7 @@ var _ = Describe("0-RTT", func() { ) var numNewConnIDs int - for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, p := range counter.getRcvdLongHeaderPackets() { for _, f := range p.frames { if _, ok := f.(*logging.NewConnectionIDFrame); ok { numNewConnIDs++ @@ -261,7 +261,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0))) }) @@ -274,7 +274,7 @@ var _ = Describe("0-RTT", func() { zeroRTTData := GeneratePRData(5 << 10) oneRTTData := PRData - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -331,7 +331,7 @@ var _ = Describe("0-RTT", func() { // check that 0-RTT packets only contain STREAM frames for the first stream var num0RTT int - for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, p := range counter.getRcvdLongHeaderPackets() { if p.hdr.Header.Type != protocol.PacketType0RTT { continue } @@ -356,7 +356,7 @@ var _ = Describe("0-RTT", func() { tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -407,7 +407,7 @@ var _ = Describe("0-RTT", func() { fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped) Expect(numDropped).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).ToNot(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).ToNot(BeEmpty()) }) It("retransmits all 0-RTT data when the server performs a Retry", func() { @@ -431,7 +431,7 @@ var _ = Describe("0-RTT", func() { return } - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -481,7 +481,7 @@ var _ = Describe("0-RTT", func() { defer mutex.Unlock() Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra Expect(secondCounter).To(BeNumerically("~", firstCounter, 20)) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5)) Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5))) }) @@ -492,14 +492,12 @@ var _ = Describe("0-RTT", func() { MaxIncomingUniStreams: maxStreams, })) - tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, getQuicConfig(&quic.Config{ MaxIncomingUniStreams: maxStreams + 1, Allow0RTT: true, - Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -537,7 +535,7 @@ var _ = Describe("0-RTT", func() { MaxIncomingStreams: maxStreams, })) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -557,16 +555,17 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) It("rejects 0-RTT when the ALPN changed", func() { tlsConf, clientConf := dialAndReceiveSessionTicket(nil) // now close the listener and dial new connection with a different ALPN - clientConf.NextProtos = []string{"new-alpn"} + // clientConf.NextProtos = []string{"new-alpn"} + clientConf.NextProtos = append(clientConf.NextProtos, "new-alpn") tlsConf.NextProtos = []string{"new-alpn"} - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -586,14 +585,14 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) It("rejects 0-RTT when the application doesn't allow it", func() { tlsConf, clientConf := dialAndReceiveSessionTicket(nil) // now close the listener and dial new connection with a different ALPN - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -613,12 +612,12 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) DescribeTable("flow control limits", func(addFlowControlLimit func(*quic.Config, uint64)) { - tracer := newPacketTracer() + counter, tracer := newPacketTracer() firstConf := getQuicConfig(&quic.Config{Allow0RTT: true}) addFlowControlLimit(firstConf, 3) tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf) @@ -670,7 +669,7 @@ var _ = Describe("0-RTT", func() { Eventually(conn.Context().Done()).Should(BeClosed()) var processedFirst bool - for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, p := range counter.getRcvdLongHeaderPackets() { for _, f := range p.frames { if sf, ok := f.(*logging.StreamFrame); ok { if !processedFirst { @@ -696,7 +695,7 @@ var _ = Describe("0-RTT", func() { It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() { tlsConf, clientConf := dialAndReceiveSessionTicket(nil) // now dial new connection with different transport parameters - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -765,14 +764,14 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) } It("queues 0-RTT packets, if the Initial is delayed", func() { tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -797,8 +796,8 @@ var _ = Describe("0-RTT", func() { transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData) - Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(counter.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) }) @@ -808,7 +807,7 @@ var _ = Describe("0-RTT", func() { EnableDatagrams: true, })) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -853,12 +852,12 @@ var _ = Describe("0-RTT", func() { Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) Expect(receivedMessage).To(Equal(sentMessage)) + Expect(conn.CloseWithError(0, "")).To(Succeed()) num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(zeroRTTPackets).To(HaveLen(1)) - Expect(conn.CloseWithError(0, "")).To(Succeed()) }) It("rejects 0-RTT datagrams when the server doesn't support datagrams anymore", func() { @@ -866,7 +865,7 @@ var _ = Describe("0-RTT", func() { EnableDatagrams: true, })) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -908,10 +907,10 @@ var _ = Describe("0-RTT", func() { Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) - Expect(conn.CloseWithError(0, "")).To(Succeed()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) }) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 7fe819308..e5ad5d4d3 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -233,7 +233,7 @@ var _ = Describe("0-RTT", func() { Eventually(conn.Context().Done()).Should(BeClosed()) } - // can be used to extract 0-RTT from a packetTracer + // can be used to extract 0-RTT from a packetCounter get0RTTPackets := func(packets []packet) []protocol.PacketNumber { var zeroRTTPackets []protocol.PacketNumber for _, p := range packets { @@ -252,7 +252,7 @@ var _ = Describe("0-RTT", func() { clientTLSConf := getTLSClientConfig() dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -277,7 +277,7 @@ var _ = Describe("0-RTT", func() { ) var numNewConnIDs int - for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, p := range counter.getRcvdLongHeaderPackets() { for _, f := range p.frames { if _, ok := f.(*logging.NewConnectionIDFrame); ok { numNewConnIDs++ @@ -293,7 +293,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0))) }) @@ -308,7 +308,7 @@ var _ = Describe("0-RTT", func() { zeroRTTData := GeneratePRData(5 << 10) oneRTTData := PRData - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -365,7 +365,7 @@ var _ = Describe("0-RTT", func() { // check that 0-RTT packets only contain STREAM frames for the first stream var num0RTT int - for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, p := range counter.getRcvdLongHeaderPackets() { if p.hdr.Header.Type != protocol.PacketType0RTT { continue } @@ -392,7 +392,7 @@ var _ = Describe("0-RTT", func() { clientConf := getTLSClientConfig() dialAndReceiveSessionTicket(tlsConf, nil, clientConf) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -443,7 +443,7 @@ var _ = Describe("0-RTT", func() { fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped) Expect(numDropped).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).ToNot(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).ToNot(BeEmpty()) }) It("retransmits all 0-RTT data when the server performs a Retry", func() { @@ -469,7 +469,7 @@ var _ = Describe("0-RTT", func() { return } - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -519,7 +519,7 @@ var _ = Describe("0-RTT", func() { defer mutex.Unlock() Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra Expect(secondCounter).To(BeNumerically("~", firstCounter, 20)) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5)) Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5))) }) @@ -532,14 +532,12 @@ var _ = Describe("0-RTT", func() { MaxIncomingUniStreams: maxStreams, }), clientConf) - tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, getQuicConfig(&quic.Config{ MaxIncomingUniStreams: maxStreams + 1, Allow0RTT: true, - Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -579,7 +577,7 @@ var _ = Describe("0-RTT", func() { MaxIncomingStreams: maxStreams, }), clientConf) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -600,7 +598,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) It("rejects 0-RTT when the ALPN changed", func() { @@ -613,7 +611,7 @@ var _ = Describe("0-RTT", func() { // Append to the client's ALPN. // crypto/tls will attempt to resume with the ALPN from the original connection clientConf.NextProtos = append(clientConf.NextProtos, "new-alpn") - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -633,7 +631,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) It("rejects 0-RTT when the application doesn't allow it", func() { @@ -642,7 +640,7 @@ var _ = Describe("0-RTT", func() { dialAndReceiveSessionTicket(tlsConf, nil, clientConf) // now close the listener and dial new connection with a different ALPN - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -662,12 +660,12 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) DescribeTable("flow control limits", func(addFlowControlLimit func(*quic.Config, uint64)) { - tracer := newPacketTracer() + counter, tracer := newPacketTracer() firstConf := getQuicConfig(&quic.Config{Allow0RTT: true}) addFlowControlLimit(firstConf, 3) tlsConf := getTLSConfig() @@ -721,7 +719,7 @@ var _ = Describe("0-RTT", func() { Eventually(conn.Context().Done()).Should(BeClosed()) var processedFirst bool - for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, p := range counter.getRcvdLongHeaderPackets() { for _, f := range p.frames { if sf, ok := f.(*logging.StreamFrame); ok { if !processedFirst { @@ -749,7 +747,7 @@ var _ = Describe("0-RTT", func() { clientConf := getTLSClientConfig() dialAndReceiveSessionTicket(tlsConf, nil, clientConf) // now dial new connection with different transport parameters - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -818,7 +816,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) } @@ -827,7 +825,7 @@ var _ = Describe("0-RTT", func() { clientConf := getTLSClientConfig() dialAndReceiveSessionTicket(tlsConf, nil, clientConf) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -852,8 +850,8 @@ var _ = Describe("0-RTT", func() { transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData) - Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(counter.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) }) @@ -879,14 +877,10 @@ var _ = Describe("0-RTT", func() { clientTLSConf := getTLSClientConfig() dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf) - tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, - getQuicConfig(&quic.Config{ - Allow0RTT: true, - Tracer: newTracer(tracer), - }), + getQuicConfig(&quic.Config{Allow0RTT: true}), ) Expect(err).ToNot(HaveOccurred()) defer ln.Close() @@ -917,14 +911,10 @@ var _ = Describe("0-RTT", func() { } dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf) - tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, - getQuicConfig(&quic.Config{ - Allow0RTT: true, - Tracer: newTracer(tracer), - }), + getQuicConfig(&quic.Config{Allow0RTT: true}), ) Expect(err).ToNot(HaveOccurred()) defer ln.Close() @@ -947,7 +937,7 @@ var _ = Describe("0-RTT", func() { EnableDatagrams: true, }), clientTLSConf) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -990,13 +980,13 @@ var _ = Describe("0-RTT", func() { <-received Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(receivedMessage).To(Equal(sentMessage)) num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(zeroRTTPackets).To(HaveLen(1)) - Expect(conn.CloseWithError(0, "")).To(Succeed()) }) It("rejects 0-RTT datagrams when the server doesn't support datagrams anymore", func() { @@ -1006,7 +996,7 @@ var _ = Describe("0-RTT", func() { EnableDatagrams: true, }), clientTLSConf) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -1048,10 +1038,10 @@ var _ = Describe("0-RTT", func() { Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) - Expect(conn.CloseWithError(0, "")).To(Succeed()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) }) diff --git a/integrationtests/tools/qlog.go b/integrationtests/tools/qlog.go index 9dcaac8dd..80d6476a0 100644 --- a/integrationtests/tools/qlog.go +++ b/integrationtests/tools/qlog.go @@ -14,8 +14,8 @@ import ( "github.com/refraction-networking/uquic/qlog" ) -func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { - return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { +func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { + return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { role := "server" if p == logging.PerspectiveClient { role = "client" diff --git a/integrationtests/versionnegotiation/handshake_test.go b/integrationtests/versionnegotiation/handshake_test.go index f68cfac08..eeef9ea35 100644 --- a/integrationtests/versionnegotiation/handshake_test.go +++ b/integrationtests/versionnegotiation/handshake_test.go @@ -2,8 +2,10 @@ package versionnegotiation import ( "context" + "errors" "fmt" "net" + "time" tls "github.com/refraction-networking/utls" @@ -20,29 +22,29 @@ type versioner interface { GetVersion() protocol.VersionNumber } -type versionNegotiationTracer struct { - logging.NullConnectionTracer - +type result struct { loggedVersions bool receivedVersionNegotiation bool chosen logging.VersionNumber clientVersions, serverVersions []logging.VersionNumber } -var _ logging.ConnectionTracer = &versionNegotiationTracer{} - -func (t *versionNegotiationTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { - if t.loggedVersions { - Fail("only expected one call to NegotiatedVersions") +func newVersionNegotiationTracer() (*result, *logging.ConnectionTracer) { + r := &result{} + return r, &logging.ConnectionTracer{ + NegotiatedVersion: func(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { + if r.loggedVersions { + Fail("only expected one call to NegotiatedVersions") + } + r.loggedVersions = true + r.chosen = chosen + r.clientVersions = clientVersions + r.serverVersions = serverVersions + }, + ReceivedVersionNegotiationPacket: func(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) { + r.receivedVersionNegotiation = true + }, } - t.loggedVersions = true - t.chosen = chosen - t.clientVersions = clientVersions - t.serverVersions = serverVersions -} - -func (t *versionNegotiationTracer) ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) { - t.receivedVersionNegotiation = true } var _ = Describe("Handshake tests", func() { @@ -85,54 +87,54 @@ var _ = Describe("Handshake tests", func() { // but it supports a bunch of versions that the client doesn't speak serverConfig := &quic.Config{} serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9} - serverTracer := &versionNegotiationTracer{} - serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + serverResult, serverTracer := newVersionNegotiationTracer() + serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return serverTracer } server, cl := startServer(getTLSConfig(), serverConfig) defer cl() - clientTracer := &versionNegotiationTracer{} + clientResult, clientTracer := newVersionNegotiationTracer() conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), - maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) logging.ConnectionTracer { + maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) *logging.ConnectionTracer { return clientTracer }}), ) Expect(err).ToNot(HaveOccurred()) Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion)) Expect(conn.CloseWithError(0, "")).To(Succeed()) - Expect(clientTracer.chosen).To(Equal(expectedVersion)) - Expect(clientTracer.receivedVersionNegotiation).To(BeFalse()) - Expect(clientTracer.clientVersions).To(Equal(protocol.SupportedVersions)) - Expect(clientTracer.serverVersions).To(BeEmpty()) - Expect(serverTracer.chosen).To(Equal(expectedVersion)) - Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions)) - Expect(serverTracer.clientVersions).To(BeEmpty()) + Expect(clientResult.chosen).To(Equal(expectedVersion)) + Expect(clientResult.receivedVersionNegotiation).To(BeFalse()) + Expect(clientResult.clientVersions).To(Equal(protocol.SupportedVersions)) + Expect(clientResult.serverVersions).To(BeEmpty()) + Expect(serverResult.chosen).To(Equal(expectedVersion)) + Expect(serverResult.serverVersions).To(Equal(serverConfig.Versions)) + Expect(serverResult.clientVersions).To(BeEmpty()) }) It("when the client supports more versions than the server supports", func() { expectedVersion := protocol.SupportedVersions[0] - // the server doesn't support the highest supported version, which is the first one the client will try + // The server doesn't support the highest supported version, which is the first one the client will try, // but it supports a bunch of versions that the client doesn't speak - serverTracer := &versionNegotiationTracer{} + serverResult, serverTracer := newVersionNegotiationTracer() serverConfig := &quic.Config{} serverConfig.Versions = supportedVersions - serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return serverTracer } server, cl := startServer(getTLSConfig(), serverConfig) defer cl() clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10} - clientTracer := &versionNegotiationTracer{} + clientResult, clientTracer := newVersionNegotiationTracer() conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), maybeAddQLOGTracer(&quic.Config{ Versions: clientVersions, - Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return clientTracer }, }), @@ -140,13 +142,53 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) Expect(conn.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[0])) Expect(conn.CloseWithError(0, "")).To(Succeed()) - Expect(clientTracer.chosen).To(Equal(expectedVersion)) - Expect(clientTracer.receivedVersionNegotiation).To(BeTrue()) - Expect(clientTracer.clientVersions).To(Equal(clientVersions)) - Expect(clientTracer.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions - Expect(serverTracer.chosen).To(Equal(expectedVersion)) - Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions)) - Expect(serverTracer.clientVersions).To(BeEmpty()) + Expect(clientResult.chosen).To(Equal(expectedVersion)) + Expect(clientResult.receivedVersionNegotiation).To(BeTrue()) + Expect(clientResult.clientVersions).To(Equal(clientVersions)) + Expect(clientResult.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions + Expect(serverResult.chosen).To(Equal(expectedVersion)) + Expect(serverResult.serverVersions).To(Equal(serverConfig.Versions)) + Expect(serverResult.clientVersions).To(BeEmpty()) + }) + + It("fails if the server disables version negotiation", func() { + // The server doesn't support the highest supported version, which is the first one the client will try, + // but it supports a bunch of versions that the client doesn't speak + _, serverTracer := newVersionNegotiationTracer() + serverConfig := &quic.Config{} + serverConfig.Versions = supportedVersions + serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { + return serverTracer + } + conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) + Expect(err).ToNot(HaveOccurred()) + tr := &quic.Transport{ + Conn: conn, + DisableVersionNegotiationPackets: true, + } + ln, err := tr.Listen(getTLSConfig(), serverConfig) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10} + clientResult, clientTracer := newVersionNegotiationTracer() + _, err = quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", conn.LocalAddr().(*net.UDPAddr).Port), + getTLSClientConfig(), + maybeAddQLOGTracer(&quic.Config{ + Versions: clientVersions, + Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { + return clientTracer + }, + HandshakeIdleTimeout: 100 * time.Millisecond, + }), + ) + Expect(err).To(HaveOccurred()) + var nerr net.Error + Expect(errors.As(err, &nerr)).To(BeTrue()) + Expect(nerr.Timeout()).To(BeTrue()) + Expect(clientResult.receivedVersionNegotiation).To(BeFalse()) }) } }) diff --git a/integrationtests/versionnegotiation/versionnegotiation_suite_test.go b/integrationtests/versionnegotiation/versionnegotiation_suite_test.go index ad7a91232..e1c7ef70d 100644 --- a/integrationtests/versionnegotiation/versionnegotiation_suite_test.go +++ b/integrationtests/versionnegotiation/versionnegotiation_suite_test.go @@ -70,7 +70,7 @@ func maybeAddQLOGTracer(c *quic.Config) *quic.Config { c.Tracer = qlogger } else if qlogger != nil { origTracer := c.Tracer - c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { return logging.NewMultiplexedConnectionTracer( qlogger(ctx, p, connID), origTracer(ctx, p, connID), diff --git a/interface.go b/interface.go index 27271d54f..68fccd680 100644 --- a/interface.go +++ b/interface.go @@ -213,6 +213,9 @@ type EarlyConnection interface { // StatelessResetKey is a key used to derive stateless reset tokens. type StatelessResetKey [32]byte +// TokenGeneratorKey is a key used to encrypt session resumption tokens. +type TokenGeneratorKey = handshake.TokenProtectorKey + // A ConnectionID is a QUIC Connection ID, as defined in RFC 9000. // It is not able to handle QUIC Connection IDs longer than 20 bytes, // as they are allowed by RFC 8999. @@ -251,7 +254,8 @@ type Config struct { // If not set, it uses all versions available. Versions []VersionNumber // HandshakeIdleTimeout is the idle timeout before completion of the handshake. - // Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted. + // If we don't receive any packet from the peer within this time, the connection attempt is aborted. + // Additionally, if the handshake doesn't complete in twice this time, the connection attempt is also aborted. // If this value is zero, the timeout is set to 5 seconds. HandshakeIdleTimeout time.Duration // MaxIdleTimeout is the maximum duration that may pass without any incoming network activity. @@ -265,13 +269,6 @@ type Config struct { // See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details. // If not set, every client is forced to prove its remote address. RequireAddressValidation func(net.Addr) bool - // MaxRetryTokenAge is the maximum age of a Retry token. - // If not set, it defaults to 5 seconds. Only valid for a server. - MaxRetryTokenAge time.Duration - // MaxTokenAge is the maximum age of the token presented during the handshake, - // for tokens that were issued on a previous connection. - // If not set, it defaults to 24 hours. Only valid for a server. - MaxTokenAge time.Duration // The TokenStore stores tokens received from the server. // Tokens are used to skip address validation on future connection attempts. // The key used to store tokens is the ServerName from the tls.Config, if set @@ -323,16 +320,12 @@ type Config struct { // Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit. // If unavailable or disabled, packets will be at most 1252 (IPv4) / 1232 (IPv6) bytes in size. DisablePathMTUDiscovery bool - // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. - // This can be useful if version information is exchanged out-of-band. - // It has no effect for a client. - DisableVersionNegotiationPackets bool // Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted. // Only valid for the server. Allow0RTT bool // Enable QUIC datagram support (RFC 9221). EnableDatagrams bool - Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer + Tracer func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer } type ClientHelloInfo struct { @@ -352,4 +345,6 @@ type ConnectionState struct { Used0RTT bool // Version is the QUIC version of the QUIC connection. Version VersionNumber + // GSO says if generic segmentation offload is used + GSO bool } diff --git a/internal/ackhandler/ackhandler.go b/internal/ackhandler/ackhandler.go index 9d0fa84cf..5f9071a6e 100644 --- a/internal/ackhandler/ackhandler.go +++ b/internal/ackhandler/ackhandler.go @@ -14,10 +14,11 @@ func NewAckHandler( initialMaxDatagramSize protocol.ByteCount, rttStats *utils.RTTStats, clientAddressValidated bool, + enableECN bool, pers protocol.Perspective, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, ) (SentPacketHandler, ReceivedPacketHandler) { - sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, pers, tracer, logger) + sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger) return sph, newReceivedPacketHandler(sph, rttStats, logger) } diff --git a/internal/ackhandler/ecn.go b/internal/ackhandler/ecn.go new file mode 100644 index 000000000..6eb943c0f --- /dev/null +++ b/internal/ackhandler/ecn.go @@ -0,0 +1,296 @@ +package ackhandler + +import ( + "fmt" + + "github.com/refraction-networking/uquic/internal/protocol" + "github.com/refraction-networking/uquic/internal/utils" + "github.com/refraction-networking/uquic/logging" +) + +type ecnState uint8 + +const ( + ecnStateInitial ecnState = iota + ecnStateTesting + ecnStateUnknown + ecnStateCapable + ecnStateFailed +) + +// must fit into an uint8, otherwise numSentTesting and numLostTesting must have a larger type +const numECNTestingPackets = 10 + +type ecnHandler interface { + SentPacket(protocol.PacketNumber, protocol.ECN) + Mode() protocol.ECN + HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) (congested bool) + LostPacket(protocol.PacketNumber) +} + +// The ecnTracker performs ECN validation of a path. +// Once failed, it doesn't do any re-validation of the path. +// It is designed only work for 1-RTT packets, it doesn't handle multiple packet number spaces. +// In order to avoid revealing any internal state to on-path observers, +// callers should make sure to start using ECN (i.e. calling Mode) for the very first 1-RTT packet sent. +// The validation logic implemented here strictly follows the algorithm described in RFC 9000 section 13.4.2 and A.4. +type ecnTracker struct { + state ecnState + numSentTesting, numLostTesting uint8 + + firstTestingPacket protocol.PacketNumber + lastTestingPacket protocol.PacketNumber + firstCapablePacket protocol.PacketNumber + + numSentECT0, numSentECT1 int64 + numAckedECT0, numAckedECT1, numAckedECNCE int64 + + tracer *logging.ConnectionTracer + logger utils.Logger +} + +var _ ecnHandler = &ecnTracker{} + +func newECNTracker(logger utils.Logger, tracer *logging.ConnectionTracer) *ecnTracker { + return &ecnTracker{ + firstTestingPacket: protocol.InvalidPacketNumber, + lastTestingPacket: protocol.InvalidPacketNumber, + firstCapablePacket: protocol.InvalidPacketNumber, + state: ecnStateInitial, + logger: logger, + tracer: tracer, + } +} + +func (e *ecnTracker) SentPacket(pn protocol.PacketNumber, ecn protocol.ECN) { + //nolint:exhaustive // These are the only ones we need to take care of. + switch ecn { + case protocol.ECNNon: + return + case protocol.ECT0: + e.numSentECT0++ + case protocol.ECT1: + e.numSentECT1++ + case protocol.ECNUnsupported: + if e.state != ecnStateFailed { + panic("didn't expect ECN to be unsupported") + } + default: + panic(fmt.Sprintf("sent packet with unexpected ECN marking: %s", ecn)) + } + + if e.state == ecnStateCapable && e.firstCapablePacket == protocol.InvalidPacketNumber { + e.firstCapablePacket = pn + } + + if e.state != ecnStateTesting { + return + } + + e.numSentTesting++ + if e.firstTestingPacket == protocol.InvalidPacketNumber { + e.firstTestingPacket = pn + } + if e.numSentECT0+e.numSentECT1 >= numECNTestingPackets { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger) + } + e.state = ecnStateUnknown + e.lastTestingPacket = pn + } +} + +func (e *ecnTracker) Mode() protocol.ECN { + switch e.state { + case ecnStateInitial: + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger) + } + e.state = ecnStateTesting + return e.Mode() + case ecnStateTesting, ecnStateCapable: + return protocol.ECT0 + case ecnStateUnknown, ecnStateFailed: + return protocol.ECNNon + default: + panic(fmt.Sprintf("unknown ECN state: %d", e.state)) + } +} + +func (e *ecnTracker) LostPacket(pn protocol.PacketNumber) { + if e.state != ecnStateTesting && e.state != ecnStateUnknown { + return + } + if !e.isTestingPacket(pn) { + return + } + e.numLostTesting++ + // Only proceed if we have sent all 10 testing packets. + if e.state != ecnStateUnknown { + return + } + if e.numLostTesting >= e.numSentTesting { + e.logger.Debugf("Disabling ECN. All testing packets were lost.") + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedLostAllTestingPackets) + } + e.state = ecnStateFailed + return + } + // Path validation also fails if some testing packets are lost, and all other testing packets where CE-marked + e.failIfMangled() +} + +// HandleNewlyAcked handles the ECN counts on an ACK frame. +// It must only be called for ACK frames that increase the largest acknowledged packet number, +// see section 13.4.2.1 of RFC 9000. +func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) (congested bool) { + if e.state == ecnStateFailed { + return false + } + + // ECN validation can fail if the received total count for either ECT(0) or ECT(1) exceeds + // the total number of packets sent with each corresponding ECT codepoint. + if ect0 > e.numSentECT0 || ect1 > e.numSentECT1 { + e.logger.Debugf("Disabling ECN. Received more ECT(0) / ECT(1) acknowledgements than packets sent.") + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedMoreECNCountsThanSent) + } + e.state = ecnStateFailed + return false + } + + // Count ECT0 and ECT1 marks that we used when sending the packets that are now being acknowledged. + var ackedECT0, ackedECT1 int64 + for _, p := range packets { + //nolint:exhaustive // We only ever send ECT(0) and ECT(1). + switch e.ecnMarking(p.PacketNumber) { + case protocol.ECT0: + ackedECT0++ + case protocol.ECT1: + ackedECT1++ + } + } + + // If an ACK frame newly acknowledges a packet that the endpoint sent with either the ECT(0) or ECT(1) + // codepoint set, ECN validation fails if the corresponding ECN counts are not present in the ACK frame. + // This check detects: + // * paths that bleach all ECN marks, and + // * peers that don't report any ECN counts + if (ackedECT0 > 0 || ackedECT1 > 0) && ect0 == 0 && ect1 == 0 && ecnce == 0 { + e.logger.Debugf("Disabling ECN. ECN-marked packet acknowledged, but no ECN counts on ACK frame.") + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedNoECNCounts) + } + e.state = ecnStateFailed + return false + } + + // Determine the increase in ECT0, ECT1 and ECNCE marks + newECT0 := ect0 - e.numAckedECT0 + newECT1 := ect1 - e.numAckedECT1 + newECNCE := ecnce - e.numAckedECNCE + + // We're only processing ACKs that increase the Largest Acked. + // Therefore, the ECN counters should only ever increase. + // Any decrease means that the peer's counting logic is broken. + if newECT0 < 0 || newECT1 < 0 || newECNCE < 0 { + e.logger.Debugf("Disabling ECN. ECN counts decreased unexpectedly.") + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedDecreasedECNCounts) + } + e.state = ecnStateFailed + return false + } + + // ECN validation also fails if the sum of the increase in ECT(0) and ECN-CE counts is less than the number + // of newly acknowledged packets that were originally sent with an ECT(0) marking. + // This could be the result of (partial) bleaching. + if newECT0+newECNCE < ackedECT0 { + e.logger.Debugf("Disabling ECN. Received less ECT(0) + ECN-CE than packets sent with ECT(0).") + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts) + } + e.state = ecnStateFailed + return false + } + // Similarly, ECN validation fails if the sum of the increases to ECT(1) and ECN-CE counts is less than + // the number of newly acknowledged packets sent with an ECT(1) marking. + if newECT1+newECNCE < ackedECT1 { + e.logger.Debugf("Disabling ECN. Received less ECT(1) + ECN-CE than packets sent with ECT(1).") + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts) + } + e.state = ecnStateFailed + return false + } + + // update our counters + e.numAckedECT0 = ect0 + e.numAckedECT1 = ect1 + e.numAckedECNCE = ecnce + + // Detect mangling (a path remarking all ECN-marked testing packets as CE), + // once all 10 testing packets have been sent out. + if e.state == ecnStateUnknown { + e.failIfMangled() + if e.state == ecnStateFailed { + return false + } + } + if e.state == ecnStateTesting || e.state == ecnStateUnknown { + var ackedTestingPacket bool + for _, p := range packets { + if e.isTestingPacket(p.PacketNumber) { + ackedTestingPacket = true + break + } + } + // This check won't succeed if the path is mangling ECN-marks (i.e. rewrites all ECN-marked packets to CE). + if ackedTestingPacket && (newECT0 > 0 || newECT1 > 0) { + e.logger.Debugf("ECN capability confirmed.") + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger) + } + e.state = ecnStateCapable + } + } + + // Don't trust CE marks before having confirmed ECN capability of the path. + // Otherwise, mangling would be misinterpreted as actual congestion. + return e.state == ecnStateCapable && newECNCE > 0 +} + +// failIfMangled fails ECN validation if all testing packets are lost or CE-marked. +func (e *ecnTracker) failIfMangled() { + numAckedECNCE := e.numAckedECNCE + int64(e.numLostTesting) + if e.numSentECT0+e.numSentECT1 > numAckedECNCE { + return + } + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) + } + e.state = ecnStateFailed +} + +func (e *ecnTracker) ecnMarking(pn protocol.PacketNumber) protocol.ECN { + if pn < e.firstTestingPacket || e.firstTestingPacket == protocol.InvalidPacketNumber { + return protocol.ECNNon + } + if pn < e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber { + return protocol.ECT0 + } + if pn < e.firstCapablePacket || e.firstCapablePacket == protocol.InvalidPacketNumber { + return protocol.ECNNon + } + // We don't need to deal with the case when ECN validation fails, + // since we're ignoring any ECN counts reported in ACK frames in that case. + return protocol.ECT0 +} + +func (e *ecnTracker) isTestingPacket(pn protocol.PacketNumber) bool { + if e.firstTestingPacket == protocol.InvalidPacketNumber { + return false + } + return pn >= e.firstTestingPacket && (pn <= e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber) +} diff --git a/internal/ackhandler/ecn_test.go b/internal/ackhandler/ecn_test.go new file mode 100644 index 000000000..644a025c2 --- /dev/null +++ b/internal/ackhandler/ecn_test.go @@ -0,0 +1,272 @@ +package ackhandler + +import ( + mocklogging "github.com/refraction-networking/uquic/internal/mocks/logging" + "github.com/refraction-networking/uquic/internal/protocol" + "github.com/refraction-networking/uquic/internal/utils" + "github.com/refraction-networking/uquic/logging" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("ECN tracker", func() { + var ecnTracker *ecnTracker + var tracer *mocklogging.MockConnectionTracer + + getAckedPackets := func(pns ...protocol.PacketNumber) []*packet { + var packets []*packet + for _, p := range pns { + packets = append(packets, &packet{PacketNumber: p}) + } + return packets + } + + BeforeEach(func() { + var tr *logging.ConnectionTracer + tr, tracer = mocklogging.NewMockConnectionTracer(mockCtrl) + ecnTracker = newECNTracker(utils.DefaultLogger, tr) + }) + + It("sends exactly 10 testing packets", func() { + tracer.EXPECT().ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger) + for i := 0; i < 9; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + // Do this twice to make sure only sent packets are counted + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + ecnTracker.SentPacket(protocol.PacketNumber(10+i), protocol.ECT0) + } + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + tracer.EXPECT().ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger) + ecnTracker.SentPacket(20, protocol.ECT0) + // In unknown state, packets shouldn't be ECN-marked. + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + }) + + sendAllTestingPackets := func() { + tracer.EXPECT().ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger) + tracer.EXPECT().ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger) + for i := 0; i < 10; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECT0) + } + } + + It("fails ECN validation if all ECN testing packets are lost", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + for i := 0; i < 9; i++ { + ecnTracker.LostPacket(protocol.PacketNumber(i)) + } + // We don't care about the loss of non-testing packets + ecnTracker.LostPacket(15) + // Now lose the last testing packet. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedLostAllTestingPackets) + ecnTracker.LostPacket(9) + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + // We still don't care about more non-testing packets being lost + ecnTracker.LostPacket(16) + }) + + It("only detects ECN mangling after sending all testing packets", func() { + tracer.EXPECT().ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger) + for i := 0; i < 9; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECT0) + ecnTracker.LostPacket(protocol.PacketNumber(i)) + } + // Send the last testing packet, and receive a + tracer.EXPECT().ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger) + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + ecnTracker.SentPacket(9, protocol.ECT0) + // Now lose the last testing packet. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedLostAllTestingPackets) + ecnTracker.LostPacket(9) + }) + + It("passes ECN validation when a testing packet is acknowledged, while still in testing state", func() { + tracer.EXPECT().ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger) + for i := 0; i < 5; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECT0) + } + tracer.EXPECT().ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(3), 1, 0, 0)).To(BeFalse()) + // make sure we continue sending ECT(0) packets + for i := 5; i < 100; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECT0) + } + }) + + It("passes ECN validation when a testing packet is acknowledged, while in unknown state", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // Lose some packets to make sure this doesn't influence the outcome. + for i := 0; i < 5; i++ { + ecnTracker.LostPacket(protocol.PacketNumber(i)) + } + tracer.EXPECT().ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger) + Expect(ecnTracker.HandleNewlyAcked([]*packet{{PacketNumber: 7}}, 1, 0, 0)).To(BeFalse()) + }) + + It("fails ECN validation when the ACK contains more ECN counts than we sent packets", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // only 10 ECT(0) packets were sent, but the ACK claims to have received 12 of them + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedMoreECNCountsThanSent) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), 12, 0, 0)).To(BeFalse()) + }) + + It("fails ECN validation when the ACK contains ECN counts for the wrong code point", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // We sent ECT(0), but this ACK acknowledges ECT(1). + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedMoreECNCountsThanSent) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), 0, 1, 0)).To(BeFalse()) + }) + + It("fails ECN validation when the ACK doesn't contain ECN counts", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // First only acknowledge packets sent without ECN marks. + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(12, 13, 14), 0, 0, 0)).To(BeFalse()) + // Now acknowledge some packets sent with ECN marks. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedNoECNCounts) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(1, 2, 3, 15), 0, 0, 0)).To(BeFalse()) + }) + + It("fails ECN validation when an ACK decreases ECN counts", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + tracer.EXPECT().ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(1, 2, 3, 12), 3, 0, 0)).To(BeFalse()) + // Now acknowledge some more packets, but decrease the ECN counts. Obviously, this doesn't make any sense. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedDecreasedECNCounts) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(4, 5, 6, 13), 2, 0, 0)).To(BeFalse()) + // make sure that new ACKs are ignored + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(7, 8, 9, 14), 5, 0, 0)).To(BeFalse()) + }) + + // This can happen if ACK are lost / reordered. + It("doesn't fail validation if the ACK contains more ECN counts than it acknowledges packets", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + tracer.EXPECT().ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(1, 2, 3, 12), 8, 0, 0)).To(BeFalse()) + }) + + It("fails ECN validation when the ACK doesn't contain enough ECN counts", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // First only acknowledge some packets sent with ECN marks. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(1, 2, 3, 12), 2, 0, 1)).To(BeTrue()) + // Now acknowledge some more packets sent with ECN marks, but don't increase the counters enough. + // This ACK acknowledges 3 more ECN-marked packets, but the counters only increase by 2. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(4, 5, 6, 15), 3, 0, 2)).To(BeFalse()) + }) + + It("detects ECN mangling if all testing packets are marked CE", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // ECN capability not confirmed yet, therefore CE marks are not regarded as congestion events + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(0, 1, 2, 3), 0, 0, 4)).To(BeFalse()) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(4, 5, 6, 10, 11, 12), 0, 0, 7)).To(BeFalse()) + // With the next ACK, all testing packets will now have been marked CE. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(7, 8, 9, 13), 0, 0, 10)).To(BeFalse()) + }) + + It("only detects ECN mangling after sending all testing packets", func() { + tracer.EXPECT().ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger) + for i := 0; i < 9; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECT0) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(protocol.PacketNumber(i)), 0, 0, int64(i+1))).To(BeFalse()) + } + // Send the last testing packet, and receive a + tracer.EXPECT().ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger) + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + ecnTracker.SentPacket(9, protocol.ECT0) + // This ACK now reports the last testing packets as CE as well. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(9), 0, 0, 10)).To(BeFalse()) + }) + + It("detects ECN mangling, if some testing packets are marked CE, and then others are lost", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // ECN capability not confirmed yet, therefore CE marks are not regarded as congestion events + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(0, 1, 2, 3), 0, 0, 4)).To(BeFalse()) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(6, 7, 8, 9), 0, 0, 8)).To(BeFalse()) + // Lose one of the two unacknowledged packets. + ecnTracker.LostPacket(4) + // By losing the last unacknowledged testing packets, we should detect the mangling. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) + ecnTracker.LostPacket(5) + }) + + It("detects ECN mangling, if some testing packets are lost, and then others are marked CE", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // Lose a few packets. + ecnTracker.LostPacket(0) + ecnTracker.LostPacket(1) + ecnTracker.LostPacket(2) + // ECN capability not confirmed yet, therefore CE marks are not regarded as congestion events + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(3, 4, 5, 6, 7, 8), 0, 0, 6)).To(BeFalse()) + // By CE-marking the last unacknowledged testing packets, we should detect the mangling. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(9), 0, 0, 7)).To(BeFalse()) + }) + + It("declares congestion", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // Receive one CE count. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(1, 2, 3, 12), 2, 0, 1)).To(BeTrue()) + // No increase in CE. No congestion. + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(4, 5, 6, 13), 5, 0, 1)).To(BeFalse()) + // Increase in CE. More congestion. + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(7, 8, 9, 14), 7, 0, 2)).To(BeTrue()) + }) +}) diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index b478ef777..56996fcc9 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -10,13 +10,13 @@ import ( // SentPacketHandler handles ACKs received for outgoing packets type SentPacketHandler interface { // SentPacket may modify the packet - SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, isPathMTUProbePacket bool) + SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket bool) // ReceivedAck processes an ACK frame. // It does not store a copy of the frame. - ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) (bool /* 1-RTT packet acked */, error) + ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* 1-RTT packet acked */, error) ReceivedBytes(protocol.ByteCount) DropPackets(protocol.EncryptionLevel) - ResetForRetry() error + ResetForRetry(rcvTime time.Time) error SetHandshakeConfirmed() // The SendMode determines if and what kind of packets can be sent. @@ -29,6 +29,7 @@ type SentPacketHandler interface { // only to be called once the handshake is complete QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */ + ECNMode(isShortHeaderPacket bool) protocol.ECN // isShortHeaderPacket should only be true for non-coalesced 1-RTT packets PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber @@ -44,7 +45,7 @@ type sentPacketTracker interface { // ReceivedPacketHandler handles ACKs needed to send for incoming packets type ReceivedPacketHandler interface { IsPotentiallyDuplicate(protocol.PacketNumber, protocol.EncryptionLevel) bool - ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error + ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, ackEliciting bool) error DropPackets(protocol.EncryptionLevel) GetAlarmTimeout() time.Time diff --git a/internal/ackhandler/mock_ecn_handler_test.go b/internal/ackhandler/mock_ecn_handler_test.go new file mode 100644 index 000000000..949268f5e --- /dev/null +++ b/internal/ackhandler/mock_ecn_handler_test.go @@ -0,0 +1,91 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/refraction-networking/uquic/internal/ackhandler (interfaces: ECNHandler) +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package ackhandler -destination mock_ecn_handler_test.go github.com/refraction-networking/uquic/internal/ackhandler ECNHandler +// +// Package ackhandler is a generated GoMock package. +package ackhandler + +import ( + reflect "reflect" + + protocol "github.com/refraction-networking/uquic/internal/protocol" + gomock "go.uber.org/mock/gomock" +) + +// MockECNHandler is a mock of ECNHandler interface. +type MockECNHandler struct { + ctrl *gomock.Controller + recorder *MockECNHandlerMockRecorder +} + +// MockECNHandlerMockRecorder is the mock recorder for MockECNHandler. +type MockECNHandlerMockRecorder struct { + mock *MockECNHandler +} + +// NewMockECNHandler creates a new mock instance. +func NewMockECNHandler(ctrl *gomock.Controller) *MockECNHandler { + mock := &MockECNHandler{ctrl: ctrl} + mock.recorder = &MockECNHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockECNHandler) EXPECT() *MockECNHandlerMockRecorder { + return m.recorder +} + +// HandleNewlyAcked mocks base method. +func (m *MockECNHandler) HandleNewlyAcked(arg0 []*packet, arg1, arg2, arg3 int64) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandleNewlyAcked", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(bool) + return ret0 +} + +// HandleNewlyAcked indicates an expected call of HandleNewlyAcked. +func (mr *MockECNHandlerMockRecorder) HandleNewlyAcked(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleNewlyAcked", reflect.TypeOf((*MockECNHandler)(nil).HandleNewlyAcked), arg0, arg1, arg2, arg3) +} + +// LostPacket mocks base method. +func (m *MockECNHandler) LostPacket(arg0 protocol.PacketNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LostPacket", arg0) +} + +// LostPacket indicates an expected call of LostPacket. +func (mr *MockECNHandlerMockRecorder) LostPacket(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockECNHandler)(nil).LostPacket), arg0) +} + +// Mode mocks base method. +func (m *MockECNHandler) Mode() protocol.ECN { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Mode") + ret0, _ := ret[0].(protocol.ECN) + return ret0 +} + +// Mode indicates an expected call of Mode. +func (mr *MockECNHandlerMockRecorder) Mode() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mode", reflect.TypeOf((*MockECNHandler)(nil).Mode)) +} + +// SentPacket mocks base method. +func (m *MockECNHandler) SentPacket(arg0 protocol.PacketNumber, arg1 protocol.ECN) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentPacket", arg0, arg1) +} + +// SentPacket indicates an expected call of SentPacket. +func (mr *MockECNHandlerMockRecorder) SentPacket(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockECNHandler)(nil).SentPacket), arg0, arg1) +} diff --git a/internal/ackhandler/mock_sent_packet_tracker_test.go b/internal/ackhandler/mock_sent_packet_tracker_test.go index cd34cfdbc..2e755658d 100644 --- a/internal/ackhandler/mock_sent_packet_tracker_test.go +++ b/internal/ackhandler/mock_sent_packet_tracker_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic/internal/ackhandler (interfaces: SentPacketTracker) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/refraction-networking/uquic/internal/ackhandler SentPacketTracker +// // Package ackhandler is a generated GoMock package. package ackhandler @@ -55,7 +59,7 @@ func (m *MockSentPacketTracker) ReceivedPacket(arg0 protocol.EncryptionLevel) { } // ReceivedPacket indicates an expected call of ReceivedPacket. -func (mr *MockSentPacketTrackerMockRecorder) ReceivedPacket(arg0 interface{}) *gomock.Call { +func (mr *MockSentPacketTrackerMockRecorder) ReceivedPacket(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockSentPacketTracker)(nil).ReceivedPacket), arg0) } diff --git a/internal/ackhandler/mockgen.go b/internal/ackhandler/mockgen.go index 3d2f10823..b36c0de1e 100644 --- a/internal/ackhandler/mockgen.go +++ b/internal/ackhandler/mockgen.go @@ -4,3 +4,6 @@ package ackhandler //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/refraction-networking/uquic/internal/ackhandler SentPacketTracker" type SentPacketTracker = sentPacketTracker + +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_ecn_handler_test.go github.com/refraction-networking/uquic/internal/ackhandler ECNHandler" +type ECNHandler = ecnHandler diff --git a/internal/ackhandler/received_packet_handler.go b/internal/ackhandler/received_packet_handler.go index 4d7d6f6d2..98c5e6132 100644 --- a/internal/ackhandler/received_packet_handler.go +++ b/internal/ackhandler/received_packet_handler.go @@ -40,29 +40,29 @@ func (h *receivedPacketHandler) ReceivedPacket( ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, - shouldInstigateAck bool, + ackEliciting bool, ) error { h.sentPackets.ReceivedPacket(encLevel) switch encLevel { case protocol.EncryptionInitial: - return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting) case protocol.EncryptionHandshake: // The Handshake packet number space might already have been dropped as a result // of processing the CRYPTO frame that was contained in this packet. if h.handshakePackets == nil { return nil } - return h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + return h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting) case protocol.Encryption0RTT: if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket { return fmt.Errorf("received packet number %d on a 0-RTT packet after receiving %d on a 1-RTT packet", pn, h.lowest1RTTPacket) } - return h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + return h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting) case protocol.Encryption1RTT: if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket { h.lowest1RTTPacket = pn } - if err := h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck); err != nil { + if err := h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting); err != nil { return err } h.appDataPackets.IgnoreBelow(h.sentPackets.GetLowestPacketNotConfirmedAcked()) diff --git a/internal/ackhandler/received_packet_tracker.go b/internal/ackhandler/received_packet_tracker.go index 2d8a813b9..750d13f93 100644 --- a/internal/ackhandler/received_packet_tracker.go +++ b/internal/ackhandler/received_packet_tracker.go @@ -13,10 +13,10 @@ import ( const packetsBeforeAck = 2 type receivedPacketTracker struct { - largestObserved protocol.PacketNumber - ignoreBelow protocol.PacketNumber - largestObservedReceivedTime time.Time - ect0, ect1, ecnce uint64 + largestObserved protocol.PacketNumber + ignoreBelow protocol.PacketNumber + largestObservedRcvdTime time.Time + ect0, ect1, ecnce uint64 packetHistory *receivedPacketHistory @@ -45,25 +45,25 @@ func newReceivedPacketTracker( } } -func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, shouldInstigateAck bool) error { - if isNew := h.packetHistory.ReceivedPacket(packetNumber); !isNew { - return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", packetNumber) +func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error { + if isNew := h.packetHistory.ReceivedPacket(pn); !isNew { + return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", pn) } - isMissing := h.isMissing(packetNumber) - if packetNumber >= h.largestObserved { - h.largestObserved = packetNumber - h.largestObservedReceivedTime = rcvTime + isMissing := h.isMissing(pn) + if pn >= h.largestObserved { + h.largestObserved = pn + h.largestObservedRcvdTime = rcvTime } - if shouldInstigateAck { + if ackEliciting { h.hasNewAck = true } - if shouldInstigateAck { - h.maybeQueueAck(packetNumber, rcvTime, isMissing) + if ackEliciting { + h.maybeQueueACK(pn, rcvTime, isMissing) } + //nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECNCE. switch ecn { - case protocol.ECNNon: case protocol.ECT0: h.ect0++ case protocol.ECT1: @@ -76,14 +76,14 @@ func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumbe // IgnoreBelow sets a lower limit for acknowledging packets. // Packets with packet numbers smaller than p will not be acked. -func (h *receivedPacketTracker) IgnoreBelow(p protocol.PacketNumber) { - if p <= h.ignoreBelow { +func (h *receivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) { + if pn <= h.ignoreBelow { return } - h.ignoreBelow = p - h.packetHistory.DeleteBelow(p) + h.ignoreBelow = pn + h.packetHistory.DeleteBelow(pn) if h.logger.Debug() { - h.logger.Debugf("\tIgnoring all packets below %d.", p) + h.logger.Debugf("\tIgnoring all packets below %d.", pn) } } @@ -103,8 +103,8 @@ func (h *receivedPacketTracker) hasNewMissingPackets() bool { return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1 } -// maybeQueueAck queues an ACK, if necessary. -func (h *receivedPacketTracker) maybeQueueAck(pn protocol.PacketNumber, rcvTime time.Time, wasMissing bool) { +// maybeQueueACK queues an ACK, if necessary. +func (h *receivedPacketTracker) maybeQueueACK(pn protocol.PacketNumber, rcvTime time.Time, wasMissing bool) { // always acknowledge the first packet if h.lastAck == nil { if !h.ackQueued { @@ -175,7 +175,7 @@ func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame { ack = &wire.AckFrame{} } ack.Reset() - ack.DelayTime = utils.Max(0, now.Sub(h.largestObservedReceivedTime)) + ack.DelayTime = utils.Max(0, now.Sub(h.largestObservedRcvdTime)) ack.ECT0 = h.ect0 ack.ECT1 = h.ect1 ack.ECNCE = h.ecnce diff --git a/internal/ackhandler/received_packet_tracker_test.go b/internal/ackhandler/received_packet_tracker_test.go index 9e6b72def..a95e05684 100644 --- a/internal/ackhandler/received_packet_tracker_test.go +++ b/internal/ackhandler/received_packet_tracker_test.go @@ -25,26 +25,26 @@ var _ = Describe("Received Packet Tracker", func() { Context("accepting packets", func() { It("saves the time when each packet arrived", func() { Expect(tracker.ReceivedPacket(protocol.PacketNumber(3), protocol.ECNNon, time.Now(), true)).To(Succeed()) - Expect(tracker.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond)) + Expect(tracker.largestObservedRcvdTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond)) }) - It("updates the largestObserved and the largestObservedReceivedTime", func() { + It("updates the largestObserved and the largestObservedRcvdTime", func() { now := time.Now() tracker.largestObserved = 3 - tracker.largestObservedReceivedTime = now.Add(-1 * time.Second) + tracker.largestObservedRcvdTime = now.Add(-1 * time.Second) Expect(tracker.ReceivedPacket(5, protocol.ECNNon, now, true)).To(Succeed()) Expect(tracker.largestObserved).To(Equal(protocol.PacketNumber(5))) - Expect(tracker.largestObservedReceivedTime).To(Equal(now)) + Expect(tracker.largestObservedRcvdTime).To(Equal(now)) }) - It("doesn't update the largestObserved and the largestObservedReceivedTime for a belated packet", func() { + It("doesn't update the largestObserved and the largestObservedRcvdTime for a belated packet", func() { now := time.Now() timestamp := now.Add(-1 * time.Second) tracker.largestObserved = 5 - tracker.largestObservedReceivedTime = timestamp + tracker.largestObservedRcvdTime = timestamp Expect(tracker.ReceivedPacket(4, protocol.ECNNon, now, true)).To(Succeed()) Expect(tracker.largestObserved).To(Equal(protocol.PacketNumber(5))) - Expect(tracker.largestObservedReceivedTime).To(Equal(timestamp)) + Expect(tracker.largestObservedRcvdTime).To(Equal(timestamp)) }) }) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 3ba7e5868..716c406cc 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -92,9 +92,12 @@ type sentPacketHandler struct { // The alarm timeout alarm time.Time + enableECN bool + ecnTracker ecnHandler + perspective protocol.Perspective - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer logger utils.Logger } @@ -110,8 +113,9 @@ func newSentPacketHandler( initialMaxDatagramSize protocol.ByteCount, rttStats *utils.RTTStats, clientAddressValidated bool, + enableECN bool, pers protocol.Perspective, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, ) *sentPacketHandler { congestion := congestion.NewCubicSender( @@ -122,7 +126,7 @@ func newSentPacketHandler( tracer, ) - return &sentPacketHandler{ + h := &sentPacketHandler{ peerCompletedAddressValidation: pers == protocol.PerspectiveServer, peerAddressValidated: pers == protocol.PerspectiveClient || clientAddressValidated, initialPackets: newPacketNumberSpace(initialPN, false), @@ -134,6 +138,11 @@ func newSentPacketHandler( tracer: tracer, logger: logger, } + if enableECN { + h.enableECN = true + h.ecnTracker = newECNTracker(logger, tracer) + } + return h } func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) { @@ -187,7 +196,7 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { default: panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel)) } - if h.tracer != nil && h.ptoCount != 0 { + if h.tracer != nil && h.tracer.UpdatedPTOCount != nil && h.ptoCount != 0 { h.tracer.UpdatedPTOCount(0) } h.ptoCount = 0 @@ -228,6 +237,7 @@ func (h *sentPacketHandler) SentPacket( streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, + ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket bool, ) { @@ -252,6 +262,10 @@ func (h *sentPacketHandler) SentPacket( } h.congestion.OnPacketSent(t, h.bytesInFlight, pn, size, isAckEliciting) + if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil { + h.ecnTracker.SentPacket(pn, ecn) + } + if !isAckEliciting { pnSpace.history.SentNonAckElicitingPacket(pn) if !h.peerCompletedAddressValidation { @@ -272,7 +286,7 @@ func (h *sentPacketHandler) SentPacket( p.includedInBytesInFlight = true pnSpace.history.SentAckElicitingPacket(p) - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedMetrics != nil { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } h.setLossDetectionTimer() @@ -302,8 +316,6 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En } } - pnSpace.largestAcked = utils.Max(pnSpace.largestAcked, largestAcked) - // Servers complete address validation when a protected packet is received. if h.perspective == protocol.PerspectiveClient && !h.peerCompletedAddressValidation && (encLevel == protocol.EncryptionHandshake || encLevel == protocol.Encryption1RTT) { @@ -333,6 +345,17 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En h.congestion.MaybeExitSlowStart() } } + + // Only inform the ECN tracker about new 1-RTT ACKs if the ACK increases the largest acked. + if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil && largestAcked > pnSpace.largestAcked { + congested := h.ecnTracker.HandleNewlyAcked(ackedPackets, int64(ack.ECT0), int64(ack.ECT1), int64(ack.ECNCE)) + if congested { + h.congestion.OnCongestionEvent(largestAcked, 0, priorInFlight) + } + } + + pnSpace.largestAcked = utils.Max(pnSpace.largestAcked, largestAcked) + if err := h.detectLostPackets(rcvTime, encLevel); err != nil { return false, err } @@ -353,14 +376,14 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En // Reset the pto_count unless the client is unsure if the server has validated the client's address. if h.peerCompletedAddressValidation { - if h.tracer != nil && h.ptoCount != 0 { + if h.tracer != nil && h.tracer.UpdatedPTOCount != nil && h.ptoCount != 0 { h.tracer.UpdatedPTOCount(0) } h.ptoCount = 0 } h.numProbesToSend = 0 - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedMetrics != nil { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } @@ -439,7 +462,7 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL if err := pnSpace.history.Remove(p.PacketNumber); err != nil { return nil, err } - if h.tracer != nil { + if h.tracer != nil && h.tracer.AcknowledgedPacket != nil { h.tracer.AcknowledgedPacket(encLevel, p.PacketNumber) } } @@ -532,7 +555,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() { if !lossTime.IsZero() { // Early retransmit timer or time loss detection. h.alarm = lossTime - if h.tracer != nil && h.alarm != oldAlarm { + if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm { h.tracer.SetLossTimer(logging.TimerTypeACK, encLevel, h.alarm) } return @@ -543,7 +566,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() { h.alarm = time.Time{} if !oldAlarm.IsZero() { h.logger.Debugf("Canceling loss detection timer. Amplification limited.") - if h.tracer != nil { + if h.tracer != nil && h.tracer.LossTimerCanceled != nil { h.tracer.LossTimerCanceled() } } @@ -555,7 +578,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() { h.alarm = time.Time{} if !oldAlarm.IsZero() { h.logger.Debugf("Canceling loss detection timer. No packets in flight.") - if h.tracer != nil { + if h.tracer != nil && h.tracer.LossTimerCanceled != nil { h.tracer.LossTimerCanceled() } } @@ -568,14 +591,14 @@ func (h *sentPacketHandler) setLossDetectionTimer() { if !oldAlarm.IsZero() { h.alarm = time.Time{} h.logger.Debugf("Canceling loss detection timer. No PTO needed..") - if h.tracer != nil { + if h.tracer != nil && h.tracer.LossTimerCanceled != nil { h.tracer.LossTimerCanceled() } } return } h.alarm = ptoTime - if h.tracer != nil && h.alarm != oldAlarm { + if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm { h.tracer.SetLossTimer(logging.TimerTypePTO, encLevel, h.alarm) } } @@ -606,7 +629,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E if h.logger.Debug() { h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.LostPacket != nil { h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold) } } @@ -616,7 +639,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E if h.logger.Debug() { h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.LostPacket != nil { h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold) } } @@ -635,7 +658,10 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E h.removeFromBytesInFlight(p) h.queueFramesForRetransmission(p) if !p.IsPathMTUProbePacket { - h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) + h.congestion.OnCongestionEvent(p.PacketNumber, p.Length, priorInFlight) + } + if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil { + h.ecnTracker.LostPacket(p.PacketNumber) } } } @@ -650,7 +676,7 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error { if h.logger.Debug() { h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.LossTimerExpired != nil { h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel) } // Early retransmit or time loss detection @@ -687,8 +713,12 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error { h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount) } if h.tracer != nil { - h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel) - h.tracer.UpdatedPTOCount(h.ptoCount) + if h.tracer.LossTimerExpired != nil { + h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel) + } + if h.tracer.UpdatedPTOCount != nil { + h.tracer.UpdatedPTOCount(h.ptoCount) + } } h.numProbesToSend += 2 //nolint:exhaustive // We never arm a PTO timer for 0-RTT packets. @@ -712,6 +742,16 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time { return h.alarm } +func (h *sentPacketHandler) ECNMode(isShortHeaderPacket bool) protocol.ECN { + if !h.enableECN { + return protocol.ECNUnsupported + } + if !isShortHeaderPacket { + return protocol.ECNNon + } + return h.ecnTracker.Mode() +} + func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { pnSpace := h.getPacketNumberSpace(encLevel) pn := pnSpace.pns.Peek() @@ -825,7 +865,7 @@ func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) { p.Frames = nil } -func (h *sentPacketHandler) ResetForRetry() error { +func (h *sentPacketHandler) ResetForRetry(now time.Time) error { h.bytesInFlight = 0 var firstPacketSendTime time.Time h.initialPackets.history.Iterate(func(p *packet) (bool, error) { @@ -851,12 +891,11 @@ func (h *sentPacketHandler) ResetForRetry() error { // Otherwise, we don't know which Initial the Retry was sent in response to. if h.ptoCount == 0 { // Don't set the RTT to a value lower than 5ms here. - now := time.Now() h.rttStats.UpdateRTT(utils.Max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now) if h.logger.Debug() { h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedMetrics != nil { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } } @@ -865,8 +904,10 @@ func (h *sentPacketHandler) ResetForRetry() error { oldAlarm := h.alarm h.alarm = time.Time{} if h.tracer != nil { - h.tracer.UpdatedPTOCount(0) - if !oldAlarm.IsZero() { + if h.tracer.UpdatedPTOCount != nil { + h.tracer.UpdatedPTOCount(0) + } + if !oldAlarm.IsZero() && h.tracer.LossTimerCanceled != nil { h.tracer.LossTimerCanceled() } } diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 15a09fa2c..fc962f5e3 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -44,7 +44,7 @@ var _ = Describe("SentPacketHandler", func() { JustBeforeEach(func() { lostPackets = nil rttStats := utils.NewRTTStats() - handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, false, perspective, nil, utils.DefaultLogger) + handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, false, false, perspective, nil, utils.DefaultLogger) streamFrame = wire.StreamFrame{ StreamID: 5, Data: []byte{0x13, 0x37}, @@ -106,7 +106,7 @@ var _ = Describe("SentPacketHandler", func() { } sentPacket := func(p *packet) { - handler.SentPacket(p.SendTime, p.PacketNumber, p.LargestAcked, p.StreamFrames, p.Frames, p.EncryptionLevel, p.Length, p.IsPathMTUProbePacket) + handler.SentPacket(p.SendTime, p.PacketNumber, p.LargestAcked, p.StreamFrames, p.Frames, p.EncryptionLevel, protocol.ECNNon, p.Length, p.IsPathMTUProbePacket) } expectInPacketHistory := func(expected []protocol.PacketNumber, encLevel protocol.EncryptionLevel) { @@ -563,7 +563,7 @@ var _ = Describe("SentPacketHandler", func() { // lose packet 1 gomock.InOrder( cong.EXPECT().MaybeExitSlowStart(), - cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)), + cong.EXPECT().OnCongestionEvent(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)), cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()), ) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} @@ -575,7 +575,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(err).ToNot(HaveOccurred()) }) - It("doesn't call OnPacketLost when a Path MTU probe packet is lost", func() { + It("doesn't call OnCongestionEvent when a Path MTU probe packet is lost", func() { cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) var mtuPacketDeclaredLost bool sentPacket(ackElicitingPacket(&packet{ @@ -590,7 +590,7 @@ var _ = Describe("SentPacketHandler", func() { }, })) sentPacket(ackElicitingPacket(&packet{PacketNumber: 2})) - // lose packet 1, but don't EXPECT any calls to OnPacketLost() + // lose packet 1, but don't EXPECT any calls to OnCongestionEvent() gomock.InOrder( cong.EXPECT().MaybeExitSlowStart(), cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()), @@ -602,7 +602,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.bytesInFlight).To(BeZero()) }) - It("calls OnPacketAcked and OnPacketLost with the right bytes_in_flight value", func() { + It("calls OnPacketAcked and OnCongestionEvent with the right bytes_in_flight value", func() { cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(4) sentPacket(ackElicitingPacket(&packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) sentPacket(ackElicitingPacket(&packet{PacketNumber: 2, SendTime: time.Now().Add(-30 * time.Minute)})) @@ -611,7 +611,7 @@ var _ = Describe("SentPacketHandler", func() { // receive the first ACK gomock.InOrder( cong.EXPECT().MaybeExitSlowStart(), - cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(4)), + cong.EXPECT().OnCongestionEvent(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(4)), cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(4), gomock.Any()), ) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} @@ -620,7 +620,7 @@ var _ = Describe("SentPacketHandler", func() { // receive the second ACK gomock.InOrder( cong.EXPECT().MaybeExitSlowStart(), - cong.EXPECT().OnPacketLost(protocol.PacketNumber(3), protocol.ByteCount(1), protocol.ByteCount(2)), + cong.EXPECT().OnCongestionEvent(protocol.PacketNumber(3), protocol.ByteCount(1), protocol.ByteCount(2)), cong.EXPECT().OnPacketAcked(protocol.PacketNumber(4), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()), ) ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 4, Largest: 4}}} @@ -984,7 +984,7 @@ var _ = Describe("SentPacketHandler", func() { Context("amplification limit, for the server, with validated address", func() { JustBeforeEach(func() { rttStats := utils.NewRTTStats() - handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, true, perspective, nil, utils.DefaultLogger) + handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, true, false, perspective, nil, utils.DefaultLogger) }) It("do not limits the window", func() { @@ -1334,7 +1334,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.bytesInFlight).ToNot(BeZero()) Expect(handler.SendMode(time.Now())).To(Equal(SendAny)) // now receive a Retry - Expect(handler.ResetForRetry()).To(Succeed()) + Expect(handler.ResetForRetry(time.Now())).To(Succeed()) Expect(lostPackets).To(Equal([]protocol.PacketNumber{42})) Expect(handler.bytesInFlight).To(BeZero()) Expect(handler.GetLossDetectionTimeout()).To(BeZero()) @@ -1369,7 +1369,7 @@ var _ = Describe("SentPacketHandler", func() { }) Expect(handler.bytesInFlight).ToNot(BeZero()) // now receive a Retry - Expect(handler.ResetForRetry()).To(Succeed()) + Expect(handler.ResetForRetry(time.Now())).To(Succeed()) Expect(handler.bytesInFlight).To(BeZero()) Expect(lostInitial).To(BeTrue()) Expect(lost0RTT).To(BeTrue()) @@ -1379,50 +1379,201 @@ var _ = Describe("SentPacketHandler", func() { }) It("uses a Retry for an RTT estimate, if it was not retransmitted", func() { + now := time.Now() sentPacket(ackElicitingPacket(&packet{ PacketNumber: 42, EncryptionLevel: protocol.EncryptionInitial, - SendTime: time.Now().Add(-500 * time.Millisecond), + SendTime: now, })) sentPacket(ackElicitingPacket(&packet{ PacketNumber: 43, EncryptionLevel: protocol.EncryptionInitial, - SendTime: time.Now().Add(-10 * time.Millisecond), + SendTime: now.Add(500 * time.Millisecond), })) - Expect(handler.ResetForRetry()).To(Succeed()) - Expect(handler.rttStats.SmoothedRTT()).To(BeNumerically("~", 500*time.Millisecond, 100*time.Millisecond)) + Expect(handler.ResetForRetry(now.Add(time.Second))).To(Succeed()) + Expect(handler.rttStats.SmoothedRTT()).To(Equal(time.Second)) }) It("uses a Retry for an RTT estimate, but doesn't set the RTT to a value lower than 5ms", func() { + now := time.Now() sentPacket(ackElicitingPacket(&packet{ PacketNumber: 42, EncryptionLevel: protocol.EncryptionInitial, - SendTime: time.Now().Add(-500 * time.Microsecond), + SendTime: now, })) sentPacket(ackElicitingPacket(&packet{ PacketNumber: 43, EncryptionLevel: protocol.EncryptionInitial, - SendTime: time.Now().Add(-10 * time.Microsecond), + SendTime: now.Add(2 * time.Millisecond), })) - Expect(handler.ResetForRetry()).To(Succeed()) + Expect(handler.ResetForRetry(now.Add(4 * time.Millisecond))).To(Succeed()) + Expect(minRTTAfterRetry).To(BeNumerically(">", 4*time.Millisecond)) Expect(handler.rttStats.SmoothedRTT()).To(Equal(minRTTAfterRetry)) }) It("doesn't use a Retry for an RTT estimate, if it was not retransmitted", func() { + now := time.Now() sentPacket(ackElicitingPacket(&packet{ PacketNumber: 42, EncryptionLevel: protocol.EncryptionInitial, - SendTime: time.Now().Add(-800 * time.Millisecond), + SendTime: now, })) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOInitial)) sentPacket(ackElicitingPacket(&packet{ PacketNumber: 43, EncryptionLevel: protocol.EncryptionInitial, - SendTime: time.Now().Add(-100 * time.Millisecond), + SendTime: now.Add(500 * time.Millisecond), })) - Expect(handler.ResetForRetry()).To(Succeed()) + Expect(handler.ResetForRetry(now.Add(time.Second))).To(Succeed()) Expect(handler.rttStats.SmoothedRTT()).To(BeZero()) }) }) + + Context("ECN handling", func() { + var ecnHandler *MockECNHandler + var cong *mocks.MockSendAlgorithmWithDebugInfos + + JustBeforeEach(func() { + cong = mocks.NewMockSendAlgorithmWithDebugInfos(mockCtrl) + cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + cong.EXPECT().OnPacketAcked(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + cong.EXPECT().MaybeExitSlowStart().AnyTimes() + ecnHandler = NewMockECNHandler(mockCtrl) + lostPackets = nil + rttStats := utils.NewRTTStats() + rttStats.UpdateRTT(time.Hour, 0, time.Now()) + handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, false, false, perspective, nil, utils.DefaultLogger) + handler.ecnTracker = ecnHandler + handler.congestion = cong + }) + + It("informs about sent packets", func() { + // Check that only 1-RTT packets are reported + handler.SentPacket(time.Now(), 100, -1, nil, nil, protocol.EncryptionInitial, protocol.ECT1, 1200, false) + handler.SentPacket(time.Now(), 101, -1, nil, nil, protocol.EncryptionHandshake, protocol.ECT0, 1200, false) + handler.SentPacket(time.Now(), 102, -1, nil, nil, protocol.Encryption0RTT, protocol.ECNCE, 1200, false) + + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(103), protocol.ECT1) + handler.SentPacket(time.Now(), 103, -1, nil, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false) + }) + + It("informs about sent packets", func() { + // Check that only 1-RTT packets are reported + handler.SentPacket(time.Now(), 100, -1, nil, nil, protocol.EncryptionInitial, protocol.ECT1, 1200, false) + handler.SentPacket(time.Now(), 101, -1, nil, nil, protocol.EncryptionHandshake, protocol.ECT0, 1200, false) + handler.SentPacket(time.Now(), 102, -1, nil, nil, protocol.Encryption0RTT, protocol.ECNCE, 1200, false) + + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(103), protocol.ECT1) + handler.SentPacket(time.Now(), 103, -1, nil, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false) + }) + + It("informs about lost packets", func() { + for i := 10; i < 20; i++ { + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT1) + handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false) + } + cong.EXPECT().OnCongestionEvent(gomock.Any(), gomock.Any(), gomock.Any()).Times(3) + ecnHandler.EXPECT().LostPacket(protocol.PacketNumber(10)) + ecnHandler.EXPECT().LostPacket(protocol.PacketNumber(11)) + ecnHandler.EXPECT().LostPacket(protocol.PacketNumber(12)) + ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + _, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 16, Smallest: 13}}}, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("processes ACKs", func() { + // Check that we only care about 1-RTT packets. + handler.SentPacket(time.Now(), 100, -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.EncryptionInitial, protocol.ECT1, 1200, false) + _, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100, Smallest: 100}}}, protocol.EncryptionInitial, time.Now()) + Expect(err).ToNot(HaveOccurred()) + + for i := 10; i < 20; i++ { + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT1) + handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false) + } + ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), int64(1), int64(2), int64(3)).DoAndReturn(func(packets []*packet, _, _, _ int64) bool { + Expect(packets).To(HaveLen(5)) + Expect(packets[0].PacketNumber).To(Equal(protocol.PacketNumber(10))) + Expect(packets[1].PacketNumber).To(Equal(protocol.PacketNumber(11))) + Expect(packets[2].PacketNumber).To(Equal(protocol.PacketNumber(12))) + Expect(packets[3].PacketNumber).To(Equal(protocol.PacketNumber(14))) + Expect(packets[4].PacketNumber).To(Equal(protocol.PacketNumber(15))) + return false + }) + _, err = handler.ReceivedAck(&wire.AckFrame{ + AckRanges: []wire.AckRange{ + {Largest: 15, Smallest: 14}, + {Largest: 12, Smallest: 10}, + }, + ECT0: 1, + ECT1: 2, + ECNCE: 3, + }, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("ignores reordered ACKs", func() { + for i := 10; i < 20; i++ { + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT1) + handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false) + } + ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), int64(1), int64(2), int64(3)).DoAndReturn(func(packets []*packet, _, _, _ int64) bool { + Expect(packets).To(HaveLen(2)) + Expect(packets[0].PacketNumber).To(Equal(protocol.PacketNumber(11))) + Expect(packets[1].PacketNumber).To(Equal(protocol.PacketNumber(12))) + return false + }) + _, err := handler.ReceivedAck(&wire.AckFrame{ + AckRanges: []wire.AckRange{{Largest: 12, Smallest: 11}}, + ECT0: 1, + ECT1: 2, + ECNCE: 3, + }, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + // acknowledge packet 10 now, but don't increase the largest acked + _, err = handler.ReceivedAck(&wire.AckFrame{ + AckRanges: []wire.AckRange{{Largest: 12, Smallest: 10}}, + ECT0: 1, + ECNCE: 3, + }, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("ignores ACKs that don't increase the largest acked", func() { + for i := 10; i < 20; i++ { + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT1) + handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false) + } + ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), int64(1), int64(2), int64(3)).DoAndReturn(func(packets []*packet, _, _, _ int64) bool { + Expect(packets).To(HaveLen(1)) + Expect(packets[0].PacketNumber).To(Equal(protocol.PacketNumber(11))) + return false + }) + _, err := handler.ReceivedAck(&wire.AckFrame{ + AckRanges: []wire.AckRange{{Largest: 11, Smallest: 11}}, + ECT0: 1, + ECT1: 2, + ECNCE: 3, + }, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + _, err = handler.ReceivedAck(&wire.AckFrame{ + AckRanges: []wire.AckRange{{Largest: 11, Smallest: 10}}, + ECT0: 1, + ECNCE: 3, + }, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("informs the congestion controller about CE events", func() { + for i := 10; i < 20; i++ { + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT0) + handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT0, 1200, false) + } + ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), int64(0), int64(0), int64(0)).Return(true) + cong.EXPECT().OnCongestionEvent(protocol.PacketNumber(15), gomock.Any(), gomock.Any()) + _, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 15, Smallest: 10}}}, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + }) + }) }) diff --git a/internal/ackhandler/u_ackhandler.go b/internal/ackhandler/u_ackhandler.go index 8d590390b..00a2a8c0a 100644 --- a/internal/ackhandler/u_ackhandler.go +++ b/internal/ackhandler/u_ackhandler.go @@ -12,11 +12,12 @@ func NewUAckHandler( initialMaxDatagramSize protocol.ByteCount, rttStats *utils.RTTStats, clientAddressValidated bool, + enableECN bool, pers protocol.Perspective, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, ) (SentPacketHandler, ReceivedPacketHandler) { - sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, pers, tracer, logger) + sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger) return &uSentPacketHandler{ sentPacketHandler: sph, }, newReceivedPacketHandler(sph, rttStats, logger) diff --git a/internal/congestion/cubic_sender.go b/internal/congestion/cubic_sender.go index 2e6895bfa..e5d297644 100644 --- a/internal/congestion/cubic_sender.go +++ b/internal/congestion/cubic_sender.go @@ -56,7 +56,7 @@ type cubicSender struct { maxDatagramSize protocol.ByteCount lastState logging.CongestionState - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer } var ( @@ -70,7 +70,7 @@ func NewCubicSender( rttStats *utils.RTTStats, initialMaxDatagramSize protocol.ByteCount, reno bool, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, ) *cubicSender { return newCubicSender( clock, @@ -90,7 +90,7 @@ func newCubicSender( initialMaxDatagramSize, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, ) *cubicSender { c := &cubicSender{ rttStats: rttStats, @@ -108,7 +108,7 @@ func newCubicSender( maxDatagramSize: initialMaxDatagramSize, } c.pacer = newPacer(c.BandwidthEstimate) - if c.tracer != nil { + if c.tracer != nil && c.tracer.UpdatedCongestionState != nil { c.lastState = logging.CongestionStateSlowStart c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart) } @@ -188,7 +188,7 @@ func (c *cubicSender) OnPacketAcked( } } -func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) { +func (c *cubicSender) OnCongestionEvent(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) { // TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets // already sent should be treated as a single loss event, since it's expected. if packetNumber <= c.largestSentAtLastCutback { @@ -296,7 +296,7 @@ func (c *cubicSender) OnConnectionMigration() { } func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) { - if c.tracer == nil || new == c.lastState { + if c.tracer == nil || c.tracer.UpdatedCongestionState == nil || new == c.lastState { return } c.tracer.UpdatedCongestionState(new) diff --git a/internal/congestion/cubic_sender_test.go b/internal/congestion/cubic_sender_test.go index 01f008b18..26f20ccab 100644 --- a/internal/congestion/cubic_sender_test.go +++ b/internal/congestion/cubic_sender_test.go @@ -80,14 +80,14 @@ var _ = Describe("Cubic Sender", func() { LoseNPacketsLen := func(n int, packetLength protocol.ByteCount) { for i := 0; i < n; i++ { ackedPacketNumber++ - sender.OnPacketLost(ackedPacketNumber, packetLength, bytesInFlight) + sender.OnCongestionEvent(ackedPacketNumber, packetLength, bytesInFlight) } bytesInFlight -= protocol.ByteCount(n) * packetLength } // Does not increment acked_packet_number_. LosePacket := func(number protocol.PacketNumber) { - sender.OnPacketLost(number, maxDatagramSize, bytesInFlight) + sender.OnCongestionEvent(number, maxDatagramSize, bytesInFlight) bytesInFlight -= maxDatagramSize } diff --git a/internal/congestion/interface.go b/internal/congestion/interface.go index 1cefb7c50..4c15b40fd 100644 --- a/internal/congestion/interface.go +++ b/internal/congestion/interface.go @@ -14,7 +14,7 @@ type SendAlgorithm interface { CanSend(bytesInFlight protocol.ByteCount) bool MaybeExitSlowStart() OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time) - OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount) + OnCongestionEvent(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount) OnRetransmissionTimeout(packetsRetransmitted bool) SetMaxDatagramSize(protocol.ByteCount) } diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index f2e6b67dd..07ade6aec 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -44,7 +44,7 @@ type cryptoSetup struct { rttStats *utils.RTTStats - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer logger utils.Logger perspective protocol.Perspective @@ -78,7 +78,7 @@ func NewCryptoSetupClient( tlsConf *tls.Config, enable0RTT bool, rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber, ) CryptoSetup { @@ -112,7 +112,7 @@ func NewCryptoSetupServer( tlsConf *tls.Config, allow0RTT bool, rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber, ) CryptoSetup { @@ -128,7 +128,7 @@ func NewCryptoSetupServer( cs.allow0RTT = allow0RTT quicConf := &qtls.QUICConfig{TLSConfig: tlsConf} - qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.accept0RTT) + qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.handleSessionTicket) addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr) cs.tlsConf = quicConf.TLSConfig @@ -166,13 +166,13 @@ func newCryptoSetup( connID protocol.ConnectionID, tp *wire.TransportParameters, rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, perspective protocol.Perspective, version protocol.VersionNumber, ) *cryptoSetup { initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version) - if tracer != nil { + if tracer != nil && tracer.UpdatedKeyFromTLS != nil { tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) } @@ -194,7 +194,7 @@ func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) { initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version) h.initialSealer = initialSealer h.initialOpener = initialOpener - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) } @@ -349,10 +349,13 @@ func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.Transpo } func (h *cryptoSetup) getDataForSessionTicket() []byte { - return (&sessionTicket{ - Parameters: h.ourParams, - RTT: h.rttStats.SmoothedRTT(), - }).Marshal() + ticket := &sessionTicket{ + RTT: h.rttStats.SmoothedRTT(), + } + if h.allow0RTT { + ticket.Parameters = h.ourParams + } + return ticket.Marshal() } // GetSessionTicket generates a new session ticket. @@ -381,12 +384,16 @@ func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { return ticket, nil } -// accept0RTT is called for the server when receiving the client's session ticket. -// It decides whether to accept 0-RTT. -func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool { +// handleSessionTicket is called for the server when receiving the client's session ticket. +// It reads parameters from the session ticket and decides whether to accept 0-RTT when the session ticket is used for 0-RTT. +func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bool) bool { var t sessionTicket - if err := t.Unmarshal(sessionTicketData); err != nil { - h.logger.Debugf("Unmarshalling transport parameters from session ticket failed: %s", err.Error()) + if err := t.Unmarshal(sessionTicketData, using0RTT); err != nil { + h.logger.Debugf("Unmarshalling session ticket failed: %s", err.Error()) + return false + } + h.rttStats.SetInitialRTT(t.RTT) + if !using0RTT { return false } valid := h.ourParams.ValidFor0RTT(t.Parameters) @@ -399,7 +406,6 @@ func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool { return false } h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT) - h.rttStats.SetInitialRTT(t.RTT) return true } @@ -453,7 +459,7 @@ func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, tr } h.mutex.Unlock() h.events = append(h.events, Event{Kind: EventReceivedReadKeys}) - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite()) } } @@ -475,7 +481,7 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t if h.logger.Debug() { h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective) } // don't set used0RTT here. 0-RTT might still get rejected. @@ -499,7 +505,7 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t h.used0RTT.Store(true) h.zeroRTTSealer = nil h.logger.Debugf("Dropping 0-RTT keys.") - if h.tracer != nil { + if h.tracer != nil && h.tracer.DroppedEncryptionLevel != nil { h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) } } @@ -507,7 +513,7 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t panic("unexpected write encryption level") } h.mutex.Unlock() - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective) } } @@ -647,7 +653,7 @@ func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) { h.zeroRTTOpener = nil h.logger.Debugf("Dropping 0-RTT keys.") - if h.tracer != nil { + if h.tracer != nil && h.tracer.DroppedEncryptionLevel != nil { h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) } } diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index fbc82512d..2a57ff91c 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -7,6 +7,8 @@ import ( "crypto/x509/pkix" "math/big" "net" + "runtime" + "strings" "time" tls "github.com/refraction-networking/utls" @@ -418,11 +420,13 @@ var _ = Describe("Crypto Setup TLS", func() { close(receivedSessionTicket) }) clientConf.ClientSessionCache = csc + const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored. const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. + serverOrigRTTStats := newRTTStatsWithRTT(serverRTT) clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( clientConf, serverConf, - clientOrigRTTStats, &utils.RTTStats{}, + clientOrigRTTStats, serverOrigRTTStats, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, false, ) @@ -435,9 +439,10 @@ var _ = Describe("Crypto Setup TLS", func() { csc.EXPECT().Get(gomock.Any()).Return(state, true) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) clientRTTStats := &utils.RTTStats{} + serverRTTStats := &utils.RTTStats{} client, _, clientErr, server, _, serverErr = handshakeWithTLSConf( clientConf, serverConf, - clientRTTStats, &utils.RTTStats{}, + clientRTTStats, serverRTTStats, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, false, ) @@ -447,6 +452,9 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(server.ConnectionState().DidResume).To(BeTrue()) Expect(client.ConnectionState().DidResume).To(BeTrue()) Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) + if !strings.Contains(runtime.Version(), "go1.20") { + Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT)) + } }) It("doesn't use session resumption if the server disabled it", func() { diff --git a/internal/handshake/session_ticket.go b/internal/handshake/session_ticket.go index 5f5421b57..23a37792a 100644 --- a/internal/handshake/session_ticket.go +++ b/internal/handshake/session_ticket.go @@ -10,7 +10,7 @@ import ( "github.com/refraction-networking/uquic/quicvarint" ) -const sessionTicketRevision = 3 +const sessionTicketRevision = 4 type sessionTicket struct { Parameters *wire.TransportParameters @@ -21,10 +21,13 @@ func (t *sessionTicket) Marshal() []byte { b := make([]byte, 0, 256) b = quicvarint.Append(b, sessionTicketRevision) b = quicvarint.Append(b, uint64(t.RTT.Microseconds())) + if t.Parameters == nil { + return b + } return t.Parameters.MarshalForSessionTicket(b) } -func (t *sessionTicket) Unmarshal(b []byte) error { +func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error { r := bytes.NewReader(b) rev, err := quicvarint.Read(r) if err != nil { @@ -37,11 +40,15 @@ func (t *sessionTicket) Unmarshal(b []byte) error { if err != nil { return errors.New("failed to read RTT") } - var tp wire.TransportParameters - if err := tp.UnmarshalFromSessionTicket(r); err != nil { - return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error()) + if using0RTT { + var tp wire.TransportParameters + if err := tp.UnmarshalFromSessionTicket(r); err != nil { + return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error()) + } + t.Parameters = &tp + } else if r.Len() > 0 { + return fmt.Errorf("the session ticket has more bytes than expected") } - t.Parameters = &tp t.RTT = time.Duration(rtt) * time.Microsecond return nil } diff --git a/internal/handshake/session_ticket_test.go b/internal/handshake/session_ticket_test.go index 83c64a7d8..e42105108 100644 --- a/internal/handshake/session_ticket_test.go +++ b/internal/handshake/session_ticket_test.go @@ -11,7 +11,7 @@ import ( ) var _ = Describe("Session Ticket", func() { - It("marshals and unmarshals a session ticket", func() { + It("marshals and unmarshals a 0-RTT session ticket", func() { ticket := &sessionTicket{ Parameters: &wire.TransportParameters{ InitialMaxStreamDataBidiLocal: 1, @@ -22,33 +22,65 @@ var _ = Describe("Session Ticket", func() { RTT: 1337 * time.Microsecond, } var t sessionTicket - Expect(t.Unmarshal(ticket.Marshal())).To(Succeed()) + Expect(t.Unmarshal(ticket.Marshal(), true)).To(Succeed()) Expect(t.Parameters.InitialMaxStreamDataBidiLocal).To(BeEquivalentTo(1)) Expect(t.Parameters.InitialMaxStreamDataBidiRemote).To(BeEquivalentTo(2)) Expect(t.Parameters.ActiveConnectionIDLimit).To(BeEquivalentTo(10)) Expect(t.Parameters.MaxDatagramFrameSize).To(BeEquivalentTo(20)) Expect(t.RTT).To(Equal(1337 * time.Microsecond)) + // fails to unmarshal the ticket as a non-0-RTT ticket + Expect(t.Unmarshal(ticket.Marshal(), false)).To(MatchError("the session ticket has more bytes than expected")) + }) + + It("marshals and unmarshals a non-0-RTT session ticket", func() { + ticket := &sessionTicket{ + RTT: 1337 * time.Microsecond, + } + var t sessionTicket + Expect(t.Unmarshal(ticket.Marshal(), false)).To(Succeed()) + Expect(t.Parameters).To(BeNil()) + Expect(t.RTT).To(Equal(1337 * time.Microsecond)) + // fails to unmarshal the ticket as a 0-RTT ticket + Expect(t.Unmarshal(ticket.Marshal(), true)).To(MatchError(ContainSubstring("unmarshaling transport parameters from session ticket failed"))) }) It("refuses to unmarshal if the ticket is too short for the revision", func() { - Expect((&sessionTicket{}).Unmarshal([]byte{})).To(MatchError("failed to read session ticket revision")) + Expect((&sessionTicket{}).Unmarshal([]byte{}, true)).To(MatchError("failed to read session ticket revision")) + Expect((&sessionTicket{}).Unmarshal([]byte{}, false)).To(MatchError("failed to read session ticket revision")) }) It("refuses to unmarshal if the revision doesn't match", func() { b := quicvarint.Append(nil, 1337) - Expect((&sessionTicket{}).Unmarshal(b)).To(MatchError("unknown session ticket revision: 1337")) + Expect((&sessionTicket{}).Unmarshal(b, true)).To(MatchError("unknown session ticket revision: 1337")) + Expect((&sessionTicket{}).Unmarshal(b, false)).To(MatchError("unknown session ticket revision: 1337")) }) It("refuses to unmarshal if the RTT cannot be read", func() { b := quicvarint.Append(nil, sessionTicketRevision) - Expect((&sessionTicket{}).Unmarshal(b)).To(MatchError("failed to read RTT")) + Expect((&sessionTicket{}).Unmarshal(b, true)).To(MatchError("failed to read RTT")) + Expect((&sessionTicket{}).Unmarshal(b, false)).To(MatchError("failed to read RTT")) }) - It("refuses to unmarshal if unmarshaling the transport parameters fails", func() { + It("refuses to unmarshal a 0-RTT session ticket if unmarshaling the transport parameters fails", func() { b := quicvarint.Append(nil, sessionTicketRevision) b = append(b, []byte("foobar")...) - err := (&sessionTicket{}).Unmarshal(b) + err := (&sessionTicket{}).Unmarshal(b, true) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("unmarshaling transport parameters from session ticket failed")) }) + + It("refuses to unmarshal if the non-0-RTT session ticket has more bytes than expected", func() { + ticket := &sessionTicket{ + Parameters: &wire.TransportParameters{ + InitialMaxStreamDataBidiLocal: 1, + InitialMaxStreamDataBidiRemote: 2, + ActiveConnectionIDLimit: 10, + MaxDatagramFrameSize: 20, + }, + RTT: 1234 * time.Microsecond, + } + err := (&sessionTicket{}).Unmarshal(ticket.Marshal(), false) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("the session ticket has more bytes than expected")) + }) }) diff --git a/internal/handshake/token_generator.go b/internal/handshake/token_generator.go index aaff94001..772bdccd6 100644 --- a/internal/handshake/token_generator.go +++ b/internal/handshake/token_generator.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/asn1" "fmt" - "io" "net" "time" @@ -45,15 +44,9 @@ type TokenGenerator struct { tokenProtector tokenProtector } -// NewTokenGenerator initializes a new TookenGenerator -func NewTokenGenerator(rand io.Reader) (*TokenGenerator, error) { - tokenProtector, err := newTokenProtector(rand) - if err != nil { - return nil, err - } - return &TokenGenerator{ - tokenProtector: tokenProtector, - }, nil +// NewTokenGenerator initializes a new TokenGenerator +func NewTokenGenerator(key TokenProtectorKey) *TokenGenerator { + return &TokenGenerator{tokenProtector: newTokenProtector(key)} } // NewRetryToken generates a new token for a Retry for a given source address diff --git a/internal/handshake/token_generator_test.go b/internal/handshake/token_generator_test.go index 50fa0545d..7f5c3f1b6 100644 --- a/internal/handshake/token_generator_test.go +++ b/internal/handshake/token_generator_test.go @@ -16,9 +16,9 @@ var _ = Describe("Token Generator", func() { var tokenGen *TokenGenerator BeforeEach(func() { - var err error - tokenGen, err = NewTokenGenerator(rand.Reader) - Expect(err).ToNot(HaveOccurred()) + var key TokenProtectorKey + rand.Read(key[:]) + tokenGen = NewTokenGenerator(key) }) It("generates a token", func() { diff --git a/internal/handshake/token_protector.go b/internal/handshake/token_protector.go index 650f230b2..f3a99e411 100644 --- a/internal/handshake/token_protector.go +++ b/internal/handshake/token_protector.go @@ -3,6 +3,7 @@ package handshake import ( "crypto/aes" "crypto/cipher" + "crypto/rand" "crypto/sha256" "fmt" "io" @@ -10,6 +11,9 @@ import ( "golang.org/x/crypto/hkdf" ) +// TokenProtectorKey is the key used to encrypt both Retry and session resumption tokens. +type TokenProtectorKey [32]byte + // TokenProtector is used to create and verify a token type tokenProtector interface { // NewToken creates a new token @@ -18,40 +22,29 @@ type tokenProtector interface { DecodeToken([]byte) ([]byte, error) } -const ( - tokenSecretSize = 32 - tokenNonceSize = 32 -) +const tokenNonceSize = 32 // tokenProtector is used to create and verify a token type tokenProtectorImpl struct { - rand io.Reader - secret []byte + key TokenProtectorKey } // newTokenProtector creates a source for source address tokens -func newTokenProtector(rand io.Reader) (tokenProtector, error) { - secret := make([]byte, tokenSecretSize) - if _, err := rand.Read(secret); err != nil { - return nil, err - } - return &tokenProtectorImpl{ - rand: rand, - secret: secret, - }, nil +func newTokenProtector(key TokenProtectorKey) tokenProtector { + return &tokenProtectorImpl{key: key} } // NewToken encodes data into a new token. func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) { - nonce := make([]byte, tokenNonceSize) - if _, err := s.rand.Read(nonce); err != nil { + var nonce [tokenNonceSize]byte + if _, err := rand.Read(nonce[:]); err != nil { return nil, err } - aead, aeadNonce, err := s.createAEAD(nonce) + aead, aeadNonce, err := s.createAEAD(nonce[:]) if err != nil { return nil, err } - return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil + return append(nonce[:], aead.Seal(nil, aeadNonce, data, nil)...), nil } // DecodeToken decodes a token. @@ -68,7 +61,7 @@ func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) { } func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) { - h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go token source")) + h := hkdf.New(sha256.New, s.key[:], nonce, []byte("quic-go token source")) key := make([]byte, 32) // use a 32 byte key, in order to select AES-256 if _, err := io.ReadFull(h, key); err != nil { return nil, nil, err diff --git a/internal/handshake/token_protector_test.go b/internal/handshake/token_protector_test.go index 03cd5320c..74eb1f0cd 100644 --- a/internal/handshake/token_protector_test.go +++ b/internal/handshake/token_protector_test.go @@ -7,55 +7,54 @@ import ( . "github.com/onsi/gomega" ) -type zeroReader struct{} - -func (r *zeroReader) Read(b []byte) (int, error) { - for i := range b { - b[i] = 0 - } - return len(b), nil -} - var _ = Describe("Token Protector", func() { var tp tokenProtector BeforeEach(func() { + var key TokenProtectorKey + rand.Read(key[:]) var err error - tp, err = newTokenProtector(rand.Reader) + tp = newTokenProtector(key) Expect(err).ToNot(HaveOccurred()) }) - It("uses the random source", func() { - tp1, err := newTokenProtector(&zeroReader{}) + It("encodes and decodes tokens", func() { + token, err := tp.NewToken([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) - tp2, err := newTokenProtector(&zeroReader{}) + Expect(token).ToNot(ContainSubstring("foobar")) + decoded, err := tp.DecodeToken(token) Expect(err).ToNot(HaveOccurred()) + Expect(decoded).To(Equal([]byte("foobar"))) + }) + + It("uses the different keys", func() { + var key1, key2 TokenProtectorKey + rand.Read(key1[:]) + rand.Read(key2[:]) + tp1 := newTokenProtector(key1) + tp2 := newTokenProtector(key2) t1, err := tp1.NewToken([]byte("foo")) Expect(err).ToNot(HaveOccurred()) t2, err := tp2.NewToken([]byte("foo")) Expect(err).ToNot(HaveOccurred()) - Expect(t1).To(Equal(t2)) - tp3, err := newTokenProtector(rand.Reader) - Expect(err).ToNot(HaveOccurred()) - t3, err := tp3.NewToken([]byte("foo")) - Expect(err).ToNot(HaveOccurred()) - Expect(t3).ToNot(Equal(t1)) - }) - It("encodes and decodes tokens", func() { - token, err := tp.NewToken([]byte("foobar")) + _, err = tp1.DecodeToken(t1) Expect(err).ToNot(HaveOccurred()) - Expect(token).ToNot(ContainSubstring("foobar")) - decoded, err := tp.DecodeToken(token) + _, err = tp1.DecodeToken(t2) + Expect(err).To(HaveOccurred()) + + // now create another token protector, reusing key1 + tp3 := newTokenProtector(key1) + _, err = tp3.DecodeToken(t1) Expect(err).ToNot(HaveOccurred()) - Expect(decoded).To(Equal([]byte("foobar"))) + _, err = tp3.DecodeToken(t2) + Expect(err).To(HaveOccurred()) }) - It("fails deconding invalid tokens", func() { + It("doesn't decode invalid tokens", func() { token, err := tp.NewToken([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) - token = token[1:] // remove the first byte - _, err = tp.DecodeToken(token) + _, err = tp.DecodeToken(token[1:]) // the token is invalid without the first byte Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("message authentication failed")) }) diff --git a/internal/handshake/u_crypto_setup.go b/internal/handshake/u_crypto_setup.go index 92074ca8b..d6dc16b17 100644 --- a/internal/handshake/u_crypto_setup.go +++ b/internal/handshake/u_crypto_setup.go @@ -34,7 +34,7 @@ type uCryptoSetup struct { rttStats *utils.RTTStats - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer logger utils.Logger perspective protocol.Perspective @@ -69,7 +69,7 @@ func NewUCryptoSetupClient( tlsConf *tls.Config, enable0RTT bool, rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber, chs *tls.ClientHelloSpec, @@ -100,7 +100,7 @@ func newUCryptoSetup( connID protocol.ConnectionID, tp *wire.TransportParameters, rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, perspective protocol.Perspective, version protocol.VersionNumber, diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index a34beb4b8..0d9664a0d 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -58,7 +58,7 @@ type updatableAEAD struct { rttStats *utils.RTTStats - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer logger utils.Logger version protocol.VersionNumber @@ -71,7 +71,7 @@ var ( _ ShortHeaderSealer = &updatableAEAD{} ) -func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD { +func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD { return &updatableAEAD{ firstPacketNumber: protocol.InvalidPacketNumber, largestAcked: protocol.InvalidPacketNumber, @@ -87,7 +87,7 @@ func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, func (a *updatableAEAD) rollKeys() { if a.prevRcvAEAD != nil { a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry) - if a.tracer != nil { + if a.tracer != nil && a.tracer.DroppedKey != nil { a.tracer.DroppedKey(a.keyPhase - 1) } a.prevRcvAEADExpiry = time.Time{} @@ -183,7 +183,7 @@ func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.Pac a.prevRcvAEAD = nil a.logger.Debugf("Dropping key phase %d", a.keyPhase-1) a.prevRcvAEADExpiry = time.Time{} - if a.tracer != nil { + if a.tracer != nil && a.tracer.DroppedKey != nil { a.tracer.DroppedKey(a.keyPhase - 1) } } @@ -217,7 +217,7 @@ func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.Pac // The peer initiated this key update. It's safe to drop the keys for the previous generation now. // Start a timer to drop the previous key generation. a.startKeyDropTimer(rcvTime) - if a.tracer != nil { + if a.tracer != nil && a.tracer.UpdatedKey != nil { a.tracer.UpdatedKey(a.keyPhase, true) } a.firstRcvdWithCurrentKey = pn @@ -309,7 +309,7 @@ func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit { if a.shouldInitiateKeyUpdate() { a.rollKeys() a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase) - if a.tracer != nil { + if a.tracer != nil && a.tracer.UpdatedKey != nil { a.tracer.UpdatedKey(a.keyPhase, false) } } diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 4bc1fa86b..af8d393f0 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -12,6 +12,7 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/qerr" "github.com/refraction-networking/uquic/internal/utils" + "github.com/refraction-networking/uquic/logging" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -63,7 +64,8 @@ var _ = Describe("Updatable AEAD", func() { ) BeforeEach(func() { - serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl) + var tr *logging.ConnectionTracer + tr, serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl) trafficSecret1 := make([]byte, 16) trafficSecret2 := make([]byte, 16) rand.Read(trafficSecret1) @@ -71,7 +73,7 @@ var _ = Describe("Updatable AEAD", func() { rttStats = utils.NewRTTStats() client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, v) - server = newUpdatableAEAD(rttStats, serverTracer, utils.DefaultLogger, v) + server = newUpdatableAEAD(rttStats, tr, utils.DefaultLogger, v) client.SetReadKey(cs, trafficSecret2) client.SetWriteKey(cs, trafficSecret1) server.SetReadKey(cs, trafficSecret1) diff --git a/internal/mocks/ackhandler/received_packet_handler.go b/internal/mocks/ackhandler/received_packet_handler.go index 569b690ef..8be3e071d 100644 --- a/internal/mocks/ackhandler/received_packet_handler.go +++ b/internal/mocks/ackhandler/received_packet_handler.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic/internal/ackhandler (interfaces: ReceivedPacketHandler) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/refraction-networking/uquic/internal/ackhandler ReceivedPacketHandler +// // Package mockackhandler is a generated GoMock package. package mockackhandler @@ -43,7 +47,7 @@ func (m *MockReceivedPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { } // DropPackets indicates an expected call of DropPackets. -func (mr *MockReceivedPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call { +func (mr *MockReceivedPacketHandlerMockRecorder) DropPackets(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockReceivedPacketHandler)(nil).DropPackets), arg0) } @@ -57,7 +61,7 @@ func (m *MockReceivedPacketHandler) GetAckFrame(arg0 protocol.EncryptionLevel, a } // GetAckFrame indicates an expected call of GetAckFrame. -func (mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAckFrame), arg0, arg1) } @@ -85,7 +89,7 @@ func (m *MockReceivedPacketHandler) IsPotentiallyDuplicate(arg0 protocol.PacketN } // IsPotentiallyDuplicate indicates an expected call of IsPotentiallyDuplicate. -func (mr *MockReceivedPacketHandlerMockRecorder) IsPotentiallyDuplicate(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockReceivedPacketHandlerMockRecorder) IsPotentiallyDuplicate(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPotentiallyDuplicate", reflect.TypeOf((*MockReceivedPacketHandler)(nil).IsPotentiallyDuplicate), arg0, arg1) } @@ -99,7 +103,7 @@ func (m *MockReceivedPacketHandler) ReceivedPacket(arg0 protocol.PacketNumber, a } // ReceivedPacket indicates an expected call of ReceivedPacket. -func (mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +func (mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockReceivedPacketHandler)(nil).ReceivedPacket), arg0, arg1, arg2, arg3, arg4) } diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index 37fae0e6a..a0234086b 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic/internal/ackhandler (interfaces: SentPacketHandler) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/refraction-networking/uquic/internal/ackhandler SentPacketHandler +// // Package mockackhandler is a generated GoMock package. package mockackhandler @@ -44,11 +48,25 @@ func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { } // DropPackets indicates an expected call of DropPackets. -func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0) } +// ECNMode mocks base method. +func (m *MockSentPacketHandler) ECNMode(arg0 bool) protocol.ECN { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ECNMode", arg0) + ret0, _ := ret[0].(protocol.ECN) + return ret0 +} + +// ECNMode indicates an expected call of ECNMode. +func (mr *MockSentPacketHandlerMockRecorder) ECNMode(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNMode", reflect.TypeOf((*MockSentPacketHandler)(nil).ECNMode), arg0) +} + // GetLossDetectionTimeout mocks base method. func (m *MockSentPacketHandler) GetLossDetectionTimeout() time.Time { m.ctrl.T.Helper() @@ -87,7 +105,7 @@ func (m *MockSentPacketHandler) PeekPacketNumber(arg0 protocol.EncryptionLevel) } // PeekPacketNumber indicates an expected call of PeekPacketNumber. -func (mr *MockSentPacketHandlerMockRecorder) PeekPacketNumber(arg0 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) PeekPacketNumber(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeekPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PeekPacketNumber), arg0) } @@ -101,7 +119,7 @@ func (m *MockSentPacketHandler) PopPacketNumber(arg0 protocol.EncryptionLevel) p } // PopPacketNumber indicates an expected call of PopPacketNumber. -func (mr *MockSentPacketHandlerMockRecorder) PopPacketNumber(arg0 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) PopPacketNumber(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PopPacketNumber), arg0) } @@ -115,7 +133,7 @@ func (m *MockSentPacketHandler) QueueProbePacket(arg0 protocol.EncryptionLevel) } // QueueProbePacket indicates an expected call of QueueProbePacket. -func (mr *MockSentPacketHandlerMockRecorder) QueueProbePacket(arg0 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) QueueProbePacket(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).QueueProbePacket), arg0) } @@ -130,7 +148,7 @@ func (m *MockSentPacketHandler) ReceivedAck(arg0 *wire.AckFrame, arg1 protocol.E } // ReceivedAck indicates an expected call of ReceivedAck. -func (mr *MockSentPacketHandlerMockRecorder) ReceivedAck(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) ReceivedAck(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedAck", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedAck), arg0, arg1, arg2) } @@ -142,23 +160,23 @@ func (m *MockSentPacketHandler) ReceivedBytes(arg0 protocol.ByteCount) { } // ReceivedBytes indicates an expected call of ReceivedBytes. -func (mr *MockSentPacketHandlerMockRecorder) ReceivedBytes(arg0 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) ReceivedBytes(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedBytes", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedBytes), arg0) } // ResetForRetry mocks base method. -func (m *MockSentPacketHandler) ResetForRetry() error { +func (m *MockSentPacketHandler) ResetForRetry(arg0 time.Time) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ResetForRetry") + ret := m.ctrl.Call(m, "ResetForRetry", arg0) ret0, _ := ret[0].(error) return ret0 } // ResetForRetry indicates an expected call of ResetForRetry. -func (mr *MockSentPacketHandlerMockRecorder) ResetForRetry() *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) ResetForRetry(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetForRetry", reflect.TypeOf((*MockSentPacketHandler)(nil).ResetForRetry)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetForRetry", reflect.TypeOf((*MockSentPacketHandler)(nil).ResetForRetry), arg0) } // SendMode mocks base method. @@ -170,21 +188,21 @@ func (m *MockSentPacketHandler) SendMode(arg0 time.Time) ackhandler.SendMode { } // SendMode indicates an expected call of SendMode. -func (mr *MockSentPacketHandlerMockRecorder) SendMode(arg0 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) SendMode(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMode", reflect.TypeOf((*MockSentPacketHandler)(nil).SendMode), arg0) } // SentPacket mocks base method. -func (m *MockSentPacketHandler) SentPacket(arg0 time.Time, arg1, arg2 protocol.PacketNumber, arg3 []ackhandler.StreamFrame, arg4 []ackhandler.Frame, arg5 protocol.EncryptionLevel, arg6 protocol.ByteCount, arg7 bool) { +func (m *MockSentPacketHandler) SentPacket(arg0 time.Time, arg1, arg2 protocol.PacketNumber, arg3 []ackhandler.StreamFrame, arg4 []ackhandler.Frame, arg5 protocol.EncryptionLevel, arg6 protocol.ECN, arg7 protocol.ByteCount, arg8 bool) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7) + m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) } // SentPacket indicates an expected call of SentPacket. -func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) } // SetHandshakeConfirmed mocks base method. @@ -206,7 +224,7 @@ func (m *MockSentPacketHandler) SetMaxDatagramSize(arg0 protocol.ByteCount) { } // SetMaxDatagramSize indicates an expected call of SetMaxDatagramSize. -func (mr *MockSentPacketHandlerMockRecorder) SetMaxDatagramSize(arg0 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) SetMaxDatagramSize(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSentPacketHandler)(nil).SetMaxDatagramSize), arg0) } diff --git a/internal/mocks/congestion.go b/internal/mocks/congestion.go index a2ba0aabc..99ffd8169 100644 --- a/internal/mocks/congestion.go +++ b/internal/mocks/congestion.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic/internal/congestion (interfaces: SendAlgorithmWithDebugInfos) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mocks -destination congestion.go github.com/refraction-networking/uquic/internal/congestion SendAlgorithmWithDebugInfos +// // Package mocks is a generated GoMock package. package mocks @@ -44,7 +48,7 @@ func (m *MockSendAlgorithmWithDebugInfos) CanSend(arg0 protocol.ByteCount) bool } // CanSend indicates an expected call of CanSend. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) CanSend(arg0 interface{}) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) CanSend(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanSend", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).CanSend), arg0) } @@ -72,7 +76,7 @@ func (m *MockSendAlgorithmWithDebugInfos) HasPacingBudget(arg0 time.Time) bool { } // HasPacingBudget indicates an expected call of HasPacingBudget. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) HasPacingBudget(arg0 interface{}) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) HasPacingBudget(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasPacingBudget", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).HasPacingBudget), arg0) } @@ -117,28 +121,28 @@ func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) MaybeExitSlowStart() *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeExitSlowStart", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).MaybeExitSlowStart)) } -// OnPacketAcked mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) OnPacketAcked(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount, arg3 time.Time) { +// OnCongestionEvent mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) OnCongestionEvent(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount) { m.ctrl.T.Helper() - m.ctrl.Call(m, "OnPacketAcked", arg0, arg1, arg2, arg3) + m.ctrl.Call(m, "OnCongestionEvent", arg0, arg1, arg2) } -// OnPacketAcked indicates an expected call of OnPacketAcked. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketAcked(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +// OnCongestionEvent indicates an expected call of OnCongestionEvent. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnCongestionEvent(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketAcked", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketAcked), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnCongestionEvent", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnCongestionEvent), arg0, arg1, arg2) } -// OnPacketLost mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) OnPacketLost(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount) { +// OnPacketAcked mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) OnPacketAcked(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount, arg3 time.Time) { m.ctrl.T.Helper() - m.ctrl.Call(m, "OnPacketLost", arg0, arg1, arg2) + m.ctrl.Call(m, "OnPacketAcked", arg0, arg1, arg2, arg3) } -// OnPacketLost indicates an expected call of OnPacketLost. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketLost(arg0, arg1, arg2 interface{}) *gomock.Call { +// OnPacketAcked indicates an expected call of OnPacketAcked. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketAcked(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketLost", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketLost), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketAcked", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketAcked), arg0, arg1, arg2, arg3) } // OnPacketSent mocks base method. @@ -148,7 +152,7 @@ func (m *MockSendAlgorithmWithDebugInfos) OnPacketSent(arg0 time.Time, arg1 prot } // OnPacketSent indicates an expected call of OnPacketSent. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketSent(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketSent(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketSent", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketSent), arg0, arg1, arg2, arg3, arg4) } @@ -160,7 +164,7 @@ func (m *MockSendAlgorithmWithDebugInfos) OnRetransmissionTimeout(arg0 bool) { } // OnRetransmissionTimeout indicates an expected call of OnRetransmissionTimeout. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnRetransmissionTimeout(arg0 interface{}) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnRetransmissionTimeout(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnRetransmissionTimeout", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnRetransmissionTimeout), arg0) } @@ -172,7 +176,7 @@ func (m *MockSendAlgorithmWithDebugInfos) SetMaxDatagramSize(arg0 protocol.ByteC } // SetMaxDatagramSize indicates an expected call of SetMaxDatagramSize. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) SetMaxDatagramSize(arg0 interface{}) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) SetMaxDatagramSize(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).SetMaxDatagramSize), arg0) } @@ -186,7 +190,7 @@ func (m *MockSendAlgorithmWithDebugInfos) TimeUntilSend(arg0 protocol.ByteCount) } // TimeUntilSend indicates an expected call of TimeUntilSend. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) TimeUntilSend(arg0 interface{}) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) TimeUntilSend(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeUntilSend", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).TimeUntilSend), arg0) } diff --git a/internal/mocks/connection_flow_controller.go b/internal/mocks/connection_flow_controller.go index 60d87d39c..a2edfc229 100644 --- a/internal/mocks/connection_flow_controller.go +++ b/internal/mocks/connection_flow_controller.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic/internal/flowcontrol (interfaces: ConnectionFlowController) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mocks -destination connection_flow_controller.go github.com/refraction-networking/uquic/internal/flowcontrol ConnectionFlowController +// // Package mocks is a generated GoMock package. package mocks @@ -41,7 +45,7 @@ func (m *MockConnectionFlowController) AddBytesRead(arg0 protocol.ByteCount) { } // AddBytesRead indicates an expected call of AddBytesRead. -func (mr *MockConnectionFlowControllerMockRecorder) AddBytesRead(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionFlowControllerMockRecorder) AddBytesRead(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesRead), arg0) } @@ -53,7 +57,7 @@ func (m *MockConnectionFlowController) AddBytesSent(arg0 protocol.ByteCount) { } // AddBytesSent indicates an expected call of AddBytesSent. -func (mr *MockConnectionFlowControllerMockRecorder) AddBytesSent(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionFlowControllerMockRecorder) AddBytesSent(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesSent), arg0) } @@ -122,7 +126,7 @@ func (m *MockConnectionFlowController) UpdateSendWindow(arg0 protocol.ByteCount) } // UpdateSendWindow indicates an expected call of UpdateSendWindow. -func (mr *MockConnectionFlowControllerMockRecorder) UpdateSendWindow(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionFlowControllerMockRecorder) UpdateSendWindow(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockConnectionFlowController)(nil).UpdateSendWindow), arg0) } diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index 5558f9730..e4151c46c 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic/internal/handshake (interfaces: CryptoSetup) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mocks -destination crypto_setup_tmp.go github.com/refraction-networking/uquic/internal/handshake CryptoSetup +// // Package mocks is a generated GoMock package. package mocks @@ -42,7 +46,7 @@ func (m *MockCryptoSetup) ChangeConnectionID(arg0 protocol.ConnectionID) { } // ChangeConnectionID indicates an expected call of ChangeConnectionID. -func (mr *MockCryptoSetupMockRecorder) ChangeConnectionID(arg0 interface{}) *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) ChangeConnectionID(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangeConnectionID", reflect.TypeOf((*MockCryptoSetup)(nil).ChangeConnectionID), arg0) } @@ -231,7 +235,7 @@ func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLev } // HandleMessage indicates an expected call of HandleMessage. -func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1) } @@ -271,7 +275,7 @@ func (m *MockCryptoSetup) SetLargest1RTTAcked(arg0 protocol.PacketNumber) error } // SetLargest1RTTAcked indicates an expected call of SetLargest1RTTAcked. -func (mr *MockCryptoSetupMockRecorder) SetLargest1RTTAcked(arg0 interface{}) *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) SetLargest1RTTAcked(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLargest1RTTAcked", reflect.TypeOf((*MockCryptoSetup)(nil).SetLargest1RTTAcked), arg0) } diff --git a/internal/mocks/logging/connection_tracer.go b/internal/mocks/logging/connection_tracer.go index 4a9d1fd70..c9728eece 100644 --- a/internal/mocks/logging/connection_tracer.go +++ b/internal/mocks/logging/connection_tracer.go @@ -1,377 +1,108 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/refraction-networking/uquic/logging (interfaces: ConnectionTracer) +//go:build !gomock && !generate -// Package mocklogging is a generated GoMock package. package mocklogging import ( - net "net" - reflect "reflect" - time "time" + "net" + "time" - protocol "github.com/refraction-networking/uquic/internal/protocol" - utils "github.com/refraction-networking/uquic/internal/utils" - wire "github.com/refraction-networking/uquic/internal/wire" - logging "github.com/refraction-networking/uquic/logging" - - gomock "go.uber.org/mock/gomock" -) - -// MockConnectionTracer is a mock of ConnectionTracer interface. -type MockConnectionTracer struct { - ctrl *gomock.Controller - recorder *MockConnectionTracerMockRecorder -} - -// MockConnectionTracerMockRecorder is the mock recorder for MockConnectionTracer. -type MockConnectionTracerMockRecorder struct { - mock *MockConnectionTracer -} - -// NewMockConnectionTracer creates a new mock instance. -func NewMockConnectionTracer(ctrl *gomock.Controller) *MockConnectionTracer { - mock := &MockConnectionTracer{ctrl: ctrl} - mock.recorder = &MockConnectionTracerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockConnectionTracer) EXPECT() *MockConnectionTracerMockRecorder { - return m.recorder -} - -// AcknowledgedPacket mocks base method. -func (m *MockConnectionTracer) AcknowledgedPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AcknowledgedPacket", arg0, arg1) -} - -// AcknowledgedPacket indicates an expected call of AcknowledgedPacket. -func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcknowledgedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).AcknowledgedPacket), arg0, arg1) -} - -// BufferedPacket mocks base method. -func (m *MockConnectionTracer) BufferedPacket(arg0 logging.PacketType, arg1 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "BufferedPacket", arg0, arg1) -} - -// BufferedPacket indicates an expected call of BufferedPacket. -func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).BufferedPacket), arg0, arg1) -} - -// Close mocks base method. -func (m *MockConnectionTracer) Close() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Close") -} - -// Close indicates an expected call of Close. -func (mr *MockConnectionTracerMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnectionTracer)(nil).Close)) -} - -// ClosedConnection mocks base method. -func (m *MockConnectionTracer) ClosedConnection(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ClosedConnection", arg0) -} - -// ClosedConnection indicates an expected call of ClosedConnection. -func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClosedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).ClosedConnection), arg0) -} - -// Debug mocks base method. -func (m *MockConnectionTracer) Debug(arg0, arg1 string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Debug", arg0, arg1) -} - -// Debug indicates an expected call of Debug. -func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockConnectionTracer)(nil).Debug), arg0, arg1) -} - -// DroppedEncryptionLevel mocks base method. -func (m *MockConnectionTracer) DroppedEncryptionLevel(arg0 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedEncryptionLevel", arg0) -} - -// DroppedEncryptionLevel indicates an expected call of DroppedEncryptionLevel. -func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedEncryptionLevel", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedEncryptionLevel), arg0) -} - -// DroppedKey mocks base method. -func (m *MockConnectionTracer) DroppedKey(arg0 protocol.KeyPhase) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedKey", arg0) -} - -// DroppedKey indicates an expected call of DroppedKey. -func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedKey", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedKey), arg0) -} - -// DroppedPacket mocks base method. -func (m *MockConnectionTracer) DroppedPacket(arg0 logging.PacketType, arg1 protocol.ByteCount, arg2 logging.PacketDropReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2) -} - -// DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2) -} - -// LossTimerCanceled mocks base method. -func (m *MockConnectionTracer) LossTimerCanceled() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LossTimerCanceled") -} - -// LossTimerCanceled indicates an expected call of LossTimerCanceled. -func (mr *MockConnectionTracerMockRecorder) LossTimerCanceled() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerCanceled", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerCanceled)) -} + "github.com/refraction-networking/uquic/internal/mocks/logging/internal" + "github.com/refraction-networking/uquic/logging" -// LossTimerExpired mocks base method. -func (m *MockConnectionTracer) LossTimerExpired(arg0 logging.TimerType, arg1 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LossTimerExpired", arg0, arg1) -} - -// LossTimerExpired indicates an expected call of LossTimerExpired. -func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerExpired", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerExpired), arg0, arg1) -} - -// LostPacket mocks base method. -func (m *MockConnectionTracer) LostPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber, arg2 logging.PacketLossReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LostPacket", arg0, arg1, arg2) -} - -// LostPacket indicates an expected call of LostPacket. -func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) -} - -// NegotiatedVersion mocks base method. -func (m *MockConnectionTracer) NegotiatedVersion(arg0 protocol.VersionNumber, arg1, arg2 []protocol.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "NegotiatedVersion", arg0, arg1, arg2) -} - -// NegotiatedVersion indicates an expected call of NegotiatedVersion. -func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) -} - -// ReceivedLongHeaderPacket mocks base method. -func (m *MockConnectionTracer) ReceivedLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedLongHeaderPacket", arg0, arg1, arg2) -} - -// ReceivedLongHeaderPacket indicates an expected call of ReceivedLongHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedLongHeaderPacket), arg0, arg1, arg2) -} - -// ReceivedRetry mocks base method. -func (m *MockConnectionTracer) ReceivedRetry(arg0 *wire.Header) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedRetry", arg0) -} - -// ReceivedRetry indicates an expected call of ReceivedRetry. -func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedRetry", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedRetry), arg0) -} - -// ReceivedShortHeaderPacket mocks base method. -func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2) -} - -// ReceivedShortHeaderPacket indicates an expected call of ReceivedShortHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedShortHeaderPacket), arg0, arg1, arg2) -} - -// ReceivedTransportParameters mocks base method. -func (m *MockConnectionTracer) ReceivedTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedTransportParameters", arg0) -} - -// ReceivedTransportParameters indicates an expected call of ReceivedTransportParameters. -func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedTransportParameters), arg0) -} - -// ReceivedVersionNegotiationPacket mocks base method. -func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0, arg1 protocol.ArbitraryLenConnectionID, arg2 []protocol.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1, arg2) -} - -// ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1, arg2) -} - -// RestoredTransportParameters mocks base method. -func (m *MockConnectionTracer) RestoredTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "RestoredTransportParameters", arg0) -} - -// RestoredTransportParameters indicates an expected call of RestoredTransportParameters. -func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0) -} - -// SentLongHeaderPacket mocks base method. -func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentLongHeaderPacket", arg0, arg1, arg2, arg3) -} - -// SentLongHeaderPacket indicates an expected call of SentLongHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3) -} - -// SentShortHeaderPacket mocks base method. -func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentShortHeaderPacket", arg0, arg1, arg2, arg3) -} - -// SentShortHeaderPacket indicates an expected call of SentShortHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3) -} - -// SentTransportParameters mocks base method. -func (m *MockConnectionTracer) SentTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentTransportParameters", arg0) -} - -// SentTransportParameters indicates an expected call of SentTransportParameters. -func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).SentTransportParameters), arg0) -} - -// SetLossTimer mocks base method. -func (m *MockConnectionTracer) SetLossTimer(arg0 logging.TimerType, arg1 protocol.EncryptionLevel, arg2 time.Time) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetLossTimer", arg0, arg1, arg2) -} - -// SetLossTimer indicates an expected call of SetLossTimer. -func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLossTimer", reflect.TypeOf((*MockConnectionTracer)(nil).SetLossTimer), arg0, arg1, arg2) -} - -// StartedConnection mocks base method. -func (m *MockConnectionTracer) StartedConnection(arg0, arg1 net.Addr, arg2, arg3 protocol.ConnectionID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "StartedConnection", arg0, arg1, arg2, arg3) -} - -// StartedConnection indicates an expected call of StartedConnection. -func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3) -} - -// UpdatedCongestionState mocks base method. -func (m *MockConnectionTracer) UpdatedCongestionState(arg0 logging.CongestionState) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedCongestionState", arg0) -} - -// UpdatedCongestionState indicates an expected call of UpdatedCongestionState. -func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) -} - -// UpdatedKey mocks base method. -func (m *MockConnectionTracer) UpdatedKey(arg0 protocol.KeyPhase, arg1 bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedKey", arg0, arg1) -} - -// UpdatedKey indicates an expected call of UpdatedKey. -func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKey", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKey), arg0, arg1) -} - -// UpdatedKeyFromTLS mocks base method. -func (m *MockConnectionTracer) UpdatedKeyFromTLS(arg0 protocol.EncryptionLevel, arg1 protocol.Perspective) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedKeyFromTLS", arg0, arg1) -} - -// UpdatedKeyFromTLS indicates an expected call of UpdatedKeyFromTLS. -func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKeyFromTLS", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKeyFromTLS), arg0, arg1) -} - -// UpdatedMetrics mocks base method. -func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 protocol.ByteCount, arg3 int) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedMetrics", arg0, arg1, arg2, arg3) -} - -// UpdatedMetrics indicates an expected call of UpdatedMetrics. -func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedMetrics", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedMetrics), arg0, arg1, arg2, arg3) -} - -// UpdatedPTOCount mocks base method. -func (m *MockConnectionTracer) UpdatedPTOCount(arg0 uint32) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedPTOCount", arg0) -} + "go.uber.org/mock/gomock" +) -// UpdatedPTOCount indicates an expected call of UpdatedPTOCount. -func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedPTOCount", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedPTOCount), arg0) +type MockConnectionTracer = internal.MockConnectionTracer + +func NewMockConnectionTracer(ctrl *gomock.Controller) (*logging.ConnectionTracer, *MockConnectionTracer) { + t := internal.NewMockConnectionTracer(ctrl) + return &logging.ConnectionTracer{ + StartedConnection: func(local, remote net.Addr, srcConnID, destConnID logging.ConnectionID) { + t.StartedConnection(local, remote, srcConnID, destConnID) + }, + NegotiatedVersion: func(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { + t.NegotiatedVersion(chosen, clientVersions, serverVersions) + }, + ClosedConnection: func(e error) { + t.ClosedConnection(e) + }, + SentTransportParameters: func(tp *logging.TransportParameters) { + t.SentTransportParameters(tp) + }, + ReceivedTransportParameters: func(tp *logging.TransportParameters) { + t.ReceivedTransportParameters(tp) + }, + RestoredTransportParameters: func(tp *logging.TransportParameters) { + t.RestoredTransportParameters(tp) + }, + SentLongHeaderPacket: func(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) { + t.SentLongHeaderPacket(hdr, size, ecn, ack, frames) + }, + SentShortHeaderPacket: func(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) { + t.SentShortHeaderPacket(hdr, size, ecn, ack, frames) + }, + ReceivedVersionNegotiationPacket: func(dest, src logging.ArbitraryLenConnectionID, versions []logging.VersionNumber) { + t.ReceivedVersionNegotiationPacket(dest, src, versions) + }, + ReceivedRetry: func(hdr *logging.Header) { + t.ReceivedRetry(hdr) + }, + ReceivedLongHeaderPacket: func(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { + t.ReceivedLongHeaderPacket(hdr, size, ecn, frames) + }, + ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { + t.ReceivedShortHeaderPacket(hdr, size, ecn, frames) + }, + BufferedPacket: func(typ logging.PacketType, size logging.ByteCount) { + t.BufferedPacket(typ, size) + }, + DroppedPacket: func(typ logging.PacketType, size logging.ByteCount, reason logging.PacketDropReason) { + t.DroppedPacket(typ, size, reason) + }, + UpdatedMetrics: func(rttStats *logging.RTTStats, cwnd, bytesInFlight logging.ByteCount, packetsInFlight int) { + t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight) + }, + AcknowledgedPacket: func(encLevel logging.EncryptionLevel, pn logging.PacketNumber) { + t.AcknowledgedPacket(encLevel, pn) + }, + LostPacket: func(encLevel logging.EncryptionLevel, pn logging.PacketNumber, reason logging.PacketLossReason) { + t.LostPacket(encLevel, pn, reason) + }, + UpdatedCongestionState: func(state logging.CongestionState) { + t.UpdatedCongestionState(state) + }, + UpdatedPTOCount: func(value uint32) { + t.UpdatedPTOCount(value) + }, + UpdatedKeyFromTLS: func(encLevel logging.EncryptionLevel, perspective logging.Perspective) { + t.UpdatedKeyFromTLS(encLevel, perspective) + }, + UpdatedKey: func(generation logging.KeyPhase, remote bool) { + t.UpdatedKey(generation, remote) + }, + DroppedEncryptionLevel: func(encLevel logging.EncryptionLevel) { + t.DroppedEncryptionLevel(encLevel) + }, + DroppedKey: func(generation logging.KeyPhase) { + t.DroppedKey(generation) + }, + SetLossTimer: func(typ logging.TimerType, encLevel logging.EncryptionLevel, exp time.Time) { + t.SetLossTimer(typ, encLevel, exp) + }, + LossTimerExpired: func(typ logging.TimerType, encLevel logging.EncryptionLevel) { + t.LossTimerExpired(typ, encLevel) + }, + LossTimerCanceled: func() { + t.LossTimerCanceled() + }, + ECNStateUpdated: func(state logging.ECNState, trigger logging.ECNStateTrigger) { + t.ECNStateUpdated(state, trigger) + }, + Close: func() { + t.Close() + }, + Debug: func(name, msg string) { + t.Debug(name, msg) + }, + }, t } diff --git a/logging/mock_connection_tracer_test.go b/internal/mocks/logging/internal/connection_tracer.go similarity index 79% rename from logging/mock_connection_tracer_test.go rename to internal/mocks/logging/internal/connection_tracer.go index e5dd38c2b..b703f980c 100644 --- a/logging/mock_connection_tracer_test.go +++ b/internal/mocks/logging/internal/connection_tracer.go @@ -1,18 +1,23 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/refraction-networking/uquic/logging (interfaces: ConnectionTracer) - -// Package logging is a generated GoMock package. -package logging +// Source: github.com/refraction-networking/uquic/internal/mocks/logging (interfaces: ConnectionTracer) +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package internal -destination internal/connection_tracer.go github.com/refraction-networking/uquic/internal/mocks/logging ConnectionTracer +// +// Package internal is a generated GoMock package. +package internal import ( net "net" reflect "reflect" time "time" - gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" utils "github.com/refraction-networking/uquic/internal/utils" wire "github.com/refraction-networking/uquic/internal/wire" + logging "github.com/refraction-networking/uquic/logging" + gomock "go.uber.org/mock/gomock" ) // MockConnectionTracer is a mock of ConnectionTracer interface. @@ -45,19 +50,19 @@ func (m *MockConnectionTracer) AcknowledgedPacket(arg0 protocol.EncryptionLevel, } // AcknowledgedPacket indicates an expected call of AcknowledgedPacket. -func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcknowledgedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).AcknowledgedPacket), arg0, arg1) } // BufferedPacket mocks base method. -func (m *MockConnectionTracer) BufferedPacket(arg0 PacketType, arg1 protocol.ByteCount) { +func (m *MockConnectionTracer) BufferedPacket(arg0 logging.PacketType, arg1 protocol.ByteCount) { m.ctrl.T.Helper() m.ctrl.Call(m, "BufferedPacket", arg0, arg1) } // BufferedPacket indicates an expected call of BufferedPacket. -func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).BufferedPacket), arg0, arg1) } @@ -81,7 +86,7 @@ func (m *MockConnectionTracer) ClosedConnection(arg0 error) { } // ClosedConnection indicates an expected call of ClosedConnection. -func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClosedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).ClosedConnection), arg0) } @@ -93,7 +98,7 @@ func (m *MockConnectionTracer) Debug(arg0, arg1 string) { } // Debug indicates an expected call of Debug. -func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockConnectionTracer)(nil).Debug), arg0, arg1) } @@ -105,7 +110,7 @@ func (m *MockConnectionTracer) DroppedEncryptionLevel(arg0 protocol.EncryptionLe } // DroppedEncryptionLevel indicates an expected call of DroppedEncryptionLevel. -func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedEncryptionLevel", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedEncryptionLevel), arg0) } @@ -117,23 +122,35 @@ func (m *MockConnectionTracer) DroppedKey(arg0 protocol.KeyPhase) { } // DroppedKey indicates an expected call of DroppedKey. -func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedKey", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedKey), arg0) } // DroppedPacket mocks base method. -func (m *MockConnectionTracer) DroppedPacket(arg0 PacketType, arg1 protocol.ByteCount, arg2 PacketDropReason) { +func (m *MockConnectionTracer) DroppedPacket(arg0 logging.PacketType, arg1 protocol.ByteCount, arg2 logging.PacketDropReason) { m.ctrl.T.Helper() m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2) } // DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2) } +// ECNStateUpdated mocks base method. +func (m *MockConnectionTracer) ECNStateUpdated(arg0 logging.ECNState, arg1 logging.ECNStateTrigger) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ECNStateUpdated", arg0, arg1) +} + +// ECNStateUpdated indicates an expected call of ECNStateUpdated. +func (mr *MockConnectionTracerMockRecorder) ECNStateUpdated(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNStateUpdated", reflect.TypeOf((*MockConnectionTracer)(nil).ECNStateUpdated), arg0, arg1) +} + // LossTimerCanceled mocks base method. func (m *MockConnectionTracer) LossTimerCanceled() { m.ctrl.T.Helper() @@ -147,25 +164,25 @@ func (mr *MockConnectionTracerMockRecorder) LossTimerCanceled() *gomock.Call { } // LossTimerExpired mocks base method. -func (m *MockConnectionTracer) LossTimerExpired(arg0 TimerType, arg1 protocol.EncryptionLevel) { +func (m *MockConnectionTracer) LossTimerExpired(arg0 logging.TimerType, arg1 protocol.EncryptionLevel) { m.ctrl.T.Helper() m.ctrl.Call(m, "LossTimerExpired", arg0, arg1) } // LossTimerExpired indicates an expected call of LossTimerExpired. -func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerExpired", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerExpired), arg0, arg1) } // LostPacket mocks base method. -func (m *MockConnectionTracer) LostPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber, arg2 PacketLossReason) { +func (m *MockConnectionTracer) LostPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber, arg2 logging.PacketLossReason) { m.ctrl.T.Helper() m.ctrl.Call(m, "LostPacket", arg0, arg1, arg2) } // LostPacket indicates an expected call of LostPacket. -func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) } @@ -177,21 +194,21 @@ func (m *MockConnectionTracer) NegotiatedVersion(arg0 protocol.VersionNumber, ar } // NegotiatedVersion indicates an expected call of NegotiatedVersion. -func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) } // ReceivedLongHeaderPacket mocks base method. -func (m *MockConnectionTracer) ReceivedLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 []Frame) { +func (m *MockConnectionTracer) ReceivedLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []logging.Frame) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedLongHeaderPacket", arg0, arg1, arg2) + m.ctrl.Call(m, "ReceivedLongHeaderPacket", arg0, arg1, arg2, arg3) } // ReceivedLongHeaderPacket indicates an expected call of ReceivedLongHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedLongHeaderPacket), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedLongHeaderPacket), arg0, arg1, arg2, arg3) } // ReceivedRetry mocks base method. @@ -201,21 +218,21 @@ func (m *MockConnectionTracer) ReceivedRetry(arg0 *wire.Header) { } // ReceivedRetry indicates an expected call of ReceivedRetry. -func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedRetry", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedRetry), arg0) } // ReceivedShortHeaderPacket mocks base method. -func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *ShortHeader, arg1 protocol.ByteCount, arg2 []Frame) { +func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []logging.Frame) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2) + m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2, arg3) } // ReceivedShortHeaderPacket indicates an expected call of ReceivedShortHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedShortHeaderPacket), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedShortHeaderPacket), arg0, arg1, arg2, arg3) } // ReceivedTransportParameters mocks base method. @@ -225,7 +242,7 @@ func (m *MockConnectionTracer) ReceivedTransportParameters(arg0 *wire.TransportP } // ReceivedTransportParameters indicates an expected call of ReceivedTransportParameters. -func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedTransportParameters), arg0) } @@ -237,7 +254,7 @@ func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0, arg1 proto } // ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1, arg2) } @@ -249,33 +266,33 @@ func (m *MockConnectionTracer) RestoredTransportParameters(arg0 *wire.TransportP } // RestoredTransportParameters indicates an expected call of RestoredTransportParameters. -func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0) } // SentLongHeaderPacket mocks base method. -func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []Frame) { +func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []logging.Frame) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SentLongHeaderPacket", arg0, arg1, arg2, arg3) + m.ctrl.Call(m, "SentLongHeaderPacket", arg0, arg1, arg2, arg3, arg4) } // SentLongHeaderPacket indicates an expected call of SentLongHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3, arg4) } // SentShortHeaderPacket mocks base method. -func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *ShortHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []Frame) { +func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []logging.Frame) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SentShortHeaderPacket", arg0, arg1, arg2, arg3) + m.ctrl.Call(m, "SentShortHeaderPacket", arg0, arg1, arg2, arg3, arg4) } // SentShortHeaderPacket indicates an expected call of SentShortHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3, arg4) } // SentTransportParameters mocks base method. @@ -285,19 +302,19 @@ func (m *MockConnectionTracer) SentTransportParameters(arg0 *wire.TransportParam } // SentTransportParameters indicates an expected call of SentTransportParameters. -func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).SentTransportParameters), arg0) } // SetLossTimer mocks base method. -func (m *MockConnectionTracer) SetLossTimer(arg0 TimerType, arg1 protocol.EncryptionLevel, arg2 time.Time) { +func (m *MockConnectionTracer) SetLossTimer(arg0 logging.TimerType, arg1 protocol.EncryptionLevel, arg2 time.Time) { m.ctrl.T.Helper() m.ctrl.Call(m, "SetLossTimer", arg0, arg1, arg2) } // SetLossTimer indicates an expected call of SetLossTimer. -func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLossTimer", reflect.TypeOf((*MockConnectionTracer)(nil).SetLossTimer), arg0, arg1, arg2) } @@ -309,19 +326,19 @@ func (m *MockConnectionTracer) StartedConnection(arg0, arg1 net.Addr, arg2, arg3 } // StartedConnection indicates an expected call of StartedConnection. -func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3) } // UpdatedCongestionState mocks base method. -func (m *MockConnectionTracer) UpdatedCongestionState(arg0 CongestionState) { +func (m *MockConnectionTracer) UpdatedCongestionState(arg0 logging.CongestionState) { m.ctrl.T.Helper() m.ctrl.Call(m, "UpdatedCongestionState", arg0) } // UpdatedCongestionState indicates an expected call of UpdatedCongestionState. -func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) } @@ -333,7 +350,7 @@ func (m *MockConnectionTracer) UpdatedKey(arg0 protocol.KeyPhase, arg1 bool) { } // UpdatedKey indicates an expected call of UpdatedKey. -func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKey", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKey), arg0, arg1) } @@ -345,7 +362,7 @@ func (m *MockConnectionTracer) UpdatedKeyFromTLS(arg0 protocol.EncryptionLevel, } // UpdatedKeyFromTLS indicates an expected call of UpdatedKeyFromTLS. -func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKeyFromTLS", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKeyFromTLS), arg0, arg1) } @@ -357,7 +374,7 @@ func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 p } // UpdatedMetrics indicates an expected call of UpdatedMetrics. -func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedMetrics", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedMetrics), arg0, arg1, arg2, arg3) } @@ -369,7 +386,7 @@ func (m *MockConnectionTracer) UpdatedPTOCount(arg0 uint32) { } // UpdatedPTOCount indicates an expected call of UpdatedPTOCount. -func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedPTOCount", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedPTOCount), arg0) } diff --git a/logging/mock_tracer_test.go b/internal/mocks/logging/internal/tracer.go similarity index 77% rename from logging/mock_tracer_test.go rename to internal/mocks/logging/internal/tracer.go index d295c96d6..707abb431 100644 --- a/logging/mock_tracer_test.go +++ b/internal/mocks/logging/internal/tracer.go @@ -1,13 +1,18 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/refraction-networking/uquic/logging (interfaces: Tracer) - -// Package logging is a generated GoMock package. -package logging +// Source: github.com/refraction-networking/uquic/internal/mocks/logging (interfaces: Tracer) +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package internal -destination internal/tracer.go github.com/refraction-networking/uquic/internal/mocks/logging Tracer +// +// Package internal is a generated GoMock package. +package internal import ( net "net" reflect "reflect" + logging "github.com/refraction-networking/uquic/logging" gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" wire "github.com/refraction-networking/uquic/internal/wire" @@ -37,25 +42,25 @@ func (m *MockTracer) EXPECT() *MockTracerMockRecorder { } // DroppedPacket mocks base method. -func (m *MockTracer) DroppedPacket(arg0 net.Addr, arg1 PacketType, arg2 protocol.ByteCount, arg3 PacketDropReason) { +func (m *MockTracer) DroppedPacket(arg0 net.Addr, arg1 logging.PacketType, arg2 protocol.ByteCount, arg3 logging.PacketDropReason) { m.ctrl.T.Helper() m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2, arg3) } // DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockTracer)(nil).DroppedPacket), arg0, arg1, arg2, arg3) } // SentPacket mocks base method. -func (m *MockTracer) SentPacket(arg0 net.Addr, arg1 *wire.Header, arg2 protocol.ByteCount, arg3 []Frame) { +func (m *MockTracer) SentPacket(arg0 net.Addr, arg1 *wire.Header, arg2 protocol.ByteCount, arg3 []logging.Frame) { m.ctrl.T.Helper() m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) } // SentPacket indicates an expected call of SentPacket. -func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) } @@ -67,7 +72,7 @@ func (m *MockTracer) SentVersionNegotiationPacket(arg0 net.Addr, arg1, arg2 prot } // SentVersionNegotiationPacket indicates an expected call of SentVersionNegotiationPacket. -func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3) } diff --git a/internal/mocks/logging/mockgen.go b/internal/mocks/logging/mockgen.go new file mode 100644 index 000000000..a4a7f4f05 --- /dev/null +++ b/internal/mocks/logging/mockgen.go @@ -0,0 +1,51 @@ +//go:build gomock || generate + +package mocklogging + +import ( + "net" + "time" + + "github.com/refraction-networking/uquic/logging" +) + +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package internal -destination internal/tracer.go github.com/refraction-networking/uquic/internal/mocks/logging Tracer" +type Tracer interface { + SentPacket(net.Addr, *logging.Header, logging.ByteCount, []logging.Frame) + SentVersionNegotiationPacket(_ net.Addr, dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) + DroppedPacket(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) +} + +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package internal -destination internal/connection_tracer.go github.com/refraction-networking/uquic/internal/mocks/logging ConnectionTracer" +type ConnectionTracer interface { + StartedConnection(local, remote net.Addr, srcConnID, destConnID logging.ConnectionID) + NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) + ClosedConnection(error) + SentTransportParameters(*logging.TransportParameters) + ReceivedTransportParameters(*logging.TransportParameters) + RestoredTransportParameters(parameters *logging.TransportParameters) // for 0-RTT + SentLongHeaderPacket(*logging.ExtendedHeader, logging.ByteCount, logging.ECN, *logging.AckFrame, []logging.Frame) + SentShortHeaderPacket(*logging.ShortHeader, logging.ByteCount, logging.ECN, *logging.AckFrame, []logging.Frame) + ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) + ReceivedRetry(*logging.Header) + ReceivedLongHeaderPacket(*logging.ExtendedHeader, logging.ByteCount, logging.ECN, []logging.Frame) + ReceivedShortHeaderPacket(*logging.ShortHeader, logging.ByteCount, logging.ECN, []logging.Frame) + BufferedPacket(logging.PacketType, logging.ByteCount) + DroppedPacket(logging.PacketType, logging.ByteCount, logging.PacketDropReason) + UpdatedMetrics(rttStats *logging.RTTStats, cwnd, bytesInFlight logging.ByteCount, packetsInFlight int) + AcknowledgedPacket(logging.EncryptionLevel, logging.PacketNumber) + LostPacket(logging.EncryptionLevel, logging.PacketNumber, logging.PacketLossReason) + UpdatedCongestionState(logging.CongestionState) + UpdatedPTOCount(value uint32) + UpdatedKeyFromTLS(logging.EncryptionLevel, logging.Perspective) + UpdatedKey(generation logging.KeyPhase, remote bool) + DroppedEncryptionLevel(logging.EncryptionLevel) + DroppedKey(generation logging.KeyPhase) + SetLossTimer(logging.TimerType, logging.EncryptionLevel, time.Time) + LossTimerExpired(logging.TimerType, logging.EncryptionLevel) + LossTimerCanceled() + ECNStateUpdated(state logging.ECNState, trigger logging.ECNStateTrigger) + // Close is called when the connection is closed. + Close() + Debug(name, msg string) +} diff --git a/internal/mocks/logging/tracer.go b/internal/mocks/logging/tracer.go index b479c66c1..9fc620996 100644 --- a/internal/mocks/logging/tracer.go +++ b/internal/mocks/logging/tracer.go @@ -1,74 +1,29 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/refraction-networking/uquic/logging (interfaces: Tracer) +//go:build !gomock && !generate -// Package mocklogging is a generated GoMock package. package mocklogging import ( - net "net" - reflect "reflect" + "net" - protocol "github.com/refraction-networking/uquic/internal/protocol" - wire "github.com/refraction-networking/uquic/internal/wire" - logging "github.com/refraction-networking/uquic/logging" - gomock "go.uber.org/mock/gomock" -) - -// MockTracer is a mock of Tracer interface. -type MockTracer struct { - ctrl *gomock.Controller - recorder *MockTracerMockRecorder -} - -// MockTracerMockRecorder is the mock recorder for MockTracer. -type MockTracerMockRecorder struct { - mock *MockTracer -} + "github.com/refraction-networking/uquic/internal/mocks/logging/internal" + "github.com/refraction-networking/uquic/logging" -// NewMockTracer creates a new mock instance. -func NewMockTracer(ctrl *gomock.Controller) *MockTracer { - mock := &MockTracer{ctrl: ctrl} - mock.recorder = &MockTracerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTracer) EXPECT() *MockTracerMockRecorder { - return m.recorder -} - -// DroppedPacket mocks base method. -func (m *MockTracer) DroppedPacket(arg0 net.Addr, arg1 logging.PacketType, arg2 protocol.ByteCount, arg3 logging.PacketDropReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2, arg3) -} - -// DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockTracer)(nil).DroppedPacket), arg0, arg1, arg2, arg3) -} - -// SentPacket mocks base method. -func (m *MockTracer) SentPacket(arg0 net.Addr, arg1 *wire.Header, arg2 protocol.ByteCount, arg3 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) -} - -// SentPacket indicates an expected call of SentPacket. -func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) -} - -// SentVersionNegotiationPacket mocks base method. -func (m *MockTracer) SentVersionNegotiationPacket(arg0 net.Addr, arg1, arg2 protocol.ArbitraryLenConnectionID, arg3 []protocol.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentVersionNegotiationPacket", arg0, arg1, arg2, arg3) -} + "go.uber.org/mock/gomock" +) -// SentVersionNegotiationPacket indicates an expected call of SentVersionNegotiationPacket. -func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3) +type MockTracer = internal.MockTracer + +func NewMockTracer(ctrl *gomock.Controller) (*logging.Tracer, *MockTracer) { + t := internal.NewMockTracer(ctrl) + return &logging.Tracer{ + SentPacket: func(remote net.Addr, hdr *logging.Header, size logging.ByteCount, frames []logging.Frame) { + t.SentPacket(remote, hdr, size, frames) + }, + SentVersionNegotiationPacket: func(remote net.Addr, dest, src logging.ArbitraryLenConnectionID, versions []logging.VersionNumber) { + t.SentVersionNegotiationPacket(remote, dest, src, versions) + }, + DroppedPacket: func(remote net.Addr, typ logging.PacketType, size logging.ByteCount, reason logging.PacketDropReason) { + t.DroppedPacket(remote, typ, size, reason) + }, + }, t } diff --git a/internal/mocks/long_header_opener.go b/internal/mocks/long_header_opener.go index 9552aff7e..adfcd1f92 100644 --- a/internal/mocks/long_header_opener.go +++ b/internal/mocks/long_header_opener.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic/internal/handshake (interfaces: LongHeaderOpener) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mocks -destination long_header_opener.go github.com/refraction-networking/uquic/internal/handshake LongHeaderOpener +// // Package mocks is a generated GoMock package. package mocks @@ -43,7 +47,7 @@ func (m *MockLongHeaderOpener) DecodePacketNumber(arg0 protocol.PacketNumber, ar } // DecodePacketNumber indicates an expected call of DecodePacketNumber. -func (mr *MockLongHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockLongHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecodePacketNumber), arg0, arg1) } @@ -55,7 +59,7 @@ func (m *MockLongHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byt } // DecryptHeader indicates an expected call of DecryptHeader. -func (mr *MockLongHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockLongHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2) } @@ -70,7 +74,7 @@ func (m *MockLongHeaderOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumbe } // Open indicates an expected call of Open. -func (mr *MockLongHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockLongHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockLongHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3) } diff --git a/internal/mocks/mockgen.go b/internal/mocks/mockgen.go index 58ed55786..30174b48a 100644 --- a/internal/mocks/mockgen.go +++ b/internal/mocks/mockgen.go @@ -1,19 +1,19 @@ +//go:build gomock || generate + package mocks -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockquic -destination quic/stream.go github.com/refraction-networking/uquic Stream" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockquic -destination quic/early_conn_tmp.go github.com/refraction-networking/uquic EarlyConnection && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_conn_tmp.go > quic/early_conn.go && rm quic/early_conn_tmp.go && go run golang.org/x/tools/cmd/goimports -w quic/early_conn.go" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocklogging -destination logging/tracer.go github.com/refraction-networking/uquic/logging Tracer" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocklogging -destination logging/connection_tracer.go github.com/refraction-networking/uquic/logging ConnectionTracer" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination short_header_sealer.go github.com/refraction-networking/uquic/internal/handshake ShortHeaderSealer" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination short_header_opener.go github.com/refraction-networking/uquic/internal/handshake ShortHeaderOpener" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination long_header_opener.go github.com/refraction-networking/uquic/internal/handshake LongHeaderOpener" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination crypto_setup_tmp.go github.com/refraction-networking/uquic/internal/handshake CryptoSetup && sed -E 's~github.com/quic-go/qtls[[:alnum:]_-]*~github.com/refraction-networking/uquic/internal/qtls~g; s~qtls.ConnectionStateWith0RTT~qtls.ConnectionState~g' crypto_setup_tmp.go > crypto_setup.go && rm crypto_setup_tmp.go && go run golang.org/x/tools/cmd/goimports -w crypto_setup.go" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination stream_flow_controller.go github.com/refraction-networking/uquic/internal/flowcontrol StreamFlowController" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination congestion.go github.com/refraction-networking/uquic/internal/congestion SendAlgorithmWithDebugInfos" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination connection_flow_controller.go github.com/refraction-networking/uquic/internal/flowcontrol ConnectionFlowController" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/refraction-networking/uquic/internal/ackhandler SentPacketHandler" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/refraction-networking/uquic/internal/ackhandler ReceivedPacketHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mockquic -destination quic/stream.go github.com/refraction-networking/uquic Stream" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mockquic -destination quic/early_conn_tmp.go github.com/refraction-networking/uquic EarlyConnection && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_conn_tmp.go > quic/early_conn.go && rm quic/early_conn_tmp.go && go run golang.org/x/tools/cmd/goimports -w quic/early_conn.go" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination short_header_sealer.go github.com/refraction-networking/uquic/internal/handshake ShortHeaderSealer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination short_header_opener.go github.com/refraction-networking/uquic/internal/handshake ShortHeaderOpener" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination long_header_opener.go github.com/refraction-networking/uquic/internal/handshake LongHeaderOpener" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination crypto_setup_tmp.go github.com/refraction-networking/uquic/internal/handshake CryptoSetup && sed -E 's~github.com/quic-go/qtls[[:alnum:]_-]*~github.com/refraction-networking/uquic/internal/qtls~g; s~qtls.ConnectionStateWith0RTT~qtls.ConnectionState~g' crypto_setup_tmp.go > crypto_setup.go && rm crypto_setup_tmp.go && go run golang.org/x/tools/cmd/goimports -w crypto_setup.go" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination stream_flow_controller.go github.com/refraction-networking/uquic/internal/flowcontrol StreamFlowController" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination congestion.go github.com/refraction-networking/uquic/internal/congestion SendAlgorithmWithDebugInfos" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination connection_flow_controller.go github.com/refraction-networking/uquic/internal/flowcontrol ConnectionFlowController" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/refraction-networking/uquic/internal/ackhandler SentPacketHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/refraction-networking/uquic/internal/ackhandler ReceivedPacketHandler" // The following command produces a warning message on OSX, however, it still generates the correct mock file. // See https://github.com/golang/mock/issues/339 for details. -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache" diff --git a/internal/mocks/quic/early_conn.go b/internal/mocks/quic/early_conn.go index 428c0eb17..cfcb27835 100644 --- a/internal/mocks/quic/early_conn.go +++ b/internal/mocks/quic/early_conn.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: EarlyConnection) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mockquic -destination quic/early_conn_tmp.go github.com/refraction-networking/uquic EarlyConnection +// // Package mockquic is a generated GoMock package. package mockquic @@ -47,7 +51,7 @@ func (m *MockEarlyConnection) AcceptStream(arg0 context.Context) (quic.Stream, e } // AcceptStream indicates an expected call of AcceptStream. -func (mr *MockEarlyConnectionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) AcceptStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockEarlyConnection)(nil).AcceptStream), arg0) } @@ -62,7 +66,7 @@ func (m *MockEarlyConnection) AcceptUniStream(arg0 context.Context) (quic.Receiv } // AcceptUniStream indicates an expected call of AcceptUniStream. -func (mr *MockEarlyConnectionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) AcceptUniStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockEarlyConnection)(nil).AcceptUniStream), arg0) } @@ -76,7 +80,7 @@ func (m *MockEarlyConnection) CloseWithError(arg0 qerr.ApplicationErrorCode, arg } // CloseWithError indicates an expected call of CloseWithError. -func (mr *MockEarlyConnectionMockRecorder) CloseWithError(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) CloseWithError(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockEarlyConnection)(nil).CloseWithError), arg0, arg1) } @@ -176,7 +180,7 @@ func (m *MockEarlyConnection) OpenStreamSync(arg0 context.Context) (quic.Stream, } // OpenStreamSync indicates an expected call of OpenStreamSync. -func (mr *MockEarlyConnectionMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) OpenStreamSync(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockEarlyConnection)(nil).OpenStreamSync), arg0) } @@ -206,7 +210,7 @@ func (m *MockEarlyConnection) OpenUniStreamSync(arg0 context.Context) (quic.Send } // OpenUniStreamSync indicates an expected call of OpenUniStreamSync. -func (mr *MockEarlyConnectionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) OpenUniStreamSync(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockEarlyConnection)(nil).OpenUniStreamSync), arg0) } @@ -221,7 +225,7 @@ func (m *MockEarlyConnection) ReceiveMessage(arg0 context.Context) ([]byte, erro } // ReceiveMessage indicates an expected call of ReceiveMessage. -func (mr *MockEarlyConnectionMockRecorder) ReceiveMessage(arg0 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) ReceiveMessage(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockEarlyConnection)(nil).ReceiveMessage), arg0) } @@ -249,7 +253,7 @@ func (m *MockEarlyConnection) SendMessage(arg0 []byte) error { } // SendMessage indicates an expected call of SendMessage. -func (mr *MockEarlyConnectionMockRecorder) SendMessage(arg0 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) SendMessage(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockEarlyConnection)(nil).SendMessage), arg0) } diff --git a/internal/mocks/quic/stream.go b/internal/mocks/quic/stream.go index 97c52c526..8ee289520 100644 --- a/internal/mocks/quic/stream.go +++ b/internal/mocks/quic/stream.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: Stream) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mockquic -destination quic/stream.go github.com/refraction-networking/uquic Stream +// // Package mockquic is a generated GoMock package. package mockquic @@ -44,7 +48,7 @@ func (m *MockStream) CancelRead(arg0 qerr.StreamErrorCode) { } // CancelRead indicates an expected call of CancelRead. -func (mr *MockStreamMockRecorder) CancelRead(arg0 interface{}) *gomock.Call { +func (mr *MockStreamMockRecorder) CancelRead(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockStream)(nil).CancelRead), arg0) } @@ -56,7 +60,7 @@ func (m *MockStream) CancelWrite(arg0 qerr.StreamErrorCode) { } // CancelWrite indicates an expected call of CancelWrite. -func (mr *MockStreamMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call { +func (mr *MockStreamMockRecorder) CancelWrite(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockStream)(nil).CancelWrite), arg0) } @@ -99,7 +103,7 @@ func (m *MockStream) Read(arg0 []byte) (int, error) { } // Read indicates an expected call of Read. -func (mr *MockStreamMockRecorder) Read(arg0 interface{}) *gomock.Call { +func (mr *MockStreamMockRecorder) Read(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStream)(nil).Read), arg0) } @@ -113,7 +117,7 @@ func (m *MockStream) SetDeadline(arg0 time.Time) error { } // SetDeadline indicates an expected call of SetDeadline. -func (mr *MockStreamMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockStreamMockRecorder) SetDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockStream)(nil).SetDeadline), arg0) } @@ -127,7 +131,7 @@ func (m *MockStream) SetReadDeadline(arg0 time.Time) error { } // SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockStreamMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockStreamMockRecorder) SetReadDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStream)(nil).SetReadDeadline), arg0) } @@ -141,7 +145,7 @@ func (m *MockStream) SetWriteDeadline(arg0 time.Time) error { } // SetWriteDeadline indicates an expected call of SetWriteDeadline. -func (mr *MockStreamMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockStreamMockRecorder) SetWriteDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStream)(nil).SetWriteDeadline), arg0) } @@ -170,7 +174,7 @@ func (m *MockStream) Write(arg0 []byte) (int, error) { } // Write indicates an expected call of Write. -func (mr *MockStreamMockRecorder) Write(arg0 interface{}) *gomock.Call { +func (mr *MockStreamMockRecorder) Write(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStream)(nil).Write), arg0) } diff --git a/internal/mocks/short_header_opener.go b/internal/mocks/short_header_opener.go index 4f609b937..08aaf8f92 100644 --- a/internal/mocks/short_header_opener.go +++ b/internal/mocks/short_header_opener.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic/internal/handshake (interfaces: ShortHeaderOpener) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mocks -destination short_header_opener.go github.com/refraction-networking/uquic/internal/handshake ShortHeaderOpener +// // Package mocks is a generated GoMock package. package mocks @@ -44,7 +48,7 @@ func (m *MockShortHeaderOpener) DecodePacketNumber(arg0 protocol.PacketNumber, a } // DecodePacketNumber indicates an expected call of DecodePacketNumber. -func (mr *MockShortHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockShortHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecodePacketNumber), arg0, arg1) } @@ -56,7 +60,7 @@ func (m *MockShortHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []by } // DecryptHeader indicates an expected call of DecryptHeader. -func (mr *MockShortHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockShortHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2) } @@ -71,7 +75,7 @@ func (m *MockShortHeaderOpener) Open(arg0, arg1 []byte, arg2 time.Time, arg3 pro } // Open indicates an expected call of Open. -func (mr *MockShortHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call { +func (mr *MockShortHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3, arg4, arg5 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockShortHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3, arg4, arg5) } diff --git a/internal/mocks/short_header_sealer.go b/internal/mocks/short_header_sealer.go index dfb654bfe..768543d60 100644 --- a/internal/mocks/short_header_sealer.go +++ b/internal/mocks/short_header_sealer.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic/internal/handshake (interfaces: ShortHeaderSealer) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mocks -destination short_header_sealer.go github.com/refraction-networking/uquic/internal/handshake ShortHeaderSealer +// // Package mocks is a generated GoMock package. package mocks @@ -41,7 +45,7 @@ func (m *MockShortHeaderSealer) EncryptHeader(arg0 []byte, arg1 *byte, arg2 []by } // EncryptHeader indicates an expected call of EncryptHeader. -func (mr *MockShortHeaderSealerMockRecorder) EncryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockShortHeaderSealerMockRecorder) EncryptHeader(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptHeader", reflect.TypeOf((*MockShortHeaderSealer)(nil).EncryptHeader), arg0, arg1, arg2) } @@ -83,7 +87,7 @@ func (m *MockShortHeaderSealer) Seal(arg0, arg1 []byte, arg2 protocol.PacketNumb } // Seal indicates an expected call of Seal. -func (mr *MockShortHeaderSealerMockRecorder) Seal(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockShortHeaderSealerMockRecorder) Seal(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seal", reflect.TypeOf((*MockShortHeaderSealer)(nil).Seal), arg0, arg1, arg2, arg3) } diff --git a/internal/mocks/stream_flow_controller.go b/internal/mocks/stream_flow_controller.go index 3a16e9375..3bacec394 100644 --- a/internal/mocks/stream_flow_controller.go +++ b/internal/mocks/stream_flow_controller.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic/internal/flowcontrol (interfaces: StreamFlowController) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mocks -destination stream_flow_controller.go github.com/refraction-networking/uquic/internal/flowcontrol StreamFlowController +// // Package mocks is a generated GoMock package. package mocks @@ -53,7 +57,7 @@ func (m *MockStreamFlowController) AddBytesRead(arg0 protocol.ByteCount) { } // AddBytesRead indicates an expected call of AddBytesRead. -func (mr *MockStreamFlowControllerMockRecorder) AddBytesRead(arg0 interface{}) *gomock.Call { +func (mr *MockStreamFlowControllerMockRecorder) AddBytesRead(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesRead), arg0) } @@ -65,7 +69,7 @@ func (m *MockStreamFlowController) AddBytesSent(arg0 protocol.ByteCount) { } // AddBytesSent indicates an expected call of AddBytesSent. -func (mr *MockStreamFlowControllerMockRecorder) AddBytesSent(arg0 interface{}) *gomock.Call { +func (mr *MockStreamFlowControllerMockRecorder) AddBytesSent(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesSent), arg0) } @@ -122,7 +126,7 @@ func (m *MockStreamFlowController) UpdateHighestReceived(arg0 protocol.ByteCount } // UpdateHighestReceived indicates an expected call of UpdateHighestReceived. -func (mr *MockStreamFlowControllerMockRecorder) UpdateHighestReceived(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStreamFlowControllerMockRecorder) UpdateHighestReceived(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHighestReceived", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateHighestReceived), arg0, arg1) } @@ -134,7 +138,7 @@ func (m *MockStreamFlowController) UpdateSendWindow(arg0 protocol.ByteCount) { } // UpdateSendWindow indicates an expected call of UpdateSendWindow. -func (mr *MockStreamFlowControllerMockRecorder) UpdateSendWindow(arg0 interface{}) *gomock.Call { +func (mr *MockStreamFlowControllerMockRecorder) UpdateSendWindow(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateSendWindow), arg0) } diff --git a/internal/mocks/tls/client_session_cache.go b/internal/mocks/tls/client_session_cache.go index 9a7c34440..30483c14f 100644 --- a/internal/mocks/tls/client_session_cache.go +++ b/internal/mocks/tls/client_session_cache.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: crypto/tls (interfaces: ClientSessionCache) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache +// // Package mocktls is a generated GoMock package. package mocktls @@ -44,7 +48,7 @@ func (m *MockClientSessionCache) Get(arg0 string) (*tls.ClientSessionState, bool } // Get indicates an expected call of Get. -func (mr *MockClientSessionCacheMockRecorder) Get(arg0 interface{}) *gomock.Call { +func (mr *MockClientSessionCacheMockRecorder) Get(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockClientSessionCache)(nil).Get), arg0) } @@ -56,7 +60,7 @@ func (m *MockClientSessionCache) Put(arg0 string, arg1 *tls.ClientSessionState) } // Put indicates an expected call of Put. -func (mr *MockClientSessionCacheMockRecorder) Put(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockClientSessionCacheMockRecorder) Put(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockClientSessionCache)(nil).Put), arg0, arg1) } diff --git a/internal/protocol/params.go b/internal/protocol/params.go index fe3a75625..3ca68bf83 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -65,9 +65,6 @@ const MaxAcceptQueueSize = 32 // TokenValidity is the duration that a (non-retry) token is considered valid const TokenValidity = 24 * time.Hour -// RetryTokenValidity is the duration that a retry token is considered valid -const RetryTokenValidity = 10 * time.Second - // MaxOutstandingSentPackets is maximum number of packets saved for retransmission. // When reached, it imposes a soft limit on sending new packets: // Sending ACKs and retransmission is still allowed, but now new regular packets can be sent. @@ -108,9 +105,6 @@ const DefaultIdleTimeout = 30 * time.Second // DefaultHandshakeIdleTimeout is the default idle timeout used before handshake completion. const DefaultHandshakeIdleTimeout = 5 * time.Second -// DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds. -const DefaultHandshakeTimeout = 10 * time.Second - // MaxKeepAliveInterval is the maximum time until we send a packet to keep a connection alive. // It should be shorter than the time that NATs clear their mapping. const MaxKeepAliveInterval = 20 * time.Second diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index b8104882b..d056cb9de 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -37,14 +37,48 @@ func (t PacketType) String() string { type ECN uint8 const ( - ECNNon ECN = iota // 00 - ECT1 // 01 - ECT0 // 10 - ECNCE // 11 + ECNUnsupported ECN = iota + ECNNon // 00 + ECT1 // 01 + ECT0 // 10 + ECNCE // 11 ) +func ParseECNHeaderBits(bits byte) ECN { + switch bits { + case 0: + return ECNNon + case 0b00000010: + return ECT0 + case 0b00000001: + return ECT1 + case 0b00000011: + return ECNCE + default: + panic("invalid ECN bits") + } +} + +func (e ECN) ToHeaderBits() byte { + //nolint:exhaustive // There are only 4 values. + switch e { + case ECNNon: + return 0 + case ECT0: + return 0b00000010 + case ECT1: + return 0b00000001 + case ECNCE: + return 0b00000011 + default: + panic("ECN unsupported") + } +} + func (e ECN) String() string { switch e { + case ECNUnsupported: + return "ECN unsupported" case ECNNon: return "Not-ECT" case ECT1: diff --git a/internal/protocol/protocol_test.go b/internal/protocol/protocol_test.go index e672d31e1..22359e6af 100644 --- a/internal/protocol/protocol_test.go +++ b/internal/protocol/protocol_test.go @@ -17,13 +17,22 @@ var _ = Describe("Protocol", func() { }) It("converts ECN bits from the IP header wire to the correct types", func() { - Expect(ECN(0)).To(Equal(ECNNon)) - Expect(ECN(0b00000010)).To(Equal(ECT0)) - Expect(ECN(0b00000001)).To(Equal(ECT1)) - Expect(ECN(0b00000011)).To(Equal(ECNCE)) + Expect(ParseECNHeaderBits(0)).To(Equal(ECNNon)) + Expect(ParseECNHeaderBits(0b00000010)).To(Equal(ECT0)) + Expect(ParseECNHeaderBits(0b00000001)).To(Equal(ECT1)) + Expect(ParseECNHeaderBits(0b00000011)).To(Equal(ECNCE)) + Expect(func() { ParseECNHeaderBits(0b1010101) }).To(Panic()) + }) + + It("converts to IP header bits", func() { + for _, v := range [...]ECN{ECNNon, ECT0, ECT1, ECNCE} { + Expect(ParseECNHeaderBits(v.ToHeaderBits())).To(Equal(v)) + } + Expect(func() { ECN(42).ToHeaderBits() }).To(Panic()) }) It("has a string representation for ECN", func() { + Expect(ECNUnsupported.String()).To(Equal("ECN unsupported")) Expect(ECNNon.String()).To(Equal("Not-ECT")) Expect(ECT0.String()).To(Equal("ECT(0)")) Expect(ECT1.String()).To(Equal("ECT(1)")) diff --git a/internal/qtls/utls.go b/internal/qtls/utls.go index 149c6cc3c..45f19e11e 100644 --- a/internal/qtls/utls.go +++ b/internal/qtls/utls.go @@ -55,7 +55,7 @@ func UQUICClient(config *QUICConfig, clientHelloSpec *tls.ClientHelloSpec) *UQUI return uqc } -func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, accept0RTT func([]byte) bool) { +func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, handleSessionTicket func([]byte, bool) bool) { conf := qconf.TLSConfig // Workaround for https://github.com/golang/go/issues/60506. @@ -69,11 +69,9 @@ func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, acce // add callbacks to save transport parameters into the session ticket origWrapSession := conf.WrapSession conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) { - // Add QUIC transport parameters if this is a 0-RTT packet. - // TODO(#3853): also save the RTT for non-0-RTT tickets - if state.EarlyData { - state.Extra = append(state.Extra, addExtraPrefix(getData())) - } + // Add QUIC session ticket + state.Extra = append(state.Extra, addExtraPrefix(getData())) + if origWrapSession != nil { return origWrapSession(cs, state) } @@ -97,14 +95,14 @@ func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, acce if err != nil || state == nil { return nil, err } - if state.EarlyData { - extra := findExtraData(state.Extra) - if unwrapCount == 1 && extra != nil { // first session ticket - state.EarlyData = accept0RTT(extra) - } else { // subsequent session ticket, can't be used for 0-RTT - state.EarlyData = false - } + + extra := findExtraData(state.Extra) + if extra != nil { + state.EarlyData = handleSessionTicket(extra, state.EarlyData && unwrapCount == 1) + } else { + state.EarlyData = false } + return state, nil } } diff --git a/internal/utils/log_test.go b/internal/utils/log_test.go index 98b513d4d..4be75dffe 100644 --- a/internal/utils/log_test.go +++ b/internal/utils/log_test.go @@ -82,7 +82,7 @@ var _ = Describe("Log", func() { DefaultLogger.Infof("info") t, err := time.Parse(format, b.String()[:b.Len()-6]) Expect(err).ToNot(HaveOccurred()) - Expect(t).To(BeTemporally("~", time.Now(), 25*time.Hour)) + Expect(t).To(BeTemporally("~", time.Now(), 48*time.Hour)) }) It("says whether debug is enabled", func() { diff --git a/logging/connection_tracer.go b/logging/connection_tracer.go new file mode 100644 index 000000000..e3f322d91 --- /dev/null +++ b/logging/connection_tracer.go @@ -0,0 +1,255 @@ +package logging + +import ( + "net" + "time" +) + +// A ConnectionTracer records events. +type ConnectionTracer struct { + StartedConnection func(local, remote net.Addr, srcConnID, destConnID ConnectionID) + NegotiatedVersion func(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) + ClosedConnection func(error) + SentTransportParameters func(*TransportParameters) + ReceivedTransportParameters func(*TransportParameters) + RestoredTransportParameters func(parameters *TransportParameters) // for 0-RTT + SentLongHeaderPacket func(*ExtendedHeader, ByteCount, ECN, *AckFrame, []Frame) + SentShortHeaderPacket func(*ShortHeader, ByteCount, ECN, *AckFrame, []Frame) + ReceivedVersionNegotiationPacket func(dest, src ArbitraryLenConnectionID, _ []VersionNumber) + ReceivedRetry func(*Header) + ReceivedLongHeaderPacket func(*ExtendedHeader, ByteCount, ECN, []Frame) + ReceivedShortHeaderPacket func(*ShortHeader, ByteCount, ECN, []Frame) + BufferedPacket func(PacketType, ByteCount) + DroppedPacket func(PacketType, ByteCount, PacketDropReason) + UpdatedMetrics func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) + AcknowledgedPacket func(EncryptionLevel, PacketNumber) + LostPacket func(EncryptionLevel, PacketNumber, PacketLossReason) + UpdatedCongestionState func(CongestionState) + UpdatedPTOCount func(value uint32) + UpdatedKeyFromTLS func(EncryptionLevel, Perspective) + UpdatedKey func(generation KeyPhase, remote bool) + DroppedEncryptionLevel func(EncryptionLevel) + DroppedKey func(generation KeyPhase) + SetLossTimer func(TimerType, EncryptionLevel, time.Time) + LossTimerExpired func(TimerType, EncryptionLevel) + LossTimerCanceled func() + ECNStateUpdated func(state ECNState, trigger ECNStateTrigger) + // Close is called when the connection is closed. + Close func() + Debug func(name, msg string) +} + +// NewMultiplexedConnectionTracer creates a new connection tracer that multiplexes events to multiple tracers. +func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTracer { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &ConnectionTracer{ + StartedConnection: func(local, remote net.Addr, srcConnID, destConnID ConnectionID) { + for _, t := range tracers { + if t.StartedConnection != nil { + t.StartedConnection(local, remote, srcConnID, destConnID) + } + } + }, + NegotiatedVersion: func(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { + for _, t := range tracers { + if t.NegotiatedVersion != nil { + t.NegotiatedVersion(chosen, clientVersions, serverVersions) + } + } + }, + ClosedConnection: func(e error) { + for _, t := range tracers { + if t.ClosedConnection != nil { + t.ClosedConnection(e) + } + } + }, + SentTransportParameters: func(tp *TransportParameters) { + for _, t := range tracers { + if t.SentTransportParameters != nil { + t.SentTransportParameters(tp) + } + } + }, + ReceivedTransportParameters: func(tp *TransportParameters) { + for _, t := range tracers { + if t.ReceivedTransportParameters != nil { + t.ReceivedTransportParameters(tp) + } + } + }, + RestoredTransportParameters: func(tp *TransportParameters) { + for _, t := range tracers { + if t.RestoredTransportParameters != nil { + t.RestoredTransportParameters(tp) + } + } + }, + SentLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { + for _, t := range tracers { + if t.SentLongHeaderPacket != nil { + t.SentLongHeaderPacket(hdr, size, ecn, ack, frames) + } + } + }, + SentShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { + for _, t := range tracers { + if t.SentShortHeaderPacket != nil { + t.SentShortHeaderPacket(hdr, size, ecn, ack, frames) + } + } + }, + ReceivedVersionNegotiationPacket: func(dest, src ArbitraryLenConnectionID, versions []VersionNumber) { + for _, t := range tracers { + if t.ReceivedVersionNegotiationPacket != nil { + t.ReceivedVersionNegotiationPacket(dest, src, versions) + } + } + }, + ReceivedRetry: func(hdr *Header) { + for _, t := range tracers { + if t.ReceivedRetry != nil { + t.ReceivedRetry(hdr) + } + } + }, + ReceivedLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, frames []Frame) { + for _, t := range tracers { + if t.ReceivedLongHeaderPacket != nil { + t.ReceivedLongHeaderPacket(hdr, size, ecn, frames) + } + } + }, + ReceivedShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, frames []Frame) { + for _, t := range tracers { + if t.ReceivedShortHeaderPacket != nil { + t.ReceivedShortHeaderPacket(hdr, size, ecn, frames) + } + } + }, + BufferedPacket: func(typ PacketType, size ByteCount) { + for _, t := range tracers { + if t.BufferedPacket != nil { + t.BufferedPacket(typ, size) + } + } + }, + DroppedPacket: func(typ PacketType, size ByteCount, reason PacketDropReason) { + for _, t := range tracers { + if t.DroppedPacket != nil { + t.DroppedPacket(typ, size, reason) + } + } + }, + UpdatedMetrics: func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) { + for _, t := range tracers { + if t.UpdatedMetrics != nil { + t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight) + } + } + }, + AcknowledgedPacket: func(encLevel EncryptionLevel, pn PacketNumber) { + for _, t := range tracers { + if t.AcknowledgedPacket != nil { + t.AcknowledgedPacket(encLevel, pn) + } + } + }, + LostPacket: func(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) { + for _, t := range tracers { + if t.LostPacket != nil { + t.LostPacket(encLevel, pn, reason) + } + } + }, + UpdatedCongestionState: func(state CongestionState) { + for _, t := range tracers { + if t.UpdatedCongestionState != nil { + t.UpdatedCongestionState(state) + } + } + }, + UpdatedPTOCount: func(value uint32) { + for _, t := range tracers { + if t.UpdatedPTOCount != nil { + t.UpdatedPTOCount(value) + } + } + }, + UpdatedKeyFromTLS: func(encLevel EncryptionLevel, perspective Perspective) { + for _, t := range tracers { + if t.UpdatedKeyFromTLS != nil { + t.UpdatedKeyFromTLS(encLevel, perspective) + } + } + }, + UpdatedKey: func(generation KeyPhase, remote bool) { + for _, t := range tracers { + if t.UpdatedKey != nil { + t.UpdatedKey(generation, remote) + } + } + }, + DroppedEncryptionLevel: func(encLevel EncryptionLevel) { + for _, t := range tracers { + if t.DroppedEncryptionLevel != nil { + t.DroppedEncryptionLevel(encLevel) + } + } + }, + DroppedKey: func(generation KeyPhase) { + for _, t := range tracers { + if t.DroppedKey != nil { + t.DroppedKey(generation) + } + } + }, + SetLossTimer: func(typ TimerType, encLevel EncryptionLevel, exp time.Time) { + for _, t := range tracers { + if t.SetLossTimer != nil { + t.SetLossTimer(typ, encLevel, exp) + } + } + }, + LossTimerExpired: func(typ TimerType, encLevel EncryptionLevel) { + for _, t := range tracers { + if t.LossTimerExpired != nil { + t.LossTimerExpired(typ, encLevel) + } + } + }, + LossTimerCanceled: func() { + for _, t := range tracers { + if t.LossTimerCanceled != nil { + t.LossTimerCanceled() + } + } + }, + ECNStateUpdated: func(state ECNState, trigger ECNStateTrigger) { + for _, t := range tracers { + if t.ECNStateUpdated != nil { + t.ECNStateUpdated(state, trigger) + } + } + }, + Close: func() { + for _, t := range tracers { + if t.Close != nil { + t.Close() + } + } + }, + Debug: func(name, msg string) { + for _, t := range tracers { + if t.Debug != nil { + t.Debug(name, msg) + } + } + }, + } +} diff --git a/logging/interface.go b/logging/interface.go index dfa434fd4..355bc09aa 100644 --- a/logging/interface.go +++ b/logging/interface.go @@ -3,9 +3,6 @@ package logging import ( - "net" - "time" - "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/qerr" "github.com/refraction-networking/uquic/internal/utils" @@ -15,6 +12,8 @@ import ( type ( // A ByteCount is used to count bytes. ByteCount = protocol.ByteCount + // ECN is the ECN value + ECN = protocol.ECN // A ConnectionID is a QUIC Connection ID. ConnectionID = protocol.ConnectionID // An ArbitraryLenConnectionID is a QUIC Connection ID that can be up to 255 bytes long. @@ -58,6 +57,19 @@ type ( RTTStats = utils.RTTStats ) +const ( + // ECNUnsupported means that no ECN value was set / received + ECNUnsupported = protocol.ECNUnsupported + // ECTNot is Not-ECT + ECTNot = protocol.ECNNon + // ECT0 is ECT(0) + ECT0 = protocol.ECT0 + // ECT1 is ECT(1) + ECT1 = protocol.ECT1 + // ECNCE is CE + ECNCE = protocol.ECNCE +) + const ( // KeyPhaseZero is key phase bit 0 KeyPhaseZero KeyPhaseBit = protocol.KeyPhaseZero @@ -97,43 +109,3 @@ type ShortHeader struct { PacketNumberLen protocol.PacketNumberLen KeyPhase KeyPhaseBit } - -// A Tracer traces events. -type Tracer interface { - SentPacket(net.Addr, *Header, ByteCount, []Frame) - SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) - DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) -} - -// A ConnectionTracer records events. -type ConnectionTracer interface { - StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) - NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) - ClosedConnection(error) - SentTransportParameters(*TransportParameters) - ReceivedTransportParameters(*TransportParameters) - RestoredTransportParameters(parameters *TransportParameters) // for 0-RTT - SentLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) - SentShortHeaderPacket(hdr *ShortHeader, size ByteCount, ack *AckFrame, frames []Frame) - ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber) - ReceivedRetry(*Header) - ReceivedLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) - ReceivedShortHeaderPacket(hdr *ShortHeader, size ByteCount, frames []Frame) - BufferedPacket(PacketType, ByteCount) - DroppedPacket(PacketType, ByteCount, PacketDropReason) - UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) - AcknowledgedPacket(EncryptionLevel, PacketNumber) - LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) - UpdatedCongestionState(CongestionState) - UpdatedPTOCount(value uint32) - UpdatedKeyFromTLS(EncryptionLevel, Perspective) - UpdatedKey(generation KeyPhase, remote bool) - DroppedEncryptionLevel(EncryptionLevel) - DroppedKey(generation KeyPhase) - SetLossTimer(TimerType, EncryptionLevel, time.Time) - LossTimerExpired(TimerType, EncryptionLevel) - LossTimerCanceled() - // Close is called when the connection is closed. - Close() - Debug(name, msg string) -} diff --git a/logging/logging_suite_test.go b/logging/logging_suite_test.go index d808adfec..e595313de 100644 --- a/logging/logging_suite_test.go +++ b/logging/logging_suite_test.go @@ -1,4 +1,4 @@ -package logging +package logging_test import ( "testing" diff --git a/logging/mockgen.go b/logging/mockgen.go deleted file mode 100644 index 66f712bcf..000000000 --- a/logging/mockgen.go +++ /dev/null @@ -1,4 +0,0 @@ -package logging - -//go:generate sh -c "go run go.uber.org/mock/mockgen -package logging -self_package github.com/refraction-networking/uquic/logging -destination mock_connection_tracer_test.go github.com/refraction-networking/uquic/logging ConnectionTracer" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package logging -self_package github.com/refraction-networking/uquic/logging -destination mock_tracer_test.go github.com/refraction-networking/uquic/logging Tracer" diff --git a/logging/multiplex.go b/logging/multiplex.go deleted file mode 100644 index 672a5cdbd..000000000 --- a/logging/multiplex.go +++ /dev/null @@ -1,226 +0,0 @@ -package logging - -import ( - "net" - "time" -) - -type tracerMultiplexer struct { - tracers []Tracer -} - -var _ Tracer = &tracerMultiplexer{} - -// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers. -func NewMultiplexedTracer(tracers ...Tracer) Tracer { - if len(tracers) == 0 { - return nil - } - if len(tracers) == 1 { - return tracers[0] - } - return &tracerMultiplexer{tracers} -} - -func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { - for _, t := range m.tracers { - t.SentPacket(remote, hdr, size, frames) - } -} - -func (m *tracerMultiplexer) SentVersionNegotiationPacket(remote net.Addr, dest, src ArbitraryLenConnectionID, versions []VersionNumber) { - for _, t := range m.tracers { - t.SentVersionNegotiationPacket(remote, dest, src, versions) - } -} - -func (m *tracerMultiplexer) DroppedPacket(remote net.Addr, typ PacketType, size ByteCount, reason PacketDropReason) { - for _, t := range m.tracers { - t.DroppedPacket(remote, typ, size, reason) - } -} - -type connTracerMultiplexer struct { - tracers []ConnectionTracer -} - -var _ ConnectionTracer = &connTracerMultiplexer{} - -// NewMultiplexedConnectionTracer creates a new connection tracer that multiplexes events to multiple tracers. -func NewMultiplexedConnectionTracer(tracers ...ConnectionTracer) ConnectionTracer { - if len(tracers) == 0 { - return nil - } - if len(tracers) == 1 { - return tracers[0] - } - return &connTracerMultiplexer{tracers: tracers} -} - -func (m *connTracerMultiplexer) StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) { - for _, t := range m.tracers { - t.StartedConnection(local, remote, srcConnID, destConnID) - } -} - -func (m *connTracerMultiplexer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { - for _, t := range m.tracers { - t.NegotiatedVersion(chosen, clientVersions, serverVersions) - } -} - -func (m *connTracerMultiplexer) ClosedConnection(e error) { - for _, t := range m.tracers { - t.ClosedConnection(e) - } -} - -func (m *connTracerMultiplexer) SentTransportParameters(tp *TransportParameters) { - for _, t := range m.tracers { - t.SentTransportParameters(tp) - } -} - -func (m *connTracerMultiplexer) ReceivedTransportParameters(tp *TransportParameters) { - for _, t := range m.tracers { - t.ReceivedTransportParameters(tp) - } -} - -func (m *connTracerMultiplexer) RestoredTransportParameters(tp *TransportParameters) { - for _, t := range m.tracers { - t.RestoredTransportParameters(tp) - } -} - -func (m *connTracerMultiplexer) SentLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) { - for _, t := range m.tracers { - t.SentLongHeaderPacket(hdr, size, ack, frames) - } -} - -func (m *connTracerMultiplexer) SentShortHeaderPacket(hdr *ShortHeader, size ByteCount, ack *AckFrame, frames []Frame) { - for _, t := range m.tracers { - t.SentShortHeaderPacket(hdr, size, ack, frames) - } -} - -func (m *connTracerMultiplexer) ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, versions []VersionNumber) { - for _, t := range m.tracers { - t.ReceivedVersionNegotiationPacket(dest, src, versions) - } -} - -func (m *connTracerMultiplexer) ReceivedRetry(hdr *Header) { - for _, t := range m.tracers { - t.ReceivedRetry(hdr) - } -} - -func (m *connTracerMultiplexer) ReceivedLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) { - for _, t := range m.tracers { - t.ReceivedLongHeaderPacket(hdr, size, frames) - } -} - -func (m *connTracerMultiplexer) ReceivedShortHeaderPacket(hdr *ShortHeader, size ByteCount, frames []Frame) { - for _, t := range m.tracers { - t.ReceivedShortHeaderPacket(hdr, size, frames) - } -} - -func (m *connTracerMultiplexer) BufferedPacket(typ PacketType, size ByteCount) { - for _, t := range m.tracers { - t.BufferedPacket(typ, size) - } -} - -func (m *connTracerMultiplexer) DroppedPacket(typ PacketType, size ByteCount, reason PacketDropReason) { - for _, t := range m.tracers { - t.DroppedPacket(typ, size, reason) - } -} - -func (m *connTracerMultiplexer) UpdatedCongestionState(state CongestionState) { - for _, t := range m.tracers { - t.UpdatedCongestionState(state) - } -} - -func (m *connTracerMultiplexer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFLight ByteCount, packetsInFlight int) { - for _, t := range m.tracers { - t.UpdatedMetrics(rttStats, cwnd, bytesInFLight, packetsInFlight) - } -} - -func (m *connTracerMultiplexer) AcknowledgedPacket(encLevel EncryptionLevel, pn PacketNumber) { - for _, t := range m.tracers { - t.AcknowledgedPacket(encLevel, pn) - } -} - -func (m *connTracerMultiplexer) LostPacket(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) { - for _, t := range m.tracers { - t.LostPacket(encLevel, pn, reason) - } -} - -func (m *connTracerMultiplexer) UpdatedPTOCount(value uint32) { - for _, t := range m.tracers { - t.UpdatedPTOCount(value) - } -} - -func (m *connTracerMultiplexer) UpdatedKeyFromTLS(encLevel EncryptionLevel, perspective Perspective) { - for _, t := range m.tracers { - t.UpdatedKeyFromTLS(encLevel, perspective) - } -} - -func (m *connTracerMultiplexer) UpdatedKey(generation KeyPhase, remote bool) { - for _, t := range m.tracers { - t.UpdatedKey(generation, remote) - } -} - -func (m *connTracerMultiplexer) DroppedEncryptionLevel(encLevel EncryptionLevel) { - for _, t := range m.tracers { - t.DroppedEncryptionLevel(encLevel) - } -} - -func (m *connTracerMultiplexer) DroppedKey(generation KeyPhase) { - for _, t := range m.tracers { - t.DroppedKey(generation) - } -} - -func (m *connTracerMultiplexer) SetLossTimer(typ TimerType, encLevel EncryptionLevel, exp time.Time) { - for _, t := range m.tracers { - t.SetLossTimer(typ, encLevel, exp) - } -} - -func (m *connTracerMultiplexer) LossTimerExpired(typ TimerType, encLevel EncryptionLevel) { - for _, t := range m.tracers { - t.LossTimerExpired(typ, encLevel) - } -} - -func (m *connTracerMultiplexer) LossTimerCanceled() { - for _, t := range m.tracers { - t.LossTimerCanceled() - } -} - -func (m *connTracerMultiplexer) Debug(name, msg string) { - for _, t := range m.tracers { - t.Debug(name, msg) - } -} - -func (m *connTracerMultiplexer) Close() { - for _, t := range m.tracers { - t.Close() - } -} diff --git a/logging/multiplex_test.go b/logging/multiplex_test.go index 0d21a0c6e..cd87641d4 100644 --- a/logging/multiplex_test.go +++ b/logging/multiplex_test.go @@ -1,12 +1,14 @@ -package logging +package logging_test import ( "errors" "net" "time" + mocklogging "github.com/refraction-networking/uquic/internal/mocks/logging" "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/wire" + . "github.com/refraction-networking/uquic/logging" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -19,21 +21,22 @@ var _ = Describe("Tracing", func() { }) It("returns the raw tracer if only one tracer is passed in", func() { - tr := NewMockTracer(mockCtrl) + tr := &Tracer{} tracer := NewMultiplexedTracer(tr) - Expect(tracer).To(BeAssignableToTypeOf(&MockTracer{})) + Expect(tracer).To(Equal(tr)) }) Context("tracing events", func() { var ( - tracer Tracer - tr1, tr2 *MockTracer + tracer *Tracer + tr1, tr2 *mocklogging.MockTracer ) BeforeEach(func() { - tr1 = NewMockTracer(mockCtrl) - tr2 = NewMockTracer(mockCtrl) - tracer = NewMultiplexedTracer(tr1, tr2) + var t1, t2 *Tracer + t1, tr1 = mocklogging.NewMockTracer(mockCtrl) + t2, tr2 = mocklogging.NewMockTracer(mockCtrl) + tracer = NewMultiplexedTracer(t1, t2, &Tracer{}) }) It("traces the PacketSent event", func() { @@ -66,18 +69,19 @@ var _ = Describe("Tracing", func() { Context("Connection Tracer", func() { var ( - tracer ConnectionTracer - tr1 *MockConnectionTracer - tr2 *MockConnectionTracer + tracer *ConnectionTracer + tr1 *mocklogging.MockConnectionTracer + tr2 *mocklogging.MockConnectionTracer ) BeforeEach(func() { - tr1 = NewMockConnectionTracer(mockCtrl) - tr2 = NewMockConnectionTracer(mockCtrl) - tracer = NewMultiplexedConnectionTracer(tr1, tr2) + var t1, t2 *ConnectionTracer + t1, tr1 = mocklogging.NewMockConnectionTracer(mockCtrl) + t2, tr2 = mocklogging.NewMockConnectionTracer(mockCtrl) + tracer = NewMultiplexedConnectionTracer(t1, t2) }) - It("trace the ConnectionStarted event", func() { + It("traces the StartedConnection event", func() { local := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4)} remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} dest := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) @@ -87,6 +91,15 @@ var _ = Describe("Tracing", func() { tracer.StartedConnection(local, remote, src, dest) }) + It("traces the NegotiatedVersion event", func() { + chosen := protocol.Version2 + client := []protocol.VersionNumber{protocol.Version1} + server := []protocol.VersionNumber{13, 37} + tr1.EXPECT().NegotiatedVersion(chosen, client, server) + tr2.EXPECT().NegotiatedVersion(chosen, client, server) + tracer.NegotiatedVersion(chosen, client, server) + }) + It("traces the ClosedConnection event", func() { e := errors.New("test err") tr1.EXPECT().ClosedConnection(e) @@ -119,18 +132,18 @@ var _ = Describe("Tracing", func() { hdr := &ExtendedHeader{Header: Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}} ack := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 10}}} ping := &PingFrame{} - tr1.EXPECT().SentLongHeaderPacket(hdr, ByteCount(1337), ack, []Frame{ping}) - tr2.EXPECT().SentLongHeaderPacket(hdr, ByteCount(1337), ack, []Frame{ping}) - tracer.SentLongHeaderPacket(hdr, 1337, ack, []Frame{ping}) + tr1.EXPECT().SentLongHeaderPacket(hdr, ByteCount(1337), ECTNot, ack, []Frame{ping}) + tr2.EXPECT().SentLongHeaderPacket(hdr, ByteCount(1337), ECTNot, ack, []Frame{ping}) + tracer.SentLongHeaderPacket(hdr, 1337, ECTNot, ack, []Frame{ping}) }) It("traces the SentShortHeaderPacket event", func() { hdr := &ShortHeader{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})} ack := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 10}}} ping := &PingFrame{} - tr1.EXPECT().SentShortHeaderPacket(hdr, ByteCount(1337), ack, []Frame{ping}) - tr2.EXPECT().SentShortHeaderPacket(hdr, ByteCount(1337), ack, []Frame{ping}) - tracer.SentShortHeaderPacket(hdr, 1337, ack, []Frame{ping}) + tr1.EXPECT().SentShortHeaderPacket(hdr, ByteCount(1337), ECNCE, ack, []Frame{ping}) + tr2.EXPECT().SentShortHeaderPacket(hdr, ByteCount(1337), ECNCE, ack, []Frame{ping}) + tracer.SentShortHeaderPacket(hdr, 1337, ECNCE, ack, []Frame{ping}) }) It("traces the ReceivedVersionNegotiationPacket event", func() { @@ -151,17 +164,17 @@ var _ = Describe("Tracing", func() { It("traces the ReceivedLongHeaderPacket event", func() { hdr := &ExtendedHeader{Header: Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}} ping := &PingFrame{} - tr1.EXPECT().ReceivedLongHeaderPacket(hdr, ByteCount(1337), []Frame{ping}) - tr2.EXPECT().ReceivedLongHeaderPacket(hdr, ByteCount(1337), []Frame{ping}) - tracer.ReceivedLongHeaderPacket(hdr, 1337, []Frame{ping}) + tr1.EXPECT().ReceivedLongHeaderPacket(hdr, ByteCount(1337), ECT1, []Frame{ping}) + tr2.EXPECT().ReceivedLongHeaderPacket(hdr, ByteCount(1337), ECT1, []Frame{ping}) + tracer.ReceivedLongHeaderPacket(hdr, 1337, ECT1, []Frame{ping}) }) It("traces the ReceivedShortHeaderPacket event", func() { hdr := &ShortHeader{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})} ping := &PingFrame{} - tr1.EXPECT().ReceivedShortHeaderPacket(hdr, ByteCount(1337), []Frame{ping}) - tr2.EXPECT().ReceivedShortHeaderPacket(hdr, ByteCount(1337), []Frame{ping}) - tracer.ReceivedShortHeaderPacket(hdr, 1337, []Frame{ping}) + tr1.EXPECT().ReceivedShortHeaderPacket(hdr, ByteCount(1337), ECT0, []Frame{ping}) + tr2.EXPECT().ReceivedShortHeaderPacket(hdr, ByteCount(1337), ECT0, []Frame{ping}) + tracer.ReceivedShortHeaderPacket(hdr, 1337, ECT0, []Frame{ping}) }) It("traces the BufferedPacket event", func() { diff --git a/logging/null_tracer.go b/logging/null_tracer.go deleted file mode 100644 index de9703857..000000000 --- a/logging/null_tracer.go +++ /dev/null @@ -1,58 +0,0 @@ -package logging - -import ( - "net" - "time" -) - -// The NullTracer is a Tracer that does nothing. -// It is useful for embedding. -type NullTracer struct{} - -var _ Tracer = &NullTracer{} - -func (n NullTracer) SentPacket(net.Addr, *Header, ByteCount, []Frame) {} -func (n NullTracer) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) { -} -func (n NullTracer) DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) {} - -// The NullConnectionTracer is a ConnectionTracer that does nothing. -// It is useful for embedding. -type NullConnectionTracer struct{} - -var _ ConnectionTracer = &NullConnectionTracer{} - -func (n NullConnectionTracer) StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) { -} - -func (n NullConnectionTracer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { -} -func (n NullConnectionTracer) ClosedConnection(err error) {} -func (n NullConnectionTracer) SentTransportParameters(*TransportParameters) {} -func (n NullConnectionTracer) ReceivedTransportParameters(*TransportParameters) {} -func (n NullConnectionTracer) RestoredTransportParameters(*TransportParameters) {} -func (n NullConnectionTracer) SentLongHeaderPacket(*ExtendedHeader, ByteCount, *AckFrame, []Frame) {} -func (n NullConnectionTracer) SentShortHeaderPacket(*ShortHeader, ByteCount, *AckFrame, []Frame) {} -func (n NullConnectionTracer) ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber) { -} -func (n NullConnectionTracer) ReceivedRetry(*Header) {} -func (n NullConnectionTracer) ReceivedLongHeaderPacket(*ExtendedHeader, ByteCount, []Frame) {} -func (n NullConnectionTracer) ReceivedShortHeaderPacket(*ShortHeader, ByteCount, []Frame) {} -func (n NullConnectionTracer) BufferedPacket(PacketType, ByteCount) {} -func (n NullConnectionTracer) DroppedPacket(PacketType, ByteCount, PacketDropReason) {} - -func (n NullConnectionTracer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) { -} -func (n NullConnectionTracer) AcknowledgedPacket(EncryptionLevel, PacketNumber) {} -func (n NullConnectionTracer) LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) {} -func (n NullConnectionTracer) UpdatedCongestionState(CongestionState) {} -func (n NullConnectionTracer) UpdatedPTOCount(uint32) {} -func (n NullConnectionTracer) UpdatedKeyFromTLS(EncryptionLevel, Perspective) {} -func (n NullConnectionTracer) UpdatedKey(keyPhase KeyPhase, remote bool) {} -func (n NullConnectionTracer) DroppedEncryptionLevel(EncryptionLevel) {} -func (n NullConnectionTracer) DroppedKey(KeyPhase) {} -func (n NullConnectionTracer) SetLossTimer(TimerType, EncryptionLevel, time.Time) {} -func (n NullConnectionTracer) LossTimerExpired(timerType TimerType, level EncryptionLevel) {} -func (n NullConnectionTracer) LossTimerCanceled() {} -func (n NullConnectionTracer) Close() {} -func (n NullConnectionTracer) Debug(name, msg string) {} diff --git a/logging/packet_header_test.go b/logging/packet_header_test.go index c10a028f4..67d3f3e30 100644 --- a/logging/packet_header_test.go +++ b/logging/packet_header_test.go @@ -1,8 +1,9 @@ -package logging +package logging_test import ( "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/wire" + . "github.com/refraction-networking/uquic/logging" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" diff --git a/logging/tracer.go b/logging/tracer.go new file mode 100644 index 000000000..5918f30f8 --- /dev/null +++ b/logging/tracer.go @@ -0,0 +1,43 @@ +package logging + +import "net" + +// A Tracer traces events. +type Tracer struct { + SentPacket func(net.Addr, *Header, ByteCount, []Frame) + SentVersionNegotiationPacket func(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) + DroppedPacket func(net.Addr, PacketType, ByteCount, PacketDropReason) +} + +// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers. +func NewMultiplexedTracer(tracers ...*Tracer) *Tracer { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &Tracer{ + SentPacket: func(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { + for _, t := range tracers { + if t.SentPacket != nil { + t.SentPacket(remote, hdr, size, frames) + } + } + }, + SentVersionNegotiationPacket: func(remote net.Addr, dest, src ArbitraryLenConnectionID, versions []VersionNumber) { + for _, t := range tracers { + if t.SentVersionNegotiationPacket != nil { + t.SentVersionNegotiationPacket(remote, dest, src, versions) + } + } + }, + DroppedPacket: func(remote net.Addr, typ PacketType, size ByteCount, reason PacketDropReason) { + for _, t := range tracers { + if t.DroppedPacket != nil { + t.DroppedPacket(remote, typ, size, reason) + } + } + }, + } +} diff --git a/logging/types.go b/logging/types.go index ad8006923..0d79b0a90 100644 --- a/logging/types.go +++ b/logging/types.go @@ -92,3 +92,37 @@ const ( // CongestionStateApplicationLimited means that the congestion controller is application limited CongestionStateApplicationLimited ) + +// ECNState is the state of the ECN state machine (see Appendix A.4 of RFC 9000) +type ECNState uint8 + +const ( + // ECNStateTesting is the testing state + ECNStateTesting ECNState = 1 + iota + // ECNStateUnknown is the unknown state + ECNStateUnknown + // ECNStateFailed is the failed state + ECNStateFailed + // ECNStateCapable is the capable state + ECNStateCapable +) + +// ECNStateTrigger is a trigger for an ECN state transition. +type ECNStateTrigger uint8 + +const ( + ECNTriggerNoTrigger ECNStateTrigger = iota + // ECNFailedNoECNCounts is emitted when an ACK acknowledges ECN-marked packets, + // but doesn't contain any ECN counts + ECNFailedNoECNCounts + // ECNFailedDecreasedECNCounts is emitted when an ACK frame decreases ECN counts + ECNFailedDecreasedECNCounts + // ECNFailedLostAllTestingPackets is emitted when all ECN testing packets are declared lost + ECNFailedLostAllTestingPackets + // ECNFailedMoreECNCountsThanSent is emitted when an ACK contains more ECN counts than ECN-marked packets were sent + ECNFailedMoreECNCountsThanSent + // ECNFailedTooFewECNCounts is emitted when an ACK contains fewer ECN counts than it acknowledges packets + ECNFailedTooFewECNCounts + // ECNFailedManglingDetected is emitted when the path marks all ECN-marked packets as CE + ECNFailedManglingDetected +) diff --git a/mock_ack_frame_source_test.go b/mock_ack_frame_source_test.go index 4ebc2d425..bc4969d47 100644 --- a/mock_ack_frame_source_test.go +++ b/mock_ack_frame_source_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: AckFrameSource) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_ack_frame_source_test.go github.com/refraction-networking/uquic AckFrameSource +// // Package quic is a generated GoMock package. package quic @@ -44,7 +48,7 @@ func (m *MockAckFrameSource) GetAckFrame(arg0 protocol.EncryptionLevel, arg1 boo } // GetAckFrame indicates an expected call of GetAckFrame. -func (mr *MockAckFrameSourceMockRecorder) GetAckFrame(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockAckFrameSourceMockRecorder) GetAckFrame(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockAckFrameSource)(nil).GetAckFrame), arg0, arg1) } diff --git a/mock_batch_conn_test.go b/mock_batch_conn_test.go index 9621e7b4e..7d3319574 100644 --- a/mock_batch_conn_test.go +++ b/mock_batch_conn_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: sys_conn_oob.go - +// +// Generated by this command: +// +// mockgen -package quic -self_package github.com/refraction-networking/uquic -source sys_conn_oob.go -destination mock_batch_conn_test.go -mock_names batchConn=MockBatchConn +// // Package quic is a generated GoMock package. package quic @@ -44,7 +48,7 @@ func (m *MockBatchConn) ReadBatch(ms []ipv4.Message, flags int) (int, error) { } // ReadBatch indicates an expected call of ReadBatch. -func (mr *MockBatchConnMockRecorder) ReadBatch(ms, flags interface{}) *gomock.Call { +func (mr *MockBatchConnMockRecorder) ReadBatch(ms, flags any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadBatch", reflect.TypeOf((*MockBatchConn)(nil).ReadBatch), ms, flags) } diff --git a/mock_conn_runner_test.go b/mock_conn_runner_test.go index d857fd580..45056041b 100644 --- a/mock_conn_runner_test.go +++ b/mock_conn_runner_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: ConnRunner) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_conn_runner_test.go github.com/refraction-networking/uquic ConnRunner +// // Package quic is a generated GoMock package. package quic @@ -43,7 +47,7 @@ func (m *MockConnRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) boo } // Add indicates an expected call of Add. -func (mr *MockConnRunnerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) Add(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockConnRunner)(nil).Add), arg0, arg1) } @@ -55,7 +59,7 @@ func (m *MockConnRunner) AddResetToken(arg0 protocol.StatelessResetToken, arg1 p } // AddResetToken indicates an expected call of AddResetToken. -func (mr *MockConnRunnerMockRecorder) AddResetToken(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) AddResetToken(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockConnRunner)(nil).AddResetToken), arg0, arg1) } @@ -69,7 +73,7 @@ func (m *MockConnRunner) GetStatelessResetToken(arg0 protocol.ConnectionID) prot } // GetStatelessResetToken indicates an expected call of GetStatelessResetToken. -func (mr *MockConnRunnerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) GetStatelessResetToken(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockConnRunner)(nil).GetStatelessResetToken), arg0) } @@ -81,7 +85,7 @@ func (m *MockConnRunner) Remove(arg0 protocol.ConnectionID) { } // Remove indicates an expected call of Remove. -func (mr *MockConnRunnerMockRecorder) Remove(arg0 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) Remove(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockConnRunner)(nil).Remove), arg0) } @@ -93,7 +97,7 @@ func (m *MockConnRunner) RemoveResetToken(arg0 protocol.StatelessResetToken) { } // RemoveResetToken indicates an expected call of RemoveResetToken. -func (mr *MockConnRunnerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) RemoveResetToken(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockConnRunner)(nil).RemoveResetToken), arg0) } @@ -105,7 +109,7 @@ func (m *MockConnRunner) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 pr } // ReplaceWithClosed indicates an expected call of ReplaceWithClosed. -func (mr *MockConnRunnerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockConnRunner)(nil).ReplaceWithClosed), arg0, arg1, arg2) } @@ -117,7 +121,7 @@ func (m *MockConnRunner) Retire(arg0 protocol.ConnectionID) { } // Retire indicates an expected call of Retire. -func (mr *MockConnRunnerMockRecorder) Retire(arg0 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) Retire(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockConnRunner)(nil).Retire), arg0) } diff --git a/mock_crypto_data_handler_test.go b/mock_crypto_data_handler_test.go index eb2e90fa1..c8a601012 100644 --- a/mock_crypto_data_handler_test.go +++ b/mock_crypto_data_handler_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: CryptoDataHandler) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_crypto_data_handler_test.go github.com/refraction-networking/uquic CryptoDataHandler +// // Package quic is a generated GoMock package. package quic @@ -44,7 +48,7 @@ func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.Encrypt } // HandleMessage indicates an expected call of HandleMessage. -func (mr *MockCryptoDataHandlerMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockCryptoDataHandlerMockRecorder) HandleMessage(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoDataHandler)(nil).HandleMessage), arg0, arg1) } diff --git a/mock_crypto_stream_test.go b/mock_crypto_stream_test.go index 002b1d17c..cc01d35a7 100644 --- a/mock_crypto_stream_test.go +++ b/mock_crypto_stream_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: CryptoStream) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_crypto_stream_test.go github.com/refraction-networking/uquic CryptoStream +// // Package quic is a generated GoMock package. package quic @@ -72,7 +76,7 @@ func (m *MockCryptoStream) HandleCryptoFrame(arg0 *wire.CryptoFrame) error { } // HandleCryptoFrame indicates an expected call of HandleCryptoFrame. -func (mr *MockCryptoStreamMockRecorder) HandleCryptoFrame(arg0 interface{}) *gomock.Call { +func (mr *MockCryptoStreamMockRecorder) HandleCryptoFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleCryptoFrame", reflect.TypeOf((*MockCryptoStream)(nil).HandleCryptoFrame), arg0) } @@ -100,7 +104,7 @@ func (m *MockCryptoStream) PopCryptoFrame(arg0 protocol.ByteCount) *wire.CryptoF } // PopCryptoFrame indicates an expected call of PopCryptoFrame. -func (mr *MockCryptoStreamMockRecorder) PopCryptoFrame(arg0 interface{}) *gomock.Call { +func (mr *MockCryptoStreamMockRecorder) PopCryptoFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopCryptoFrame", reflect.TypeOf((*MockCryptoStream)(nil).PopCryptoFrame), arg0) } @@ -115,7 +119,7 @@ func (m *MockCryptoStream) Write(arg0 []byte) (int, error) { } // Write indicates an expected call of Write. -func (mr *MockCryptoStreamMockRecorder) Write(arg0 interface{}) *gomock.Call { +func (mr *MockCryptoStreamMockRecorder) Write(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockCryptoStream)(nil).Write), arg0) } diff --git a/mock_frame_source_test.go b/mock_frame_source_test.go index 7319739cf..5454890f5 100644 --- a/mock_frame_source_test.go +++ b/mock_frame_source_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: FrameSource) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_frame_source_test.go github.com/refraction-networking/uquic FrameSource +// // Package quic is a generated GoMock package. package quic @@ -45,7 +49,7 @@ func (m *MockFrameSource) AppendControlFrames(arg0 []ackhandler.Frame, arg1 prot } // AppendControlFrames indicates an expected call of AppendControlFrames. -func (mr *MockFrameSourceMockRecorder) AppendControlFrames(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockFrameSourceMockRecorder) AppendControlFrames(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendControlFrames", reflect.TypeOf((*MockFrameSource)(nil).AppendControlFrames), arg0, arg1, arg2) } @@ -60,7 +64,7 @@ func (m *MockFrameSource) AppendStreamFrames(arg0 []ackhandler.StreamFrame, arg1 } // AppendStreamFrames indicates an expected call of AppendStreamFrames. -func (mr *MockFrameSourceMockRecorder) AppendStreamFrames(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockFrameSourceMockRecorder) AppendStreamFrames(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendStreamFrames", reflect.TypeOf((*MockFrameSource)(nil).AppendStreamFrames), arg0, arg1, arg2) } diff --git a/mock_mtu_discoverer_test.go b/mock_mtu_discoverer_test.go index a1c15fb1c..1af111f4f 100644 --- a/mock_mtu_discoverer_test.go +++ b/mock_mtu_discoverer_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: MTUDiscoverer) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_mtu_discoverer_test.go github.com/refraction-networking/uquic MTUDiscoverer +// // Package quic is a generated GoMock package. package quic @@ -74,7 +78,7 @@ func (m *MockMTUDiscoverer) ShouldSendProbe(arg0 time.Time) bool { } // ShouldSendProbe indicates an expected call of ShouldSendProbe. -func (mr *MockMTUDiscovererMockRecorder) ShouldSendProbe(arg0 interface{}) *gomock.Call { +func (mr *MockMTUDiscovererMockRecorder) ShouldSendProbe(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldSendProbe", reflect.TypeOf((*MockMTUDiscoverer)(nil).ShouldSendProbe), arg0) } @@ -86,7 +90,7 @@ func (m *MockMTUDiscoverer) Start(arg0 protocol.ByteCount) { } // Start indicates an expected call of Start. -func (mr *MockMTUDiscovererMockRecorder) Start(arg0 interface{}) *gomock.Call { +func (mr *MockMTUDiscovererMockRecorder) Start(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockMTUDiscoverer)(nil).Start), arg0) } diff --git a/mock_packer_test.go b/mock_packer_test.go index 409a7a8ca..f73f6f5c7 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: Packer) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_packer_test.go github.com/refraction-networking/uquic Packer +// // Package quic is a generated GoMock package. package quic @@ -46,7 +50,7 @@ func (m *MockPacker) AppendPacket(arg0 *packetBuffer, arg1 protocol.ByteCount, a } // AppendPacket indicates an expected call of AppendPacket. -func (mr *MockPackerMockRecorder) AppendPacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) AppendPacket(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendPacket", reflect.TypeOf((*MockPacker)(nil).AppendPacket), arg0, arg1, arg2) } @@ -61,7 +65,7 @@ func (m *MockPacker) MaybePackProbePacket(arg0 protocol.EncryptionLevel, arg1 pr } // MaybePackProbePacket indicates an expected call of MaybePackProbePacket. -func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackProbePacket", reflect.TypeOf((*MockPacker)(nil).MaybePackProbePacket), arg0, arg1, arg2) } @@ -77,7 +81,7 @@ func (m *MockPacker) PackAckOnlyPacket(arg0 protocol.ByteCount, arg1 protocol.Ve } // PackAckOnlyPacket indicates an expected call of PackAckOnlyPacket. -func (mr *MockPackerMockRecorder) PackAckOnlyPacket(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) PackAckOnlyPacket(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackAckOnlyPacket", reflect.TypeOf((*MockPacker)(nil).PackAckOnlyPacket), arg0, arg1) } @@ -92,7 +96,7 @@ func (m *MockPacker) PackApplicationClose(arg0 *qerr.ApplicationError, arg1 prot } // PackApplicationClose indicates an expected call of PackApplicationClose. -func (mr *MockPackerMockRecorder) PackApplicationClose(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) PackApplicationClose(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackApplicationClose", reflect.TypeOf((*MockPacker)(nil).PackApplicationClose), arg0, arg1, arg2) } @@ -107,7 +111,7 @@ func (m *MockPacker) PackCoalescedPacket(arg0 bool, arg1 protocol.ByteCount, arg } // PackCoalescedPacket indicates an expected call of PackCoalescedPacket. -func (mr *MockPackerMockRecorder) PackCoalescedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) PackCoalescedPacket(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackCoalescedPacket", reflect.TypeOf((*MockPacker)(nil).PackCoalescedPacket), arg0, arg1, arg2) } @@ -122,7 +126,7 @@ func (m *MockPacker) PackConnectionClose(arg0 *qerr.TransportError, arg1 protoco } // PackConnectionClose indicates an expected call of PackConnectionClose. -func (mr *MockPackerMockRecorder) PackConnectionClose(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) PackConnectionClose(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackConnectionClose", reflect.TypeOf((*MockPacker)(nil).PackConnectionClose), arg0, arg1, arg2) } @@ -138,7 +142,7 @@ func (m *MockPacker) PackMTUProbePacket(arg0 ackhandler.Frame, arg1 protocol.Byt } // PackMTUProbePacket indicates an expected call of PackMTUProbePacket. -func (mr *MockPackerMockRecorder) PackMTUProbePacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) PackMTUProbePacket(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackMTUProbePacket", reflect.TypeOf((*MockPacker)(nil).PackMTUProbePacket), arg0, arg1, arg2) } @@ -150,7 +154,7 @@ func (m *MockPacker) SetToken(arg0 []byte) { } // SetToken indicates an expected call of SetToken. -func (mr *MockPackerMockRecorder) SetToken(arg0 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) SetToken(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetToken", reflect.TypeOf((*MockPacker)(nil).SetToken), arg0) } diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index b9e28eebe..ea49f66bc 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: PacketHandlerManager) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_packet_handler_manager_test.go github.com/refraction-networking/uquic PacketHandlerManager +// // Package quic is a generated GoMock package. package quic @@ -43,7 +47,7 @@ func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHa } // Add indicates an expected call of Add. -func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1) } @@ -55,7 +59,7 @@ func (m *MockPacketHandlerManager) AddResetToken(arg0 protocol.StatelessResetTok } // AddResetToken indicates an expected call of AddResetToken. -func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddResetToken), arg0, arg1) } @@ -69,7 +73,7 @@ func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionI } // AddWithConnID indicates an expected call of AddWithConnID. -func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), arg0, arg1, arg2) } @@ -81,7 +85,7 @@ func (m *MockPacketHandlerManager) Close(arg0 error) { } // Close indicates an expected call of Close. -func (mr *MockPacketHandlerManagerMockRecorder) Close(arg0 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) Close(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close), arg0) } @@ -108,7 +112,7 @@ func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandle } // Get indicates an expected call of Get. -func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0) } @@ -123,7 +127,7 @@ func (m *MockPacketHandlerManager) GetByResetToken(arg0 protocol.StatelessResetT } // GetByResetToken indicates an expected call of GetByResetToken. -func (mr *MockPacketHandlerManagerMockRecorder) GetByResetToken(arg0 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) GetByResetToken(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetByResetToken), arg0) } @@ -137,7 +141,7 @@ func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.Connecti } // GetStatelessResetToken indicates an expected call of GetStatelessResetToken. -func (mr *MockPacketHandlerManagerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) GetStatelessResetToken(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetStatelessResetToken), arg0) } @@ -149,7 +153,7 @@ func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) { } // Remove indicates an expected call of Remove. -func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0) } @@ -161,7 +165,7 @@ func (m *MockPacketHandlerManager) RemoveResetToken(arg0 protocol.StatelessReset } // RemoveResetToken indicates an expected call of RemoveResetToken. -func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).RemoveResetToken), arg0) } @@ -173,7 +177,7 @@ func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 []protocol.ConnectionI } // ReplaceWithClosed indicates an expected call of ReplaceWithClosed. -func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockPacketHandlerManager)(nil).ReplaceWithClosed), arg0, arg1, arg2) } @@ -185,7 +189,7 @@ func (m *MockPacketHandlerManager) Retire(arg0 protocol.ConnectionID) { } // Retire indicates an expected call of Retire. -func (mr *MockPacketHandlerManagerMockRecorder) Retire(arg0 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) Retire(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockPacketHandlerManager)(nil).Retire), arg0) } diff --git a/mock_packet_handler_test.go b/mock_packet_handler_test.go index 8402d5e46..852d1f004 100644 --- a/mock_packet_handler_test.go +++ b/mock_packet_handler_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: PacketHandler) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_packet_handler_test.go github.com/refraction-networking/uquic PacketHandler +// // Package quic is a generated GoMock package. package quic @@ -41,7 +45,7 @@ func (m *MockPacketHandler) destroy(arg0 error) { } // destroy indicates an expected call of destroy. -func (mr *MockPacketHandlerMockRecorder) destroy(arg0 interface{}) *gomock.Call { +func (mr *MockPacketHandlerMockRecorder) destroy(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockPacketHandler)(nil).destroy), arg0) } @@ -67,7 +71,7 @@ func (m *MockPacketHandler) handlePacket(arg0 receivedPacket) { } // handlePacket indicates an expected call of handlePacket. -func (mr *MockPacketHandlerMockRecorder) handlePacket(arg0 interface{}) *gomock.Call { +func (mr *MockPacketHandlerMockRecorder) handlePacket(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockPacketHandler)(nil).handlePacket), arg0) } diff --git a/mock_packetconn_test.go b/mock_packetconn_test.go index c8e20bf28..6e317e3ed 100644 --- a/mock_packetconn_test.go +++ b/mock_packetconn_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: net (interfaces: PacketConn) - +// +// Generated by this command: +// +// mockgen -package quic -self_package github.com/refraction-networking/uquic -self_package github.com/refraction-networking/uquic -destination mock_packetconn_test.go net PacketConn +// // Package quic is a generated GoMock package. package quic @@ -74,7 +78,7 @@ func (m *MockPacketConn) ReadFrom(arg0 []byte) (int, net.Addr, error) { } // ReadFrom indicates an expected call of ReadFrom. -func (mr *MockPacketConnMockRecorder) ReadFrom(arg0 interface{}) *gomock.Call { +func (mr *MockPacketConnMockRecorder) ReadFrom(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadFrom", reflect.TypeOf((*MockPacketConn)(nil).ReadFrom), arg0) } @@ -88,7 +92,7 @@ func (m *MockPacketConn) SetDeadline(arg0 time.Time) error { } // SetDeadline indicates an expected call of SetDeadline. -func (mr *MockPacketConnMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockPacketConnMockRecorder) SetDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetDeadline), arg0) } @@ -102,7 +106,7 @@ func (m *MockPacketConn) SetReadDeadline(arg0 time.Time) error { } // SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockPacketConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockPacketConnMockRecorder) SetReadDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetReadDeadline), arg0) } @@ -116,7 +120,7 @@ func (m *MockPacketConn) SetWriteDeadline(arg0 time.Time) error { } // SetWriteDeadline indicates an expected call of SetWriteDeadline. -func (mr *MockPacketConnMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockPacketConnMockRecorder) SetWriteDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetWriteDeadline), arg0) } @@ -131,7 +135,7 @@ func (m *MockPacketConn) WriteTo(arg0 []byte, arg1 net.Addr) (int, error) { } // WriteTo indicates an expected call of WriteTo. -func (mr *MockPacketConnMockRecorder) WriteTo(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPacketConnMockRecorder) WriteTo(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteTo", reflect.TypeOf((*MockPacketConn)(nil).WriteTo), arg0, arg1) } diff --git a/mock_quic_conn_test.go b/mock_quic_conn_test.go index 7ce24333d..fc7f1efaa 100644 --- a/mock_quic_conn_test.go +++ b/mock_quic_conn_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: QUICConn) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_quic_conn_test.go github.com/refraction-networking/uquic QUICConn +// // Package quic is a generated GoMock package. package quic @@ -47,7 +51,7 @@ func (m *MockQUICConn) AcceptStream(arg0 context.Context) (Stream, error) { } // AcceptStream indicates an expected call of AcceptStream. -func (mr *MockQUICConnMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { +func (mr *MockQUICConnMockRecorder) AcceptStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQUICConn)(nil).AcceptStream), arg0) } @@ -62,7 +66,7 @@ func (m *MockQUICConn) AcceptUniStream(arg0 context.Context) (ReceiveStream, err } // AcceptUniStream indicates an expected call of AcceptUniStream. -func (mr *MockQUICConnMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { +func (mr *MockQUICConnMockRecorder) AcceptUniStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQUICConn)(nil).AcceptUniStream), arg0) } @@ -76,7 +80,7 @@ func (m *MockQUICConn) CloseWithError(arg0 qerr.ApplicationErrorCode, arg1 strin } // CloseWithError indicates an expected call of CloseWithError. -func (mr *MockQUICConnMockRecorder) CloseWithError(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockQUICConnMockRecorder) CloseWithError(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockQUICConn)(nil).CloseWithError), arg0, arg1) } @@ -190,7 +194,7 @@ func (m *MockQUICConn) OpenStreamSync(arg0 context.Context) (Stream, error) { } // OpenStreamSync indicates an expected call of OpenStreamSync. -func (mr *MockQUICConnMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call { +func (mr *MockQUICConnMockRecorder) OpenStreamSync(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockQUICConn)(nil).OpenStreamSync), arg0) } @@ -220,7 +224,7 @@ func (m *MockQUICConn) OpenUniStreamSync(arg0 context.Context) (SendStream, erro } // OpenUniStreamSync indicates an expected call of OpenUniStreamSync. -func (mr *MockQUICConnMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call { +func (mr *MockQUICConnMockRecorder) OpenUniStreamSync(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQUICConn)(nil).OpenUniStreamSync), arg0) } @@ -235,7 +239,7 @@ func (m *MockQUICConn) ReceiveMessage(arg0 context.Context) ([]byte, error) { } // ReceiveMessage indicates an expected call of ReceiveMessage. -func (mr *MockQUICConnMockRecorder) ReceiveMessage(arg0 interface{}) *gomock.Call { +func (mr *MockQUICConnMockRecorder) ReceiveMessage(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockQUICConn)(nil).ReceiveMessage), arg0) } @@ -263,7 +267,7 @@ func (m *MockQUICConn) SendMessage(arg0 []byte) error { } // SendMessage indicates an expected call of SendMessage. -func (mr *MockQUICConnMockRecorder) SendMessage(arg0 interface{}) *gomock.Call { +func (mr *MockQUICConnMockRecorder) SendMessage(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockQUICConn)(nil).SendMessage), arg0) } @@ -275,7 +279,7 @@ func (m *MockQUICConn) destroy(arg0 error) { } // destroy indicates an expected call of destroy. -func (mr *MockQUICConnMockRecorder) destroy(arg0 interface{}) *gomock.Call { +func (mr *MockQUICConnMockRecorder) destroy(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockQUICConn)(nil).destroy), arg0) } @@ -315,7 +319,7 @@ func (m *MockQUICConn) handlePacket(arg0 receivedPacket) { } // handlePacket indicates an expected call of handlePacket. -func (mr *MockQUICConnMockRecorder) handlePacket(arg0 interface{}) *gomock.Call { +func (mr *MockQUICConnMockRecorder) handlePacket(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockQUICConn)(nil).handlePacket), arg0) } diff --git a/mock_raw_conn_test.go b/mock_raw_conn_test.go index 0a1a0f3ac..f7862c62c 100644 --- a/mock_raw_conn_test.go +++ b/mock_raw_conn_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/quic-go/quic-go (interfaces: RawConn) - +// Source: github.com/refraction-networking/uquic (interfaces: RawConn) +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_raw_conn_test.go github.com/refraction-networking/uquic RawConn +// // Package quic is a generated GoMock package. package quic @@ -9,6 +13,7 @@ import ( reflect "reflect" time "time" + protocol "github.com/refraction-networking/uquic/internal/protocol" gomock "go.uber.org/mock/gomock" ) @@ -87,24 +92,24 @@ func (m *MockRawConn) SetReadDeadline(arg0 time.Time) error { } // SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockRawConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockRawConnMockRecorder) SetReadDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockRawConn)(nil).SetReadDeadline), arg0) } // WritePacket mocks base method. -func (m *MockRawConn) WritePacket(arg0 []byte, arg1 net.Addr, arg2 []byte) (int, error) { +func (m *MockRawConn) WritePacket(arg0 []byte, arg1 net.Addr, arg2 []byte, arg3 uint16, arg4 protocol.ECN) (int, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "WritePacket", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "WritePacket", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(int) ret1, _ := ret[1].(error) return ret0, ret1 } // WritePacket indicates an expected call of WritePacket. -func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2, arg3, arg4) } // capabilities mocks base method. diff --git a/mock_receive_stream_internal_test.go b/mock_receive_stream_internal_test.go index 6237feef5..48b9063c8 100644 --- a/mock_receive_stream_internal_test.go +++ b/mock_receive_stream_internal_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: ReceiveStreamI) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_receive_stream_internal_test.go github.com/refraction-networking/uquic ReceiveStreamI +// // Package quic is a generated GoMock package. package quic @@ -44,7 +48,7 @@ func (m *MockReceiveStreamI) CancelRead(arg0 qerr.StreamErrorCode) { } // CancelRead indicates an expected call of CancelRead. -func (mr *MockReceiveStreamIMockRecorder) CancelRead(arg0 interface{}) *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) CancelRead(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockReceiveStreamI)(nil).CancelRead), arg0) } @@ -59,7 +63,7 @@ func (m *MockReceiveStreamI) Read(arg0 []byte) (int, error) { } // Read indicates an expected call of Read. -func (mr *MockReceiveStreamIMockRecorder) Read(arg0 interface{}) *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) Read(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockReceiveStreamI)(nil).Read), arg0) } @@ -73,7 +77,7 @@ func (m *MockReceiveStreamI) SetReadDeadline(arg0 time.Time) error { } // SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockReceiveStreamIMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) SetReadDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockReceiveStreamI)(nil).SetReadDeadline), arg0) } @@ -99,7 +103,7 @@ func (m *MockReceiveStreamI) closeForShutdown(arg0 error) { } // closeForShutdown indicates an expected call of closeForShutdown. -func (mr *MockReceiveStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) closeForShutdown(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockReceiveStreamI)(nil).closeForShutdown), arg0) } @@ -127,7 +131,7 @@ func (m *MockReceiveStreamI) handleResetStreamFrame(arg0 *wire.ResetStreamFrame) } // handleResetStreamFrame indicates an expected call of handleResetStreamFrame. -func (mr *MockReceiveStreamIMockRecorder) handleResetStreamFrame(arg0 interface{}) *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) handleResetStreamFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleResetStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleResetStreamFrame), arg0) } @@ -141,7 +145,7 @@ func (m *MockReceiveStreamI) handleStreamFrame(arg0 *wire.StreamFrame) error { } // handleStreamFrame indicates an expected call of handleStreamFrame. -func (mr *MockReceiveStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) handleStreamFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleStreamFrame), arg0) } diff --git a/mock_sealing_manager_test.go b/mock_sealing_manager_test.go index 1b9c9214c..6a691b3be 100644 --- a/mock_sealing_manager_test.go +++ b/mock_sealing_manager_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: SealingManager) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_sealing_manager_test.go github.com/refraction-networking/uquic SealingManager +// // Package quic is a generated GoMock package. package quic diff --git a/mock_send_conn_test.go b/mock_send_conn_test.go index f55feaeb2..92fe43a96 100644 --- a/mock_send_conn_test.go +++ b/mock_send_conn_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: SendConn) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_send_conn_test.go github.com/refraction-networking/uquic SendConn +// // Package quic is a generated GoMock package. package quic @@ -78,17 +82,17 @@ func (mr *MockSendConnMockRecorder) RemoteAddr() *gomock.Call { } // Write mocks base method. -func (m *MockSendConn) Write(arg0 []byte, arg1 protocol.ByteCount) error { +func (m *MockSendConn) Write(arg0 []byte, arg1 uint16, arg2 protocol.ECN) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", arg0, arg1) + ret := m.ctrl.Call(m, "Write", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // Write indicates an expected call of Write. -func (mr *MockSendConnMockRecorder) Write(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockSendConnMockRecorder) Write(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendConn)(nil).Write), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendConn)(nil).Write), arg0, arg1, arg2) } // capabilities mocks base method. diff --git a/mock_send_stream_internal_test.go b/mock_send_stream_internal_test.go index 46bf75659..6b878c943 100644 --- a/mock_send_stream_internal_test.go +++ b/mock_send_stream_internal_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: SendStreamI) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_send_stream_internal_test.go github.com/refraction-networking/uquic SendStreamI +// // Package quic is a generated GoMock package. package quic @@ -46,7 +50,7 @@ func (m *MockSendStreamI) CancelWrite(arg0 qerr.StreamErrorCode) { } // CancelWrite indicates an expected call of CancelWrite. -func (mr *MockSendStreamIMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) CancelWrite(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockSendStreamI)(nil).CancelWrite), arg0) } @@ -88,7 +92,7 @@ func (m *MockSendStreamI) SetWriteDeadline(arg0 time.Time) error { } // SetWriteDeadline indicates an expected call of SetWriteDeadline. -func (mr *MockSendStreamIMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) SetWriteDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockSendStreamI)(nil).SetWriteDeadline), arg0) } @@ -117,7 +121,7 @@ func (m *MockSendStreamI) Write(arg0 []byte) (int, error) { } // Write indicates an expected call of Write. -func (mr *MockSendStreamIMockRecorder) Write(arg0 interface{}) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) Write(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendStreamI)(nil).Write), arg0) } @@ -129,7 +133,7 @@ func (m *MockSendStreamI) closeForShutdown(arg0 error) { } // closeForShutdown indicates an expected call of closeForShutdown. -func (mr *MockSendStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) closeForShutdown(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockSendStreamI)(nil).closeForShutdown), arg0) } @@ -141,7 +145,7 @@ func (m *MockSendStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) { } // handleStopSendingFrame indicates an expected call of handleStopSendingFrame. -func (mr *MockSendStreamIMockRecorder) handleStopSendingFrame(arg0 interface{}) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) handleStopSendingFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockSendStreamI)(nil).handleStopSendingFrame), arg0) } @@ -171,7 +175,7 @@ func (m *MockSendStreamI) popStreamFrame(arg0 protocol.ByteCount, arg1 protocol. } // popStreamFrame indicates an expected call of popStreamFrame. -func (mr *MockSendStreamIMockRecorder) popStreamFrame(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) popStreamFrame(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockSendStreamI)(nil).popStreamFrame), arg0, arg1) } @@ -183,7 +187,7 @@ func (m *MockSendStreamI) updateSendWindow(arg0 protocol.ByteCount) { } // updateSendWindow indicates an expected call of updateSendWindow. -func (mr *MockSendStreamIMockRecorder) updateSendWindow(arg0 interface{}) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) updateSendWindow(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "updateSendWindow", reflect.TypeOf((*MockSendStreamI)(nil).updateSendWindow), arg0) } diff --git a/mock_sender_test.go b/mock_sender_test.go index 4c9899374..8e9291b56 100644 --- a/mock_sender_test.go +++ b/mock_sender_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: Sender) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_sender_test.go github.com/refraction-networking/uquic Sender +// // Package quic is a generated GoMock package. package quic @@ -75,15 +79,15 @@ func (mr *MockSenderMockRecorder) Run() *gomock.Call { } // Send mocks base method. -func (m *MockSender) Send(arg0 *packetBuffer, arg1 protocol.ByteCount) { +func (m *MockSender) Send(arg0 *packetBuffer, arg1 uint16, arg2 protocol.ECN) { m.ctrl.T.Helper() - m.ctrl.Call(m, "Send", arg0, arg1) + m.ctrl.Call(m, "Send", arg0, arg1, arg2) } // Send indicates an expected call of Send. -func (mr *MockSenderMockRecorder) Send(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockSenderMockRecorder) Send(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockSender)(nil).Send), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockSender)(nil).Send), arg0, arg1, arg2) } // WouldBlock mocks base method. diff --git a/mock_stream_getter_test.go b/mock_stream_getter_test.go index 3ed365151..21a873c98 100644 --- a/mock_stream_getter_test.go +++ b/mock_stream_getter_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: StreamGetter) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_getter_test.go github.com/refraction-networking/uquic StreamGetter +// // Package quic is a generated GoMock package. package quic @@ -44,7 +48,7 @@ func (m *MockStreamGetter) GetOrOpenReceiveStream(arg0 protocol.StreamID) (recei } // GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream. -func (mr *MockStreamGetterMockRecorder) GetOrOpenReceiveStream(arg0 interface{}) *gomock.Call { +func (mr *MockStreamGetterMockRecorder) GetOrOpenReceiveStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenReceiveStream), arg0) } @@ -59,7 +63,7 @@ func (m *MockStreamGetter) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStre } // GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream. -func (mr *MockStreamGetterMockRecorder) GetOrOpenSendStream(arg0 interface{}) *gomock.Call { +func (mr *MockStreamGetterMockRecorder) GetOrOpenSendStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenSendStream), arg0) } diff --git a/mock_stream_internal_test.go b/mock_stream_internal_test.go index 369f43be1..3c482fb23 100644 --- a/mock_stream_internal_test.go +++ b/mock_stream_internal_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: StreamI) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_internal_test.go github.com/refraction-networking/uquic StreamI +// // Package quic is a generated GoMock package. package quic @@ -46,7 +50,7 @@ func (m *MockStreamI) CancelRead(arg0 qerr.StreamErrorCode) { } // CancelRead indicates an expected call of CancelRead. -func (mr *MockStreamIMockRecorder) CancelRead(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) CancelRead(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockStreamI)(nil).CancelRead), arg0) } @@ -58,7 +62,7 @@ func (m *MockStreamI) CancelWrite(arg0 qerr.StreamErrorCode) { } // CancelWrite indicates an expected call of CancelWrite. -func (mr *MockStreamIMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) CancelWrite(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockStreamI)(nil).CancelWrite), arg0) } @@ -101,7 +105,7 @@ func (m *MockStreamI) Read(arg0 []byte) (int, error) { } // Read indicates an expected call of Read. -func (mr *MockStreamIMockRecorder) Read(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) Read(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStreamI)(nil).Read), arg0) } @@ -115,7 +119,7 @@ func (m *MockStreamI) SetDeadline(arg0 time.Time) error { } // SetDeadline indicates an expected call of SetDeadline. -func (mr *MockStreamIMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) SetDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockStreamI)(nil).SetDeadline), arg0) } @@ -129,7 +133,7 @@ func (m *MockStreamI) SetReadDeadline(arg0 time.Time) error { } // SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockStreamIMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) SetReadDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStreamI)(nil).SetReadDeadline), arg0) } @@ -143,7 +147,7 @@ func (m *MockStreamI) SetWriteDeadline(arg0 time.Time) error { } // SetWriteDeadline indicates an expected call of SetWriteDeadline. -func (mr *MockStreamIMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) SetWriteDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStreamI)(nil).SetWriteDeadline), arg0) } @@ -172,7 +176,7 @@ func (m *MockStreamI) Write(arg0 []byte) (int, error) { } // Write indicates an expected call of Write. -func (mr *MockStreamIMockRecorder) Write(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) Write(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStreamI)(nil).Write), arg0) } @@ -184,7 +188,7 @@ func (m *MockStreamI) closeForShutdown(arg0 error) { } // closeForShutdown indicates an expected call of closeForShutdown. -func (mr *MockStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) closeForShutdown(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockStreamI)(nil).closeForShutdown), arg0) } @@ -212,7 +216,7 @@ func (m *MockStreamI) handleResetStreamFrame(arg0 *wire.ResetStreamFrame) error } // handleResetStreamFrame indicates an expected call of handleResetStreamFrame. -func (mr *MockStreamIMockRecorder) handleResetStreamFrame(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) handleResetStreamFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleResetStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleResetStreamFrame), arg0) } @@ -224,7 +228,7 @@ func (m *MockStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) { } // handleStopSendingFrame indicates an expected call of handleStopSendingFrame. -func (mr *MockStreamIMockRecorder) handleStopSendingFrame(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) handleStopSendingFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockStreamI)(nil).handleStopSendingFrame), arg0) } @@ -238,7 +242,7 @@ func (m *MockStreamI) handleStreamFrame(arg0 *wire.StreamFrame) error { } // handleStreamFrame indicates an expected call of handleStreamFrame. -func (mr *MockStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) handleStreamFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleStreamFrame), arg0) } @@ -268,7 +272,7 @@ func (m *MockStreamI) popStreamFrame(arg0 protocol.ByteCount, arg1 protocol.Vers } // popStreamFrame indicates an expected call of popStreamFrame. -func (mr *MockStreamIMockRecorder) popStreamFrame(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) popStreamFrame(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockStreamI)(nil).popStreamFrame), arg0, arg1) } @@ -280,7 +284,7 @@ func (m *MockStreamI) updateSendWindow(arg0 protocol.ByteCount) { } // updateSendWindow indicates an expected call of updateSendWindow. -func (mr *MockStreamIMockRecorder) updateSendWindow(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) updateSendWindow(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "updateSendWindow", reflect.TypeOf((*MockStreamI)(nil).updateSendWindow), arg0) } diff --git a/mock_stream_manager_test.go b/mock_stream_manager_test.go index 20e69a837..3d5e18de2 100644 --- a/mock_stream_manager_test.go +++ b/mock_stream_manager_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: StreamManager) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_manager_test.go github.com/refraction-networking/uquic StreamManager +// // Package quic is a generated GoMock package. package quic @@ -46,7 +50,7 @@ func (m *MockStreamManager) AcceptStream(arg0 context.Context) (Stream, error) { } // AcceptStream indicates an expected call of AcceptStream. -func (mr *MockStreamManagerMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) AcceptStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream), arg0) } @@ -61,7 +65,7 @@ func (m *MockStreamManager) AcceptUniStream(arg0 context.Context) (ReceiveStream } // AcceptUniStream indicates an expected call of AcceptUniStream. -func (mr *MockStreamManagerMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) AcceptUniStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream), arg0) } @@ -73,7 +77,7 @@ func (m *MockStreamManager) CloseWithError(arg0 error) { } // CloseWithError indicates an expected call of CloseWithError. -func (mr *MockStreamManagerMockRecorder) CloseWithError(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) CloseWithError(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockStreamManager)(nil).CloseWithError), arg0) } @@ -87,7 +91,7 @@ func (m *MockStreamManager) DeleteStream(arg0 protocol.StreamID) error { } // DeleteStream indicates an expected call of DeleteStream. -func (mr *MockStreamManagerMockRecorder) DeleteStream(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) DeleteStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStream", reflect.TypeOf((*MockStreamManager)(nil).DeleteStream), arg0) } @@ -102,7 +106,7 @@ func (m *MockStreamManager) GetOrOpenReceiveStream(arg0 protocol.StreamID) (rece } // GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream. -func (mr *MockStreamManagerMockRecorder) GetOrOpenReceiveStream(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) GetOrOpenReceiveStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenReceiveStream), arg0) } @@ -117,7 +121,7 @@ func (m *MockStreamManager) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStr } // GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream. -func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0) } @@ -129,7 +133,7 @@ func (m *MockStreamManager) HandleMaxStreamsFrame(arg0 *wire.MaxStreamsFrame) { } // HandleMaxStreamsFrame indicates an expected call of HandleMaxStreamsFrame. -func (mr *MockStreamManagerMockRecorder) HandleMaxStreamsFrame(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) HandleMaxStreamsFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMaxStreamsFrame", reflect.TypeOf((*MockStreamManager)(nil).HandleMaxStreamsFrame), arg0) } @@ -159,7 +163,7 @@ func (m *MockStreamManager) OpenStreamSync(arg0 context.Context) (Stream, error) } // OpenStreamSync indicates an expected call of OpenStreamSync. -func (mr *MockStreamManagerMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) OpenStreamSync(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync), arg0) } @@ -189,7 +193,7 @@ func (m *MockStreamManager) OpenUniStreamSync(arg0 context.Context) (SendStream, } // OpenUniStreamSync indicates an expected call of OpenUniStreamSync. -func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStreamSync), arg0) } @@ -213,7 +217,7 @@ func (m *MockStreamManager) UpdateLimits(arg0 *wire.TransportParameters) { } // UpdateLimits indicates an expected call of UpdateLimits. -func (mr *MockStreamManagerMockRecorder) UpdateLimits(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) UpdateLimits(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLimits", reflect.TypeOf((*MockStreamManager)(nil).UpdateLimits), arg0) } diff --git a/mock_stream_sender_test.go b/mock_stream_sender_test.go index 2c0608568..f3165076a 100644 --- a/mock_stream_sender_test.go +++ b/mock_stream_sender_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: StreamSender) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_sender_test.go github.com/refraction-networking/uquic StreamSender +// // Package quic is a generated GoMock package. package quic @@ -42,7 +46,7 @@ func (m *MockStreamSender) onHasStreamData(arg0 protocol.StreamID) { } // onHasStreamData indicates an expected call of onHasStreamData. -func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0 interface{}) *gomock.Call { +func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamData", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamData), arg0) } @@ -54,7 +58,7 @@ func (m *MockStreamSender) onStreamCompleted(arg0 protocol.StreamID) { } // onStreamCompleted indicates an expected call of onStreamCompleted. -func (mr *MockStreamSenderMockRecorder) onStreamCompleted(arg0 interface{}) *gomock.Call { +func (mr *MockStreamSenderMockRecorder) onStreamCompleted(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onStreamCompleted", reflect.TypeOf((*MockStreamSender)(nil).onStreamCompleted), arg0) } @@ -66,7 +70,7 @@ func (m *MockStreamSender) queueControlFrame(arg0 wire.Frame) { } // queueControlFrame indicates an expected call of queueControlFrame. -func (mr *MockStreamSenderMockRecorder) queueControlFrame(arg0 interface{}) *gomock.Call { +func (mr *MockStreamSenderMockRecorder) queueControlFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "queueControlFrame", reflect.TypeOf((*MockStreamSender)(nil).queueControlFrame), arg0) } diff --git a/mock_token_store_test.go b/mock_token_store_test.go index c5d286431..3d921b2ea 100644 --- a/mock_token_store_test.go +++ b/mock_token_store_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: TokenStore) - +// +// Generated by this command: +// +// mockgen -package quic -self_package github.com/refraction-networking/uquic -self_package github.com/refraction-networking/uquic -destination mock_token_store_test.go github.com/refraction-networking/uquic TokenStore +// // Package quic is a generated GoMock package. package quic @@ -42,7 +46,7 @@ func (m *MockTokenStore) Pop(arg0 string) *ClientToken { } // Pop indicates an expected call of Pop. -func (mr *MockTokenStoreMockRecorder) Pop(arg0 interface{}) *gomock.Call { +func (mr *MockTokenStoreMockRecorder) Pop(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Pop", reflect.TypeOf((*MockTokenStore)(nil).Pop), arg0) } @@ -54,7 +58,7 @@ func (m *MockTokenStore) Put(arg0 string, arg1 *ClientToken) { } // Put indicates an expected call of Put. -func (mr *MockTokenStoreMockRecorder) Put(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockTokenStoreMockRecorder) Put(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockTokenStore)(nil).Put), arg0, arg1) } diff --git a/mock_unknown_packet_handler_test.go b/mock_unknown_packet_handler_test.go deleted file mode 100644 index fa702529c..000000000 --- a/mock_unknown_packet_handler_test.go +++ /dev/null @@ -1,58 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/refraction-networking/uquic (interfaces: UnknownPacketHandler) - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "go.uber.org/mock/gomock" -) - -// MockUnknownPacketHandler is a mock of UnknownPacketHandler interface. -type MockUnknownPacketHandler struct { - ctrl *gomock.Controller - recorder *MockUnknownPacketHandlerMockRecorder -} - -// MockUnknownPacketHandlerMockRecorder is the mock recorder for MockUnknownPacketHandler. -type MockUnknownPacketHandlerMockRecorder struct { - mock *MockUnknownPacketHandler -} - -// NewMockUnknownPacketHandler creates a new mock instance. -func NewMockUnknownPacketHandler(ctrl *gomock.Controller) *MockUnknownPacketHandler { - mock := &MockUnknownPacketHandler{ctrl: ctrl} - mock.recorder = &MockUnknownPacketHandlerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockUnknownPacketHandler) EXPECT() *MockUnknownPacketHandlerMockRecorder { - return m.recorder -} - -// handlePacket mocks base method. -func (m *MockUnknownPacketHandler) handlePacket(arg0 receivedPacket) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "handlePacket", arg0) -} - -// handlePacket indicates an expected call of handlePacket. -func (mr *MockUnknownPacketHandlerMockRecorder) handlePacket(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockUnknownPacketHandler)(nil).handlePacket), arg0) -} - -// setCloseError mocks base method. -func (m *MockUnknownPacketHandler) setCloseError(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "setCloseError", arg0) -} - -// setCloseError indicates an expected call of setCloseError. -func (mr *MockUnknownPacketHandlerMockRecorder) setCloseError(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setCloseError", reflect.TypeOf((*MockUnknownPacketHandler)(nil).setCloseError), arg0) -} diff --git a/mock_unpacker_test.go b/mock_unpacker_test.go index 89dbf9c87..56cbe69cc 100644 --- a/mock_unpacker_test.go +++ b/mock_unpacker_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/refraction-networking/uquic (interfaces: Unpacker) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_unpacker_test.go github.com/refraction-networking/uquic Unpacker +// // Package quic is a generated GoMock package. package quic @@ -46,7 +50,7 @@ func (m *MockUnpacker) UnpackLongHeader(arg0 *wire.Header, arg1 time.Time, arg2 } // UnpackLongHeader indicates an expected call of UnpackLongHeader. -func (mr *MockUnpackerMockRecorder) UnpackLongHeader(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockUnpackerMockRecorder) UnpackLongHeader(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackLongHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackLongHeader), arg0, arg1, arg2, arg3) } @@ -64,7 +68,7 @@ func (m *MockUnpacker) UnpackShortHeader(arg0 time.Time, arg1 []byte) (protocol. } // UnpackShortHeader indicates an expected call of UnpackShortHeader. -func (mr *MockUnpackerMockRecorder) UnpackShortHeader(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockUnpackerMockRecorder) UnpackShortHeader(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackShortHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackShortHeader), arg0, arg1) } diff --git a/mockgen.go b/mockgen.go index 5adfbf591..b6267e5c3 100644 --- a/mockgen.go +++ b/mockgen.go @@ -62,9 +62,6 @@ type QUICConn = quicConn //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_packet_handler_test.go github.com/refraction-networking/uquic PacketHandler" type PacketHandler = packetHandler -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_unknown_packet_handler_test.go github.com/refraction-networking/uquic UnknownPacketHandler" -type UnknownPacketHandler = unknownPacketHandler - //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_packet_handler_manager_test.go github.com/refraction-networking/uquic PacketHandlerManager" type PacketHandlerManager = packetHandlerManager diff --git a/packet_handler_map.go b/packet_handler_map.go index 41287c2aa..9c4ebf876 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -4,7 +4,6 @@ import ( "crypto/hmac" "crypto/rand" "crypto/sha256" - "errors" "hash" "io" "net" @@ -21,14 +20,17 @@ type connCapabilities struct { DF bool // GSO (Generic Segmentation Offload) supported GSO bool + // ECN (Explicit Congestion Notifications) supported + ECN bool } // rawConn is a connection that allow reading of a receivedPackeh. type rawConn interface { ReadPacket() (receivedPacket, error) // WritePacket writes a packet on the wire. - // If GSO is enabled, it's the caller's responsibility to set the correct control message. - WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) + // gsoSize is the size of a single packet, or 0 to disable GSO. + // It is invalid to set gsoSize if capabilities.GSO is not set. + WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error) LocalAddr() net.Addr SetReadDeadline(time.Time) error io.Closer @@ -42,13 +44,6 @@ type closePacket struct { info packetInfo } -type unknownPacketHandler interface { - handlePacket(receivedPacket) - setCloseError(error) -} - -var errListenerAlreadySet = errors.New("listener already set") - type packetHandlerMap struct { mutex sync.Mutex handlers map[protocol.ConnectionID]packetHandler diff --git a/packet_packer.go b/packet_packer.go index e59304ee6..d1b548898 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -1,6 +1,8 @@ package quic import ( + crand "crypto/rand" + "encoding/binary" "errors" "fmt" @@ -9,6 +11,7 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/qerr" "github.com/refraction-networking/uquic/internal/wire" + "golang.org/x/exp/rand" ) var errNothingToPack = errors.New("nothing to pack") @@ -67,6 +70,11 @@ type coalescedPacket struct { shortHdrPacket *shortHeaderPacket } +// IsOnlyShortHeaderPacket says if this packet only contains a short header packet (and no long header packets). +func (p *coalescedPacket) IsOnlyShortHeaderPacket() bool { + return len(p.longHdrPackets) == 0 && p.shortHdrPacket != nil +} + func (p *longHeaderPacket) EncryptionLevel() protocol.EncryptionLevel { //nolint:exhaustive // Will never be called for Retry packets (and they don't have encrypted data). switch p.header.Type { @@ -122,6 +130,7 @@ type packetPacker struct { acks ackFrameSource datagramQueue *datagramQueue retransmissionQueue *retransmissionQueue + rand rand.Rand numNonAckElicitingAcks int } @@ -140,6 +149,9 @@ func newPacketPacker( datagramQueue *datagramQueue, perspective protocol.Perspective, ) *packetPacker { + var b [8]byte + _, _ = crand.Read(b[:]) + return &packetPacker{ cryptoSetup: cryptoSetup, getDestConnID: getDestConnID, @@ -151,6 +163,7 @@ func newPacketPacker( perspective: perspective, framer: framer, acks: acks, + rand: *rand.New(rand.NewSource(binary.BigEndian.Uint64(b[:]))), pnManager: packetNumberManager, } } @@ -832,6 +845,8 @@ func (p *packetPacker) appendShortHeaderPacket( }, nil } +// appendPacketPayload serializes the payload of a packet into the raw byte slice. +// It modifies the order of payload.frames. func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) { payloadOffset := len(raw) if pl.ack != nil { @@ -844,6 +859,11 @@ func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen pr if paddingLen > 0 { raw = append(raw, make([]byte, paddingLen)...) } + // Randomize the order of the control frames. + // This makes sure that the receiver doesn't rely on the order in which frames are packed. + if len(pl.frames) > 1 { + p.rand.Shuffle(len(pl.frames), func(i, j int) { pl.frames[i], pl.frames[j] = pl.frames[j], pl.frames[i] }) + } for _, f := range pl.frames { var err error raw, err = f.Frame.Append(raw, v) diff --git a/packet_packer_test.go b/packet_packer_test.go index 74e05a67a..13b20d63a 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -655,6 +655,7 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(packet).ToNot(BeNil()) Expect(packet.longHdrPackets).To(HaveLen(1)) + Expect(packet.IsOnlyShortHeaderPacket()).To(BeFalse()) // cut off the tag that the mock sealer added // packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] hdr, _, _, err := wire.ParsePacket(packet.buffer.Data) @@ -874,6 +875,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) + Expect(p.IsOnlyShortHeaderPacket()).To(BeFalse()) parsePacket(p.buffer.Data) }) @@ -1047,6 +1049,7 @@ var _ = Describe("Packet packer", func() { packer.retransmissionQueue.addAppData(&wire.PingFrame{}) p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) + Expect(p.IsOnlyShortHeaderPacket()).To(BeFalse()) Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) Expect(p.longHdrPackets).To(HaveLen(1)) Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) @@ -1422,6 +1425,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.MaybePackProbePacket(protocol.Encryption1RTT, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) + Expect(p.IsOnlyShortHeaderPacket()).To(BeTrue()) Expect(p.longHdrPackets).To(BeEmpty()) Expect(p.shortHdrPacket).ToNot(BeNil()) packet := p.shortHdrPacket @@ -1448,6 +1452,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.MaybePackProbePacket(protocol.Encryption1RTT, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) + Expect(p.IsOnlyShortHeaderPacket()).To(BeTrue()) Expect(p.longHdrPackets).To(BeEmpty()) Expect(p.shortHdrPacket).ToNot(BeNil()) packet := p.shortHdrPacket diff --git a/qlog/event.go b/qlog/event.go index 1bc19f754..ad90824dc 100644 --- a/qlog/event.go +++ b/qlog/event.go @@ -158,6 +158,7 @@ type eventPacketSent struct { PayloadLength logging.ByteCount Frames frames IsCoalesced bool + ECN logging.ECN Trigger string } @@ -172,6 +173,9 @@ func (e eventPacketSent) MarshalJSONObject(enc *gojay.Encoder) { enc.ObjectKey("raw", rawInfo{Length: e.Length, PayloadLength: e.PayloadLength}) enc.ArrayKeyOmitEmpty("frames", e.Frames) enc.BoolKeyOmitEmpty("is_coalesced", e.IsCoalesced) + if e.ECN != logging.ECNUnsupported { + enc.StringKey("ecn", ecn(e.ECN).String()) + } enc.StringKeyOmitEmpty("trigger", e.Trigger) } @@ -180,6 +184,7 @@ type eventPacketReceived struct { Length logging.ByteCount PayloadLength logging.ByteCount Frames frames + ECN logging.ECN IsCoalesced bool Trigger string } @@ -195,6 +200,9 @@ func (e eventPacketReceived) MarshalJSONObject(enc *gojay.Encoder) { enc.ObjectKey("raw", rawInfo{Length: e.Length, PayloadLength: e.PayloadLength}) enc.ArrayKeyOmitEmpty("frames", e.Frames) enc.BoolKeyOmitEmpty("is_coalesced", e.IsCoalesced) + if e.ECN != logging.ECNUnsupported { + enc.StringKey("ecn", ecn(e.ECN).String()) + } enc.StringKeyOmitEmpty("trigger", e.Trigger) } @@ -516,6 +524,20 @@ func (e eventCongestionStateUpdated) MarshalJSONObject(enc *gojay.Encoder) { enc.StringKey("new", e.state.String()) } +type eventECNStateUpdated struct { + state logging.ECNState + trigger logging.ECNStateTrigger +} + +func (e eventECNStateUpdated) Category() category { return categoryRecovery } +func (e eventECNStateUpdated) Name() string { return "ecn_state_updated" } +func (e eventECNStateUpdated) IsNil() bool { return false } + +func (e eventECNStateUpdated) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("new", ecnState(e.state).String()) + enc.StringKeyOmitEmpty("trigger", ecnStateTrigger(e.trigger).String()) +} + type eventGeneric struct { name string msg string diff --git a/qlog/qlog.go b/qlog/qlog.go index d120c320a..e94da3e20 100644 --- a/qlog/qlog.go +++ b/qlog/qlog.go @@ -63,11 +63,9 @@ type connectionTracer struct { lastMetrics *metrics } -var _ logging.ConnectionTracer = &connectionTracer{} - // NewConnectionTracer creates a new tracer to record a qlog for a connection. -func NewConnectionTracer(w io.WriteCloser, p protocol.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer { - t := &connectionTracer{ +func NewConnectionTracer(w io.WriteCloser, p protocol.Perspective, odcid protocol.ConnectionID) *logging.ConnectionTracer { + t := connectionTracer{ w: w, perspective: p, odcid: odcid, @@ -76,7 +74,84 @@ func NewConnectionTracer(w io.WriteCloser, p protocol.Perspective, odcid protoco referenceTime: time.Now(), } go t.run() - return t + return &logging.ConnectionTracer{ + StartedConnection: func(local, remote net.Addr, srcConnID, destConnID logging.ConnectionID) { + t.StartedConnection(local, remote, srcConnID, destConnID) + }, + NegotiatedVersion: func(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { + t.NegotiatedVersion(chosen, clientVersions, serverVersions) + }, + ClosedConnection: func(e error) { t.ClosedConnection(e) }, + SentTransportParameters: func(tp *wire.TransportParameters) { t.SentTransportParameters(tp) }, + ReceivedTransportParameters: func(tp *wire.TransportParameters) { t.ReceivedTransportParameters(tp) }, + RestoredTransportParameters: func(tp *wire.TransportParameters) { t.RestoredTransportParameters(tp) }, + SentLongHeaderPacket: func(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) { + t.SentLongHeaderPacket(hdr, size, ecn, ack, frames) + }, + SentShortHeaderPacket: func(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) { + t.SentShortHeaderPacket(hdr, size, ecn, ack, frames) + }, + ReceivedLongHeaderPacket: func(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { + t.ReceivedLongHeaderPacket(hdr, size, ecn, frames) + }, + ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { + t.ReceivedShortHeaderPacket(hdr, size, ecn, frames) + }, + ReceivedRetry: func(hdr *wire.Header) { + t.ReceivedRetry(hdr) + }, + ReceivedVersionNegotiationPacket: func(dest, src logging.ArbitraryLenConnectionID, versions []logging.VersionNumber) { + t.ReceivedVersionNegotiationPacket(dest, src, versions) + }, + BufferedPacket: func(pt logging.PacketType, size protocol.ByteCount) { + t.BufferedPacket(pt, size) + }, + DroppedPacket: func(pt logging.PacketType, size protocol.ByteCount, reason logging.PacketDropReason) { + t.DroppedPacket(pt, size, reason) + }, + UpdatedMetrics: func(rttStats *utils.RTTStats, cwnd, bytesInFlight protocol.ByteCount, packetsInFlight int) { + t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight) + }, + LostPacket: func(encLevel protocol.EncryptionLevel, pn protocol.PacketNumber, lossReason logging.PacketLossReason) { + t.LostPacket(encLevel, pn, lossReason) + }, + UpdatedCongestionState: func(state logging.CongestionState) { + t.UpdatedCongestionState(state) + }, + UpdatedPTOCount: func(value uint32) { + t.UpdatedPTOCount(value) + }, + UpdatedKeyFromTLS: func(encLevel protocol.EncryptionLevel, pers protocol.Perspective) { + t.UpdatedKeyFromTLS(encLevel, pers) + }, + UpdatedKey: func(generation protocol.KeyPhase, remote bool) { + t.UpdatedKey(generation, remote) + }, + DroppedEncryptionLevel: func(encLevel protocol.EncryptionLevel) { + t.DroppedEncryptionLevel(encLevel) + }, + DroppedKey: func(generation protocol.KeyPhase) { + t.DroppedKey(generation) + }, + SetLossTimer: func(tt logging.TimerType, encLevel protocol.EncryptionLevel, timeout time.Time) { + t.SetLossTimer(tt, encLevel, timeout) + }, + LossTimerExpired: func(tt logging.TimerType, encLevel protocol.EncryptionLevel) { + t.LossTimerExpired(tt, encLevel) + }, + LossTimerCanceled: func() { + t.LossTimerCanceled() + }, + ECNStateUpdated: func(state logging.ECNState, trigger logging.ECNStateTrigger) { + t.ECNStateUpdated(state, trigger) + }, + Debug: func(name, msg string) { + t.Debug(name, msg) + }, + Close: func() { + t.Close() + }, + } } func (t *connectionTracer) run() { @@ -253,15 +328,33 @@ func (t *connectionTracer) toTransportParameters(tp *wire.TransportParameters) * } } -func (t *connectionTracer) SentLongHeaderPacket(hdr *logging.ExtendedHeader, packetSize logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) { - t.sentPacket(*transformLongHeader(hdr), packetSize, hdr.Length, ack, frames) -} - -func (t *connectionTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, packetSize logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) { - t.sentPacket(*transformShortHeader(hdr), packetSize, 0, ack, frames) -} - -func (t *connectionTracer) sentPacket(hdr gojay.MarshalerJSONObject, packetSize, payloadLen logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) { +func (t *connectionTracer) SentLongHeaderPacket( + hdr *logging.ExtendedHeader, + size logging.ByteCount, + ecn logging.ECN, + ack *logging.AckFrame, + frames []logging.Frame, +) { + t.sentPacket(*transformLongHeader(hdr), size, hdr.Length, ecn, ack, frames) +} + +func (t *connectionTracer) SentShortHeaderPacket( + hdr *logging.ShortHeader, + size logging.ByteCount, + ecn logging.ECN, + ack *logging.AckFrame, + frames []logging.Frame, +) { + t.sentPacket(*transformShortHeader(hdr), size, 0, ecn, ack, frames) +} + +func (t *connectionTracer) sentPacket( + hdr gojay.MarshalerJSONObject, + size, payloadLen logging.ByteCount, + ecn logging.ECN, + ack *logging.AckFrame, + frames []logging.Frame, +) { numFrames := len(frames) if ack != nil { numFrames++ @@ -276,14 +369,15 @@ func (t *connectionTracer) sentPacket(hdr gojay.MarshalerJSONObject, packetSize, t.mutex.Lock() t.recordEvent(time.Now(), &eventPacketSent{ Header: hdr, - Length: packetSize, + Length: size, PayloadLength: payloadLen, + ECN: ecn, Frames: fs, }) t.mutex.Unlock() } -func (t *connectionTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, packetSize logging.ByteCount, frames []logging.Frame) { +func (t *connectionTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { fs := make([]frame, len(frames)) for i, f := range frames { fs[i] = frame{Frame: f} @@ -292,14 +386,15 @@ func (t *connectionTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, t.mutex.Lock() t.recordEvent(time.Now(), &eventPacketReceived{ Header: header, - Length: packetSize, + Length: size, PayloadLength: hdr.Length, + ECN: ecn, Frames: fs, }) t.mutex.Unlock() } -func (t *connectionTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, packetSize logging.ByteCount, frames []logging.Frame) { +func (t *connectionTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { fs := make([]frame, len(frames)) for i, f := range frames { fs[i] = frame{Frame: f} @@ -308,8 +403,9 @@ func (t *connectionTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, p t.mutex.Lock() t.recordEvent(time.Now(), &eventPacketReceived{ Header: header, - Length: packetSize, - PayloadLength: packetSize - wire.ShortHeaderLen(hdr.DestConnectionID, hdr.PacketNumberLen), + Length: size, + PayloadLength: size - wire.ShortHeaderLen(hdr.DestConnectionID, hdr.PacketNumberLen), + ECN: ecn, Frames: fs, }) t.mutex.Unlock() @@ -482,6 +578,12 @@ func (t *connectionTracer) LossTimerCanceled() { t.mutex.Unlock() } +func (t *connectionTracer) ECNStateUpdated(state logging.ECNState, trigger logging.ECNStateTrigger) { + t.mutex.Lock() + t.recordEvent(time.Now(), &eventECNStateUpdated{state: state, trigger: trigger}) + t.mutex.Unlock() +} + func (t *connectionTracer) Debug(name, msg string) { t.mutex.Lock() t.recordEvent(time.Now(), &eventGeneric{ diff --git a/qlog/qlog_test.go b/qlog/qlog_test.go index 4581630bd..4638d200c 100644 --- a/qlog/qlog_test.go +++ b/qlog/qlog_test.go @@ -70,7 +70,7 @@ var _ = Describe("Tracing", func() { Context("connection tracer", func() { var ( - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer buf *bytes.Buffer ) @@ -419,6 +419,7 @@ var _ = Describe("Tracing", func() { PacketNumber: 1337, }, 987, + logging.ECNCE, nil, []logging.Frame{ &logging.MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987}, @@ -439,6 +440,7 @@ var _ = Describe("Tracing", func() { Expect(hdr).To(HaveKeyWithValue("packet_number", float64(1337))) Expect(hdr).To(HaveKeyWithValue("scid", "04030201")) Expect(ev).To(HaveKey("frames")) + Expect(ev).To(HaveKeyWithValue("ecn", "CE")) frames := ev["frames"].([]interface{}) Expect(frames).To(HaveLen(2)) Expect(frames[0].(map[string]interface{})).To(HaveKeyWithValue("frame_type", "max_stream_data")) @@ -452,6 +454,7 @@ var _ = Describe("Tracing", func() { PacketNumber: 1337, }, 123, + logging.ECNUnsupported, &logging.AckFrame{AckRanges: []logging.AckRange{{Smallest: 1, Largest: 10}}}, []logging.Frame{&logging.MaxDataFrame{MaximumData: 987}}, ) @@ -461,6 +464,7 @@ var _ = Describe("Tracing", func() { Expect(raw).To(HaveKeyWithValue("length", float64(123))) Expect(raw).ToNot(HaveKey("payload_length")) Expect(ev).To(HaveKey("header")) + Expect(ev).ToNot(HaveKey("ecn")) hdr := ev["header"].(map[string]interface{}) Expect(hdr).To(HaveKeyWithValue("packet_type", "1RTT")) Expect(hdr).To(HaveKeyWithValue("packet_number", float64(1337))) @@ -485,6 +489,7 @@ var _ = Describe("Tracing", func() { PacketNumber: 1337, }, 789, + logging.ECT0, []logging.Frame{ &logging.MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987}, &logging.StreamFrame{StreamID: 123, Offset: 1234, Length: 6, Fin: true}, @@ -498,6 +503,7 @@ var _ = Describe("Tracing", func() { raw := ev["raw"].(map[string]interface{}) Expect(raw).To(HaveKeyWithValue("length", float64(789))) Expect(raw).To(HaveKeyWithValue("payload_length", float64(1234))) + Expect(ev).To(HaveKeyWithValue("ecn", "ECT(0)")) Expect(ev).To(HaveKey("header")) hdr := ev["header"].(map[string]interface{}) Expect(hdr).To(HaveKeyWithValue("packet_type", "initial")) @@ -520,6 +526,7 @@ var _ = Describe("Tracing", func() { tracer.ReceivedShortHeaderPacket( shdr, 789, + logging.ECT1, []logging.Frame{ &logging.MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987}, &logging.StreamFrame{StreamID: 123, Offset: 1234, Length: 6, Fin: true}, @@ -533,6 +540,7 @@ var _ = Describe("Tracing", func() { raw := ev["raw"].(map[string]interface{}) Expect(raw).To(HaveKeyWithValue("length", float64(789))) Expect(raw).To(HaveKeyWithValue("payload_length", float64(789-(1+8+3)))) + Expect(ev).To(HaveKeyWithValue("ecn", "ECT(1)")) Expect(ev).To(HaveKey("header")) hdr := ev["header"].(map[string]interface{}) Expect(hdr).To(HaveKeyWithValue("packet_type", "1RTT")) @@ -857,6 +865,27 @@ var _ = Describe("Tracing", func() { Expect(ev).To(HaveKeyWithValue("event_type", "cancelled")) }) + It("records an ECN state transition, without a trigger", func() { + tracer.ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:ecn_state_updated")) + ev := entry.Event + Expect(ev).To(HaveLen(1)) + Expect(ev).To(HaveKeyWithValue("new", "unknown")) + }) + + It("records an ECN state transition, with a trigger", func() { + tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedNoECNCounts) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:ecn_state_updated")) + ev := entry.Event + Expect(ev).To(HaveLen(2)) + Expect(ev).To(HaveKeyWithValue("new", "failed")) + Expect(ev).To(HaveKeyWithValue("trigger", "ACK doesn't contain ECN marks")) + }) + It("records a generic event", func() { tracer.Debug("foo", "bar") entry := exportAndParseSingle() diff --git a/qlog/types.go b/qlog/types.go index ff6db9797..ed1c98ac6 100644 --- a/qlog/types.go +++ b/qlog/types.go @@ -312,3 +312,61 @@ func (s congestionState) String() string { return "unknown congestion state" } } + +type ecn logging.ECN + +func (e ecn) String() string { + //nolint:exhaustive // The unsupported value is never logged. + switch logging.ECN(e) { + case logging.ECTNot: + return "Not-ECT" + case logging.ECT0: + return "ECT(0)" + case logging.ECT1: + return "ECT(1)" + case logging.ECNCE: + return "CE" + default: + return "unknown ECN" + } +} + +type ecnState logging.ECNState + +func (e ecnState) String() string { + switch logging.ECNState(e) { + case logging.ECNStateTesting: + return "testing" + case logging.ECNStateUnknown: + return "unknown" + case logging.ECNStateCapable: + return "capable" + case logging.ECNStateFailed: + return "failed" + default: + return "unknown ECN state" + } +} + +type ecnStateTrigger logging.ECNStateTrigger + +func (e ecnStateTrigger) String() string { + switch logging.ECNStateTrigger(e) { + case logging.ECNTriggerNoTrigger: + return "" + case logging.ECNFailedNoECNCounts: + return "ACK doesn't contain ECN marks" + case logging.ECNFailedDecreasedECNCounts: + return "ACK decreases ECN counts" + case logging.ECNFailedLostAllTestingPackets: + return "all ECN testing packets declared lost" + case logging.ECNFailedMoreECNCountsThanSent: + return "ACK contains more ECN counts than ECN-marked packets sent" + case logging.ECNFailedTooFewECNCounts: + return "ACK contains fewer new ECN counts than acknowledged ECN-marked packets" + case logging.ECNFailedManglingDetected: + return "ECN mangling detected" + default: + return "unknown ECN state trigger" + } +} diff --git a/qlog/types_test.go b/qlog/types_test.go index 2b8312f6f..df613b7c3 100644 --- a/qlog/types_test.go +++ b/qlog/types_test.go @@ -127,4 +127,31 @@ var _ = Describe("Types", func() { Expect(congestionState(logging.CongestionStateApplicationLimited).String()).To(Equal("application_limited")) Expect(congestionState(logging.CongestionStateRecovery).String()).To(Equal("recovery")) }) + + It("has a string representation for the ECN bits", func() { + Expect(ecn(logging.ECT0).String()).To(Equal("ECT(0)")) + Expect(ecn(logging.ECT1).String()).To(Equal("ECT(1)")) + Expect(ecn(logging.ECNCE).String()).To(Equal("CE")) + Expect(ecn(logging.ECTNot).String()).To(Equal("Not-ECT")) + Expect(ecn(42).String()).To(Equal("unknown ECN")) + }) + + It("has a string representation for the ECN state", func() { + Expect(ecnState(logging.ECNStateTesting).String()).To(Equal("testing")) + Expect(ecnState(logging.ECNStateUnknown).String()).To(Equal("unknown")) + Expect(ecnState(logging.ECNStateFailed).String()).To(Equal("failed")) + Expect(ecnState(logging.ECNStateCapable).String()).To(Equal("capable")) + Expect(ecnState(42).String()).To(Equal("unknown ECN state")) + }) + + It("has a string representation for the ECN state trigger", func() { + Expect(ecnStateTrigger(logging.ECNTriggerNoTrigger).String()).To(Equal("")) + Expect(ecnStateTrigger(logging.ECNFailedNoECNCounts).String()).To(Equal("ACK doesn't contain ECN marks")) + Expect(ecnStateTrigger(logging.ECNFailedDecreasedECNCounts).String()).To(Equal("ACK decreases ECN counts")) + Expect(ecnStateTrigger(logging.ECNFailedLostAllTestingPackets).String()).To(Equal("all ECN testing packets declared lost")) + Expect(ecnStateTrigger(logging.ECNFailedMoreECNCountsThanSent).String()).To(Equal("ACK contains more ECN counts than ECN-marked packets sent")) + Expect(ecnStateTrigger(logging.ECNFailedTooFewECNCounts).String()).To(Equal("ACK contains fewer new ECN counts than acknowledged ECN-marked packets")) + Expect(ecnStateTrigger(logging.ECNFailedManglingDetected).String()).To(Equal("ECN mangling detected")) + Expect(ecnStateTrigger(42).String()).To(Equal("unknown ECN state trigger")) + }) }) diff --git a/send_conn.go b/send_conn.go index a3feaf628..29090e320 100644 --- a/send_conn.go +++ b/send_conn.go @@ -1,8 +1,6 @@ package quic import ( - "fmt" - "math" "net" "github.com/refraction-networking/uquic/internal/protocol" @@ -11,7 +9,7 @@ import ( // A sendConn allows sending using a simple Write() on a non-connected packet conn. type sendConn interface { - Write(b []byte, size protocol.ByteCount) error + Write(b []byte, gsoSize uint16, ecn protocol.ECN) error Close() error LocalAddr() net.Addr RemoteAddr() net.Addr @@ -27,8 +25,7 @@ type sconn struct { logger utils.Logger - info packetInfo - oob []byte + packetInfoOOB []byte // If GSO enabled, and we receive a GSO error for this remote address, GSO is disabled. gotGSOError bool } @@ -46,33 +43,20 @@ func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logge } oob := info.OOB() - // add 32 bytes, so we can add the UDP_SEGMENT msg + // increase oob slice capacity, so we can add the UDP_SEGMENT and ECN control messages without allocating l := len(oob) - oob = append(oob, make([]byte, 32)...) - oob = oob[:l] + oob = append(oob, make([]byte, 64)...)[:l] return &sconn{ - rawConn: c, - localAddr: localAddr, - remoteAddr: remote, - info: info, - oob: oob, - logger: logger, + rawConn: c, + localAddr: localAddr, + remoteAddr: remote, + packetInfoOOB: oob, + logger: logger, } } -func (c *sconn) Write(p []byte, size protocol.ByteCount) error { - if !c.capabilities().GSO { - if protocol.ByteCount(len(p)) != size { - panic(fmt.Sprintf("inconsistent packet size (%d vs %d)", len(p), size)) - } - _, err := c.WritePacket(p, c.remoteAddr, c.oob) - return err - } - // GSO is supported. Append the control message and send. - if size > math.MaxUint16 { - panic("size overflow") - } - _, err := c.WritePacket(p, c.remoteAddr, appendUDPSegmentSizeMsg(c.oob, uint16(size))) +func (c *sconn) Write(p []byte, gsoSize uint16, ecn protocol.ECN) error { + _, err := c.WritePacket(p, c.remoteAddr, c.packetInfoOOB, gsoSize, ecn) if err != nil && isGSOError(err) { // disable GSO for future calls c.gotGSOError = true @@ -82,10 +66,10 @@ func (c *sconn) Write(p []byte, size protocol.ByteCount) error { // send out the packets one by one for len(p) > 0 { l := len(p) - if l > int(size) { - l = int(size) + if l > int(gsoSize) { + l = int(gsoSize) } - if _, err := c.WritePacket(p[:l], c.remoteAddr, c.oob); err != nil { + if _, err := c.WritePacket(p[:l], c.remoteAddr, c.packetInfoOOB, 0, ecn); err != nil { return err } p = p[l:] diff --git a/send_conn_test.go b/send_conn_test.go index 382c0ccaf..5c23aa7ca 100644 --- a/send_conn_test.go +++ b/send_conn_test.go @@ -3,7 +3,9 @@ package quic import ( "net" "net/netip" + "runtime" + "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/utils" . "github.com/onsi/ginkgo/v2" @@ -35,48 +37,43 @@ var _ = Describe("Connection (for sending packets)", func() { Expect(c.LocalAddr().String()).To(Equal("127.0.0.42:1234")) }) - if platformSupportsGSO { - It("writes with GSO", func() { + // We're not using an OOB conn on windows, and packetInfo.OOB() always returns an empty slice. + if runtime.GOOS != "windows" { + It("sets the OOB", func() { rawConn := NewMockRawConn(mockCtrl) rawConn.EXPECT().LocalAddr() - rawConn.EXPECT().capabilities().Return(connCapabilities{GSO: true}).AnyTimes() - c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) - rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any()).Do(func(_ []byte, _ net.Addr, oob []byte) { - msg := appendUDPSegmentSizeMsg([]byte{}, 3) - Expect(oob).To(Equal(msg)) - }) - Expect(c.Write([]byte("foobar"), 3)).To(Succeed()) + rawConn.EXPECT().capabilities().AnyTimes() + pi := packetInfo{addr: netip.IPv6Loopback()} + Expect(pi.OOB()).ToNot(BeEmpty()) + c := newSendConn(rawConn, remoteAddr, pi, utils.DefaultLogger) + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, pi.OOB(), uint16(0), protocol.ECT1) + Expect(c.Write([]byte("foobar"), 0, protocol.ECT1)).To(Succeed()) }) + } - It("disables GSO if writing fails", func() { + It("writes", func() { + rawConn := NewMockRawConn(mockCtrl) + rawConn.EXPECT().LocalAddr() + rawConn.EXPECT().capabilities().AnyTimes() + c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(3), protocol.ECNCE) + Expect(c.Write([]byte("foobar"), 3, protocol.ECNCE)).To(Succeed()) + }) + + if platformSupportsGSO { + It("disables GSO if sending fails", func() { rawConn := NewMockRawConn(mockCtrl) rawConn.EXPECT().LocalAddr() rawConn.EXPECT().capabilities().Return(connCapabilities{GSO: true}).AnyTimes() c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) Expect(c.capabilities().GSO).To(BeTrue()) gomock.InOrder( - rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any()).DoAndReturn(func(_ []byte, _ net.Addr, oob []byte) (int, error) { - msg := appendUDPSegmentSizeMsg([]byte{}, 3) - Expect(oob).To(Equal(msg)) - return 0, errGSO - }), - rawConn.EXPECT().WritePacket([]byte("foo"), remoteAddr, gomock.Len(0)).Return(3, nil), - rawConn.EXPECT().WritePacket([]byte("bar"), remoteAddr, gomock.Len(0)).Return(3, nil), + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(4), protocol.ECNCE).Return(0, errGSO), + rawConn.EXPECT().WritePacket([]byte("foob"), remoteAddr, gomock.Any(), uint16(0), protocol.ECNCE).Return(4, nil), + rawConn.EXPECT().WritePacket([]byte("ar"), remoteAddr, gomock.Any(), uint16(0), protocol.ECNCE).Return(2, nil), ) - Expect(c.Write([]byte("foobar"), 3)).To(Succeed()) - Expect(c.capabilities().GSO).To(BeFalse()) // GSO support is now disabled - // make sure we actually enforce that - Expect(func() { c.Write([]byte("foobar"), 3) }).To(PanicWith("inconsistent packet size (6 vs 3)")) - }) - } else { - It("writes without GSO", func() { - remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} - rawConn := NewMockRawConn(mockCtrl) - rawConn.EXPECT().LocalAddr() - rawConn.EXPECT().capabilities() - c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) - rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Len(0)) - Expect(c.Write([]byte("foobar"), 6)).To(Succeed()) + Expect(c.Write([]byte("foobar"), 4, protocol.ECNCE)).To(Succeed()) + Expect(c.capabilities().GSO).To(BeFalse()) }) } }) diff --git a/send_queue.go b/send_queue.go index ffcd25533..d18ed6811 100644 --- a/send_queue.go +++ b/send_queue.go @@ -3,7 +3,7 @@ package quic import "github.com/refraction-networking/uquic/internal/protocol" type sender interface { - Send(p *packetBuffer, packetSize protocol.ByteCount) + Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN) Run() error WouldBlock() bool Available() <-chan struct{} @@ -11,8 +11,9 @@ type sender interface { } type queueEntry struct { - buf *packetBuffer - size protocol.ByteCount + buf *packetBuffer + gsoSize uint16 + ecn protocol.ECN } type sendQueue struct { @@ -40,9 +41,9 @@ func newSendQueue(conn sendConn) sender { // Send sends out a packet. It's guaranteed to not block. // Callers need to make sure that there's actually space in the send queue by calling WouldBlock. // Otherwise Send will panic. -func (h *sendQueue) Send(p *packetBuffer, size protocol.ByteCount) { +func (h *sendQueue) Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN) { select { - case h.queue <- queueEntry{buf: p, size: size}: + case h.queue <- queueEntry{buf: p, gsoSize: gsoSize, ecn: ecn}: // clear available channel if we've reached capacity if len(h.queue) == sendQueueCapacity { select { @@ -77,7 +78,7 @@ func (h *sendQueue) Run() error { // make sure that all queued packets are actually sent out shouldClose = true case e := <-h.queue: - if err := h.conn.Write(e.buf.Data, e.size); err != nil { + if err := h.conn.Write(e.buf.Data, e.gsoSize, e.ecn); err != nil { // This additional check enables: // 1. Checking for "datagram too large" message from the kernel, as such, // 2. Path MTU discovery,and diff --git a/send_queue_test.go b/send_queue_test.go index 6abe56123..9dd64847d 100644 --- a/send_queue_test.go +++ b/send_queue_test.go @@ -28,10 +28,10 @@ var _ = Describe("Send Queue", func() { It("sends a packet", func() { p := getPacket([]byte("foobar")) - q.Send(p, 10) // make sure the packet size is passed through to the conn + q.Send(p, 10, protocol.ECT1) // make sure the packet size is passed through to the conn written := make(chan struct{}) - c.EXPECT().Write([]byte("foobar"), protocol.ByteCount(10)).Do(func([]byte, protocol.ByteCount) { close(written) }) + c.EXPECT().Write([]byte("foobar"), uint16(10), protocol.ECT1).Do(func([]byte, uint16, protocol.ECN) { close(written) }) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -47,19 +47,19 @@ var _ = Describe("Send Queue", func() { It("panics when Send() is called although there's no space in the queue", func() { for i := 0; i < sendQueueCapacity; i++ { Expect(q.WouldBlock()).To(BeFalse()) - q.Send(getPacket([]byte("foobar")), 6) + q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon) } Expect(q.WouldBlock()).To(BeTrue()) - Expect(func() { q.Send(getPacket([]byte("raboof")), 6) }).To(Panic()) + Expect(func() { q.Send(getPacket([]byte("raboof")), 6, protocol.ECNNon) }).To(Panic()) }) It("signals when sending is possible again", func() { Expect(q.WouldBlock()).To(BeFalse()) - q.Send(getPacket([]byte("foobar1")), 6) + q.Send(getPacket([]byte("foobar1")), 6, protocol.ECNNon) Consistently(q.Available()).ShouldNot(Receive()) // now start sending out packets. This should free up queue space. - c.EXPECT().Write(gomock.Any(), gomock.Any()).MinTimes(1).MaxTimes(2) + c.EXPECT().Write(gomock.Any(), gomock.Any(), protocol.ECNNon).MinTimes(1).MaxTimes(2) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -69,7 +69,7 @@ var _ = Describe("Send Queue", func() { Eventually(q.Available()).Should(Receive()) Expect(q.WouldBlock()).To(BeFalse()) - Expect(func() { q.Send(getPacket([]byte("foobar2")), 7) }).ToNot(Panic()) + Expect(func() { q.Send(getPacket([]byte("foobar2")), 7, protocol.ECNNon) }).ToNot(Panic()) q.Close() Eventually(done).Should(BeClosed()) @@ -79,7 +79,7 @@ var _ = Describe("Send Queue", func() { write := make(chan struct{}, 1) written := make(chan struct{}, 100) // now start sending out packets. This should free up queue space. - c.EXPECT().Write(gomock.Any(), gomock.Any()).DoAndReturn(func([]byte, protocol.ByteCount) error { + c.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func([]byte, uint16, protocol.ECN) error { written <- struct{}{} <-write return nil @@ -94,19 +94,19 @@ var _ = Describe("Send Queue", func() { close(done) }() - q.Send(getPacket([]byte("foobar")), 6) + q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon) <-written // now fill up the send queue for i := 0; i < sendQueueCapacity; i++ { Expect(q.WouldBlock()).To(BeFalse()) - q.Send(getPacket([]byte("foobar")), 6) + q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon) } // One more packet is queued when it's picked up by Run and written to the connection. // In this test, it's blocked on write channel in the mocked Write call. <-written Eventually(q.WouldBlock()).Should(BeFalse()) - q.Send(getPacket([]byte("foobar")), 6) + q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon) Expect(q.WouldBlock()).To(BeTrue()) Consistently(q.Available()).ShouldNot(Receive()) @@ -132,15 +132,15 @@ var _ = Describe("Send Queue", func() { // the run loop exits if there is a write error testErr := errors.New("test error") - c.EXPECT().Write(gomock.Any(), gomock.Any()).Return(testErr) - q.Send(getPacket([]byte("foobar")), 6) + c.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Return(testErr) + q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon) Eventually(done).Should(BeClosed()) sent := make(chan struct{}) go func() { defer GinkgoRecover() - q.Send(getPacket([]byte("raboof")), 6) - q.Send(getPacket([]byte("quux")), 4) + q.Send(getPacket([]byte("raboof")), 6, protocol.ECNNon) + q.Send(getPacket([]byte("quux")), 4, protocol.ECNNon) close(sent) }() @@ -149,7 +149,7 @@ var _ = Describe("Send Queue", func() { It("blocks Close() until the packet has been sent out", func() { written := make(chan []byte) - c.EXPECT().Write(gomock.Any(), gomock.Any()).Do(func(p []byte, _ protocol.ByteCount) { written <- p }) + c.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(p []byte, _ uint16, _ protocol.ECN) { written <- p }) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -157,7 +157,7 @@ var _ = Describe("Send Queue", func() { close(done) }() - q.Send(getPacket([]byte("foobar")), 6) + q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon) closed := make(chan struct{}) go func() { diff --git a/server.go b/server.go index 5b80ec29c..d2bcc7da3 100644 --- a/server.go +++ b/server.go @@ -2,7 +2,6 @@ package quic import ( "context" - "crypto/rand" "errors" "fmt" "net" @@ -60,7 +59,8 @@ type zeroRTTQueue struct { type baseServer struct { mutex sync.Mutex - acceptEarlyConns bool + disableVersionNegotiation bool + acceptEarlyConns bool tlsConf *tls.Config config *Config @@ -68,6 +68,7 @@ type baseServer struct { conn rawConn tokenGenerator *handshake.TokenGenerator + maxTokenAge time.Duration connIDGenerator ConnectionIDGenerator connHandler packetHandlerManager @@ -93,7 +94,7 @@ type baseServer struct { *tls.Config, *handshake.TokenGenerator, bool, /* client address validated by an address validation token */ - logging.ConnectionTracer, + *logging.ConnectionTracer, uint64, utils.Logger, protocol.VersionNumber, @@ -109,7 +110,7 @@ type baseServer struct { connQueue chan quicConn connQueueLen int32 // to be used as an atomic - tracer logging.Tracer + tracer *logging.Tracer logger utils.Logger } @@ -225,32 +226,33 @@ func newServer( connIDGenerator ConnectionIDGenerator, tlsConf *tls.Config, config *Config, - tracer logging.Tracer, + tracer *logging.Tracer, onClose func(), + tokenGeneratorKey TokenGeneratorKey, + maxTokenAge time.Duration, + disableVersionNegotiation bool, acceptEarly bool, -) (*baseServer, error) { - tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) - if err != nil { - return nil, err - } +) *baseServer { s := &baseServer{ - conn: conn, - tlsConf: tlsConf, - config: config, - tokenGenerator: tokenGenerator, - connIDGenerator: connIDGenerator, - connHandler: connHandler, - connQueue: make(chan quicConn), - errorChan: make(chan struct{}), - running: make(chan struct{}), - receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets), - versionNegotiationQueue: make(chan receivedPacket, 4), - invalidTokenQueue: make(chan receivedPacket, 4), - newConn: newConnection, - tracer: tracer, - logger: utils.DefaultLogger.WithPrefix("server"), - acceptEarlyConns: acceptEarly, - onClose: onClose, + conn: conn, + tlsConf: tlsConf, + config: config, + tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey), + maxTokenAge: maxTokenAge, + connIDGenerator: connIDGenerator, + connHandler: connHandler, + connQueue: make(chan quicConn), + errorChan: make(chan struct{}), + running: make(chan struct{}), + receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets), + versionNegotiationQueue: make(chan receivedPacket, 4), + invalidTokenQueue: make(chan receivedPacket, 4), + newConn: newConnection, + tracer: tracer, + logger: utils.DefaultLogger.WithPrefix("server"), + acceptEarlyConns: acceptEarly, + disableVersionNegotiation: disableVersionNegotiation, + onClose: onClose, } if acceptEarly { s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{} @@ -258,7 +260,7 @@ func newServer( go s.run() go s.runSendQueue() s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) - return s, nil + return s } func (s *baseServer) run() { @@ -351,7 +353,7 @@ func (s *baseServer) handlePacket(p receivedPacket) { case s.receivedPackets <- p: default: s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size()) - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) } } @@ -364,7 +366,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st if wire.IsVersionNegotiationPacket(p.data) { s.logger.Debugf("Dropping Version Negotiation packet.") - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -377,20 +379,20 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st // drop the packet if we failed to parse the protocol version if err != nil { s.logger.Debugf("Dropping a packet with an unknown version") - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) } return false } // send a Version Negotiation Packet if the client is speaking a different protocol version if !protocol.IsSupportedVersion(s.config.Versions, v) { - if s.config.DisableVersionNegotiationPackets { + if s.disableVersionNegotiation { return false } if p.Size() < protocol.MinUnknownVersionPacketSize { s.logger.Debugf("Dropping a packet with an unsupported version number %d that is too small (%d bytes)", v, p.Size()) - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -400,7 +402,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st if wire.Is0RTTPacket(p.data) { if !s.acceptEarlyConns { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -412,7 +414,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st // The header will then be parsed again. hdr, _, _, err := wire.ParsePacket(p.data) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Error parsing packet: %s", err) @@ -420,7 +422,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st } if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize { s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size()) - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -431,7 +433,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st // There's little point in sending a Stateless Reset, since the client // might not have received the token yet. s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data)) - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -450,7 +452,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st func (s *baseServer) handle0RTTPacket(p receivedPacket) bool { connID, err := wire.ParseConnectionID(p.data, 0) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError) } return false @@ -464,7 +466,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool { if q, ok := s.zeroRTTQueues[connID]; ok { if len(q.packets) >= protocol.Max0RTTQueueLen { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) } return false @@ -474,7 +476,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool { } if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) } return false @@ -502,7 +504,7 @@ func (s *baseServer) cleanupZeroRTTQueues(now time.Time) { continue } for _, p := range q.packets { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) } p.buffer.Release() @@ -526,10 +528,10 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool { if !token.ValidateRemoteAddr(addr) { return false } - if !token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxTokenAge { + if !token.IsRetryToken && time.Since(token.SentTime) > s.maxTokenAge { return false } - if token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxRetryTokenAge { + if token.IsRetryToken && time.Since(token.SentTime) > s.config.maxRetryTokenAge() { return false } return true @@ -538,7 +540,7 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool { func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { p.buffer.Release() - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) } return errors.New("too short connection ID") @@ -623,7 +625,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error } config = populateConfig(conf) } - var tracer logging.ConnectionTracer + var tracer *logging.ConnectionTracer if config.Tracer != nil { // Use the same connection ID that is passed to the client's GetLogWriter callback. connID := hdr.DestConnectionID @@ -740,10 +742,10 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe // append the Retry integrity tag tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version) buf.Data = append(buf.Data, tag[:]...) - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentPacket != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) } - _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported) return err } @@ -761,7 +763,7 @@ func (s *baseServer) maybeSendInvalidToken(p receivedPacket) { hdr, _, _, err := wire.ParsePacket(p.data) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Error parsing packet: %s", err) @@ -776,14 +778,14 @@ func (s *baseServer) maybeSendInvalidToken(p receivedPacket) { // Only send INVALID_TOKEN if we can unprotect the packet. // This makes sure that we won't send it for packets that were corrupted. if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) } return } hdrLen := extHdr.ParsedLen() if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError) } return @@ -839,10 +841,10 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han replyHdr.Log(s.logger) wire.LogFrame(s.logger, ccf, true) - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentPacket != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) } - _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported) return err } @@ -868,7 +870,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { _, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data) if err != nil { // should never happen s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs") - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) } return @@ -877,10 +879,10 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { s.logger.Debugf("Client offered version %s, sending Version Negotiation", v) data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions) - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentVersionNegotiationPacket != nil { s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions) } - if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { + if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) } } diff --git a/server_test.go b/server_test.go index d1563ad30..71cdf9a7b 100644 --- a/server_test.go +++ b/server_test.go @@ -178,8 +178,9 @@ var _ = Describe("Server", func() { ) BeforeEach(func() { - tracer = mocklogging.NewMockTracer(mockCtrl) - tr = &Transport{Conn: conn, Tracer: tracer} + var t *logging.Tracer + t, tracer = mocklogging.NewMockTracer(mockCtrl) + tr = &Transport{Conn: conn, Tracer: t} ln, err := tr.Listen(tlsConf, nil) Expect(err).ToNot(HaveOccurred()) serv = ln.baseServer @@ -292,7 +293,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -358,7 +359,7 @@ var _ = Describe("Server", func() { }) It("doesn't send a Version Negotiation packets if sending them is disabled", func() { - serv.config.DisableVersionNegotiationPackets = true + serv.disableVersionNegotiation = true srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) packet := getPacket(&wire.Header{ @@ -494,7 +495,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -553,7 +554,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -606,7 +607,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -642,7 +643,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -713,7 +714,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -834,7 +835,8 @@ var _ = Describe("Server", func() { It("sends an INVALID_TOKEN error, if an expired retry token is received", func() { serv.config.RequireAddressValidation = func(net.Addr) bool { return true } - serv.config.MaxRetryTokenAge = time.Millisecond + serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout + Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond)) raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) @@ -900,7 +902,7 @@ var _ = Describe("Server", func() { It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() { serv.config.RequireAddressValidation = func(net.Addr) bool { return true } - serv.config.MaxTokenAge = time.Millisecond + serv.maxTokenAge = time.Millisecond raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} token, err := serv.tokenGenerator.NewToken(raddr) Expect(err).ToNot(HaveOccurred()) @@ -1022,7 +1024,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -1099,7 +1101,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -1172,7 +1174,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -1215,7 +1217,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -1279,7 +1281,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -1329,8 +1331,9 @@ var _ = Describe("Server", func() { ) BeforeEach(func() { - tracer = mocklogging.NewMockTracer(mockCtrl) - tr = &Transport{Conn: conn, Tracer: tracer} + var t *logging.Tracer + t, tracer = mocklogging.NewMockTracer(mockCtrl) + tr = &Transport{Conn: conn, Tracer: t} ln, err := tr.ListenEarly(tlsConf, nil) Expect(err).ToNot(HaveOccurred()) phm = NewMockPacketHandlerManager(mockCtrl) @@ -1404,13 +1407,13 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, ) quicConn { conn := NewMockQUICConn(mockCtrl) - var calls []*gomock.Call + var calls []any calls = append(calls, conn.EXPECT().handlePacket(initial)) for _, p := range zeroRTTPackets { calls = append(calls, conn.EXPECT().handlePacket(p)) diff --git a/sys_conn.go b/sys_conn.go index 88e098fd6..36fc8ab72 100644 --- a/sys_conn.go +++ b/sys_conn.go @@ -104,7 +104,13 @@ func (c *basicConn) ReadPacket() (receivedPacket, error) { }, nil } -func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte) (n int, err error) { +func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16, ecn protocol.ECN) (n int, err error) { + if gsoSize != 0 { + panic("cannot use GSO with a basicConn") + } + if ecn != protocol.ECNUnsupported { + panic("cannot use ECN with a basicConn") + } return c.PacketConn.WriteTo(b, addr) } diff --git a/sys_conn_helper_darwin.go b/sys_conn_helper_darwin.go index 758cf7788..d761072f2 100644 --- a/sys_conn_helper_darwin.go +++ b/sys_conn_helper_darwin.go @@ -15,6 +15,8 @@ const ( ipv4PKTINFO = unix.IP_RECVPKTINFO ) +const ecnIPv4DataLen = 4 + // ReadBatch only returns a single packet on OSX, // see https://godoc.org/golang.org/x/net/ipv4#PacketConn.ReadBatch. const batchSize = 1 diff --git a/sys_conn_helper_freebsd.go b/sys_conn_helper_freebsd.go index a2baae3b3..a53ca2eae 100644 --- a/sys_conn_helper_freebsd.go +++ b/sys_conn_helper_freebsd.go @@ -14,6 +14,8 @@ const ( ipv4PKTINFO = 0x7 ) +const ecnIPv4DataLen = 1 + const batchSize = 8 func parseIPv4PktInfo(body []byte) (ip netip.Addr, _ uint32, ok bool) { diff --git a/sys_conn_helper_linux.go b/sys_conn_helper_linux.go index 6a049241b..622f4e6f3 100644 --- a/sys_conn_helper_linux.go +++ b/sys_conn_helper_linux.go @@ -19,6 +19,8 @@ const ( ipv4PKTINFO = unix.IP_PKTINFO ) +const ecnIPv4DataLen = 4 + const batchSize = 8 // needs to smaller than MaxUint8 (otherwise the type of oobConn.readPos has to be changed) func forceSetReceiveBuffer(c syscall.RawConn, bytes int) error { diff --git a/sys_conn_oob.go b/sys_conn_oob.go index 66d5ce67c..67acab1a4 100644 --- a/sys_conn_oob.go +++ b/sys_conn_oob.go @@ -8,9 +8,12 @@ import ( "log" "net" "net/netip" + "os" + "strconv" "sync" "syscall" "time" + "unsafe" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" @@ -56,6 +59,11 @@ func inspectWriteBuffer(c syscall.RawConn) (int, error) { return size, serr } +func isECNDisabled() bool { + disabled, err := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_ECN")) + return err == nil && disabled +} + type oobConn struct { OOBCapablePacketConn batchConn batchConn @@ -140,6 +148,7 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) { cap: connCapabilities{ DF: supportsDF, GSO: isGSOSupported(rawConn), + ECN: !isECNDisabled(), }, } for i := 0; i < batchSize; i++ { @@ -188,7 +197,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { if hdr.Level == unix.IPPROTO_IP { switch hdr.Type { case msgTypeIPTOS: - p.ecn = protocol.ECN(body[0] & ecnMask) + p.ecn = protocol.ParseECNHeaderBits(body[0] & ecnMask) case ipv4PKTINFO: ip, ifIndex, ok := parseIPv4PktInfo(body) if ok { @@ -205,7 +214,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { if hdr.Level == unix.IPPROTO_IPV6 { switch hdr.Type { case unix.IPV6_TCLASS: - p.ecn = protocol.ECN(body[0] & ecnMask) + p.ecn = protocol.ParseECNHeaderBits(body[0] & ecnMask) case unix.IPV6_PKTINFO: // struct in6_pktinfo { // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ @@ -228,8 +237,26 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { } // WritePacket writes a new packet. -// If the connection supports GSO, it's the caller's responsibility to append the right control mesage. -func (c *oobConn) WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) { +func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error) { + oob := packetInfoOOB + if gsoSize > 0 { + if !c.capabilities().GSO { + panic("GSO disabled") + } + oob = appendUDPSegmentSizeMsg(oob, gsoSize) + } + if ecn != protocol.ECNUnsupported { + if !c.capabilities().ECN { + panic("tried to send a ECN-marked packet although ECN is disabled") + } + if remoteUDPAddr, ok := addr.(*net.UDPAddr); ok { + if remoteUDPAddr.IP.To4() != nil { + oob = appendIPv4ECNMsg(oob, ecn) + } else { + oob = appendIPv6ECNMsg(oob, ecn) + } + } + } n, _, err := c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) return n, err } @@ -273,3 +300,32 @@ func (info *packetInfo) OOB() []byte { } return nil } + +func appendIPv4ECNMsg(b []byte, val protocol.ECN) []byte { + startLen := len(b) + b = append(b, make([]byte, unix.CmsgSpace(ecnIPv4DataLen))...) + h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen])) + h.Level = syscall.IPPROTO_IP + h.Type = unix.IP_TOS + h.SetLen(unix.CmsgLen(ecnIPv4DataLen)) + + // UnixRights uses the private `data` method, but I *think* this achieves the same goal. + offset := startLen + unix.CmsgSpace(0) + b[offset] = val.ToHeaderBits() + return b +} + +func appendIPv6ECNMsg(b []byte, val protocol.ECN) []byte { + startLen := len(b) + const dataLen = 4 + b = append(b, make([]byte, unix.CmsgSpace(dataLen))...) + h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen])) + h.Level = syscall.IPPROTO_IPV6 + h.Type = unix.IPV6_TCLASS + h.SetLen(unix.CmsgLen(dataLen)) + + // UnixRights uses the private `data` method, but I *think* this achieves the same goal. + offset := startLen + unix.CmsgSpace(0) + b[offset] = val.ToHeaderBits() + return b +} diff --git a/sys_conn_oob_test.go b/sys_conn_oob_test.go index 0a4efb947..3d96df136 100644 --- a/sys_conn_oob_test.go +++ b/sys_conn_oob_test.go @@ -18,6 +18,16 @@ import ( "go.uber.org/mock/gomock" ) +type oobRecordingConn struct { + *net.UDPConn + oobs [][]byte +} + +func (c *oobRecordingConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) { + c.oobs = append(c.oobs, oob) + return c.UDPConn.WriteMsgUDP(b, oob, addr) +} + var _ = Describe("OOB Conn Test", func() { runServer := func(network, address string) (*net.UDPConn, <-chan receivedPacket) { addr, err := net.ResolveUDPAddr(network, address) @@ -43,7 +53,7 @@ var _ = Describe("OOB Conn Test", func() { return udpConn, packetChan } - Context("ECN conn", func() { + Context("reading ECN-marked packets", func() { sendPacketWithECN := func(network string, addr *net.UDPAddr, setECN func(uintptr)) net.Addr { conn, err := net.DialUDP(network, nil, addr) ExpectWithOffset(1, err).ToNot(HaveOccurred()) @@ -129,6 +139,42 @@ var _ = Describe("OOB Conn Test", func() { Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeFalse()) Expect(p.ecn).To(Equal(protocol.ECT1)) }) + + It("sends packets with ECN on IPv4", func() { + conn, packetChan := runServer("udp4", "localhost:0") + defer conn.Close() + + c, err := net.ListenUDP("udp4", nil) + Expect(err).ToNot(HaveOccurred()) + defer c.Close() + + for _, val := range []protocol.ECN{protocol.ECNNon, protocol.ECT1, protocol.ECT0, protocol.ECNCE} { + _, _, err = c.WriteMsgUDP([]byte("foobar"), appendIPv4ECNMsg([]byte{}, val), conn.LocalAddr().(*net.UDPAddr)) + Expect(err).ToNot(HaveOccurred()) + var p receivedPacket + Eventually(packetChan).Should(Receive(&p)) + Expect(p.data).To(Equal([]byte("foobar"))) + Expect(p.ecn).To(Equal(val)) + } + }) + + It("sends packets with ECN on IPv6", func() { + conn, packetChan := runServer("udp6", "[::1]:0") + defer conn.Close() + + c, err := net.ListenUDP("udp6", nil) + Expect(err).ToNot(HaveOccurred()) + defer c.Close() + + for _, val := range []protocol.ECN{protocol.ECNNon, protocol.ECT1, protocol.ECT0, protocol.ECNCE} { + _, _, err = c.WriteMsgUDP([]byte("foobar"), appendIPv6ECNMsg([]byte{}, val), conn.LocalAddr().(*net.UDPAddr)) + Expect(err).ToNot(HaveOccurred()) + var p receivedPacket + Eventually(packetChan).Should(Receive(&p)) + Expect(p.data).To(Equal([]byte("foobar"))) + Expect(p.ecn).To(Equal(val)) + } + }) }) Context("Packet Info conn", func() { @@ -242,4 +288,50 @@ var _ = Describe("OOB Conn Test", func() { } }) }) + + Context("sending ECN-marked packets", func() { + It("sets the ECN control message", func() { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + c := &oobRecordingConn{UDPConn: udpConn} + oobConn, err := newConn(c, true) + Expect(err).ToNot(HaveOccurred()) + + oob := make([]byte, 0, 123) + oobConn.WritePacket([]byte("foobar"), addr, oob, 0, protocol.ECNCE) + Expect(c.oobs).To(HaveLen(1)) + oobMsg := c.oobs[0] + Expect(oobMsg).ToNot(BeEmpty()) + Expect(oobMsg).To(HaveCap(cap(oob))) // check that it appended to oob + expected := appendIPv4ECNMsg([]byte{}, protocol.ECNCE) + Expect(oobMsg).To(Equal(expected)) + }) + }) + + if platformSupportsGSO { + Context("GSO", func() { + It("appends the GSO control message", func() { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + c := &oobRecordingConn{UDPConn: udpConn} + oobConn, err := newConn(c, true) + Expect(err).ToNot(HaveOccurred()) + Expect(oobConn.capabilities().GSO).To(BeTrue()) + + oob := make([]byte, 0, 123) + oobConn.WritePacket([]byte("foobar"), addr, oob, 3, protocol.ECNCE) + Expect(c.oobs).To(HaveLen(1)) + oobMsg := c.oobs[0] + Expect(oobMsg).ToNot(BeEmpty()) + Expect(oobMsg).To(HaveCap(cap(oob))) // check that it appended to oob + expected := appendUDPSegmentSizeMsg([]byte{}, 3) + // Check that the first control message is the OOB control message. + Expect(oobMsg[:len(expected)]).To(Equal(expected)) + }) + }) + } }) diff --git a/transport.go b/transport.go index a0d0784a3..e1474e848 100644 --- a/transport.go +++ b/transport.go @@ -17,6 +17,8 @@ import ( "github.com/refraction-networking/uquic/logging" ) +var errListenerAlreadySet = errors.New("listener already set") + // The Transport is the central point to manage incoming and outgoing QUIC connections. // QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple. // This means that a single UDP socket can be used for listening for incoming connections, as well as @@ -58,8 +60,26 @@ type Transport struct { // See section 10.3 of RFC 9000 for details. StatelessResetKey *StatelessResetKey + // The TokenGeneratorKey is used to encrypt session resumption tokens. + // If no key is configured, a random key will be generated. + // If multiple servers are authoritative for the same domain, they should use the same key, + // see section 8.1.3 of RFC 9000 for details. + TokenGeneratorKey *TokenGeneratorKey + + // MaxTokenAge is the maximum age of the resumption token presented during the handshake. + // These tokens allow skipping address resumption when resuming a QUIC connection, + // and are especially useful when using 0-RTT. + // If not set, it defaults to 24 hours. + // See section 8.1.3 of RFC 9000 for details. + MaxTokenAge time.Duration + + // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. + // This can be useful if version information is exchanged out-of-band. + // It has no effect for clients. + DisableVersionNegotiationPackets bool + // A Tracer traces events that don't belong to a single QUIC connection. - Tracer logging.Tracer + Tracer *logging.Tracer handlerMap packetHandlerManager @@ -74,7 +94,7 @@ type Transport struct { // If no ConnectionIDGenerator is set, this is set to a default. connIDGenerator ConnectionIDGenerator - server unknownPacketHandler + server *baseServer conn rawConn @@ -96,28 +116,10 @@ type Transport struct { // There can only be a single listener on any net.PacketConn. // Listen may only be called again after the current Listener was closed. func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) { - if tlsConf == nil { - return nil, errors.New("quic: tls.Config not set") - } - if err := validateConfig(conf); err != nil { - return nil, err - } - - t.mutex.Lock() - defer t.mutex.Unlock() - - if t.server != nil { - return nil, errListenerAlreadySet - } - conf = populateServerConfig(conf) - if err := t.init(false); err != nil { - return nil, err - } - s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, false) + s, err := t.createServer(tlsConf, conf, false) if err != nil { return nil, err } - t.server = s return &Listener{baseServer: s}, nil } @@ -125,6 +127,14 @@ func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) // There can only be a single listener on any net.PacketConn. // Listen may only be called again after the current Listener was closed. func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) { + s, err := t.createServer(tlsConf, conf, true) + if err != nil { + return nil, err + } + return &EarlyListener{baseServer: s}, nil +} + +func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bool) (*baseServer, error) { if tlsConf == nil { return nil, errors.New("quic: tls.Config not set") } @@ -142,12 +152,21 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen if err := t.init(false); err != nil { return nil, err } - s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, true) - if err != nil { - return nil, err - } + s := newServer( + t.conn, + t.handlerMap, + t.connIDGenerator, + tlsConf, + conf, + t.Tracer, + t.closeServer, + *t.TokenGeneratorKey, + t.MaxTokenAge, + t.DisableVersionNegotiationPackets, + allow0RTT, + ) t.server = s - return &EarlyListener{baseServer: s}, nil + return s, nil } // Dial dials a new connection to a remote host (not using 0-RTT). @@ -200,6 +219,14 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error { t.closeQueue = make(chan closePacket, 4) t.statelessResetQueue = make(chan receivedPacket, 4) + if t.TokenGeneratorKey == nil { + var key TokenGeneratorKey + if _, err := rand.Read(key[:]); err != nil { + t.initErr = err + return + } + t.TokenGeneratorKey = &key + } if t.ConnectionIDGenerator != nil { t.connIDGenerator = t.ConnectionIDGenerator @@ -225,7 +252,7 @@ func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) { if err := t.init(false); err != nil { return 0, err } - return t.conn.WritePacket(b, addr, nil) + return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported) } func (t *Transport) enqueueClosePacket(p closePacket) { @@ -243,7 +270,7 @@ func (t *Transport) runSendQueue() { case <-t.listening: return case p := <-t.closeQueue: - t.conn.WritePacket(p.payload, p.addr, p.info.OOB()) + t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0, protocol.ECNUnsupported) case p := <-t.statelessResetQueue: t.sendStatelessReset(p) } @@ -348,7 +375,7 @@ func (t *Transport) handlePacket(p receivedPacket) { connID, err := wire.ParseConnectionID(p.data, t.connIDLen) if err != nil { t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) - if t.Tracer != nil { + if t.Tracer != nil && t.Tracer.DroppedPacket != nil { t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) } p.buffer.MaybeRelease() @@ -411,7 +438,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) { rand.Read(data) data[0] = (data[0] & 0x7f) | 0x40 data = append(data, token[:]...) - if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { + if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil { t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err) } } @@ -443,7 +470,7 @@ func (t *Transport) handleNonQUICPacket(p receivedPacket) { select { case t.nonQUICPackets <- p: default: - if t.Tracer != nil { + if t.Tracer != nil && t.Tracer.DroppedPacket != nil { t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) } } diff --git a/transport_test.go b/transport_test.go index 588562184..97ddfabed 100644 --- a/transport_test.go +++ b/transport_test.go @@ -127,11 +127,11 @@ var _ = Describe("Transport", func() { It("drops unparseable QUIC packets", func() { addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} packetChan := make(chan packetToRead) - tracer := mocklogging.NewMockTracer(mockCtrl) + t, tracer := mocklogging.NewMockTracer(mockCtrl) tr := &Transport{ Conn: newMockPacketConn(packetChan), ConnectionIDLength: 10, - Tracer: tracer, + Tracer: t, } tr.init(true) dropped := make(chan struct{}) @@ -329,11 +329,9 @@ var _ = Describe("Transport", func() { It("allows receiving non-QUIC packets", func() { remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} packetChan := make(chan packetToRead) - tracer := mocklogging.NewMockTracer(mockCtrl) tr := &Transport{ Conn: newMockPacketConn(packetChan), ConnectionIDLength: 10, - Tracer: tracer, } tr.init(true) receivedPacketChan := make(chan []byte) @@ -363,11 +361,11 @@ var _ = Describe("Transport", func() { It("drops non-QUIC packet if the application doesn't process them quickly enough", func() { remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} packetChan := make(chan packetToRead) - tracer := mocklogging.NewMockTracer(mockCtrl) + t, tracer := mocklogging.NewMockTracer(mockCtrl) tr := &Transport{ Conn: newMockPacketConn(packetChan), ConnectionIDLength: 10, - Tracer: tracer, + Tracer: t, } tr.init(true) diff --git a/u_connection.go b/u_connection.go index dc118bc6c..7c11749c8 100644 --- a/u_connection.go +++ b/u_connection.go @@ -24,7 +24,7 @@ var newUClientConnection = func( initialPacketNumber protocol.PacketNumber, enable0RTT bool, hasNegotiatedVersion bool, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, tracingID uint64, logger utils.Logger, v protocol.VersionNumber, @@ -67,11 +67,12 @@ var newUClientConnection = func( ) s.preSetup() s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) - s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewUAckHandler( // [UQUIC] + s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewUAckHandler( initialPacketNumber, getMaxPacketSize(s.conn.RemoteAddr()), s.rttStats, - false, /* has no effect */ + false, // has no effect + s.conn.capabilities().ECN, s.perspective, s.tracer, s.logger, @@ -82,7 +83,7 @@ var newUClientConnection = func( } s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) - oneRTTStream := newCryptoStream(true) + oneRTTStream := newCryptoStream() var params *wire.TransportParameters @@ -134,8 +135,7 @@ var newUClientConnection = func( params.MaxDatagramFrameSize = protocol.InvalidByteCount } } - - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentTransportParameters != nil { s.tracer.SentTransportParameters(params) } cs := handshake.NewUCryptoSetupClient(