Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sctp.Association.Abort(reason string) method #183

Merged
merged 1 commit into from
May 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 69 additions & 5 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ type Association struct {
willSendShutdownAck bool
willSendShutdownComplete bool

willSendAbort bool
willSendAbortCause errorCause

// Reconfig
myNextRSN uint32
reconfigs map[uint32]*chunkReconfig
Expand Down Expand Up @@ -469,6 +472,26 @@ func (a *Association) close() error {
return err
}

// Abort sends the abort packet with user initiated abort and immediately
// closes the connection.
func (a *Association) Abort(reason string) {
a.log.Debugf("[%s] aborting association: %s", a.name, reason)

a.lock.Lock()

a.willSendAbort = true
a.willSendAbortCause = &errorCauseUserInitiatedAbort{
upperLayerAbortReason: []byte(reason),
}

a.lock.Unlock()

a.awakeWriteLoop()

// Wait for readLoop to end
<-a.readLoopCloseCh
}

func (a *Association) closeAllTimers() {
// Close all retransmission & ack timers
a.t1Init.close()
Expand Down Expand Up @@ -829,12 +852,39 @@ func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]by
return rawPackets, ok
}

func (a *Association) gatherAbortPacket() ([]byte, error) {
cause := a.willSendAbortCause

a.willSendAbort = false
a.willSendAbortCause = nil

abort := &chunkAbort{}

if cause != nil {
abort.errorCauses = []errorCause{cause}
}

raw, err := a.createPacket([]chunk{abort}).marshal()

return raw, err
}

// gatherOutbound gathers outgoing packets. The returned bool value set to
// false means the association should be closed down after the final send.
func (a *Association) gatherOutbound() ([][]byte, bool) {
a.lock.Lock()
defer a.lock.Unlock()

if a.willSendAbort {
pkt, err := a.gatherAbortPacket()
if err != nil {
a.log.Warnf("[%s] failed to serialize an abort packet", a.name)
return nil, false
}

return [][]byte{pkt}, false
}

rawPackets := [][]byte{}

if a.controlQueue.size() > 0 {
Expand Down Expand Up @@ -1747,6 +1797,17 @@ func (a *Association) handleShutdownComplete(_ *chunkShutdownComplete) error {
return nil
}

func (a *Association) handleAbort(c *chunkAbort) error {
var errStr string
for _, e := range c.errorCauses {
errStr += fmt.Sprintf("(%s)", e)
}

_ = a.close()

return fmt.Errorf("[%s] %w: %s", a.name, errChunk, errStr)
}

// createForwardTSN generates ForwardTSN chunk.
// This method will be be called if useForwardTSN is set to false.
// The caller should hold the lock.
Expand Down Expand Up @@ -2251,6 +2312,8 @@ func (a *Association) handleChunk(p *packet, c chunk) error {
return nil
}

isAbort := false

switch c := c.(type) {
case *chunkInit:
packets, err = a.handleInit(p, c)
Expand All @@ -2259,11 +2322,8 @@ func (a *Association) handleChunk(p *packet, c chunk) error {
err = a.handleInitAck(p, c)

case *chunkAbort:
var errStr string
for _, e := range c.errorCauses {
errStr += fmt.Sprintf("(%s)", e)
}
return fmt.Errorf("[%s] %w: %s", a.name, errChunk, errStr)
isAbort = true
err = a.handleAbort(c)

case *chunkError:
var errStr string
Expand Down Expand Up @@ -2306,6 +2366,10 @@ func (a *Association) handleChunk(p *packet, c chunk) error {

// Log and return, the only condition that is fatal is a ABORT chunk
if err != nil {
if isAbort {
return err
}

a.log.Errorf("Failed to handle chunk: %v", err)
return nil
}
Expand Down
57 changes: 53 additions & 4 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2652,11 +2652,12 @@ func TestAssociation_ShutdownDuringWrite(t *testing.T) {
}
}

func TestAssociation_HandlePacketBeforeInit(t *testing.T) {
func TestAssociation_HandlePacketInCookieWaitState(t *testing.T) {
loggerFactory := logging.NewDefaultLoggerFactory()

testCases := map[string]struct {
inputPacket *packet
skipClose bool
}{
"InitAck": {
inputPacket: &packet{
Expand All @@ -2680,6 +2681,8 @@ func TestAssociation_HandlePacketBeforeInit(t *testing.T) {
destinationPort: 1,
chunks: []chunk{&chunkAbort{}},
},
// Prevent "use of close network connection" error on close.
skipClose: true,
},
"CoockeEcho": {
inputPacket: &packet{
Expand Down Expand Up @@ -2774,9 +2777,12 @@ func TestAssociation_HandlePacketBeforeInit(t *testing.T) {
LoggerFactory: loggerFactory,
})
a.init(true)
defer func() {
assert.NoError(t, a.close())
}()

if !testCase.skipClose {
defer func() {
assert.NoError(t, a.close())
}()
}

packet, err := testCase.inputPacket.marshal()
assert.NoError(t, err)
Expand All @@ -2788,3 +2794,46 @@ func TestAssociation_HandlePacketBeforeInit(t *testing.T) {
})
}
}

func TestAssociation_Abort(t *testing.T) {
runtime.GC()
n0 := runtime.NumGoroutine()

defer func() {
runtime.GC()
assert.Equal(t, n0, runtime.NumGoroutine(), "goroutine is leaked")
}()

a1, a2 := createAssocs(t)

s11, err := a1.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)

s21, err := a2.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)

testData := []byte("test")

i, err := s11.Write(testData)
assert.Equal(t, len(testData), i)
assert.NoError(t, err)

buf := make([]byte, len(testData))
i, err = s21.Read(buf)
assert.Equal(t, len(testData), i)
assert.NoError(t, err)
assert.Equal(t, testData, buf)

a1.Abort("1234")

// Wait for close read loop channels to prevent flaky tests.
select {
case <-a2.readLoopCloseCh:
case <-time.After(1 * time.Second):
assert.Fail(t, "timed out waiting for a2 read loop to close")
}

i, err = s21.Read(buf)
assert.Equal(t, i, 0, "expected no data read")
assert.Error(t, err, "User Initiated Abort: 1234", "expected abort reason")
}
3 changes: 3 additions & 0 deletions error_cause.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,16 @@ func buildErrorCause(raw []byte) (errorCause, error) {
e = &errorCauseUnrecognizedChunkType{}
case protocolViolation:
e = &errorCauseProtocolViolation{}
case userInitiatedAbort:
e = &errorCauseUserInitiatedAbort{}
default:
return nil, fmt.Errorf("%w: %s", errBuildErrorCaseHandle, c.String())
}

if err := e.unmarshal(raw); err != nil {
return nil, err
}

return e, nil
}

Expand Down
46 changes: 46 additions & 0 deletions error_cause_user_initiated_abort.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package sctp

import (
"fmt"
)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you able to add the following text just to be consistent with other error_cause_xxx.go files?

/*
   This error cause MAY be included in ABORT chunks that are sent
   because of an upper-layer request.  The upper layer can specify an
   Upper Layer Abort Reason that is transported by SCTP transparently
   and MAY be delivered to the upper-layer protocol at the peer.

        0                   1                   2                   3
        0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
       +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
       |         Cause Code=12         |      Cause Length=Variable    |
       +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
       /                    Upper Layer Abort Reason                   /
       \                                                               \
       +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
*/

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'll try to address it over the next couple of days

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

Copy link
Member

@enobufs enobufs May 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @jeremija !

Some auto tests are failing with TestAssociation_HandlePacketBeforeInit/Abort.
From what I can see, a.handleAbort() calls a.close() which sets the state to closed. Then the test case calls a.Close() which complains "use of closed network connection".

I think calling a.close() on the receipt of Abort chunk is right. I believe we could safely skip a.Close() in the test when the test case is "Abort" in TestAssociation_HandlePacketBeforeInit.

Otherwise, it all looks good to me!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is something from before but I think this test case name should be changed to something like: "TestAssociation_HandlePacketInCookieWaitState" IMO.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I didn't realize there was a failing test. It should be fixed now. I also renamed the test case.

/*
This error cause MAY be included in ABORT chunks that are sent
because of an upper-layer request. The upper layer can specify an
Upper Layer Abort Reason that is transported by SCTP transparently
and MAY be delivered to the upper-layer protocol at the peer.

0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Cause Code=12 | Cause Length=Variable |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/ Upper Layer Abort Reason /
\ \
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
*/
type errorCauseUserInitiatedAbort struct {
errorCauseHeader
upperLayerAbortReason []byte
}

func (e *errorCauseUserInitiatedAbort) marshal() ([]byte, error) {
e.code = userInitiatedAbort
e.errorCauseHeader.raw = e.upperLayerAbortReason
return e.errorCauseHeader.marshal()
}

func (e *errorCauseUserInitiatedAbort) unmarshal(raw []byte) error {
err := e.errorCauseHeader.unmarshal(raw)
if err != nil {
return err
}

e.upperLayerAbortReason = e.errorCauseHeader.raw
return nil
}

// String makes errorCauseUserInitiatedAbort printable
func (e *errorCauseUserInitiatedAbort) String() string {
return fmt.Sprintf("%s: %s", e.errorCauseHeader.String(), e.upperLayerAbortReason)
}