From 5f78d9008398c346ec69cab02ff03bf2c6bdbe47 Mon Sep 17 00:00:00 2001
From: Willem van Bergen <willem@railsdoctors.com>
Date: Thu, 12 Mar 2015 13:35:15 -0300
Subject: [PATCH] Add high water mark offset support to the consumer.

---
 consumer.go                 | 16 ++++++++++++++--
 consumer_test.go            | 26 ++++++++++++++++++++------
 functional_consumer_test.go | 34 ++++++++++++++++++++++++++++++++++
 mocks/consumer.go           | 11 ++++++++---
 4 files changed, 76 insertions(+), 11 deletions(-)

diff --git a/consumer.go b/consumer.go
index 2a7a94a8d..3adc0a7de 100644
--- a/consumer.go
+++ b/consumer.go
@@ -3,6 +3,7 @@ package sarama
 import (
 	"fmt"
 	"sync"
+	"sync/atomic"
 	"time"
 )
 
@@ -255,6 +256,11 @@ type PartitionConsumer interface {
 	// errors are logged and not returned over this channel. If you want to implement any custom errpr
 	// handling, set your config's Consumer.Return.Errors setting to true, and read from this channel.
 	Errors() <-chan *ConsumerError
+
+	// HighWaterMarkOffset returns the high water mark offset of the partition, i.e. the offset that will
+	// be used for the next message that will be produced. You can use this to determine how far behind
+	// the processing is.
+	HighWaterMarkOffset() int64
 }
 
 type partitionConsumer struct {
@@ -268,8 +274,9 @@ type partitionConsumer struct {
 	errors         chan *ConsumerError
 	trigger, dying chan none
 
-	fetchSize int32
-	offset    int64
+	fetchSize           int32
+	offset              int64
+	highWaterMarkOffset int64
 }
 
 func (child *partitionConsumer) sendError(err error) {
@@ -391,6 +398,10 @@ func (child *partitionConsumer) Close() error {
 	return nil
 }
 
+func (child *partitionConsumer) HighWaterMarkOffset() int64 {
+	return atomic.LoadInt64(&child.highWaterMarkOffset)
+}
+
 func (child *partitionConsumer) handleResponse(response *FetchResponse) error {
 	block := response.GetBlock(child.topic, child.partition)
 	if block == nil {
@@ -422,6 +433,7 @@ func (child *partitionConsumer) handleResponse(response *FetchResponse) error {
 
 	// we got messages, reset our fetch size in case it was increased for a previous request
 	child.fetchSize = child.conf.Consumer.Fetch.Default
+	atomic.StoreInt64(&child.highWaterMarkOffset, block.HighWaterMarkOffset)
 
 	incomplete := false
 	atLeastOne := false
diff --git a/consumer_test.go b/consumer_test.go
index 0611c6e23..3d171121c 100644
--- a/consumer_test.go
+++ b/consumer_test.go
@@ -59,7 +59,7 @@ func TestConsumerOffsetManual(t *testing.T) {
 	leader.Close()
 }
 
-func TestConsumerLatestOffset(t *testing.T) {
+func TestConsumerOffsetNewest(t *testing.T) {
 	seedBroker := newMockBroker(t, 1)
 	leader := newMockBroker(t, 2)
 
@@ -69,15 +69,17 @@ func TestConsumerLatestOffset(t *testing.T) {
 	seedBroker.Returns(metadataResponse)
 
 	offsetResponseNewest := new(OffsetResponse)
-	offsetResponseNewest.AddTopicPartition("my_topic", 0, 0x010102)
+	offsetResponseNewest.AddTopicPartition("my_topic", 0, 10)
 	leader.Returns(offsetResponseNewest)
 
 	offsetResponseOldest := new(OffsetResponse)
-	offsetResponseOldest.AddTopicPartition("my_topic", 0, 0x010101)
+	offsetResponseOldest.AddTopicPartition("my_topic", 0, 7)
 	leader.Returns(offsetResponseOldest)
 
 	fetchResponse := new(FetchResponse)
-	fetchResponse.AddMessage("my_topic", 0, nil, ByteEncoder([]byte{0x00, 0x0E}), 0x010101)
+	fetchResponse.AddMessage("my_topic", 0, nil, ByteEncoder([]byte{0x00, 0x0E}), 10)
+	block := fetchResponse.GetBlock("my_topic", 0)
+	block.HighWaterMarkOffset = 14
 	leader.Returns(fetchResponse)
 
 	master, err := NewConsumer([]string{seedBroker.Addr()}, nil)
@@ -91,12 +93,24 @@ func TestConsumerLatestOffset(t *testing.T) {
 		t.Fatal(err)
 	}
 
+	msg := <-consumer.Messages()
+
+	// we deliver one message, so it should be one higher than we return in the OffsetResponse
+	if msg.Offset != 10 {
+		t.Error("Latest message offset not fetched correctly:", msg.Offset)
+	}
+
+	if hwmo := consumer.HighWaterMarkOffset(); hwmo != 14 {
+		t.Errorf("Expected high water mark offset 14, found %d", hwmo)
+	}
+
 	leader.Close()
 	safeClose(t, consumer)
 	safeClose(t, master)
 
-	// we deliver one message, so it should be one higher than we return in the OffsetResponse
-	if consumer.(*partitionConsumer).offset != 0x010102 {
+	// We deliver one message, so it should be one higher than we return in the OffsetResponse.
+	// This way it is set correctly for the next FetchRequest.
+	if consumer.(*partitionConsumer).offset != 11 {
 		t.Error("Latest offset not fetched correctly:", consumer.(*partitionConsumer).offset)
 	}
 }
diff --git a/functional_consumer_test.go b/functional_consumer_test.go
index 73074f28d..6afd5cc52 100644
--- a/functional_consumer_test.go
+++ b/functional_consumer_test.go
@@ -23,3 +23,37 @@ func TestFuncConsumerOffsetOutOfRange(t *testing.T) {
 
 	safeClose(t, consumer)
 }
+
+func TestConsumerHighWaterMarkOffset(t *testing.T) {
+	checkKafkaAvailability(t)
+
+	p, err := NewSyncProducer(kafkaBrokers, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer safeClose(t, p)
+
+	_, offset, err := p.SendMessage(&ProducerMessage{Topic: "test.1", Value: StringEncoder("Test")})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	c, err := NewConsumer(kafkaBrokers, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer safeClose(t, c)
+
+	pc, err := c.ConsumePartition("test.1", 0, OffsetOldest)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	<-pc.Messages()
+
+	if hwmo := pc.HighWaterMarkOffset(); hwmo != offset+1 {
+		t.Logf("Last produced offset %d; high water mark should be one higher but found %d.", offset, hwmo)
+	}
+
+	safeClose(t, pc)
+}
diff --git a/mocks/consumer.go b/mocks/consumer.go
index ff851e7f5..acf0894ee 100644
--- a/mocks/consumer.go
+++ b/mocks/consumer.go
@@ -2,6 +2,7 @@ package mocks
 
 import (
 	"sync"
+	"sync/atomic"
 
 	"github.com/Shopify/sarama"
 )
@@ -175,13 +176,13 @@ type PartitionConsumer struct {
 	consumed                bool
 	errorsShouldBeDrained   bool
 	messagesShouldBeDrained bool
+	highWaterMarkOffset     int64
 }
 
 func (pc *PartitionConsumer) handleExpectations() {
 	pc.l.Lock()
 	defer pc.l.Unlock()
 
-	var offset int64
 	for ex := range pc.expectations {
 		if ex.Err != nil {
 			pc.errors <- &sarama.ConsumerError{
@@ -190,11 +191,11 @@ func (pc *PartitionConsumer) handleExpectations() {
 				Err:       ex.Err,
 			}
 		} else {
-			offset++
+			atomic.AddInt64(&pc.highWaterMarkOffset, 1)
 
 			ex.Msg.Topic = pc.topic
 			ex.Msg.Partition = pc.partition
-			ex.Msg.Offset = offset
+			ex.Msg.Offset = atomic.LoadInt64(&pc.highWaterMarkOffset)
 
 			pc.messages <- ex.Msg
 		}
@@ -274,6 +275,10 @@ func (pc *PartitionConsumer) Messages() <-chan *sarama.ConsumerMessage {
 	return pc.messages
 }
 
+func (pc *PartitionConsumer) HighWaterMarkOffset() int64 {
+	return atomic.LoadInt64(&pc.highWaterMarkOffset) + 1
+}
+
 ///////////////////////////////////////////////////
 // Expectation API
 ///////////////////////////////////////////////////