Skip to content

Commit

Permalink
Support multiple record batches, closes #1022
Browse files Browse the repository at this point in the history
  • Loading branch information
bobrik committed Jan 14, 2018
1 parent f0c3255 commit 068e0b7
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 75 deletions.
21 changes: 16 additions & 5 deletions consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,14 +601,25 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*Consu
child.fetchSize = child.conf.Consumer.Fetch.Default
atomic.StoreInt64(&child.highWaterMarkOffset, block.HighWaterMarkOffset)

if control, err := block.Records.isControl(); err != nil || control {
return nil, err
}

if block.Records.recordsType == legacyRecords {
return child.parseMessages(block.Records.msgSet)
}
return child.parseRecords(block.Records.recordBatch)

messages := []*ConsumerMessage{}
for _, recordBatch := range block.Records.recordBatchSet.batches {
if recordBatch.Control {
continue
}

recordBatchMessages, err := child.parseRecords(recordBatch)
messages = append(messages, recordBatchMessages...)

if err != nil {
return messages, err
}
}

return messages, nil
}

// brokerConsumer
Expand Down
20 changes: 9 additions & 11 deletions fetch_response.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package sarama

import "time"
import (
"time"
)

type AbortedTransaction struct {
ProducerID int64
Expand Down Expand Up @@ -301,22 +303,18 @@ func (r *FetchResponse) AddRecord(topic string, partition int32, key, value Enco
frb := r.getOrCreateBlock(topic, partition)
kb, vb := encodeKV(key, value)
rec := &Record{Key: kb, Value: vb, OffsetDelta: offset}
batch := frb.Records.recordBatch
if batch == nil {
batch = &RecordBatch{Version: 2}
frb.Records = newDefaultRecords(batch)
if frb.Records.recordBatchSet == nil {
frb.Records = newDefaultRecords([]*RecordBatch{&RecordBatch{Version: 2}})
}
batch.addRecord(rec)
frb.Records.recordBatchSet.batches[0].addRecord(rec)
}

func (r *FetchResponse) SetLastOffsetDelta(topic string, partition int32, offset int32) {
frb := r.getOrCreateBlock(topic, partition)
batch := frb.Records.recordBatch
if batch == nil {
batch = &RecordBatch{Version: 2}
frb.Records = newDefaultRecords(batch)
if frb.Records.recordBatchSet == nil {
frb.Records = newDefaultRecords([]*RecordBatch{&RecordBatch{Version: 2}})
}
batch.LastOffsetDelta = offset
frb.Records.recordBatchSet.batches[0].LastOffsetDelta = offset
}

func (r *FetchResponse) SetLastStableOffset(topic string, partition int32, offset int64) {
Expand Down
2 changes: 1 addition & 1 deletion fetch_response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ func TestOneRecordFetchResponse(t *testing.T) {
if n != 1 {
t.Fatal("Decoding produced incorrect number of records.")
}
rec := block.Records.recordBatch.Records[0]
rec := block.Records.recordBatchSet.batches[0].Records[0]
if !bytes.Equal(rec.Key, []byte{0x01, 0x02, 0x03, 0x04}) {
t.Error("Decoding produced incorrect record key.")
}
Expand Down
4 changes: 2 additions & 2 deletions produce_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (r *ProduceRequest) encode(pe packetEncoder) error {
}
if metricRegistry != nil {
if r.Version >= 3 {
topicRecordCount += updateBatchMetrics(records.recordBatch, compressionRatioMetric, topicCompressionRatioMetric)
topicRecordCount += updateBatchMetrics(records.recordBatchSet.batches[0], compressionRatioMetric, topicCompressionRatioMetric)
} else {
topicRecordCount += updateMsgSetMetrics(records.msgSet, compressionRatioMetric, topicCompressionRatioMetric)
}
Expand Down Expand Up @@ -248,5 +248,5 @@ func (r *ProduceRequest) AddSet(topic string, partition int32, set *MessageSet)

func (r *ProduceRequest) AddBatch(topic string, partition int32, batch *RecordBatch) {
r.ensureRecords(topic, partition)
r.records[topic][partition] = newDefaultRecords(batch)
r.records[topic][partition] = newDefaultRecords([]*RecordBatch{batch})
}
10 changes: 5 additions & 5 deletions produce_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (ps *produceSet) add(msg *ProducerMessage) error {
ProducerID: -1, /* No producer id */
Codec: ps.parent.conf.Producer.Compression,
}
set = &partitionSet{recordsToSend: newDefaultRecords(batch)}
set = &partitionSet{recordsToSend: newDefaultRecords([]*RecordBatch{batch})}
size = recordBatchOverhead
} else {
set = &partitionSet{recordsToSend: newLegacyRecords(new(MessageSet))}
Expand All @@ -79,7 +79,7 @@ func (ps *produceSet) add(msg *ProducerMessage) error {
rec := &Record{
Key: key,
Value: val,
TimestampDelta: timestamp.Sub(set.recordsToSend.recordBatch.FirstTimestamp),
TimestampDelta: timestamp.Sub(set.recordsToSend.recordBatchSet.batches[0].FirstTimestamp),
}
size += len(key) + len(val)
if len(msg.Headers) > 0 {
Expand All @@ -89,7 +89,7 @@ func (ps *produceSet) add(msg *ProducerMessage) error {
size += len(rec.Headers[i].Key) + len(rec.Headers[i].Value) + 2*binary.MaxVarintLen32
}
}
set.recordsToSend.recordBatch.addRecord(rec)
set.recordsToSend.recordBatchSet.batches[0].addRecord(rec)
} else {
msgToSend := &Message{Codec: CompressionNone, Key: key, Value: val}
if ps.parent.conf.Version.IsAtLeast(V0_10_0_0) {
Expand Down Expand Up @@ -122,11 +122,11 @@ func (ps *produceSet) buildRequest() *ProduceRequest {
for topic, partitionSet := range ps.msgs {
for partition, set := range partitionSet {
if req.Version >= 3 {
for i, record := range set.recordsToSend.recordBatch.Records {
for i, record := range set.recordsToSend.recordBatchSet.batches[0].Records {
record.OffsetDelta = int64(i)
}

req.AddBatch(topic, partition, set.recordsToSend.recordBatch)
req.AddBatch(topic, partition, set.recordsToSend.recordBatchSet.batches[0])
continue
}
if ps.parent.conf.Producer.Compression == CompressionNone {
Expand Down
2 changes: 1 addition & 1 deletion produce_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ func TestProduceSetV3RequestBuilding(t *testing.T) {
t.Error("Wrong request version")
}

batch := req.records["t1"][0].recordBatch
batch := req.records["t1"][0].recordBatchSet.batches[0]
if batch.FirstTimestamp != now {
t.Errorf("Wrong first timestamp: %v", batch.FirstTimestamp)
}
Expand Down
41 changes: 41 additions & 0 deletions record_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,47 @@ func (e recordsArray) decode(pd packetDecoder) error {
return nil
}

// RecordBatchSet holds an ordered list of record batches. Its encode method
// writes the batches one after another to a single encoder, and its decode
// method reads batches from a single decoder until no bytes remain.
type RecordBatchSet struct {
batches []*RecordBatch
}

// encode writes every batch of the set to pe in slice order. It stops at the
// first batch whose own encode fails and returns that error unchanged; nil is
// returned when all batches (possibly none) were encoded.
func (rbs *RecordBatchSet) encode(pe packetEncoder) error {
	batches := rbs.batches
	for i := 0; i < len(batches); i++ {
		err := batches[i].encode(pe)
		if err != nil {
			return err
		}
	}
	return nil
}

// decode replaces rbs.batches with the batches read from pd, decoding one
// batch after another until pd reports no remaining bytes.
//
// Once at least one batch has been collected, decoding stops quietly (nil
// error) when the next batch fails with ErrInsufficientData or is flagged as
// PartialTrailingRecord; that batch is not added to the set. Any other decode
// error, or ErrInsufficientData on the very first batch, is returned as is.
// A first batch flagged PartialTrailingRecord is still kept.
func (rbs *RecordBatchSet) decode(pd packetDecoder) error {
	rbs.batches = []*RecordBatch{}

	for pd.remaining() != 0 {
		batch := new(RecordBatch)
		err := batch.decode(pd)
		haveBatch := len(rbs.batches) > 0

		switch {
		case err == ErrInsufficientData && haveBatch:
			// A truncated tail after a complete batch is not an error.
			return nil
		case err != nil:
			return err
		case batch.PartialTrailingRecord && haveBatch:
			// Skip the incomplete batch; at least one full batch is already held.
			return nil
		}

		rbs.batches = append(rbs.batches, batch)
	}

	return nil
}

type RecordBatch struct {
FirstOffset int64
PartitionLeaderEpoch int32
Expand Down
58 changes: 24 additions & 34 deletions records.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package sarama

import "fmt"
import (
"fmt"
)

const (
unknownRecords = iota
Expand All @@ -13,28 +15,28 @@ const (

// Records implements a union type containing either a RecordBatchSet or a legacy MessageSet.
type Records struct {
recordsType int
msgSet *MessageSet
recordBatch *RecordBatch
recordsType int
msgSet *MessageSet
recordBatchSet *RecordBatchSet
}

// newLegacyRecords returns a Records value tagged as legacyRecords that
// carries the given message set.
func newLegacyRecords(msgSet *MessageSet) Records {
	records := Records{recordsType: legacyRecords}
	records.msgSet = msgSet
	return records
}

func newDefaultRecords(batch *RecordBatch) Records {
return Records{recordsType: defaultRecords, recordBatch: batch}
func newDefaultRecords(batches []*RecordBatch) Records {
return Records{recordsType: defaultRecords, recordBatchSet: &RecordBatchSet{batches}}
}

// setTypeFromFields sets type of Records depending on which of msgSet or recordBatchSet is not nil.
// The first return value indicates whether both fields are nil (and the type is not set).
// If both fields are not nil, it returns an error.
func (r *Records) setTypeFromFields() (bool, error) {
if r.msgSet == nil && r.recordBatch == nil {
if r.msgSet == nil && r.recordBatchSet == nil {
return true, nil
}
if r.msgSet != nil && r.recordBatch != nil {
return false, fmt.Errorf("both msgSet and recordBatch are set, but record type is unknown")
if r.msgSet != nil && r.recordBatchSet != nil {
return false, fmt.Errorf("both msgSet and recordBatchSet are set, but record type is unknown")
}
r.recordsType = defaultRecords
if r.msgSet != nil {
Expand All @@ -57,10 +59,10 @@ func (r *Records) encode(pe packetEncoder) error {
}
return r.msgSet.encode(pe)
case defaultRecords:
if r.recordBatch == nil {
if r.recordBatchSet == nil {
return nil
}
return r.recordBatch.encode(pe)
return r.recordBatchSet.encode(pe)
}
return fmt.Errorf("unknown records type: %v", r.recordsType)
}
Expand Down Expand Up @@ -95,8 +97,8 @@ func (r *Records) decode(pd packetDecoder) error {
r.msgSet = &MessageSet{}
return r.msgSet.decode(pd)
case defaultRecords:
r.recordBatch = &RecordBatch{}
return r.recordBatch.decode(pd)
r.recordBatchSet = &RecordBatchSet{batches: []*RecordBatch{}}
return r.recordBatchSet.decode(pd)
}
return fmt.Errorf("unknown records type: %v", r.recordsType)
}
Expand All @@ -115,10 +117,14 @@ func (r *Records) numRecords() (int, error) {
}
return len(r.msgSet.Messages), nil
case defaultRecords:
if r.recordBatch == nil {
if r.recordBatchSet == nil {
return 0, nil
}
return len(r.recordBatch.Records), nil
s := 0
for i := range r.recordBatchSet.batches {
s += len(r.recordBatchSet.batches[i].Records)
}
return s, nil
}
return 0, fmt.Errorf("unknown records type: %v", r.recordsType)
}
Expand All @@ -139,29 +145,13 @@ func (r *Records) isPartial() (bool, error) {
}
return r.msgSet.PartialTrailingMessage, nil
case defaultRecords:
if r.recordBatch == nil {
if r.recordBatchSet == nil {
return false, nil
}
return r.recordBatch.PartialTrailingRecord, nil
}
return false, fmt.Errorf("unknown records type: %v", r.recordsType)
}

func (r *Records) isControl() (bool, error) {
if r.recordsType == unknownRecords {
if empty, err := r.setTypeFromFields(); err != nil || empty {
return false, err
if len(r.recordBatchSet.batches) == 1 {
return r.recordBatchSet.batches[0].PartialTrailingRecord, nil
}
}

switch r.recordsType {
case legacyRecords:
return false, nil
case defaultRecords:
if r.recordBatch == nil {
return false, nil
}
return r.recordBatch.Control, nil
}
return false, fmt.Errorf("unknown records type: %v", r.recordsType)
}
20 changes: 4 additions & 16 deletions records_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,6 @@ func TestLegacyRecords(t *testing.T) {
if p {
t.Errorf("MessageSet shouldn't have a partial trailing message")
}

c, err := r.isControl()
if err != nil {
t.Fatal(err)
}
if c {
t.Errorf("MessageSet can't be a control batch")
}
}

func TestDefaultRecords(t *testing.T) {
Expand All @@ -84,7 +76,7 @@ func TestDefaultRecords(t *testing.T) {
},
}

r := newDefaultRecords(batch)
r := newDefaultRecords([]*RecordBatch{batch})

exp, err := encode(batch, nil)
if err != nil {
Expand Down Expand Up @@ -113,8 +105,8 @@ func TestDefaultRecords(t *testing.T) {
if r.recordsType != defaultRecords {
t.Fatalf("Wrong records type %v, expected %v", r.recordsType, defaultRecords)
}
if !reflect.DeepEqual(batch, r.recordBatch) {
t.Errorf("Wrong decoding for default records, wanted %#+v, got %#+v", batch, r.recordBatch)
if !reflect.DeepEqual(batch, r.recordBatchSet.batches[0]) {
t.Errorf("Wrong decoding for default records, wanted %#+v, got %#+v", batch, r.recordBatchSet.batches[0])
}

n, err := r.numRecords()
Expand All @@ -133,11 +125,7 @@ func TestDefaultRecords(t *testing.T) {
t.Errorf("RecordBatch shouldn't have a partial trailing record")
}

c, err := r.isControl()
if err != nil {
t.Fatal(err)
}
if c {
if r.recordBatchSet.batches[0].Control {
t.Errorf("RecordBatch shouldn't be a control batch")
}
}

0 comments on commit 068e0b7

Please sign in to comment.