diff --git a/coordinator/protobuf_client.go b/coordinator/protobuf_client.go index 561902bb4ad..815c7d5dc29 100644 --- a/coordinator/protobuf_client.go +++ b/coordinator/protobuf_client.go @@ -46,7 +46,7 @@ const ( ) func NewProtobufClient(hostAndPort string, writeTimeout time.Duration) *ProtobufClient { - log.Debug("NewProtobufClient: ", hostAndPort) + log.Debug("NewProtobufClient: %s", hostAndPort) return &ProtobufClient{ hostAndPort: hostAndPort, requestBuffer: make(map[uint32]*runningRequest), @@ -179,6 +179,14 @@ func (self *ProtobufClient) MakeRequest(request *protocol.Request, r cluster.Res func (self *ProtobufClient) readResponses() { message := make([]byte, 0, MAX_RESPONSE_SIZE) buff := bytes.NewBuffer(message) + + connErrFn := func(err error) { + log.Error("Error while reading messsage: %s", err.Error()) + self.conn.Close() + self.conn = nil + time.Sleep(200 * time.Millisecond) + } + for !self.stopped { buff.Reset() conn := self.getConnection() @@ -190,21 +198,19 @@ func (self *ProtobufClient) readResponses() { var err error err = binary.Read(conn, binary.LittleEndian, &messageSizeU) if err != nil { - log.Error("Error while reading messsage size: %d", err) - time.Sleep(200 * time.Millisecond) + connErrFn(err) continue } messageSize := int64(messageSizeU) messageReader := io.LimitReader(conn, messageSize) _, err = io.Copy(buff, messageReader) if err != nil { - log.Error("Error while reading message: %d", err) - time.Sleep(200 * time.Millisecond) + connErrFn(err) continue } response, err := protocol.DecodeResponse(buff) if err != nil { - log.Error("error unmarshaling response: %s", err) + log.Error("error unmarshaling response: %s", err.Error()) time.Sleep(200 * time.Millisecond) } else { self.sendResponse(response) @@ -273,6 +279,7 @@ func (self *ProtobufClient) reconnect() net.Conn { self.attempts = 0 } + self.attempts = 0 self.conn = conn log.Info("connected to %s", self.hostAndPort) return conn @@ -286,7 +293,7 @@ func (self *ProtobufClient) peridicallySweepTimedOutRequests() { for k, req := range self.requestBuffer { if req.timeMade.Before(maxAge) { delete(self.requestBuffer, k) - log.Warn("Request timed out: ", req.request) + log.Warn("Request timed out: %v", req.request) } } self.requestBufferLock.Unlock() diff --git a/coordinator/protobuf_client_test.go b/coordinator/protobuf_client_test.go index cc9068957d3..6b5950da876 100644 --- a/coordinator/protobuf_client_test.go +++ b/coordinator/protobuf_client_test.go @@ -90,7 +90,7 @@ func FakeHeartbeatServer() *PingResponseServer { type ProtobufClientSuite struct{} -var _ = gocheck.Suite(&ProtobufClient{}) +var _ = gocheck.Suite(&ProtobufClientSuite{}) func (self *ProtobufClientSuite) BenchmarkSingle(c *gocheck.C) { var HEARTBEAT_TYPE = protocol.Request_HEARTBEAT @@ -109,3 +109,49 @@ func (self *ProtobufClientSuite) BenchmarkSingle(c *gocheck.C) { <-responseChan } } + +func (pcs *ProtobufClientSuite) TestReadResponsesWhenRemoteClosesConnection(c *gocheck.C) { + // Channel used to kill the remote connection + dieCh := make(chan struct{}) + // Channel for remote client to notify test that ProtobufClient connected + connectedCh := make(chan struct{}) + // Channel the remote client will use to tell the test that it has closed + connClosedCh := make(chan struct{}) + + // Remote connection (talking to a ProtobufClient) + handleConnFn := func(conn net.Conn) { + connectedCh <- struct{}{} + <-dieCh + conn.Close() + close(connClosedCh) + } + + // Remote server listening for ProtobufClient connection requests + l, _ := net.Listen("tcp", "127.0.0.1:0") + go func() { + for { + conn, _ := l.Accept() + go handleConnFn(conn) + } + }() + + // Create a ProtobufClient and connect to the remote server we just setup + client := NewProtobufClient(l.Addr().String(), time.Second) + client.Connect() + select { + case <-connectedCh: + case <-time.After(500 * time.Millisecond): + c.Errorf("Waiting for ProtobufClient to connect timed out") + return + } + c.Assert(client, gocheck.NotNil) + c.Assert(client.conn, gocheck.NotNil) + + // Tell remote side to close the connection + dieCh <- struct{}{} + + // Make sure ProtobufClient set the connection to nil + <-connClosedCh + time.Sleep(100 * time.Millisecond) + c.Assert(client.conn, gocheck.IsNil) +}