From f784cc1c73be5364714c300cc8dc1c4676232952 Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Mon, 26 Feb 2024 14:19:04 -0500 Subject: [PATCH] Make Client always commit params --- datachannel.go | 20 +++++++++----------- datachannel_test.go | 9 ++------- 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/datachannel.go b/datachannel.go index 8ed18fb..f3e811e 100644 --- a/datachannel.go +++ b/datachannel.go @@ -74,12 +74,12 @@ type Config struct { LoggerFactory logging.LoggerFactory } -func newDataChannel(stream *sctp.Stream, config *Config) (*DataChannel, error) { +func newDataChannel(stream *sctp.Stream, config *Config) *DataChannel { return &DataChannel{ Config: *config, stream: stream, log: config.LoggerFactory.NewLogger("datachannel"), - }, nil + } } // Dial opens a data channels over SCTP @@ -118,7 +118,12 @@ func Client(stream *sctp.Stream, config *Config) (*DataChannel, error) { return nil, fmt.Errorf("failed to send ChannelOpen %w", err) } } - return newDataChannel(stream, config) + dc := newDataChannel(stream, config) + + if err := dc.commitReliabilityParams(); err != nil { + return nil, err + } + return dc, nil } // Accept is used to accept incoming data channels over SCTP @@ -167,10 +172,7 @@ func Server(stream *sctp.Stream, config *Config) (*DataChannel, error) { config.Label = string(openMsg.Label) config.Protocol = string(openMsg.Protocol) - dataChannel, err := newDataChannel(stream, config) - if err != nil { - return nil, err - } + dataChannel := newDataChannel(stream, config) err = dataChannel.writeDataChannelAck() if err != nil { @@ -283,10 +285,6 @@ func (c *DataChannel) handleDCEP(data []byte) error { switch msg := msg.(type) { case *channelAck: - c.log.Debug("Received DATA_CHANNEL_ACK") - if err = c.commitReliabilityParams(); err != nil { - return err - } c.onOpenComplete() default: return fmt.Errorf("%w %v", ErrInvalidMessageType, msg) diff --git a/datachannel_test.go b/datachannel_test.go index 83cc017..9c4965f 100644 --- a/datachannel_test.go +++ b/datachannel_test.go @@ -384,21 +384,16 @@ func TestDataChannel(t *testing.T) { assert.True(t, reflect.DeepEqual(dc0.Config, *cfg), "local config should match") assert.True(t, reflect.DeepEqual(dc1.Config, *cfg), "remote config should match") - err = dc0.commitReliabilityParams() - assert.NoError(t, err, "should succeed") - err = dc1.commitReliabilityParams() - assert.NoError(t, err, "should succeed") - var n int binary.BigEndian.PutUint32(sbuf, 1) n, err = dc0.WriteDataChannel(sbuf, true) - assert.Nil(t, err, "Read() should succeed") + assert.Nil(t, err, "Write() should succeed") assert.Equal(t, len(sbuf), n, "data length should match") binary.BigEndian.PutUint32(sbuf, 2) n, err = dc0.WriteDataChannel(sbuf, true) - assert.Nil(t, err, "Read() should succeed") + assert.Nil(t, err, "Write() should succeed") assert.Equal(t, len(sbuf), n, "data length should match") time.Sleep(100 * time.Millisecond)