diff --git a/inmem_transport.go b/inmem_transport.go index 02a7a0f9bf3..bb42eeb68b9 100644 --- a/inmem_transport.go +++ b/inmem_transport.go @@ -147,11 +147,17 @@ func (i *InmemTransport) makeRPC(target ServerAddress, args interface{}, r io.Re // Send the RPC over respCh := make(chan RPCResponse) - peer.consumerCh <- RPC{ + req := RPC{ Command: args, Reader: r, RespChan: respCh, } + select { + case peer.consumerCh <- req: + case <-time.After(timeout): + err = fmt.Errorf("send timed out") + return + } // Wait for a response select { diff --git a/inmem_transport_test.go b/inmem_transport_test.go index 82c95348a58..2ac8709a0fb 100644 --- a/inmem_transport_test.go +++ b/inmem_transport_test.go @@ -1,7 +1,9 @@ package raft import ( + "github.com/stretchr/testify/require" "testing" + "time" ) func TestInmemTransportImpl(t *testing.T) { @@ -16,3 +18,66 @@ func TestInmemTransportImpl(t *testing.T) { t.Fatalf("InmemTransport is not a WithPeers Transport") } } + +func TestInmemTransportWriteTimeout(t *testing.T) { + // InmemTransport should timeout if the other end has gone away + // when it tries to send a request. + // Use unbuffered channels so that we can see the write failing + // without having to contrive to fill up the buffer first. + timeout := 10 * time.Millisecond + t1 := &InmemTransport{ + consumerCh: make(chan RPC), + localAddr: NewInmemAddr(), + peers: make(map[ServerAddress]*InmemTransport), + timeout: timeout, + } + t2 := &InmemTransport{ + consumerCh: make(chan RPC), + localAddr: NewInmemAddr(), + peers: make(map[ServerAddress]*InmemTransport), + timeout: timeout, + } + a2 := t2.LocalAddr() + t1.Connect(a2, t2) + + stop := make(chan struct{}) + stopped := make(chan struct{}) + go func() { + defer close(stopped) + var i uint64 + for { + select { + case <-stop: + return + case rpc := <-t2.Consumer(): + i++ + rpc.Respond(&AppendEntriesResponse{ + Success: true, + LastLog: i, + }, nil) + } + } + }() + + var resp AppendEntriesResponse + // Sanity check that sending is working before stopping the + // responder. + err := t1.AppendEntries("server1", a2, &AppendEntriesRequest{}, &resp) + NoErr(err, t) + require.True(t, resp.LastLog == 1) + + close(stop) + select { + case <-stopped: + case <-time.After(time.Second): + t.Fatalf("timed out waiting for responder to stop") + } + + err = t1.AppendEntries("server1", a2, &AppendEntriesRequest{}, &resp) + if err == nil { + t.Fatalf("expected AppendEntries to time out") + } + if err.Error() != "send timed out" { + t.Fatalf("unexpected error: %v", err) + } +}