Skip to content

Commit

Permalink
Limit size of encrypted packet queue
Browse files Browse the repository at this point in the history
Before we would queue unbounded encrypted data until the handshake
finished/timed out. This sets a limit of 100 packets until the handshake
completes.
  • Loading branch information
sukunrt authored and Sean-Der committed Jun 1, 2024
1 parent fbbdf66 commit 9be049c
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 34 deletions.
94 changes: 61 additions & 33 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ const (
inboundBufferSize = 8192
// Default replay protection window is specified by RFC 6347 Section 4.1.2.6
defaultReplayProtectionWindow = 64
// maxAppDataPacketQueueSize is the maximum number of app data packets we will
// enqueue before the handshake is completed
maxAppDataPacketQueueSize = 100
)

func invalidKeyingLabels() map[string]bool {
Expand Down Expand Up @@ -88,7 +91,7 @@ type Conn struct {
replayProtectionWindow uint
}

func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool, initialState *State) (*Conn, error) {
func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool) (*Conn, error) {
if err := validateConfig(config); err != nil {
return nil, err
}
Expand All @@ -97,21 +100,6 @@ func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, co
return nil, errNilNextConn
}

cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
if err != nil {
return nil, err
}

signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
if err != nil {
return nil, err
}

workerInterval := initialTickerInterval
if config.FlightInterval != 0 {
workerInterval = config.FlightInterval
}

loggerFactory := config.LoggerFactory
if loggerFactory == nil {
loggerFactory = logging.NewDefaultLoggerFactory()
Expand Down Expand Up @@ -162,6 +150,28 @@ func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, co

c.setRemoteEpoch(0)
c.setLocalEpoch(0)
return c, nil
}

func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
if conn == nil {
return nil, errNilNextConn

Check warning on line 158 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L158

Added line #L158 was not covered by tests
}

cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
if err != nil {
return nil, err

Check warning on line 163 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L163

Added line #L163 was not covered by tests
}

signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
if err != nil {
return nil, err

Check warning on line 168 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L168

Added line #L168 was not covered by tests
}

workerInterval := initialTickerInterval
if config.FlightInterval != 0 {
workerInterval = config.FlightInterval
}

serverName := config.ServerName
// Do not allow the use of an IP address literal as an SNI value.
Expand Down Expand Up @@ -193,7 +203,7 @@ func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, co
clientCAs: config.ClientCAs,
customCipherSuites: config.CustomCipherSuites,
retransmitInterval: workerInterval,
log: logger,
log: conn.log,
initialEpoch: 0,
keyLogWriter: config.KeyLogWriter,
sessionStore: config.SessionStore,
Expand Down Expand Up @@ -222,30 +232,30 @@ func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, co
var initialFSMState handshakeState

if initialState != nil {
if c.state.isClient {
if conn.state.isClient {
initialFlight = flight5
} else {
initialFlight = flight6
}
initialFSMState = handshakeFinished

c.state = *initialState
conn.state = *initialState
} else {
if c.state.isClient {
if conn.state.isClient {
initialFlight = flight1
} else {
initialFlight = flight0
}
initialFSMState = handshakePreparing
}
// Do handshake
if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
if err := conn.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
return nil, err
}

c.log.Trace("Handshake Completed")
conn.log.Trace("Handshake Completed")

return c, nil
return conn, nil
}

// Dial connects to the given network address and establishes a DTLS connection on top.
Expand Down Expand Up @@ -301,16 +311,24 @@ func ClientWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr,
return nil, errPSKAndIdentityMustBeSetForClient
}

return createConn(ctx, conn, rAddr, config, true, nil)
dconn, err := createConn(conn, rAddr, config, true)
if err != nil {
return nil, err
}

return handshakeConn(ctx, dconn, config, true, nil)
}

// ServerWithContext listens for incoming DTLS connections.
func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
if config == nil {
return nil, errNoConfigProvided
}

return createConn(ctx, conn, rAddr, config, false, nil)
dconn, err := createConn(conn, rAddr, config, false)
if err != nil {
return nil, err
}
return handshakeConn(ctx, dconn, config, false, nil)
}

// Read reads data from the connection.
Expand Down Expand Up @@ -738,6 +756,14 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error {
return nil
}

func (c *Conn) enqueueEncryptedPackets(packet addrPkt) bool {
if len(c.encryptedPackets) < maxAppDataPacketQueueSize {
c.encryptedPackets = append(c.encryptedPackets, packet)
return true
}
return false
}

func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.Addr, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit
h := &recordlayer.Header{}
// Set connection ID size so that records of content type tls12_cid will
Expand All @@ -751,7 +777,6 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
c.log.Debugf("discarded broken packet: %v", err)
return false, nil, nil
}

// Validate epoch
remoteEpoch := c.state.getRemoteEpoch()
if h.Epoch > remoteEpoch {
Expand All @@ -762,8 +787,9 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
return false, nil, nil
}
if enqueue {
c.log.Debug("received packet of next epoch, queuing packet")
c.encryptedPackets = append(c.encryptedPackets, addrPkt{rAddr, buf})
if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok {
c.log.Debug("received packet of next epoch, queuing packet")
}
}
return false, nil, nil
}
Expand All @@ -790,8 +816,9 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
if h.Epoch != 0 {
if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, addrPkt{rAddr, buf})
c.log.Debug("handshake not finished, queuing packet")
if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok {
c.log.Debug("handshake not finished, queuing packet")

Check warning on line 820 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L819-L820

Added lines #L819 - L820 were not covered by tests
}
}
return false, nil, nil
}
Expand Down Expand Up @@ -883,8 +910,9 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
case *protocol.ChangeCipherSpec:
if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, addrPkt{rAddr, buf})
c.log.Debugf("CipherSuite not initialized, queuing packet")
if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok {
c.log.Debugf("CipherSuite not initialized, queuing packet")
}
}
return false, nil, nil
}
Expand Down
84 changes: 84 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3279,3 +3279,87 @@ func (c *connWithCallback) Write(b []byte) (int, error) {
}
return c.Conn.Write(b)
}

func TestApplicationDataQueueLimited(t *testing.T) {
// Limit runtime in case of deadlocks
lim := test.TimeOut(time.Second * 20)
defer lim.Stop()

// Check for leaking routines
report := test.CheckRoutines(t)
defer report()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

ca, cb := dpipe.Pipe()
defer ca.Close() //nolint:errcheck
defer cb.Close() //nolint:errcheck

done := make(chan struct{})
go func() {
serverCert, err := selfsign.GenerateSelfSigned()
if err != nil {
t.Error(err)
return
}
cfg := &Config{}
cfg.Certificates = []tls.Certificate{serverCert}

dconn, err := createConn(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), cfg, false)
if err != nil {
t.Error(err)
return
}
go func() {
for i := 0; i < 5; i++ {
dconn.lock.RLock()
qlen := len(dconn.encryptedPackets)
dconn.lock.RUnlock()
if qlen > maxAppDataPacketQueueSize {
t.Error("too many encrypted packets enqueued", len(dconn.encryptedPackets))
}
t.Log(qlen)
time.Sleep(1 * time.Second)
}
}()
if _, err := handshakeConn(ctx, dconn, cfg, false, nil); err == nil {
t.Error("expected handshake to fail")
}
close(done)
}()
extensions := []extension.Extension{}

time.Sleep(50 * time.Millisecond)

err := sendClientHello([]byte{}, ca, 0, extensions)
if err != nil {
t.Fatal(err)
}

time.Sleep(50 * time.Millisecond)

for i := 0; i < 1000; i++ {
// Send an application data packet
packet, err := (&recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
SequenceNumber: uint64(3),
Epoch: 1, // use an epoch greater than 0
},
Content: &protocol.ApplicationData{
Data: []byte{1, 2, 3, 4},
},
}).Marshal()
if err != nil {
t.Fatal(err)
}
ca.Write(packet) // nolint
if i%100 == 0 {
time.Sleep(10 * time.Millisecond)
}
}
time.Sleep(1 * time.Second)
ca.Close() // nolint
<-done
}
6 changes: 5 additions & 1 deletion resume.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ func Resume(state *State, conn net.PacketConn, rAddr net.Addr, config *Config) (
if err := state.initCipherSuite(); err != nil {
return nil, err
}
c, err := createConn(context.Background(), conn, rAddr, config, state.isClient, state)
dconn, err := createConn(conn, rAddr, config, state.isClient)
if err != nil {
return nil, err

Check warning on line 18 in resume.go

View check run for this annotation

Codecov / codecov/patch

resume.go#L18

Added line #L18 was not covered by tests
}
c, err := handshakeConn(context.Background(), dconn, config, state.isClient, state)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 9be049c

Please sign in to comment.