Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

limit size of encrypted packet queue #628

Merged
merged 2 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 61 additions & 33 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
inboundBufferSize = 8192
// Default replay protection window is specified by RFC 6347 Section 4.1.2.6
defaultReplayProtectionWindow = 64
// maxAppDataPacketQueueSize is the maximum number of app data packets we will
// enqueue before the handshake is completed
maxAppDataPacketQueueSize = 100
)

func invalidKeyingLabels() map[string]bool {
Expand Down Expand Up @@ -81,7 +84,7 @@
replayProtectionWindow uint
}

func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
func createConn(nextConn net.Conn, config *Config, isClient bool) (*Conn, error) {
err := validateConfig(config)
if err != nil {
return nil, err
Expand All @@ -91,21 +94,6 @@
return nil, errNilNextConn
}

cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
if err != nil {
return nil, err
}

signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
if err != nil {
return nil, err
}

workerInterval := initialTickerInterval
if config.FlightInterval != 0 {
workerInterval = config.FlightInterval
}

loggerFactory := config.LoggerFactory
if loggerFactory == nil {
loggerFactory = logging.NewDefaultLoggerFactory()
Expand Down Expand Up @@ -149,6 +137,28 @@

c.setRemoteEpoch(0)
c.setLocalEpoch(0)
return c, nil
}

func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
if conn == nil {
return nil, errNilNextConn
}

cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
if err != nil {
return nil, err
}

signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
if err != nil {
return nil, err
}

workerInterval := initialTickerInterval
if config.FlightInterval != 0 {
workerInterval = config.FlightInterval
}

serverName := config.ServerName
// Do not allow the use of an IP address literal as an SNI value.
Expand Down Expand Up @@ -180,7 +190,7 @@
clientCAs: config.ClientCAs,
customCipherSuites: config.CustomCipherSuites,
retransmitInterval: workerInterval,
log: logger,
log: conn.log,
initialEpoch: 0,
keyLogWriter: config.KeyLogWriter,
sessionStore: config.SessionStore,
Expand All @@ -205,30 +215,30 @@
var initialFSMState handshakeState

if initialState != nil {
if c.state.isClient {
if conn.state.isClient {
initialFlight = flight5
} else {
initialFlight = flight6
}
initialFSMState = handshakeFinished

c.state = *initialState
conn.state = *initialState
} else {
if c.state.isClient {
if conn.state.isClient {
initialFlight = flight1
} else {
initialFlight = flight0
}
initialFSMState = handshakePreparing
}
// Do handshake
if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
if err := conn.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
return nil, err
}

c.log.Trace("Handshake Completed")
conn.log.Trace("Handshake Completed")

return c, nil
return conn, nil
}

// Dial connects to the given network address and establishes a DTLS connection on top.
Expand Down Expand Up @@ -279,16 +289,24 @@
return nil, errPSKAndIdentityMustBeSetForClient
}

return createConn(ctx, conn, config, true, nil)
dconn, err := createConn(conn, config, true)
if err != nil {
return nil, err
}

return handshakeConn(ctx, dconn, config, true, nil)
}

// ServerWithContext listens for incoming DTLS connections.
func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
if config == nil {
return nil, errNoConfigProvided
}

return createConn(ctx, conn, config, false, nil)
dconn, err := createConn(conn, config, false)
if err != nil {
return nil, err
}
return handshakeConn(ctx, dconn, config, false, nil)
}

// Read reads data from the connection.
Expand Down Expand Up @@ -654,6 +672,14 @@
return nil
}

func (c *Conn) enqueueEncryptedPackets(packet []byte) bool {
if len(c.encryptedPackets) < maxAppDataPacketQueueSize {
c.encryptedPackets = append(c.encryptedPackets, packet)
return true
}
return false
}

func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit
h := &recordlayer.Header{}
if err := h.Unmarshal(buf); err != nil {
Expand All @@ -662,7 +688,6 @@
c.log.Debugf("discarded broken packet: %v", err)
return false, nil, nil
}

// Validate epoch
remoteEpoch := c.state.getRemoteEpoch()
if h.Epoch > remoteEpoch {
Expand All @@ -673,8 +698,9 @@
return false, nil, nil
}
if enqueue {
c.log.Debug("received packet of next epoch, queuing packet")
c.encryptedPackets = append(c.encryptedPackets, buf)
if ok := c.enqueueEncryptedPackets(buf); ok {
c.log.Debug("received packet of next epoch, queuing packet")
}
}
return false, nil, nil
}
Expand All @@ -697,8 +723,9 @@
if h.Epoch != 0 {
if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debug("handshake not finished, queuing packet")
if ok := c.enqueueEncryptedPackets(buf); ok {
c.log.Debug("handshake not finished, queuing packet")
}
}
return false, nil, nil
}
Expand Down Expand Up @@ -749,8 +776,9 @@
case *protocol.ChangeCipherSpec:
if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debugf("CipherSuite not initialized, queuing packet")
if ok := c.enqueueEncryptedPackets(buf); ok {
c.log.Debugf("CipherSuite not initialized, queuing packet")
}
}
return false, nil, nil
}
Expand Down Expand Up @@ -828,7 +856,7 @@
done := make(chan struct{})
ctxRead, cancelRead := context.WithCancel(context.Background())
c.cancelHandshakeReader = cancelRead
cfg.onFlightState = func(f flightVal, s handshakeState) {

Check warning on line 859 in conn.go

View workflow job for this annotation

GitHub Actions / lint / Go

unused-parameter: parameter 'f' seems to be unused, consider removing or renaming it as _ (revive)
if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() {
c.setHandshakeCompletedSuccessfully()
close(done)
Expand Down
85 changes: 85 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3050,3 +3050,88 @@
}
return c.Conn.Write(b)
}

func TestApplicationDataQueueLimited(t *testing.T) {
// Limit runtime in case of deadlocks
lim := test.TimeOut(time.Second * 20)
defer lim.Stop()

// Check for leaking routines
report := test.CheckRoutines(t)
defer report()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

ca, cb := dpipe.Pipe()
defer ca.Close()

Check failure on line 3067 in conn_test.go

View workflow job for this annotation

GitHub Actions / lint / Go

Error return value of `ca.Close` is not checked (errcheck)
defer cb.Close()

Check failure on line 3068 in conn_test.go

View workflow job for this annotation

GitHub Actions / lint / Go

Error return value of `cb.Close` is not checked (errcheck)

done := make(chan struct{})
go func() {
serverCert, err := selfsign.GenerateSelfSigned()
if err != nil {
t.Error(err)
return
}
cfg := &Config{}
cfg.Certificates = []tls.Certificate{serverCert}

dconn, err := createConn(cb, cfg, false)
if err != nil {
t.Error(err)
return
}
go func() {
for i := 0; i < 5; i++ {
dconn.lock.RLock()
qlen := len(dconn.encryptedPackets)
dconn.lock.RUnlock()
if qlen > maxAppDataPacketQueueSize {
t.Error("too many encrypted packets enqueued", len(dconn.encryptedPackets))
}
t.Log(qlen)
time.Sleep(1 * time.Second)
}

Check failure on line 3096 in conn_test.go

View workflow job for this annotation

GitHub Actions / lint / Go

File is not `gofumpt`-ed (gofumpt)
}()

Check failure on line 3097 in conn_test.go

View workflow job for this annotation

GitHub Actions / lint / Go

unnecessary trailing newline (whitespace)
if _, err := handshakeConn(ctx, dconn, cfg, false, nil); err == nil {
t.Error("expected handshake to fail")
}
close(done)
}()
extensions := []extension.Extension{}

time.Sleep(50 * time.Millisecond)

err := sendClientHello([]byte{}, ca, 0, extensions)
if err != nil {
t.Fatal(err)
}

time.Sleep(50 * time.Millisecond)

for i := 0; i < 1000; i++ {
// Send an application data packet
packet, err := (&recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
SequenceNumber: uint64(3),
Epoch: 1, // use an epoch greater than 0
},
Content: &protocol.ApplicationData{
Data: []byte{1, 2, 3, 4},
},
}).Marshal()
if err != nil {
t.Fatal(err)
}
ca.Write(packet)

Check failure on line 3129 in conn_test.go

View workflow job for this annotation

GitHub Actions / lint / Go

Error return value of `ca.Write` is not checked (errcheck)
if i%100 == 0 {
time.Sleep(10 * time.Millisecond)
}
}
time.Sleep(1 * time.Second)
ca.Close()

Check failure on line 3135 in conn_test.go

View workflow job for this annotation

GitHub Actions / lint / Go

Error return value of `ca.Close` is not checked (errcheck)
<-done
}
6 changes: 5 additions & 1 deletion resume.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ func Resume(state *State, conn net.Conn, config *Config) (*Conn, error) {
if err := state.initCipherSuite(); err != nil {
return nil, err
}
c, err := createConn(context.Background(), conn, config, state.isClient, state)
dconn, err := createConn(conn, config, state.isClient)
if err != nil {
return nil, err
}
c, err := handshakeConn(context.Background(), dconn, config, state.isClient, state)
if err != nil {
return nil, err
}
Expand Down
Loading