diff --git a/aggregate/aggregator.go b/aggregate/aggregator.go index 0914d5a..7353be8 100644 --- a/aggregate/aggregator.go +++ b/aggregate/aggregator.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/kinesis" "github.com/sirupsen/logrus" + "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/proto" ) @@ -19,8 +20,8 @@ var ( const ( maximumRecordSize = 1024 * 1024 // 1 MB defaultMaxAggRecordSize = 20 * 1024 // 20K - pKeyIdxSize = 8 - aggProtobufBytes = 2 // Marshalling the data into protobuf adds an additional 2 bytes. + initialAggRecordSize = 0 + fieldNumberSize = 1 // All field numbers are below 16, meaning they will only take up 1 byte ) // Aggregator kinesis aggregator @@ -38,6 +39,7 @@ func NewAggregator() *Aggregator { partitionKeys: make(map[string]uint64, 0), records: make([]*Record, 0), maxAggRecordSize: defaultMaxAggRecordSize, + aggSize: initialAggRecordSize, } } @@ -59,8 +61,16 @@ func (a *Aggregator) AddRecord(partitionKey string, data []byte) (entry *kinesis PartitionKey: aws.String(partitionKey), }, nil } + // Check if we need to add a new partition key, and if we do how much space it will take + pKeyIdx, pKeyAddedSize := a.checkPartitionKey(partitionKey) - if a.getSize()+dataSize+partitionKeySize+pKeyIdxSize >= maximumRecordSize { + // data field size is proto size of data + data field number size + // partition key field size is varint of index size + field number size + recordSize := protowire.SizeBytes(dataSize) + fieldNumberSize + protowire.SizeVarint(pKeyIdx) + fieldNumberSize + // Total size is proto size of data + field number of parent proto + addedSize := protowire.SizeBytes(recordSize) + fieldNumberSize + + if a.getSize()+addedSize+pKeyAddedSize >= maximumRecordSize { // Aggregate records, and return entry, err = a.AggregateRecords() if err != nil { @@ -76,7 +86,7 @@ func (a *Aggregator) AddRecord(partitionKey string, data []byte) (entry *kinesis PartitionKeyIndex: &partitionKeyIndex, }) - a.aggSize += dataSize + pKeyIdxSize + a.aggSize += addedSize return entry, err } @@ -132,10 +142,22 @@ func (a *Aggregator) addPartitionKey(partitionKey string) uint64 { idx := uint64(len(a.partitionKeys)) a.partitionKeys[partitionKey] = idx - a.aggSize += len([]byte(partitionKey)) + + partitionKeyLen := len([]byte(partitionKey)) + a.aggSize += protowire.SizeBytes(partitionKeyLen) + fieldNumberSize return idx } +func (a *Aggregator) checkPartitionKey(partitionKey string) (uint64, int) { + if idx, ok := a.partitionKeys[partitionKey]; ok { + return idx, 0 + } + + idx := uint64(len(a.partitionKeys)) + partitionKeyLen := len([]byte(partitionKey)) + return idx, protowire.SizeBytes(partitionKeyLen) + fieldNumberSize +} + func (a *Aggregator) getPartitionKeys() []string { keys := make([]string, 0) for pk := range a.partitionKeys { @@ -146,11 +168,11 @@ func (a *Aggregator) getPartitionKeys() []string { // getSize of protobuf records, partitionKeys, magicNumber, and md5sum in bytes func (a *Aggregator) getSize() int { - return a.aggSize + kclMagicNumberLen + md5.Size + aggProtobufBytes + return kclMagicNumberLen + md5.Size + a.aggSize } func (a *Aggregator) clearBuffer() { a.partitionKeys = make(map[string]uint64, 0) a.records = make([]*Record, 0) - a.aggSize = 0 + a.aggSize = initialAggRecordSize } diff --git a/aggregate/aggregator_test.go b/aggregate/aggregator_test.go new file mode 100644 index 0000000..a46e3cf --- /dev/null +++ b/aggregate/aggregator_test.go @@ -0,0 +1,21 @@ +package aggregate + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +const concurrencyRetryLimit = 4 + +func TestAddRecordCalculatesCorrectSize(t *testing.T) { + aggregator := NewAggregator() + + _, err := aggregator.AddRecord("test partition key", []byte("test value")) + assert.Equal(t, nil, err, "Expected aggregator not to return error") + assert.Equal(t, 36, aggregator.aggSize, "Expected aggregator to compute correct size") + + _, err = aggregator.AddRecord("test partition key 2", []byte("test value 2")) + assert.Equal(t, nil, err, "Expected aggregator not to return error") + assert.Equal(t, 76, aggregator.aggSize, "Expected aggregator to compute correct size") +}