diff --git a/join_group_request.go b/join_group_request.go index 656db4562..3a7ba1712 100644 --- a/join_group_request.go +++ b/join_group_request.go @@ -1,11 +1,36 @@ package sarama +type GroupProtocol struct { + Name string + Metadata []byte +} + +func (p *GroupProtocol) decode(pd packetDecoder) (err error) { + p.Name, err = pd.getString() + if err != nil { + return err + } + p.Metadata, err = pd.getBytes() + return err +} + +func (p *GroupProtocol) encode(pe packetEncoder) (err error) { + if err := pe.putString(p.Name); err != nil { + return err + } + if err := pe.putBytes(p.Metadata); err != nil { + return err + } + return nil +} + type JoinGroupRequest struct { - GroupId string - SessionTimeout int32 - MemberId string - ProtocolType string - GroupProtocols map[string][]byte + GroupId string + SessionTimeout int32 + MemberId string + ProtocolType string + GroupProtocols map[string][]byte // deprecated; use OrderedGroupProtocols + OrderedGroupProtocols []*GroupProtocol } func (r *JoinGroupRequest) encode(pe packetEncoder) error { @@ -20,16 +45,31 @@ func (r *JoinGroupRequest) encode(pe packetEncoder) error { return err } - if err := pe.putArrayLength(len(r.GroupProtocols)); err != nil { - return err - } - for name, metadata := range r.GroupProtocols { - if err := pe.putString(name); err != nil { + if len(r.GroupProtocols) > 0 { + if len(r.OrderedGroupProtocols) > 0 { + return PacketDecodingError{"cannot specify both GroupProtocols and OrderedGroupProtocols on JoinGroupRequest"} + } + + if err := pe.putArrayLength(len(r.GroupProtocols)); err != nil { return err } - if err := pe.putBytes(metadata); err != nil { + for name, metadata := range r.GroupProtocols { + if err := pe.putString(name); err != nil { + return err + } + if err := pe.putBytes(metadata); err != nil { + return err + } + } + } else { + if err := pe.putArrayLength(len(r.OrderedGroupProtocols)); err != nil { return err } + for _, protocol := range r.OrderedGroupProtocols { + if err := protocol.encode(pe); err != nil { + return err + } + } } return nil @@ -62,16 +102,12 @@ func (r *JoinGroupRequest) decode(pd packetDecoder, version int16) (err error) { r.GroupProtocols = make(map[string][]byte) for i := 0; i < n; i++ { - name, err := pd.getString() - if err != nil { - return err - } - metadata, err := pd.getBytes() - if err != nil { + protocol := &GroupProtocol{} + if err := protocol.decode(pd); err != nil { return err } - - r.GroupProtocols[name] = metadata + r.GroupProtocols[protocol.Name] = protocol.Metadata + r.OrderedGroupProtocols = append(r.OrderedGroupProtocols, protocol) } return nil @@ -90,11 +126,10 @@ func (r *JoinGroupRequest) requiredVersion() KafkaVersion { } func (r *JoinGroupRequest) AddGroupProtocol(name string, metadata []byte) { - if r.GroupProtocols == nil { - r.GroupProtocols = make(map[string][]byte) - } - - r.GroupProtocols[name] = metadata + r.OrderedGroupProtocols = append(r.OrderedGroupProtocols, &GroupProtocol{ + Name: name, + Metadata: metadata, + }) } func (r *JoinGroupRequest) AddGroupProtocolMetadata(name string, metadata *ConsumerGroupMemberMetadata) error { diff --git a/join_group_request_test.go b/join_group_request_test.go index 8a6448c0e..1ba3308bb 100644 --- a/join_group_request_test.go +++ b/join_group_request_test.go @@ -23,19 +23,35 @@ var ( ) func TestJoinGroupRequest(t *testing.T) { - var request *JoinGroupRequest - - request = new(JoinGroupRequest) + request := new(JoinGroupRequest) request.GroupId = "TestGroup" request.SessionTimeout = 100 request.ProtocolType = "consumer" testRequest(t, "no protocols", request, joinGroupRequestNoProtocols) +} + +func TestJoinGroupRequestOneProtocol(t *testing.T) { + request := new(JoinGroupRequest) + request.GroupId = "TestGroup" + request.SessionTimeout = 100 + request.MemberId = "OneProtocol" + request.ProtocolType = "consumer" + request.AddGroupProtocol("one", []byte{0x01, 0x02, 0x03}) + packet := testRequestEncode(t, "one protocol", request, joinGroupRequestOneProtocol) + request.GroupProtocols = make(map[string][]byte) + request.GroupProtocols["one"] = []byte{0x01, 0x02, 0x03} + testRequestDecode(t, "one protocol", request, packet) +} - request = new(JoinGroupRequest) +func TestJoinGroupRequestDeprecatedEncode(t *testing.T) { + request := new(JoinGroupRequest) request.GroupId = "TestGroup" request.SessionTimeout = 100 request.MemberId = "OneProtocol" request.ProtocolType = "consumer" + request.GroupProtocols = make(map[string][]byte) + request.GroupProtocols["one"] = []byte{0x01, 0x02, 0x03} + packet := testRequestEncode(t, "one protocol", request, joinGroupRequestOneProtocol) request.AddGroupProtocol("one", []byte{0x01, 0x02, 0x03}) - testRequest(t, "one protocol", request, joinGroupRequestOneProtocol) + testRequestDecode(t, "one protocol", request, packet) } diff --git a/request_test.go b/request_test.go index e54575434..bd9cef4eb 100644 --- a/request_test.go +++ b/request_test.go @@ -50,7 +50,11 @@ func testVersionDecodable(t *testing.T, name string, out versionedDecoder, in [] } func testRequest(t *testing.T, name string, rb protocolBody, expected []byte) { - // Encoder request + packet := testRequestEncode(t, name, rb, expected) + testRequestDecode(t, name, rb, packet) +} + +func testRequestEncode(t *testing.T, name string, rb protocolBody, expected []byte) []byte { req := &request{correlationID: 123, clientID: "foo", body: rb} packet, err := encode(req, nil) headerSize := 14 + len("foo") @@ -59,7 +63,10 @@ func testRequest(t *testing.T, name string, rb protocolBody, expected []byte) { } else if !bytes.Equal(packet[headerSize:], expected) { t.Error("Encoding", name, "failed\ngot ", packet[headerSize:], "\nwant", expected) } - // Decoder request + return packet +} + +func testRequestDecode(t *testing.T, name string, rb protocolBody, packet []byte) { decoded, n, err := decodeRequest(bytes.NewReader(packet)) if err != nil { t.Error("Failed to decode request", err)