diff --git a/errors.go b/errors.go index 70f2b9bfdf..3d48be8d12 100644 --- a/errors.go +++ b/errors.go @@ -37,6 +37,9 @@ var ErrShuttingDown = errors.New("kafka: message received by producer in process // ErrMessageTooLarge is returned when the next message to consume is larger than the configured Consumer.Fetch.Max var ErrMessageTooLarge = errors.New("kafka: message is larger than Consumer.Fetch.Max") +// ErrMessageTooLarge is returned when a JoinGroup request returns a protocol type that is not supported by sarama. +var ErrUnknownGroupProtocol = errors.New("kafka: encountered an unknown group protocol") + // PacketEncodingError is returned from a failure while encoding a Kafka packet. This can happen, for example, // if you try to encode a string over 2^15 characters in length, since Kafka's encoding rules do not permit that. type PacketEncodingError struct { diff --git a/join_group_request.go b/join_group_request.go new file mode 100644 index 0000000000..9565714604 --- /dev/null +++ b/join_group_request.go @@ -0,0 +1,127 @@ +package sarama + +type JoinGroupRequest struct { + GroupId string + SessionTimeout int32 + MemberId string + ProtocolType string + GroupProtocols []GroupProtocol +} + +func (r *JoinGroupRequest) encode(pe packetEncoder) error { + if err := pe.putString(r.GroupId); err != nil { + return err + } + pe.putInt32(r.SessionTimeout) + if err := pe.putString(r.MemberId); err != nil { + return err + } + if err := pe.putString(r.ProtocolType); err != nil { + return err + } + + if err := pe.putArrayLength(len(r.GroupProtocols)); err != nil { + return err + } + for _, groupProtocol := range r.GroupProtocols { + if err := groupProtocol.encodeGroupProtocol(pe); err != nil { + return err + } + } + + return nil +} + +func (r *JoinGroupRequest) decode(pd packetDecoder) (err error) { + r.GroupId, err = pd.getString() + if err != nil { + return + } + + r.SessionTimeout, err = pd.getInt32() + if err != nil { + return + } + + r.MemberId, err = pd.getString() + if err != nil { + return + } + + r.ProtocolType, err = pd.getString() + if err != nil { + return + } + + switch r.ProtocolType { + case "consumer": + n, err := pd.getArrayLength() + if err != nil { + return err + } + + r.GroupProtocols = make([]GroupProtocol, n) + for i := 0; i < n; i++ { + r.GroupProtocols[i] = new(ConsumerGroupProtocol) + if err := r.GroupProtocols[i].decodeGroupProtocol(pd); err != nil { + return nil + } + } + + default: + return ErrUnknownGroupProtocol + } + + return nil +} + +func (r *JoinGroupRequest) key() int16 { + return 11 +} + +func (r *JoinGroupRequest) version() int16 { + return 0 +} + +type GroupProtocol interface { + encodeGroupProtocol(packetEncoder) error + decodeGroupProtocol(packetDecoder) error +} + +type ConsumerGroupProtocol struct { + ProtocolName string + Version int16 + Subscription []string + UserData []byte +} + +func (cgp *ConsumerGroupProtocol) encodeGroupProtocol(pe packetEncoder) error { + if err := pe.putString(cgp.ProtocolName); err != nil { + return err + } + pe.putInt16(cgp.Version) + if err := pe.putStringArray(cgp.Subscription); err != nil { + return err + } + return pe.putBytes(cgp.UserData) +} + +func (cgp *ConsumerGroupProtocol) decodeGroupProtocol(pd packetDecoder) (err error) { + cgp.ProtocolName, err = pd.getString() + if err != nil { + return + } + + cgp.Version, err = pd.getInt16() + if err != nil { + return + } + + cgp.Subscription, err = pd.getStringArray() + if err != nil { + return + } + + cgp.UserData, err = pd.getBytes() + return +} diff --git a/join_group_response.go b/join_group_response.go new file mode 100644 index 0000000000..d8ddae5dd7 --- /dev/null +++ b/join_group_response.go @@ -0,0 +1,111 @@ +package sarama + +type JoinGroupResponse struct { + ErrorCode int16 + GenerationId int32 + GroupProtocol string + LeaderId string + MemberId string + Members []*GroupMember +} + +func (r *JoinGroupResponse) encode(pe packetEncoder) error { + pe.putInt16(r.ErrorCode) + pe.putInt32(r.GenerationId) + + if err := pe.putString(r.GroupProtocol); err != nil { + return err + } + + if err := pe.putString(r.LeaderId); err != nil { + return err + } + + if err := pe.putString(r.MemberId); err != nil { + return err + } + + if err := pe.putArrayLength(len(r.Members)); err != nil { + return err + } + for _, member := range r.Members { + if err := member.encode(pe); err != nil { + return err + } + } + + return nil +} + +func (r *JoinGroupResponse) decode(pd packetDecoder) (err error) { + r.ErrorCode, err = pd.getInt16() + if err != nil { + return err + } + + r.GenerationId, err = pd.getInt32() + if err != nil { + return err + } + + r.GroupProtocol, err = pd.getString() + if err != nil { + return err + } + + r.LeaderId, err = pd.getString() + if err != nil { + return err + } + + r.MemberId, err = pd.getString() + if err != nil { + return err + } + + n, err := pd.getArrayLength() + if err != nil { + return err + } + + r.Members = make([]*GroupMember, n) + for i := 0; i < n; i++ { + r.Members[i] = new(GroupMember) + if err := r.Members[i].decode(pd); err != nil { + return nil + } + } + + return nil +} + +type GroupMember struct { + MemberId string + MemberMetadata []byte +} + +func (gm *GroupMember) encode(pe packetEncoder) error { + if err := pe.putString(gm.MemberId); err != nil { + return err + } + + if err := pe.putBytes(gm.MemberMetadata); err != nil { + return err + } + + return nil +} + +func (gm *GroupMember) decode(pd packetDecoder) (err error) { + gm.MemberId, err = pd.getString() + if err != nil { + return err + } + + gm.MemberMetadata, err = pd.getBytes() + if err != nil { + return err + } + + return nil +} diff --git a/packet_decoder.go b/packet_decoder.go index 0342223136..28670c0e62 100644 --- a/packet_decoder.go +++ b/packet_decoder.go @@ -16,6 +16,7 @@ type packetDecoder interface { getString() (string, error) getInt32Array() ([]int32, error) getInt64Array() ([]int64, error) + getStringArray() ([]string, error) // Subsets remaining() int diff --git a/packet_encoder.go b/packet_encoder.go index 2c5710938c..0df6e24aa6 100644 --- a/packet_encoder.go +++ b/packet_encoder.go @@ -15,6 +15,7 @@ type packetEncoder interface { putBytes(in []byte) error putRawBytes(in []byte) error putString(in string) error + putStringArray(in []string) error putInt32Array(in []int32) error putInt64Array(in []int64) error diff --git a/prep_encoder.go b/prep_encoder.go index 58fb4fc2c4..8c6ba8502c 100644 --- a/prep_encoder.go +++ b/prep_encoder.go @@ -66,6 +66,21 @@ func (pe *prepEncoder) putString(in string) error { return nil } +func (pe *prepEncoder) putStringArray(in []string) error { + err := pe.putArrayLength(len(in)) + if err != nil { + return err + } + + for _, str := range in { + if err := pe.putString(str); err != nil { + return err + } + } + + return nil +} + func (pe *prepEncoder) putInt32Array(in []int32) error { err := pe.putArrayLength(len(in)) if err != nil { diff --git a/real_decoder.go b/real_decoder.go index 235c8a80d8..e3ea331048 100644 --- a/real_decoder.go +++ b/real_decoder.go @@ -181,6 +181,33 @@ func (rd *realDecoder) getInt64Array() ([]int64, error) { return ret, nil } +func (rd *realDecoder) getStringArray() ([]string, error) { + if rd.remaining() < 4 { + rd.off = len(rd.raw) + return nil, ErrInsufficientData + } + n := int(binary.BigEndian.Uint32(rd.raw[rd.off:])) + rd.off += 4 + + if n == 0 { + return nil, nil + } + + if n < 0 { + return nil, PacketDecodingError{"invalid array length"} + } + + ret := make([]string, n) + for i := range ret { + if str, err := rd.getString(); err != nil { + return nil, err + } else { + ret[i] = str + } + } + return ret, nil +} + // subsets func (rd *realDecoder) remaining() int { diff --git a/real_encoder.go b/real_encoder.go index b50f54bc5c..076fdd0ca1 100644 --- a/real_encoder.go +++ b/real_encoder.go @@ -61,6 +61,21 @@ func (re *realEncoder) putString(in string) error { return nil } +func (re *realEncoder) putStringArray(in []string) error { + err := re.putArrayLength(len(in)) + if err != nil { + return err + } + + for _, val := range in { + if err := re.putString(val); err != nil { + return err + } + } + + return nil +} + func (re *realEncoder) putInt32Array(in []int32) error { err := re.putArrayLength(len(in)) if err != nil {