Skip to content

Commit

Permalink
Fix race condition with graceful shutdown of MockBroker
Browse files Browse the repository at this point in the history
- use Logger instead of Logf in unit test to stay consistent
- add MmockBroker.WaitForExpectations for graceful shutdown
  • Loading branch information
slaunay committed Jul 13, 2016
1 parent f9642ad commit b7f401f
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 20 deletions.
13 changes: 8 additions & 5 deletions broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sarama
import (
"fmt"
"testing"
"time"

"github.com/rcrowley/go-metrics"
)
Expand Down Expand Up @@ -55,12 +56,9 @@ func TestBrokerAccessors(t *testing.T) {

func TestSimpleBrokerCommunication(t *testing.T) {
for _, tt := range brokerTestTable {
t.Log("Testing broker communication for", tt.name)
Logger.Printf("Testing broker communication for %s", tt.name)
mb := NewMockBroker(t, 0)
// Do not add expectation for ProduceRequest (No Response)
if len(tt.response) != 0 {
mb.Returns(&mockEncoder{tt.response})
}
mb.Returns(&mockEncoder{tt.response})
broker := NewBroker(mb.Addr())
// Set the broker id in order to validate local broker metrics
broker.id = 0
Expand All @@ -77,6 +75,11 @@ func TestSimpleBrokerCommunication(t *testing.T) {
if err != nil {
t.Error(err)
}
// Wait up to 500 ms for the remote broker to process requests
// in order to have consistent metrics
if err := mb.WaitForExpectations(500 * time.Millisecond); err != nil {
t.Error(err)
}
mb.Close()
validateBrokerMetrics(t, broker, mb)
}
Expand Down
55 changes: 40 additions & 15 deletions mockbroker.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sarama
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -50,10 +51,12 @@ type MockBroker struct {
closing chan none
stopper chan none
expectations chan encoder
done sync.WaitGroup
listener net.Listener
t TestReporter
latency time.Duration
handler requestHandlerFunc
origHandler bool
history []*RequestResponse
lock sync.Mutex
}
Expand Down Expand Up @@ -116,6 +119,21 @@ func (b *MockBroker) Addr() string {
return b.listener.Addr().String()
}

// Wait for the remaining expectations to be consumed or that the timeout expires
func (b *MockBroker) WaitForExpectations(timeout time.Duration) error {
c := make(chan none)
go func() {
b.done.Wait()
close(c)
}()
select {
case <-c:
return nil
case <-time.After(timeout):
return errors.New(fmt.Sprintf("Not all expectations have been honoured after %v", timeout))
}
}

// Close terminates the broker blocking until it stops internal goroutines and
// releases all resources.
func (b *MockBroker) Close() {
Expand All @@ -137,6 +155,7 @@ func (b *MockBroker) Close() {
func (b *MockBroker) setHandler(handler requestHandlerFunc) {
b.lock.Lock()
b.handler = handler
b.origHandler = false
b.lock.Unlock()
}

Expand Down Expand Up @@ -196,6 +215,7 @@ func (b *MockBroker) handleRequests(conn net.Conn, idx int, wg *sync.WaitGroup)
}

b.lock.Lock()
originalHandlerUsed := b.origHandler
res := b.handler(req)
requestResponse := RequestResponse{req.body, res, bytesRead, 0}
b.history = append(b.history, &requestResponse)
Expand All @@ -212,23 +232,25 @@ func (b *MockBroker) handleRequests(conn net.Conn, idx int, wg *sync.WaitGroup)
b.serverError(err)
break
}
if len(encodedRes) == 0 {
continue
}

binary.BigEndian.PutUint32(resHeader, uint32(len(encodedRes)+4))
binary.BigEndian.PutUint32(resHeader[4:], uint32(req.correlationID))
if _, err = conn.Write(resHeader); err != nil {
b.serverError(err)
break
if len(encodedRes) != 0 {
binary.BigEndian.PutUint32(resHeader, uint32(len(encodedRes)+4))
binary.BigEndian.PutUint32(resHeader[4:], uint32(req.correlationID))
if _, err = conn.Write(resHeader); err != nil {
b.serverError(err)
break
}
if _, err = conn.Write(encodedRes); err != nil {
b.serverError(err)
break
}
b.lock.Lock()
requestResponse.ResponseSize = len(resHeader) + len(encodedRes)
b.lock.Unlock()
}
if _, err = conn.Write(encodedRes); err != nil {
b.serverError(err)
break
// Prevent negative wait group in case we are using a custom handler
if originalHandlerUsed {
b.done.Done()
}
b.lock.Lock()
requestResponse.ResponseSize = len(resHeader) + len(encodedRes)
b.lock.Unlock()
}
Logger.Printf("*** mockbroker/%d/%d: connection closed, err=%v", b.BrokerID(), idx, err)
}
Expand Down Expand Up @@ -280,8 +302,10 @@ func NewMockBrokerAddr(t TestReporter, brokerID int32, addr string) *MockBroker
t: t,
brokerID: brokerID,
expectations: make(chan encoder, 512),
done: sync.WaitGroup{},
}
broker.handler = broker.defaultRequestHandler
broker.origHandler = true

broker.listener, err = net.Listen("tcp", addr)
if err != nil {
Expand All @@ -304,5 +328,6 @@ func NewMockBrokerAddr(t TestReporter, brokerID int32, addr string) *MockBroker
}

func (b *MockBroker) Returns(e encoder) {
b.done.Add(1)
b.expectations <- e
}

0 comments on commit b7f401f

Please sign in to comment.