Skip to content

Commit

Permalink
extract useForwardTSN in handleInit
Browse files Browse the repository at this point in the history
Relates to pion/webrtc#1270
  • Loading branch information
enobufs committed Jun 29, 2020
1 parent 6f6d053 commit 201f752
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 5 deletions.
30 changes: 25 additions & 5 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -821,11 +821,6 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {
// responding, the endpoint MUST send the INIT ACK back to the same
// address that the original INIT (sent by this endpoint) was sent.

// https://tools.ietf.org/html/rfc4960#section-5.2.1
// Upon receipt of an INIT in the COOKIE-ECHOED state, an endpoint MUST
// respond with an INIT ACK using the same parameters it sent in its
// original INIT chunk (including its Initiate Tag, unchanged)

if state != closed && state != cookieWait && state != cookieEchoed {
// 5.2.2. Unexpected INIT in States Other than CLOSED, COOKIE-ECHOED,
// COOKIE-WAIT, and SHUTDOWN-ACK-SENT
Expand All @@ -845,6 +840,21 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {
// subtracting one from it.
a.peerLastTSN = i.initialTSN - 1

for _, param := range i.params {
switch v := param.(type) {
case *paramSupportedExtensions:
for _, t := range v.ChunkTypes {
if t == ctForwardTSN {
a.log.Debugf("[%s] use ForwardTSN (on init)\n", a.name)
a.useForwardTSN = true
}
}
}
}
if !a.useForwardTSN {
a.log.Warnf("[%s] not using ForwardTSN (on init)\n", a.name)
}

outbound := &packet{}
outbound.verificationTag = a.peerVerificationTag
outbound.sourcePort = a.sourcePort
Expand Down Expand Up @@ -917,11 +927,15 @@ func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error {
case *paramSupportedExtensions:
for _, t := range v.ChunkTypes {
if t == ctForwardTSN {
a.log.Debugf("[%s] use ForwardTSN (on initAck)\n", a.name)
a.useForwardTSN = true
}
}
}
}
if !a.useForwardTSN {
a.log.Warnf("[%s] not using ForwardTSN (on initAck)\n", a.name)
}
if cookieParam == nil {
return errors.Errorf("no cookie in InitAck")
}
Expand Down Expand Up @@ -1579,6 +1593,7 @@ func (a *Association) handleForwardTSN(c *chunkForwardTSN) []*packet {
a.log.Tracef("[%s] FwdTSN: %s", a.name, c.String())

if !a.useForwardTSN {
a.log.Warn("[%s] received FwdTSN but not enabled")
// Return an error chunk
cerr := &chunkError{
errorCauses: []errorCause{&errorCauseUnrecognizedChunkType{}},
Expand All @@ -1599,6 +1614,8 @@ func (a *Association) handleForwardTSN(c *chunkForwardTSN) []*packet {
// send a SACK to its peer (the sender of the FORWARD TSN) since such a
// duplicate may indicate the previous SACK was lost in the network.

a.log.Tracef("[%s] should send ack? newCumTSN=%d peerLastTSN=%d\n",
a.name, c.newCumulativeTSN, a.peerLastTSN)
if sna32LTE(c.newCumulativeTSN, a.peerLastTSN) {
a.log.Tracef("[%s] sending ack on Forward TSN", a.name)
a.ackState = ackStateImmediate
Expand Down Expand Up @@ -1860,6 +1877,7 @@ func (a *Association) sendPayloadData(chunks []*chunkPayloadData) error {

// The caller should hold the lock.
func (a *Association) checkPartialReliabilityStatus(c *chunkPayloadData) {
a.log.Debugf("[%s] PR check: in, useForwardTSN=%v", a.name, a.useForwardTSN)
if !a.useForwardTSN {
return
}
Expand All @@ -1880,6 +1898,8 @@ func (a *Association) checkPartialReliabilityStatus(c *chunkPayloadData) {
}
}
s.lock.RUnlock()
} else {
a.log.Errorf("[%s] stream %d not found)", a.name, c.streamIdentifier)
}
}

Expand Down
61 changes: 61 additions & 0 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2420,3 +2420,64 @@ func TestStats(t *testing.T) {
assert.Equal(t, conn.bytesReceived, a.BytesReceived())
assert.Equal(t, conn.bytesSent, a.BytesSent())
}

func TestAssocHandleInit(t *testing.T) {
loggerFactory := logging.NewDefaultLoggerFactory()

handleInitTest := func(t *testing.T, initialState uint32, expectErr bool) {
a := createAssociation(Config{
NetConn: &dumbConn{},
LoggerFactory: loggerFactory,
})
a.setState(initialState)
pkt := &packet{
sourcePort: 5001,
destinationPort: 5002,
}
init := &chunkInit{}
init.initialTSN = 1234
init.numOutboundStreams = 1001
init.numInboundStreams = 1002
init.initiateTag = 5678
init.advertisedReceiverWindowCredit = 512 * 1024
setSupportedExtensions(&init.chunkInitCommon)

_, err := a.handleInit(pkt, init)
if expectErr {
assert.Error(t, err, "should fail")
return
}
assert.NoError(t, err, "should succeed")
assert.Equal(t, uint32(init.initialTSN-1), a.peerLastTSN, "should match")
assert.Equal(t, uint16(1001), a.myMaxNumOutboundStreams, "should match")
assert.Equal(t, uint16(1002), a.myMaxNumInboundStreams, "should match")
assert.Equal(t, uint32(5678), a.peerVerificationTag, "should match")
assert.Equal(t, pkt.sourcePort, a.destinationPort, "should match")
assert.Equal(t, pkt.destinationPort, a.sourcePort, "should match")
assert.True(t, a.useForwardTSN, "should be set to true")
}

t.Run("normal", func(t *testing.T) {
handleInitTest(t, closed, false)
})

t.Run("unexpected state established", func(t *testing.T) {
handleInitTest(t, established, true)
})

t.Run("unexpected state shutdownAckSent", func(t *testing.T) {
handleInitTest(t, shutdownAckSent, true)
})

t.Run("unexpected state shutdownPending", func(t *testing.T) {
handleInitTest(t, shutdownPending, true)
})

t.Run("unexpected state shutdownReceived", func(t *testing.T) {
handleInitTest(t, shutdownReceived, true)
})

t.Run("unexpected state shutdownSent", func(t *testing.T) {
handleInitTest(t, shutdownSent, true)
})
}
2 changes: 2 additions & 0 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ func (s *Stream) SetReliabilityParams(unordered bool, relType byte, relVal uint3
// setReliabilityParams sets reliability parameters for this stream.
// The caller should hold the lock.
func (s *Stream) setReliabilityParams(unordered bool, relType byte, relVal uint32) {
s.log.Debugf("[%s] reliability params: ordered=%v type=%d value=%d",
s.name, !unordered, relType, relVal)
s.unordered = unordered
s.reliabilityType = relType
s.reliabilityValue = relVal
Expand Down

0 comments on commit 201f752

Please sign in to comment.