diff --git a/association.go b/association.go index b0b5e4f9..6f6f5f0d 100644 --- a/association.go +++ b/association.go @@ -821,11 +821,6 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { // responding, the endpoint MUST send the INIT ACK back to the same // address that the original INIT (sent by this endpoint) was sent. - // https://tools.ietf.org/html/rfc4960#section-5.2.1 - // Upon receipt of an INIT in the COOKIE-ECHOED state, an endpoint MUST - // respond with an INIT ACK using the same parameters it sent in its - // original INIT chunk (including its Initiate Tag, unchanged) - if state != closed && state != cookieWait && state != cookieEchoed { // 5.2.2. Unexpected INIT in States Other than CLOSED, COOKIE-ECHOED, // COOKIE-WAIT, and SHUTDOWN-ACK-SENT @@ -845,6 +840,21 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { // subtracting one from it. a.peerLastTSN = i.initialTSN - 1 + for _, param := range i.params { + switch v := param.(type) { + case *paramSupportedExtensions: + for _, t := range v.ChunkTypes { + if t == ctForwardTSN { + a.log.Debugf("[%s] use ForwardTSN (on init)\n", a.name) + a.useForwardTSN = true + } + } + } + } + if !a.useForwardTSN { + a.log.Warnf("[%s] not using ForwardTSN (on init)\n", a.name) + } + outbound := &packet{} outbound.verificationTag = a.peerVerificationTag outbound.sourcePort = a.sourcePort @@ -917,11 +927,15 @@ func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error { case *paramSupportedExtensions: for _, t := range v.ChunkTypes { if t == ctForwardTSN { + a.log.Debugf("[%s] use ForwardTSN (on initAck)\n", a.name) a.useForwardTSN = true } } } } + if !a.useForwardTSN { + a.log.Warnf("[%s] not using ForwardTSN (on initAck)\n", a.name) + } if cookieParam == nil { return errors.Errorf("no cookie in InitAck") } @@ -1579,6 +1593,7 @@ func (a *Association) handleForwardTSN(c *chunkForwardTSN) []*packet { a.log.Tracef("[%s] FwdTSN: %s", a.name, c.String()) if !a.useForwardTSN { + a.log.Warn("[%s] received FwdTSN but not enabled") // Return an error chunk cerr := &chunkError{ errorCauses: []errorCause{&errorCauseUnrecognizedChunkType{}}, @@ -1599,6 +1614,8 @@ func (a *Association) handleForwardTSN(c *chunkForwardTSN) []*packet { // send a SACK to its peer (the sender of the FORWARD TSN) since such a // duplicate may indicate the previous SACK was lost in the network. + a.log.Tracef("[%s] should send ack? newCumTSN=%d peerLastTSN=%d\n", + a.name, c.newCumulativeTSN, a.peerLastTSN) if sna32LTE(c.newCumulativeTSN, a.peerLastTSN) { a.log.Tracef("[%s] sending ack on Forward TSN", a.name) a.ackState = ackStateImmediate @@ -1860,6 +1877,7 @@ func (a *Association) sendPayloadData(chunks []*chunkPayloadData) error { // The caller should hold the lock. func (a *Association) checkPartialReliabilityStatus(c *chunkPayloadData) { + a.log.Debugf("[%s] PR check: in, useForwardTSN=%v", a.name, a.useForwardTSN) if !a.useForwardTSN { return } @@ -1880,6 +1898,8 @@ func (a *Association) checkPartialReliabilityStatus(c *chunkPayloadData) { } } s.lock.RUnlock() + } else { + a.log.Errorf("[%s] stream %d not found)", a.name, c.streamIdentifier) } } diff --git a/association_test.go b/association_test.go index 2ca13813..040d9870 100644 --- a/association_test.go +++ b/association_test.go @@ -2420,3 +2420,64 @@ func TestStats(t *testing.T) { assert.Equal(t, conn.bytesReceived, a.BytesReceived()) assert.Equal(t, conn.bytesSent, a.BytesSent()) } + +func TestAssocHandleInit(t *testing.T) { + loggerFactory := logging.NewDefaultLoggerFactory() + + handleInitTest := func(t *testing.T, initialState uint32, expectErr bool) { + a := createAssociation(Config{ + NetConn: &dumbConn{}, + LoggerFactory: loggerFactory, + }) + a.setState(initialState) + pkt := &packet{ + sourcePort: 5001, + destinationPort: 5002, + } + init := &chunkInit{} + init.initialTSN = 1234 + init.numOutboundStreams = 1001 + init.numInboundStreams = 1002 + init.initiateTag = 5678 + init.advertisedReceiverWindowCredit = 512 * 1024 + setSupportedExtensions(&init.chunkInitCommon) + + _, err := a.handleInit(pkt, init) + if expectErr { + assert.Error(t, err, "should fail") + return + } + assert.NoError(t, err, "should succeed") + assert.Equal(t, uint32(init.initialTSN-1), a.peerLastTSN, "should match") + assert.Equal(t, uint16(1001), a.myMaxNumOutboundStreams, "should match") + assert.Equal(t, uint16(1002), a.myMaxNumInboundStreams, "should match") + assert.Equal(t, uint32(5678), a.peerVerificationTag, "should match") + assert.Equal(t, pkt.sourcePort, a.destinationPort, "should match") + assert.Equal(t, pkt.destinationPort, a.sourcePort, "should match") + assert.True(t, a.useForwardTSN, "should be set to true") + } + + t.Run("normal", func(t *testing.T) { + handleInitTest(t, closed, false) + }) + + t.Run("unexpected state established", func(t *testing.T) { + handleInitTest(t, established, true) + }) + + t.Run("unexpected state shutdownAckSent", func(t *testing.T) { + handleInitTest(t, shutdownAckSent, true) + }) + + t.Run("unexpected state shutdownPending", func(t *testing.T) { + handleInitTest(t, shutdownPending, true) + }) + + t.Run("unexpected state shutdownReceived", func(t *testing.T) { + handleInitTest(t, shutdownReceived, true) + }) + + t.Run("unexpected state shutdownSent", func(t *testing.T) { + handleInitTest(t, shutdownSent, true) + }) +} diff --git a/stream.go b/stream.go index 17933812..c563d680 100644 --- a/stream.go +++ b/stream.go @@ -70,6 +70,8 @@ func (s *Stream) SetReliabilityParams(unordered bool, relType byte, relVal uint3 // setReliabilityParams sets reliability parameters for this stream. // The caller should hold the lock. func (s *Stream) setReliabilityParams(unordered bool, relType byte, relVal uint32) { + s.log.Debugf("[%s] reliability params: ordered=%v type=%d value=%d", + s.name, !unordered, relType, relVal) s.unordered = unordered s.reliabilityType = relType s.reliabilityValue = relVal