Skip to content

Commit

Permalink
Add "end to end" tests for variable meta data round robin functionality.
Browse files Browse the repository at this point in the history
  • Loading branch information
Kami committed Oct 21, 2020
1 parent a59228f commit f46a37e
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 6 deletions.
40 changes: 35 additions & 5 deletions internal/helloworld/greeter_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

context "golang.org/x/net/context"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats"
)

Expand Down Expand Up @@ -36,6 +37,7 @@ type Greeter struct {
mutex *sync.RWMutex
callCounts map[CallType]int
calls map[CallType][][]*HelloRequest
metadata map[CallType][][]metadata.MD
}

func randomSleep() {
Expand All @@ -49,22 +51,32 @@ func (s *Greeter) recordCall(ct CallType) int {

s.callCounts[ct]++
var messages []*HelloRequest
var metadataItems []metadata.MD
s.calls[ct] = append(s.calls[ct], messages)
s.metadata[ct] = append(s.metadata[ct], metadataItems)

return len(s.calls[ct]) - 1
}

func (s *Greeter) recordMessage(ct CallType, callIdx int, msg *HelloRequest) {
func (s *Greeter) recordMessageAndMetadata(ct CallType, callIdx int, msg *HelloRequest, ctx context.Context) {
s.mutex.Lock()
defer s.mutex.Unlock()

s.calls[ct][callIdx] = append(s.calls[ct][callIdx], msg)

var md metadata.MD

if ctx != nil {
md, _ = metadata.FromIncomingContext(ctx)
}

s.metadata[ct][callIdx] = append(s.metadata[ct][callIdx], md)
}

// SayHello implements helloworld.GreeterServer
func (s *Greeter) SayHello(ctx context.Context, in *HelloRequest) (*HelloReply, error) {
callIdx := s.recordCall(Unary)
s.recordMessage(Unary, callIdx, in)
s.recordMessageAndMetadata(Unary, callIdx, in, ctx)

randomSleep()

Expand All @@ -74,7 +86,7 @@ func (s *Greeter) SayHello(ctx context.Context, in *HelloRequest) (*HelloReply,
// SayHellos lists all hellos
func (s *Greeter) SayHellos(req *HelloRequest, stream Greeter_SayHellosServer) error {
callIdx := s.recordCall(ServerStream)
s.recordMessage(ServerStream, callIdx, req)
s.recordMessageAndMetadata(ServerStream, callIdx, req, nil)

randomSleep()

Expand Down Expand Up @@ -104,7 +116,7 @@ func (s *Greeter) SayHelloCS(stream Greeter_SayHelloCSServer) error {
if err != nil {
return err
}
s.recordMessage(ClientStream, callIdx, in)
s.recordMessageAndMetadata(ClientStream, callIdx, in, nil)
msgCount++
}
}
Expand All @@ -124,7 +136,7 @@ func (s *Greeter) SayHelloBidi(stream Greeter_SayHelloBidiServer) error {
return err
}

s.recordMessage(Bidi, callIdx, in)
s.recordMessageAndMetadata(Bidi, callIdx, in, nil)
msg := "Hello " + in.Name
if err := stream.Send(&HelloReply{Message: msg}); err != nil {
return err
Expand All @@ -148,6 +160,12 @@ func (s *Greeter) ResetCounters() {
s.calls[ClientStream] = make([][]*HelloRequest, 0)
s.calls[Bidi] = make([][]*HelloRequest, 0)

s.metadata = make(map[CallType][][]metadata.MD)
s.metadata[Unary] = make([][]metadata.MD, 0)
s.metadata[ServerStream] = make([][]metadata.MD, 0)
s.metadata[ClientStream] = make([][]metadata.MD, 0)
s.metadata[Bidi] = make([][]metadata.MD, 0)

s.mutex.Unlock()

if s.Stats != nil {
Expand Down Expand Up @@ -180,6 +198,18 @@ func (s *Greeter) GetCalls(key CallType) [][]*HelloRequest {
return nil
}

// GetMetadata gets the received metadata for the specific call type
func (s *Greeter) GetMetadata(key CallType) [][]metadata.MD {
s.mutex.Lock()
val, ok := s.metadata[key]
s.mutex.Unlock()

if ok {
return val
}
return nil
}

// GetConnectionCount gets the connection count
func (s *Greeter) GetConnectionCount() int {
return s.Stats.GetConnectionCount()
Expand Down
37 changes: 36 additions & 1 deletion runner/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ func TestRunUnary(t *testing.T) {
assert.Equal(t, 5, connCount)
})

t.Run("test round-robin c = 2", func(t *testing.T) {
t.Run("test data and metadata round-robin c = 2", func(t *testing.T) {
gs.ResetCounters()

data := make([]map[string]interface{}, 3)
Expand All @@ -421,6 +421,7 @@ func TestRunUnary(t *testing.T) {
WithDialTimeout(time.Duration(20*time.Second)),
WithInsecure(true),
WithData(data),
WithMetadataFromJSON(`[{"index": "1 one"}, {"index": "2 two"}, {"index": "3 three"}]`),
)

assert.NoError(t, err)
Expand All @@ -429,6 +430,24 @@ func TestRunUnary(t *testing.T) {
count := gs.GetCount(callType)
assert.Equal(t, 6, count)

// Verify metadata

// We specify 3 unique metadata items over which the requester should round-robin
// for all of the 6 requests. This means we should see each unique item twice.
metadata := gs.GetMetadata(callType)
assert.Equal(t, len(metadata), 6)

seenMetadataIndexValues := make([]string, 0)

for _, metadataItem := range metadata {
seenMetadataIndexValues = append(seenMetadataIndexValues, metadataItem[0]["index"][0])
}

// we don't expect to have the same order of elements since requests are concurrent
assert.ElementsMatch(t, []string{"1 one", "2 two", "3 three", "1 one", "2 two", "3 three"},
seenMetadataIndexValues)

// Verify actual payload / messages
calls := gs.GetCalls(callType)
assert.NotNil(t, calls)
assert.Len(t, calls, 6)
Expand Down Expand Up @@ -462,6 +481,7 @@ func TestRunUnary(t *testing.T) {
WithDialTimeout(time.Duration(20*time.Second)),
WithInsecure(true),
WithData(data),
WithMetadataFromJSON(`{"index": "1 one"}`),
)

assert.NoError(t, err)
Expand All @@ -470,6 +490,21 @@ func TestRunUnary(t *testing.T) {
count := gs.GetCount(callType)
assert.Equal(t, 6, count)

// Verify metadata
// We specify a single item for metadata which should be used for all the requests
metadata := gs.GetMetadata(callType)
assert.Equal(t, len(metadata), 6)

seenMetadataIndexValues := make([]string, 0)

for _, metadataItem := range metadata {
seenMetadataIndexValues = append(seenMetadataIndexValues, metadataItem[0]["index"][0])
}

assert.ElementsMatch(t, []string{"1 one", "1 one", "1 one", "1 one", "1 one", "1 one"},
seenMetadataIndexValues)

// Verify actual payload / messages
calls := gs.GetCalls(callType)
assert.NotNil(t, calls)
assert.Len(t, calls, 6)
Expand Down

0 comments on commit f46a37e

Please sign in to comment.