diff --git a/association.go b/association.go index bcd3ec74..6a293746 100644 --- a/association.go +++ b/association.go @@ -16,11 +16,11 @@ import ( ) const ( - receiveMTU uint32 = 8192 // MTU for inbound packet (from DTLS) - initialMTU uint32 = 1228 // initial MTU for outgoing packets (to DTLS) - maxReceiveBufferSize uint32 = 64 * 1024 - commonHeaderSize uint32 = 12 - dataChunkHeaderSize uint32 = 16 + receiveMTU uint32 = 8192 // MTU for inbound packet (from DTLS) + initialMTU uint32 = 1228 // initial MTU for outgoing packets (to DTLS) + initialRecvBufSize uint32 = 1024 * 1024 + commonHeaderSize uint32 = 12 + dataChunkHeaderSize uint32 = 16 ) // association state enums @@ -109,15 +109,14 @@ type Association struct { netConn net.Conn - peerVerificationTag uint32 - myVerificationTag uint32 - state uint32 - myNextTSN uint32 // nextTSN - peerLastTSN uint32 // lastRcvdTSN - minTSN2MeasureRTT uint32 // for RTT measurement - willRetransmitDataChunks bool - willSendForwardTSN bool - willRetransmitFast bool + peerVerificationTag uint32 + myVerificationTag uint32 + state uint32 + myNextTSN uint32 // nextTSN + peerLastTSN uint32 // lastRcvdTSN + minTSN2MeasureRTT uint32 // for RTT measurement + willSendForwardTSN bool + willRetransmitFast bool // Reconfig myNextRSN uint32 @@ -141,6 +140,7 @@ type Association struct { useForwardTSN bool // Congestion control parameters + maxReceiveBufferSize uint32 cwnd uint32 // my congestion window size rwnd uint32 // calculated peer's receiver windows size ssthresh uint32 // slow start threshold @@ -184,8 +184,9 @@ type Association struct { // Config collects the arguments to createAssociation construction into // a single structure type Config struct { - NetConn net.Conn - LoggerFactory logging.LoggerFactory + NetConn net.Conn + MaxReceiveBufferSize uint32 + LoggerFactory logging.LoggerFactory } // Server accepts a SCTP stream over a conn @@ -224,9 +225,17 @@ func createAssociation(config Config) *Association { rs := rand.NewSource(time.Now().UnixNano()) r := rand.New(rs) + var maxReceiveBufferSize uint32 + if config.MaxReceiveBufferSize == 0 { + maxReceiveBufferSize = initialRecvBufSize + } else { + maxReceiveBufferSize = config.MaxReceiveBufferSize + } + tsn := r.Uint32() a := &Association{ netConn: config.NetConn, + maxReceiveBufferSize: maxReceiveBufferSize, myMaxNumOutboundStreams: math.MaxUint16, myMaxNumInboundStreams: math.MaxUint16, payloadQueue: newPayloadQueue(), @@ -288,7 +297,7 @@ func (a *Association) init(isClient bool) { init.numOutboundStreams = a.myMaxNumOutboundStreams init.numInboundStreams = a.myMaxNumInboundStreams init.initiateTag = a.myVerificationTag - init.advertisedReceiverWindowCredit = maxReceiveBufferSize + init.advertisedReceiverWindowCredit = a.maxReceiveBufferSize setSupportedExtensions(&init.chunkInitCommon) a.storedInit = init @@ -507,17 +516,13 @@ func (a *Association) gatherOutbound() [][]byte { state := a.getState() if state == established { - if a.willRetransmitDataChunks { - a.willRetransmitDataChunks = false - for _, p := range a.getDataPacketsToRetransmit() { - raw, err := p.marshal() - if err != nil { - a.log.Warnf("[%s] failed to serialize a DATA packet to be retransmitted", a.name) - continue - } - a.log.Debugf("[%s] retransmitting %d bytes", a.name, len(raw)) - rawPackets = append(rawPackets, raw) + for _, p := range a.getDataPacketsToRetransmit() { + raw, err := p.marshal() + if err != nil { + a.log.Warnf("[%s] failed to serialize a DATA packet to be retransmitted", a.name) + continue } + rawPackets = append(rawPackets, raw) } // Pop unsent data chunks from the pending queue to send as much as @@ -789,7 +794,7 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { initAck.numOutboundStreams = a.myMaxNumOutboundStreams initAck.numInboundStreams = a.myMaxNumInboundStreams initAck.initiateTag = a.myVerificationTag - initAck.advertisedReceiverWindowCredit = maxReceiveBufferSize + initAck.advertisedReceiverWindowCredit = a.maxReceiveBufferSize if a.myCookie == nil { a.myCookie = newRandomStateCookie() @@ -1021,10 +1026,10 @@ func (a *Association) getMyReceiverWindowCredit() uint32 { bytesQueued += uint32(s.getNumBytesInReassemblyQueue()) } - if bytesQueued >= maxReceiveBufferSize { + if bytesQueued >= a.maxReceiveBufferSize { return 0 } - return maxReceiveBufferSize - bytesQueued + return a.maxReceiveBufferSize - bytesQueued } // OpenStream opens a stream @@ -1222,8 +1227,8 @@ func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) { // path MTU. if !a.inFastRecovery && a.pendingQueue.size() > 0 { - //a.cwnd += min32(uint32(totalBytesAcked), a.cwnd) - a.cwnd += min32(uint32(totalBytesAcked), a.mtu) + a.cwnd += min32(uint32(totalBytesAcked), a.cwnd) // TCP way + //a.cwnd += min32(uint32(totalBytesAcked), a.mtu) // SCTP way (slow) a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (SS)", a.name, a.cwnd, a.ssthresh, totalBytesAcked) } else { @@ -1687,12 +1692,9 @@ func (a *Association) popPendingDataChunksToSend() ([]*chunkPayloadData, []uint1 // is 0), the data sender can always have one DATA chunk in flight to // the receiver if allowed by cwnd (see rule B, below). - usingFullWindow := (a.inflightQueue.getNumBytes() == 0) - for { c := a.pendingQueue.peek() if c == nil { - usingFullWindow = false break // no more pending data } @@ -1742,15 +1744,10 @@ func (a *Association) popPendingDataChunksToSend() ([]*chunkPayloadData, []uint1 // min(cwnd,rnwd) would end up with generating SACK, on the remote, for // every single DATA chunks when sending large amount of data at once. // In order to overcome this situation, here we set I-bit to true for - // the last chunk only when the following condition is met: - // - // - When sending full size of windlow (= min(cwnd, rwnd)) + // the last chunk only. lastChunk := chunks[len(chunks)-1] - if usingFullWindow { - //a.log.Debugf("sending tsn=%d with immediateSack: total_chunks=%d", lastChunk.tsn, len(chunks)) - lastChunk.immediateSack = true - } + lastChunk.immediateSack = true } } @@ -1849,7 +1846,7 @@ func (a *Association) getDataPacketsToRetransmit() []*packet { break // end of pending data } - if c.acked || c.abandoned { + if !c.retransmit { continue } @@ -1860,6 +1857,9 @@ func (a *Association) getDataPacketsToRetransmit() []*packet { break } + // reset the retransmit flag not to retransmit again before the next + // t3-rtx timer fires + c.retransmit = false bytesToSend += len(c.userData) c.nSent++ @@ -2013,7 +2013,7 @@ func (a *Association) onRetransmissionTimeout(id int, nRtos uint) { a.log.Debugf("[%s] T3-rtx timed out: nRtos=%d cwnd=%d ssthresh=%d", a.name, nRtos, a.cwnd, a.ssthresh) - a.willRetransmitDataChunks = true + a.inflightQueue.markAllToRetrasmit() a.awakeWriteLoop() return diff --git a/association_test.go b/association_test.go index dbb0bda6..fd71f6f5 100644 --- a/association_test.go +++ b/association_test.go @@ -230,7 +230,7 @@ func (c *dumbConn) SetWriteDeadline(t time.Time) error { //////////////////////////////////////////////////////////////////////////////// -func createNewAssociationPair(br *test.Bridge, ackMode int) (*Association, *Association, error) { +func createNewAssociationPair(br *test.Bridge, ackMode int, recvBufSize uint32) (*Association, *Association, error) { var a0, a1 *Association var err0, err1 error loggerFactory := logging.NewDefaultLoggerFactory() @@ -240,15 +240,17 @@ func createNewAssociationPair(br *test.Bridge, ackMode int) (*Association, *Asso go func() { a0, err0 = Client(Config{ - NetConn: br.GetConn0(), - LoggerFactory: loggerFactory, + NetConn: br.GetConn0(), + MaxReceiveBufferSize: recvBufSize, + LoggerFactory: loggerFactory, }) handshake0Ch <- true }() go func() { a1, err1 = Client(Config{ - NetConn: br.GetConn1(), - LoggerFactory: loggerFactory, + NetConn: br.GetConn1(), + MaxReceiveBufferSize: recvBufSize, + LoggerFactory: loggerFactory, }) handshake1Ch <- true }() @@ -415,7 +417,7 @@ func TestAssocReliable(t *testing.T) { const msg = "ABC" br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeNoDelay) + a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -454,7 +456,7 @@ func TestAssocReliable(t *testing.T) { var ppi PayloadProtocolIdentifier br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeNoDelay) + a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -508,7 +510,7 @@ func TestAssocReliable(t *testing.T) { var ppi PayloadProtocolIdentifier br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeNoDelay) + a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -549,7 +551,7 @@ func TestAssocReliable(t *testing.T) { var ppi PayloadProtocolIdentifier br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeNoDelay) + a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -590,7 +592,7 @@ func TestAssocReliable(t *testing.T) { var ppi PayloadProtocolIdentifier br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeNoDelay) + a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -649,7 +651,7 @@ func TestAssocReliable(t *testing.T) { var ppi PayloadProtocolIdentifier br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeNoDelay) + a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -703,7 +705,7 @@ func TestAssocReliable(t *testing.T) { const msg = "Hello" br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeNoDelay) + a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -765,7 +767,7 @@ func TestAssocUnreliable(t *testing.T) { const si uint16 = 1 br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeNoDelay) + a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -816,7 +818,7 @@ func TestAssocUnreliable(t *testing.T) { const si uint16 = 1 br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeNoDelay) + a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -869,7 +871,7 @@ func TestAssocUnreliable(t *testing.T) { const si uint16 = 2 br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeNoDelay) + a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -920,7 +922,7 @@ func TestAssocUnreliable(t *testing.T) { const si uint16 = 1 br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeNoDelay) + a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -971,7 +973,7 @@ func TestAssocUnreliable(t *testing.T) { const si uint16 = 3 br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeNoDelay) + a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -1023,7 +1025,7 @@ func TestAssocUnreliable(t *testing.T) { const si uint16 = 3 br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeNoDelay) + a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -1585,7 +1587,7 @@ func TestAssocT3RtxTimer(t *testing.T) { var ppi PayloadProtocolIdentifier br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeNoDelay) + a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -1647,7 +1649,7 @@ func TestAssocCongestionControl(t *testing.T) { var ppi PayloadProtocolIdentifier br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeNormal) + a0, a1, err := createNewAssociationPair(br, ackModeNormal, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -1709,6 +1711,7 @@ func TestAssocCongestionControl(t *testing.T) { }) t.Run("Congestion Avoidance", func(t *testing.T) { + const maxReceiveBufferSize uint32 = 64 * 1024 const si uint16 = 6 const nPacketsToSend = 2000 var n int @@ -1718,7 +1721,7 @@ func TestAssocCongestionControl(t *testing.T) { br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeNormal) + a0, a1, err := createNewAssociationPair(br, ackModeNormal, maxReceiveBufferSize) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -1792,6 +1795,7 @@ func TestAssocCongestionControl(t *testing.T) { // This is to test even rwnd becomes 0, sender should be able to send a zero window probe // on T3-rtx retramission timeout to complete receiving all the packets. t.Run("Slow reader", func(t *testing.T) { + const maxReceiveBufferSize uint32 = 64 * 1024 const si uint16 = 6 nPacketsToSend := int(math.Floor(float64(maxReceiveBufferSize)/1000.0)) * 2 var n int @@ -1801,7 +1805,7 @@ func TestAssocCongestionControl(t *testing.T) { br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeNoDelay) + a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, maxReceiveBufferSize) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -1893,7 +1897,7 @@ func TestAssocDelayedAck(t *testing.T) { br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeAlwaysDelay) + a0, a1, err := createNewAssociationPair(br, ackModeAlwaysDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -1960,7 +1964,7 @@ func TestAssocReset(t *testing.T) { const msg = "ABC" br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeNoDelay) + a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -2018,7 +2022,7 @@ func TestAssocReset(t *testing.T) { const msg = "ABC" br := test.NewBridge() - a0, a1, err := createNewAssociationPair(br, ackModeNoDelay) + a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } diff --git a/chunk_payload_data.go b/chunk_payload_data.go index 374727f7..01c5102d 100644 --- a/chunk_payload_data.go +++ b/chunk_payload_data.go @@ -64,6 +64,10 @@ type chunkPayloadData struct { since time.Time nSent uint32 // number of transmission made for this chunk abandoned bool + + // Retransmission flag set when T1-RTX timeout occurred and this + // chunk is still in the inflight queue + retransmit bool } const ( diff --git a/payload_queue.go b/payload_queue.go index 139103d0..1877d579 100644 --- a/payload_queue.go +++ b/payload_queue.go @@ -160,6 +160,15 @@ func (q *payloadQueue) getLastTSNReceived() (uint32, bool) { return q.sorted[qlen-1], true } +func (q *payloadQueue) markAllToRetrasmit() { + for _, c := range q.chunkMap { + if c.acked || c.abandoned { + continue + } + c.retransmit = true + } +} + func (q *payloadQueue) getNumBytes() int { return q.nBytes } diff --git a/payload_queue_test.go b/payload_queue_test.go index a1d504f7..048fd23c 100644 --- a/payload_queue_test.go +++ b/payload_queue_test.go @@ -113,4 +113,23 @@ func TestPayloadQueue(t *testing.T) { assert.True(t, ok, "should be false") assert.Equal(t, uint32(21), tsn, "should match") }) + + t.Run("markAllToRetrasmit", func(t *testing.T) { + pq := newPayloadQueue() + for i := 0; i < 3; i++ { + pq.push(makePayload(uint32(i+1), 10), 0) + } + pq.markAsAcked(2) + pq.markAllToRetrasmit() + + c, ok := pq.get(1) + assert.True(t, ok, "should be true") + assert.True(t, c.retransmit, "should be marked as retransmit") + c, ok = pq.get(2) + assert.True(t, ok, "should be true") + assert.False(t, c.retransmit, "should NOT be marked as retransmit") + c, ok = pq.get(3) + assert.True(t, ok, "should be true") + assert.True(t, c.retransmit, "should be marked as retransmit") + }) } diff --git a/vnet_test.go b/vnet_test.go index 714953eb..16edda8f 100644 --- a/vnet_test.go +++ b/vnet_test.go @@ -141,6 +141,7 @@ func testRwndFull(t *testing.T, unordered bool) { shutDownClient := make(chan struct{}) shutDownServer := make(chan struct{}) + maxReceiveBufferSize := uint32(64 * 1024) msgSize := int(float32(maxReceiveBufferSize)/2) + int(initialMTU) msg := make([]byte, msgSize) rand.Read(msg) // nolint:errcheck,gosec @@ -159,8 +160,9 @@ func testRwndFull(t *testing.T, unordered bool) { // server association assoc, err := Server(Config{ - NetConn: conn, - LoggerFactory: loggerFactory, + NetConn: conn, + MaxReceiveBufferSize: maxReceiveBufferSize, + LoggerFactory: loggerFactory, }) if !assert.NoError(t, err, "should succeed") { return @@ -232,8 +234,9 @@ func testRwndFull(t *testing.T, unordered bool) { // client association assoc, err := Client(Config{ - NetConn: conn, - LoggerFactory: loggerFactory, + NetConn: conn, + MaxReceiveBufferSize: maxReceiveBufferSize, + LoggerFactory: loggerFactory, }) if !assert.NoError(t, err, "should succeed") { return