From fef768624fec09d6f64e8f72a798f50e36a295bd Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Mon, 6 May 2019 13:29:38 -0700 Subject: [PATCH] fix: reset a stream even if closed remotely Otherwise, the other side may continue _reading_. --- multiplex_test.go | 34 ++++++++++++++++++++++++++++++++++ stream.go | 19 ++++++++++++------- 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/multiplex_test.go b/multiplex_test.go index 20608bc..2e194ec 100644 --- a/multiplex_test.go +++ b/multiplex_test.go @@ -8,6 +8,8 @@ import ( "net" "testing" "time" + + streammux "github.com/libp2p/go-stream-muxer" ) func init() { @@ -362,6 +364,38 @@ func TestReset(t *testing.T) { } } +func TestResetAfterEOF(t *testing.T) { + a, b := net.Pipe() + + mpa := NewMultiplex(a, false) + mpb := NewMultiplex(b, true) + + defer mpa.Close() + defer mpb.Close() + + sa, err := mpa.NewStream() + if err != nil { + t.Fatal(err) + } + sb, err := mpb.Accept() + + if err := sa.Close(); err != nil { + t.Fatal(err) + } + + n, err := sb.Read([]byte{0}) + if n != 0 || err != io.EOF { + t.Fatal(err) + } + + sb.Reset() + + n, err = sa.Read([]byte{0}) + if n != 0 || err != streammux.ErrReset { + t.Fatal(err) + } +} + func TestOpenAfterClose(t *testing.T) { a, b := net.Pipe() diff --git a/stream.go b/stream.go index 47f1e79..dfd70d7 100644 --- a/stream.go +++ b/stream.go @@ -203,22 +203,27 @@ func (s *Stream) Close() error { func (s *Stream) Reset() error { s.clLock.Lock() - isClosed := s.isClosed() - if s.closedRemote && isClosed { + + // Don't reset when fully closed. + if s.closedRemote && s.isClosed() { s.clLock.Unlock() return nil } - if !s.closedRemote { - close(s.reset) - // We generally call this to tell the other side to go away. No point in waiting around. - go s.mp.sendMsg(context.Background(), s.id.header(resetTag), nil) + // Don't reset twice. + select { + case <-s.reset: + s.clLock.Unlock() + return nil + default: } + close(s.reset) s.doCloseLocal() - s.closedRemote = true + go s.mp.sendMsg(context.Background(), s.id.header(resetTag), nil) + s.clLock.Unlock() s.mp.chLock.Lock()