Skip to content

Commit

Permalink
add a context to NewStream, remove the NewStreamTimeout
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Dec 19, 2020
1 parent 67680fb commit d561095
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 27 deletions.
5 changes: 3 additions & 2 deletions benchmarks_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package multiplex

import (
"context"
"io"
"math/rand"
"net"
Expand Down Expand Up @@ -64,7 +65,7 @@ func testSmallPackets(b *testing.B, n1, n2 net.Conn) {

streamPairs := make([][]*Stream, 0)
for i := 0; i < mp; i++ {
sa, err := mpa.NewStream()
sa, err := mpa.NewStream(context.Background())
if err != nil {
b.Error(err)
}
Expand Down Expand Up @@ -190,7 +191,7 @@ func benchmarkPackets(b *testing.B, msgs [][]byte) {
func benchmarkPacketsWithConn(b *testing.B, parallelism int, msgs [][]byte, mpa, mpb *Multiplex) {
streamPairs := make([][]*Stream, 0)
for i := 0; i < parallelism*runtime.GOMAXPROCS(0); i++ {
sa, err := mpa.NewStream()
sa, err := mpa.NewStream(context.Background())
if err != nil {
b.Error(err)
}
Expand Down
13 changes: 6 additions & 7 deletions multiplex.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ var errTimeout = timeout{}
var errStreamClosed = errors.New("stream closed")

var (
NewStreamTimeout = time.Minute
ResetStreamTimeout = 2 * time.Minute

WriteCoalesceDelay = 100 * time.Microsecond
Expand Down Expand Up @@ -291,12 +290,12 @@ func (mp *Multiplex) nextChanID() uint64 {
}

// NewStream creates a new stream.
func (mp *Multiplex) NewStream() (*Stream, error) {
return mp.NewNamedStream("")
func (mp *Multiplex) NewStream(ctx context.Context) (*Stream, error) {
return mp.NewNamedStream(ctx,"")
}

// NewNamedStream creates a new named stream.
func (mp *Multiplex) NewNamedStream(name string) (*Stream, error) {
func (mp *Multiplex) NewNamedStream(ctx context.Context, name string) (*Stream, error) {
mp.chLock.Lock()

// We could call IsClosed but this is faster (given that we already have
Expand All @@ -319,11 +318,11 @@ func (mp *Multiplex) NewNamedStream(name string) (*Stream, error) {
mp.channels[s.id] = s
mp.chLock.Unlock()

ctx, cancel := context.WithTimeout(context.Background(), NewStreamTimeout)
defer cancel()

err := mp.sendMsg(ctx.Done(), nil, header, []byte(name))
if err != nil {
if err == errTimeout {
return nil, ctx.Err()
}
return nil, err
}

Expand Down
63 changes: 45 additions & 18 deletions multiplex_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package multiplex

import (
"context"
"fmt"
"io"
"io/ioutil"
Expand Down Expand Up @@ -28,7 +29,7 @@ func TestSlowReader(t *testing.T) {

mes := []byte("Hello world")

sa, err := mpa.NewStream()
sa, err := mpa.NewStream(context.Background())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -85,7 +86,7 @@ func TestBasicStreams(t *testing.T) {
}
}()

s, err := mpa.NewStream()
s, err := mpa.NewStream(context.Background())
if err != nil {
t.Fatal(err)
}
Expand All @@ -105,6 +106,32 @@ func TestBasicStreams(t *testing.T) {
mpb.Close()
}

func TestOpenStreamDeadline(t *testing.T) {
a, _ := net.Pipe()
mp := NewMultiplex(a, false)

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
var counter int
var deadlineExceeded bool
for i := 0; i < 1000; i++ {
if _, err := mp.NewStream(ctx); err != nil {
if err != context.DeadlineExceeded {
t.Fatalf("expected the error to be a deadline error, got %s", err.Error())
}
deadlineExceeded = true
break
}
counter++
}
if counter == 0 {
t.Fatal("expected at least some streams to open successfully")
}
if !deadlineExceeded {
t.Fatal("expected a deadline error to occur at some point")
}
}

func TestWriteAfterClose(t *testing.T) {
a, b := net.Pipe()

Expand Down Expand Up @@ -134,7 +161,7 @@ func TestWriteAfterClose(t *testing.T) {
close(done)
}()

s, err := mpa.NewStream()
s, err := mpa.NewStream(context.Background())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -178,7 +205,7 @@ func TestEcho(t *testing.T) {
io.Copy(s, s)
}()

s, err := mpa.NewStream()
s, err := mpa.NewStream(context.Background())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -214,7 +241,7 @@ func TestFullClose(t *testing.T) {

mes := make([]byte, 40960)
rand.Read(mes)
s, err := mpa.NewStream()
s, err := mpa.NewStream(context.Background())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -281,7 +308,7 @@ func TestHalfClose(t *testing.T) {
}
}()

s, err := mpa.NewStream()
s, err := mpa.NewStream(context.Background())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -339,7 +366,7 @@ func TestClosing(t *testing.T) {
mpa := NewMultiplex(a, false)
mpb := NewMultiplex(b, true)

_, err := mpb.NewStream()
_, err := mpb.NewStream(context.Background())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -373,7 +400,7 @@ func TestReset(t *testing.T) {
defer mpa.Close()
defer mpb.Close()

sa, err := mpa.NewStream()
sa, err := mpa.NewStream(context.Background())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -425,7 +452,7 @@ func TestCancelRead(t *testing.T) {
defer mpa.Close()
defer mpb.Close()

sa, err := mpa.NewStream()
sa, err := mpa.NewStream(context.Background())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -486,7 +513,7 @@ func TestCancelWrite(t *testing.T) {
defer mpa.Close()
defer mpb.Close()

sa, err := mpa.NewStream()
sa, err := mpa.NewStream(context.Background())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -560,7 +587,7 @@ func TestCancelReadAfterWrite(t *testing.T) {
defer mpa.Close()
defer mpb.Close()

sa, err := mpa.NewStream()
sa, err := mpa.NewStream(context.Background())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -603,7 +630,7 @@ func TestResetAfterEOF(t *testing.T) {
defer mpa.Close()
defer mpb.Close()

sa, err := mpa.NewStream()
sa, err := mpa.NewStream(context.Background())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -632,7 +659,7 @@ func TestOpenAfterClose(t *testing.T) {
mpa := NewMultiplex(a, false)
mpb := NewMultiplex(b, true)

sa, err := mpa.NewStream()
sa, err := mpa.NewStream(context.Background())
if err != nil {
t.Fatal(err)
}
Expand All @@ -646,12 +673,12 @@ func TestOpenAfterClose(t *testing.T) {

mpa.Close()

s, err := mpa.NewStream()
s, err := mpa.NewStream(context.Background())
if err == nil || s != nil {
t.Fatal("opened a stream on a closed connection")
}

s, err = mpa.NewStream()
s, err = mpa.NewStream(context.Background())
if err == nil || s != nil {
t.Fatal("opened a stream on a closed connection")
}
Expand All @@ -668,7 +695,7 @@ func TestDeadline(t *testing.T) {
defer mpa.Close()
defer mpb.Close()

sa, err := mpa.NewStream()
sa, err := mpa.NewStream(context.Background())
if err != nil {
t.Fatal(err)
}
Expand All @@ -694,7 +721,7 @@ func TestReadAfterClose(t *testing.T) {
defer mpa.Close()
defer mpb.Close()

sa, err := mpa.NewStream()
sa, err := mpa.NewStream(context.Background())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -735,7 +762,7 @@ func TestFuzzCloseStream(t *testing.T) {
streams := make([]*Stream, 100)
for i := range streams {
var err error
streams[i], err = mpb.NewStream()
streams[i], err = mpb.NewStream(context.Background())
if err != nil {
t.Fatal(err)
}
Expand Down

0 comments on commit d561095

Please sign in to comment.