diff --git a/association.go b/association.go index 8130834d..167fce12 100644 --- a/association.go +++ b/association.go @@ -979,19 +979,23 @@ func (a *Association) handleHeartbeat(c *chunkHeartbeat) []*packet { func (a *Association) handleCookieEcho(c *chunkCookieEcho) []*packet { state := a.getState() a.log.Debugf("[%s] COOKIE-ECHO received in state '%s'", a.name, getAssociationStateString(state)) - if state != closed && state != cookieWait && state != cookieEchoed { - return nil - } + if state != established { + if state != closed && state != cookieWait && state != cookieEchoed { + return nil + } + if !bytes.Equal(a.myCookie.cookie, c.cookie) { + return nil + } - if !bytes.Equal(a.myCookie.cookie, c.cookie) { - return nil - } + a.t1Init.stop() + a.storedInit = nil - a.t1Init.stop() - a.storedInit = nil + a.t1Cookie.stop() + a.storedCookieEcho = nil - a.t1Cookie.stop() - a.storedCookieEcho = nil + a.setState(established) + a.handshakeCompletedCh <- nil + } p := &packet{ verificationTag: a.peerVerificationTag, @@ -999,9 +1003,6 @@ func (a *Association) handleCookieEcho(c *chunkCookieEcho) []*packet { destinationPort: a.destinationPort, chunks: []chunk{&chunkCookieAck{}}, } - - a.setState(established) - a.handshakeCompletedCh <- nil return pack(p) } diff --git a/vnet_test.go b/vnet_test.go index e445a830..7d699295 100644 --- a/vnet_test.go +++ b/vnet_test.go @@ -21,11 +21,13 @@ type vNetEnvConfig struct { } type vNetEnv struct { - wan *vnet.Router - net0 *vnet.Net - net1 *vnet.Net - numToDropData int - numToDropReconfig int + wan *vnet.Router + net0 *vnet.Net + net1 *vnet.Net + numToDropData int + numToDropReconfig int + numToDropCookieEcho int + numToDropCookieAck int } func (venv *vNetEnv) dropNextDataChunk(numToDrop int) { @@ -36,6 +38,14 @@ func (venv *vNetEnv) dropNextReconfigChunk(numToDrop int) { venv.numToDropReconfig = numToDrop } +func (venv *vNetEnv) dropNextCookieEchoChunk(numToDrop int) { + venv.numToDropCookieEcho = numToDrop +} + +func (venv *vNetEnv) dropNextCookieAckChunk(numToDrop int) { + venv.numToDropCookieAck = numToDrop +} + func buildVNetEnv(cfg *vNetEnvConfig) (*vNetEnv, error) { log := cfg.log @@ -88,6 +98,20 @@ func buildVNetEnv(cfg *vNetEnvConfig) (*vNetEnv, error) { log.Infof("Chunk filter: drop RECONFIG %s", chunk.String()) break loop } + case *chunkCookieEcho: + if venv.numToDropCookieEcho > 0 { + toDrop = true + venv.numToDropCookieEcho-- + log.Infof("Chunk filter: drop %s", chunk.String()) + break loop + } + case *chunkCookieAck: + if venv.numToDropCookieAck > 0 { + toDrop = true + venv.numToDropCookieAck-- + log.Infof("Chunk filter: drop %s", chunk.String()) + break loop + } } } return !toDrop @@ -183,7 +207,7 @@ func testRwndFull(t *testing.T, unordered bool) { } defer assoc.Close() // nolint:errcheck - log.Info("server handlshake complete") + log.Info("server handshake complete") close(serverHandshakeDone) stream, err := assoc.AcceptStream() @@ -257,7 +281,7 @@ func testRwndFull(t *testing.T, unordered bool) { } defer assoc.Close() // nolint:errcheck - log.Info("client handlshake complete") + log.Info("client handshake complete") close(clientHandshakeDone) stream, err := assoc.OpenStream(777, PayloadTypeWebRTCBinary) @@ -408,7 +432,7 @@ func testStreamClose(t *testing.T, dropReconfig bool) { } defer assoc.Close() // nolint:errcheck - log.Info("server handlshake complete") + log.Info("server handshake complete") stream, err := assoc.AcceptStream() if !assert.NoError(t, err, "should succeed") { @@ -457,7 +481,7 @@ func testStreamClose(t *testing.T, dropReconfig bool) { } defer assoc.Close() // nolint:errcheck - log.Info("client handlshake complete") + log.Info("client handshake complete") stream, err := assoc.OpenStream(777, PayloadTypeWebRTCBinary) if !assert.NoError(t, err, "should succeed") { @@ -544,3 +568,116 @@ func TestStreamClose(t *testing.T) { testStreamClose(t, true) }) } + +// this test case reproduces the issue mentioned in +// https://github.com/pion/webrtc/issues/1270#issuecomment-653953743 +// and confirmes the fix. +// To reproduce the case mentioned above: +// * Use simultaneous-open (SCTP) +// * Drop both of the first COOKIE-ECHO and COOKIE-ACK +func TestCookieEchoRetransmission(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + loggerFactory := logging.NewDefaultLoggerFactory() + log := loggerFactory.NewLogger("test") + + venv, err := buildVNetEnv(&vNetEnvConfig{ + minDelay: 200 * time.Millisecond, + loggerFactory: loggerFactory, + log: log, + }) + if !assert.NoError(t, err, "should succeed") { + return + } + if !assert.NotNil(t, venv, "should not be nil") { + return + } + defer venv.wan.Stop() // nolint:errcheck + + // To cause the cookie echo retransmission, both COOKIE-ECHO + // and COOKIE-ACK chunks need to be dropped at the same time. + venv.dropNextCookieEchoChunk(1) + venv.dropNextCookieAckChunk(1) + + serverHandshakeDone := make(chan struct{}) + clientHandshakeDone := make(chan struct{}) + waitAllHandshakeDone := make(chan struct{}) + clientShutDown := make(chan struct{}) + serverShutDown := make(chan struct{}) + + maxReceiveBufferSize := uint32(64 * 1024) + + // Go routine for Server + go func() { + defer close(serverShutDown) + // connected UDP conn for server + conn, err := venv.net0.DialUDP("udp4", + &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, + &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, + ) + if !assert.NoError(t, err, "should succeed") { + return + } + defer conn.Close() // nolint:errcheck + + // server association + // using Client for simultaneous open + assoc, err := Client(Config{ + NetConn: conn, + MaxReceiveBufferSize: maxReceiveBufferSize, + LoggerFactory: loggerFactory, + }) + if !assert.NoError(t, err, "should succeed") { + return + } + defer assoc.Close() // nolint:errcheck + + log.Info("server handshake complete") + close(serverHandshakeDone) + <-waitAllHandshakeDone + }() + + // Go routine for Client + go func() { + defer close(clientShutDown) + // connected UDP conn for client + conn, err := venv.net1.DialUDP("udp4", + &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, + &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, + ) + if !assert.NoError(t, err, "should succeed") { + return + } + + // client association + assoc, err := Client(Config{ + NetConn: conn, + MaxReceiveBufferSize: maxReceiveBufferSize, + LoggerFactory: loggerFactory, + }) + if !assert.NoError(t, err, "should succeed") { + return + } + defer assoc.Close() // nolint:errcheck + + log.Info("client handshake complete") + close(clientHandshakeDone) + <-waitAllHandshakeDone + }() + + // + // Scenario + // + + // wait until both handshake complete + <-clientHandshakeDone + <-serverHandshakeDone + close(waitAllHandshakeDone) + + log.Info("handshake complete") + + <-clientShutDown + <-serverShutDown + log.Info("all done") +}