Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple record batches, closes #1022 #1023

Merged
merged 5 commits into from
Jan 22, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,12 +570,12 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*Consu
return nil, block.Err
}

nRecs, err := block.Records.numRecords()
nRecs, err := block.numRecords()
if err != nil {
return nil, err
}
if nRecs == 0 {
partialTrailingMessage, err := block.Records.isPartial()
partialTrailingMessage, err := block.isPartial()
if err != nil {
return nil, err
}
Expand All @@ -601,14 +601,33 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*Consu
child.fetchSize = child.conf.Consumer.Fetch.Default
atomic.StoreInt64(&child.highWaterMarkOffset, block.HighWaterMarkOffset)

if control, err := block.Records.isControl(); err != nil || control {
return nil, err
}
messages := []*ConsumerMessage{}
for _, records := range block.RecordsSet {
if control, err := records.isControl(); err != nil || control {
continue
}

switch records.recordsType {
case legacyRecords:
messageSetMessages, err := child.parseMessages(records.msgSet)
if err != nil {
return nil, err
}

if block.Records.recordsType == legacyRecords {
return child.parseMessages(block.Records.msgSet)
messages = append(messages, messageSetMessages...)
case defaultRecords:
recordBatchMessages, err := child.parseRecords(records.recordBatch)
if err != nil {
return nil, err
}

messages = append(messages, recordBatchMessages...)
default:
return nil, fmt.Errorf("unknown records type: %v", records.recordsType)
}
}
return child.parseRecords(block.Records.recordBatch)

return messages, nil
}

// brokerConsumer
Expand Down
98 changes: 79 additions & 19 deletions fetch_response.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package sarama

import "time"
import (
"time"
)

type AbortedTransaction struct {
ProducerID int64
Expand Down Expand Up @@ -31,7 +33,9 @@ type FetchResponseBlock struct {
HighWaterMarkOffset int64
LastStableOffset int64
AbortedTransactions []*AbortedTransaction
Records Records
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically it's a breaking change to remove this; could you leave it with a comment that it's deprecated, and just fill it in with the first set or something?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, let me know if that's what you had in mind.

Records *Records // deprecated: use FetchResponseBlock.Records
RecordsSet []*Records
Partial bool
}

func (b *FetchResponseBlock) decode(pd packetDecoder, version int16) (err error) {
Expand Down Expand Up @@ -79,15 +83,69 @@ func (b *FetchResponseBlock) decode(pd packetDecoder, version int16) (err error)
if err != nil {
return err
}
if recordsSize > 0 {
if err = b.Records.decode(recordsDecoder); err != nil {

b.RecordsSet = []*Records{}

for recordsDecoder.remaining() > 0 {
records := &Records{}
if err := records.decode(recordsDecoder); err != nil {
// If we have at least one decoded records, this is not an error
if err == ErrInsufficientData {
if len(b.RecordsSet) == 0 {
b.Partial = true
}
break
}
return err
}

partial, err := records.isPartial()
if err != nil {
return err
}

// If we have at least one full records, we skip incomplete ones
if partial && len(b.RecordsSet) > 0 {
break
}

b.RecordsSet = append(b.RecordsSet, records)

if b.Records == nil {
b.Records = records
}
}

return nil
}

func (b *FetchResponseBlock) numRecords() (int, error) {
sum := 0

for _, records := range b.RecordsSet {
count, err := records.numRecords()
if err != nil {
return 0, err
}

sum += count
}

return sum, nil
}

func (b *FetchResponseBlock) isPartial() (bool, error) {
if b.Partial {
return true, nil
}

if len(b.RecordsSet) == 1 {
return b.RecordsSet[0].isPartial()
}

return false, nil
}

func (b *FetchResponseBlock) encode(pe packetEncoder, version int16) (err error) {
pe.putInt16(int16(b.Err))

Expand All @@ -107,9 +165,11 @@ func (b *FetchResponseBlock) encode(pe packetEncoder, version int16) (err error)
}

pe.push(&lengthField{})
err = b.Records.encode(pe)
if err != nil {
return err
for _, records := range b.RecordsSet {
err = records.encode(pe)
if err != nil {
return err
}
}
return pe.pop()
}
Expand Down Expand Up @@ -289,33 +349,33 @@ func (r *FetchResponse) AddMessage(topic string, partition int32, key, value Enc
kb, vb := encodeKV(key, value)
msg := &Message{Key: kb, Value: vb}
msgBlock := &MessageBlock{Msg: msg, Offset: offset}
set := frb.Records.msgSet
if set == nil {
set = &MessageSet{}
frb.Records = newLegacyRecords(set)
if len(frb.RecordsSet) == 0 {
records := newLegacyRecords(&MessageSet{})
frb.RecordsSet = []*Records{&records}
}
set := frb.RecordsSet[0].msgSet
set.Messages = append(set.Messages, msgBlock)
}

func (r *FetchResponse) AddRecord(topic string, partition int32, key, value Encoder, offset int64) {
frb := r.getOrCreateBlock(topic, partition)
kb, vb := encodeKV(key, value)
rec := &Record{Key: kb, Value: vb, OffsetDelta: offset}
batch := frb.Records.recordBatch
if batch == nil {
batch = &RecordBatch{Version: 2}
frb.Records = newDefaultRecords(batch)
if len(frb.RecordsSet) == 0 {
records := newDefaultRecords(&RecordBatch{Version: 2})
frb.RecordsSet = []*Records{&records}
}
batch := frb.RecordsSet[0].recordBatch
batch.addRecord(rec)
}

func (r *FetchResponse) SetLastOffsetDelta(topic string, partition int32, offset int32) {
frb := r.getOrCreateBlock(topic, partition)
batch := frb.Records.recordBatch
if batch == nil {
batch = &RecordBatch{Version: 2}
frb.Records = newDefaultRecords(batch)
if len(frb.RecordsSet) == 0 {
records := newDefaultRecords(&RecordBatch{Version: 2})
frb.RecordsSet = []*Records{&records}
}
batch := frb.RecordsSet[0].recordBatch
batch.LastOffsetDelta = offset
}

Expand Down
18 changes: 9 additions & 9 deletions fetch_response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,22 +117,22 @@ func TestOneMessageFetchResponse(t *testing.T) {
if block.HighWaterMarkOffset != 0x10101010 {
t.Error("Decoding didn't produce correct high water mark offset.")
}
partial, err := block.Records.isPartial()
partial, err := block.isPartial()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if partial {
t.Error("Decoding detected a partial trailing message where there wasn't one.")
}

n, err := block.Records.numRecords()
n, err := block.numRecords()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n != 1 {
t.Fatal("Decoding produced incorrect number of messages.")
}
msgBlock := block.Records.msgSet.Messages[0]
msgBlock := block.RecordsSet[0].msgSet.Messages[0]
if msgBlock.Offset != 0x550000 {
t.Error("Decoding produced incorrect message offset.")
}
Expand Down Expand Up @@ -170,22 +170,22 @@ func TestOneRecordFetchResponse(t *testing.T) {
if block.HighWaterMarkOffset != 0x10101010 {
t.Error("Decoding didn't produce correct high water mark offset.")
}
partial, err := block.Records.isPartial()
partial, err := block.isPartial()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if partial {
t.Error("Decoding detected a partial trailing record where there wasn't one.")
}

n, err := block.Records.numRecords()
n, err := block.numRecords()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n != 1 {
t.Fatal("Decoding produced incorrect number of records.")
}
rec := block.Records.recordBatch.Records[0]
rec := block.RecordsSet[0].recordBatch.Records[0]
if !bytes.Equal(rec.Key, []byte{0x01, 0x02, 0x03, 0x04}) {
t.Error("Decoding produced incorrect record key.")
}
Expand Down Expand Up @@ -216,22 +216,22 @@ func TestOneMessageFetchResponseV4(t *testing.T) {
if block.HighWaterMarkOffset != 0x10101010 {
t.Error("Decoding didn't produce correct high water mark offset.")
}
partial, err := block.Records.isPartial()
partial, err := block.isPartial()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if partial {
t.Error("Decoding detected a partial trailing record where there wasn't one.")
}

n, err := block.Records.numRecords()
n, err := block.numRecords()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n != 1 {
t.Fatal("Decoding produced incorrect number of records.")
}
msgBlock := block.Records.msgSet.Messages[0]
msgBlock := block.RecordsSet[0].msgSet.Messages[0]
if msgBlock.Offset != 0x550000 {
t.Error("Decoding produced incorrect message offset.")
}
Expand Down
9 changes: 9 additions & 0 deletions message_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ func (ms *MessageSet) decode(pd packetDecoder) (err error) {
ms.Messages = nil

for pd.remaining() > 0 {
magic, err := magicValue(pd)
if err != nil {
return err
}

if magic > 1 {
return nil
}

msb := new(MessageBlock)
err = msb.decode(pd)
switch err {
Expand Down
20 changes: 13 additions & 7 deletions records.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,12 @@ func (r *Records) encode(pe packetEncoder) error {
}
return r.recordBatch.encode(pe)
}

return fmt.Errorf("unknown records type: %v", r.recordsType)
}

func (r *Records) setTypeFromMagic(pd packetDecoder) error {
dec, err := pd.peek(magicOffset, magicLength)
if err != nil {
return err
}

magic, err := dec.getInt8()
magic, err := magicValue(pd)
if err != nil {
return err
}
Expand All @@ -80,13 +76,14 @@ func (r *Records) setTypeFromMagic(pd packetDecoder) error {
if magic < 2 {
r.recordsType = legacyRecords
}

return nil
}

func (r *Records) decode(pd packetDecoder) error {
if r.recordsType == unknownRecords {
if err := r.setTypeFromMagic(pd); err != nil {
return nil
return err
}
}

Expand Down Expand Up @@ -165,3 +162,12 @@ func (r *Records) isControl() (bool, error) {
}
return false, fmt.Errorf("unknown records type: %v", r.recordsType)
}

func magicValue(pd packetDecoder) (int8, error) {
dec, err := pd.peek(magicOffset, magicLength)
if err != nil {
return 0, err
}

return dec.getInt8()
}