Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit size of encrypted packet queue #632

Merged
merged 1 commit into from
Jun 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 61 additions & 33 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
inboundBufferSize = 8192
// Default replay protection window is specified by RFC 6347 Section 4.1.2.6
defaultReplayProtectionWindow = 64
// maxAppDataPacketQueueSize is the maximum number of app data packets we will
// enqueue before the handshake is completed
maxAppDataPacketQueueSize = 100
)

func invalidKeyingLabels() map[string]bool {
Expand Down Expand Up @@ -88,7 +91,7 @@
replayProtectionWindow uint
}

func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool, initialState *State) (*Conn, error) {
func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool) (*Conn, error) {
if err := validateConfig(config); err != nil {
return nil, err
}
Expand All @@ -97,21 +100,6 @@
return nil, errNilNextConn
}

cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
if err != nil {
return nil, err
}

signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
if err != nil {
return nil, err
}

workerInterval := initialTickerInterval
if config.FlightInterval != 0 {
workerInterval = config.FlightInterval
}

loggerFactory := config.LoggerFactory
if loggerFactory == nil {
loggerFactory = logging.NewDefaultLoggerFactory()
Expand Down Expand Up @@ -162,6 +150,28 @@

c.setRemoteEpoch(0)
c.setLocalEpoch(0)
return c, nil
}

func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
if conn == nil {
return nil, errNilNextConn

Check warning on line 158 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L158

Added line #L158 was not covered by tests
}

cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
if err != nil {
return nil, err

Check warning on line 163 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L163

Added line #L163 was not covered by tests
}

signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
if err != nil {
return nil, err

Check warning on line 168 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L168

Added line #L168 was not covered by tests
}

workerInterval := initialTickerInterval
if config.FlightInterval != 0 {
workerInterval = config.FlightInterval
}

serverName := config.ServerName
// Do not allow the use of an IP address literal as an SNI value.
Expand Down Expand Up @@ -193,7 +203,7 @@
clientCAs: config.ClientCAs,
customCipherSuites: config.CustomCipherSuites,
retransmitInterval: workerInterval,
log: logger,
log: conn.log,
initialEpoch: 0,
keyLogWriter: config.KeyLogWriter,
sessionStore: config.SessionStore,
Expand Down Expand Up @@ -222,30 +232,30 @@
var initialFSMState handshakeState

if initialState != nil {
if c.state.isClient {
if conn.state.isClient {
initialFlight = flight5
} else {
initialFlight = flight6
}
initialFSMState = handshakeFinished

c.state = *initialState
conn.state = *initialState
} else {
if c.state.isClient {
if conn.state.isClient {
initialFlight = flight1
} else {
initialFlight = flight0
}
initialFSMState = handshakePreparing
}
// Do handshake
if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
if err := conn.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
return nil, err
}

c.log.Trace("Handshake Completed")
conn.log.Trace("Handshake Completed")

return c, nil
return conn, nil
}

// Dial connects to the given network address and establishes a DTLS connection on top.
Expand Down Expand Up @@ -301,16 +311,24 @@
return nil, errPSKAndIdentityMustBeSetForClient
}

return createConn(ctx, conn, rAddr, config, true, nil)
dconn, err := createConn(conn, rAddr, config, true)
if err != nil {
return nil, err
}

return handshakeConn(ctx, dconn, config, true, nil)
}

// ServerWithContext listens for incoming DTLS connections.
func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
if config == nil {
return nil, errNoConfigProvided
}

return createConn(ctx, conn, rAddr, config, false, nil)
dconn, err := createConn(conn, rAddr, config, false)
if err != nil {
return nil, err
}
return handshakeConn(ctx, dconn, config, false, nil)
}

// Read reads data from the connection.
Expand Down Expand Up @@ -738,6 +756,14 @@
return nil
}

func (c *Conn) enqueueEncryptedPackets(packet addrPkt) bool {
if len(c.encryptedPackets) < maxAppDataPacketQueueSize {
c.encryptedPackets = append(c.encryptedPackets, packet)
return true
}
return false
}

func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.Addr, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit
h := &recordlayer.Header{}
// Set connection ID size so that records of content type tls12_cid will
Expand All @@ -751,7 +777,6 @@
c.log.Debugf("discarded broken packet: %v", err)
return false, nil, nil
}

// Validate epoch
remoteEpoch := c.state.getRemoteEpoch()
if h.Epoch > remoteEpoch {
Expand All @@ -762,8 +787,9 @@
return false, nil, nil
}
if enqueue {
c.log.Debug("received packet of next epoch, queuing packet")
c.encryptedPackets = append(c.encryptedPackets, addrPkt{rAddr, buf})
if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok {
c.log.Debug("received packet of next epoch, queuing packet")
}
}
return false, nil, nil
}
Expand All @@ -790,8 +816,9 @@
if h.Epoch != 0 {
if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, addrPkt{rAddr, buf})
c.log.Debug("handshake not finished, queuing packet")
if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok {
c.log.Debug("handshake not finished, queuing packet")

Check warning on line 820 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L819-L820

Added lines #L819 - L820 were not covered by tests
}
}
return false, nil, nil
}
Expand Down Expand Up @@ -883,8 +910,9 @@
case *protocol.ChangeCipherSpec:
if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, addrPkt{rAddr, buf})
c.log.Debugf("CipherSuite not initialized, queuing packet")
if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok {
c.log.Debugf("CipherSuite not initialized, queuing packet")
}
}
return false, nil, nil
}
Expand Down
84 changes: 84 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3279,3 +3279,87 @@ func (c *connWithCallback) Write(b []byte) (int, error) {
}
return c.Conn.Write(b)
}

func TestApplicationDataQueueLimited(t *testing.T) {
// Limit runtime in case of deadlocks
lim := test.TimeOut(time.Second * 20)
defer lim.Stop()

// Check for leaking routines
report := test.CheckRoutines(t)
defer report()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

ca, cb := dpipe.Pipe()
defer ca.Close() //nolint:errcheck
defer cb.Close() //nolint:errcheck

done := make(chan struct{})
go func() {
serverCert, err := selfsign.GenerateSelfSigned()
if err != nil {
t.Error(err)
return
}
cfg := &Config{}
cfg.Certificates = []tls.Certificate{serverCert}

dconn, err := createConn(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), cfg, false)
if err != nil {
t.Error(err)
return
}
go func() {
for i := 0; i < 5; i++ {
dconn.lock.RLock()
qlen := len(dconn.encryptedPackets)
dconn.lock.RUnlock()
if qlen > maxAppDataPacketQueueSize {
t.Error("too many encrypted packets enqueued", len(dconn.encryptedPackets))
}
t.Log(qlen)
time.Sleep(1 * time.Second)
}
}()
if _, err := handshakeConn(ctx, dconn, cfg, false, nil); err == nil {
t.Error("expected handshake to fail")
}
close(done)
}()
extensions := []extension.Extension{}

time.Sleep(50 * time.Millisecond)

err := sendClientHello([]byte{}, ca, 0, extensions)
if err != nil {
t.Fatal(err)
}

time.Sleep(50 * time.Millisecond)

for i := 0; i < 1000; i++ {
// Send an application data packet
packet, err := (&recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
SequenceNumber: uint64(3),
Epoch: 1, // use an epoch greater than 0
},
Content: &protocol.ApplicationData{
Data: []byte{1, 2, 3, 4},
},
}).Marshal()
if err != nil {
t.Fatal(err)
}
ca.Write(packet) // nolint
if i%100 == 0 {
time.Sleep(10 * time.Millisecond)
}
}
time.Sleep(1 * time.Second)
ca.Close() // nolint
<-done
}
6 changes: 5 additions & 1 deletion resume.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
if err := state.initCipherSuite(); err != nil {
return nil, err
}
c, err := createConn(context.Background(), conn, rAddr, config, state.isClient, state)
dconn, err := createConn(conn, rAddr, config, state.isClient)
if err != nil {
return nil, err

Check warning on line 18 in resume.go

View check run for this annotation

Codecov / codecov/patch

resume.go#L18

Added line #L18 was not covered by tests
}
c, err := handshakeConn(context.Background(), dconn, config, state.isClient, state)
if err != nil {
return nil, err
}
Expand Down
Loading