diff --git a/listener.go b/listener.go index 1129622..b6655fc 100644 --- a/listener.go +++ b/listener.go @@ -222,7 +222,7 @@ func (l *listener) ConsumeClaim(session sarama.ConsumerGroupSession, claim saram } func (l *listener) onNewMessage(msg *sarama.ConsumerMessage, session sarama.ConsumerGroupSession) { - messageContext := context.WithValue(context.Background(), contextTopicKey, msg.Topic) + messageContext := context.WithValue(session.Context(), contextTopicKey, msg.Topic) messageContext = context.WithValue(messageContext, contextkeyKey, msg.Key) messageContext = context.WithValue(messageContext, contextOffsetKey, msg.Offset) messageContext = context.WithValue(messageContext, contextTimestampKey, msg.Timestamp) @@ -245,11 +245,13 @@ func (l *listener) onNewMessage(msg *sarama.ConsumerMessage, session sarama.Cons err := l.handleMessageWithRetry(messageContext, handler, msg, *handler.Config.ConsumerMaxRetries) if err != nil { - err = fmt.Errorf("processing failed after all possible attempts attempts: %w", err) + err = fmt.Errorf("processing failed: %w", err) l.handleErrorMessage(err, handler, msg) } - session.MarkMessage(msg, "") + if !errors.Is(err, context.Canceled) { + session.MarkMessage(msg, "") + } } func (l *listener) handleErrorMessage(initialError error, handler Handler, msg *sarama.ConsumerMessage) { diff --git a/listener_test.go b/listener_test.go index e8352a5..92c3ba6 100644 --- a/listener_test.go +++ b/listener_test.go @@ -172,6 +172,7 @@ func Test_ConsumeClaim_Happy_Path(t *testing.T) { consumerGroupClaim.On("Messages").Return((<-chan *sarama.ConsumerMessage)(msgChanel)) consumerGroupSession := &mocks.ConsumerGroupSession{} + consumerGroupSession.On("Context").Return(context.Background()) consumerGroupSession.On("MarkMessage", mock.Anything, mock.Anything).Return() handlerCalled := false @@ -215,6 +216,7 @@ func Test_ConsumeClaim_Message_Error_WithErrorTopic(t *testing.T) { consumerGroupClaim.On("Messages").Return((<-chan *sarama.ConsumerMessage)(msgChanel)) consumerGroupSession := &mocks.ConsumerGroupSession{} + consumerGroupSession.On("Context").Return(context.Background()) consumerGroupSession.On("MarkMessage", mock.Anything, mock.Anything).Return() producer := &mocks.MockProducer{} @@ -268,6 +270,7 @@ func Test_ConsumeClaim_Message_Error_WithPanicTopic(t *testing.T) { consumerGroupClaim.On("Messages").Return((<-chan *sarama.ConsumerMessage)(msgChanel)) consumerGroupSession := &mocks.ConsumerGroupSession{} + consumerGroupSession.On("Context").Return(context.Background()) consumerGroupSession.On("MarkMessage", mock.Anything, mock.Anything).Return() producer := &mocks.MockProducer{} @@ -322,6 +325,7 @@ func Test_ConsumeClaim_Message_Error_WithHandlerSpecificRetryTopic(t *testing.T) consumerGroupClaim.On("Messages").Return((<-chan *sarama.ConsumerMessage)(msgChanel)) consumerGroupSession := &mocks.ConsumerGroupSession{} + consumerGroupSession.On("Context").Return(context.Background()) consumerGroupSession.On("MarkMessage", mock.Anything, mock.Anything).Return() producer := &mocks.MockProducer{} @@ -368,6 +372,54 @@ func Test_ConsumeClaim_Message_Error_WithHandlerSpecificRetryTopic(t *testing.T) producer.AssertExpectations(t) } +func Test_ConsumeClaim_Message_Error_Context_Cancelled_Does_Not_Commit_Offset(t *testing.T) { + PushConsumerErrorsToRetryTopic = false + PushConsumerErrorsToDeadletterTopic = false + + // Arrange + msgChanel := make(chan *sarama.ConsumerMessage, 1) + msgChanel <- &sarama.ConsumerMessage{ + Topic: "topic-test", + } + close(msgChanel) + + consumerGroupClaim := &mocks.ConsumerGroupClaim{} + consumerGroupClaim.On("Messages").Return((<-chan *sarama.ConsumerMessage)(msgChanel)) + + consumerGroupSession := &mocks.ConsumerGroupSession{} + consumerGroupSession.On("Context").Return(context.Background()) + + producer := &mocks.MockProducer{} + + handlerCalled := false + handlerProcessor := func(ctx context.Context, msg *sarama.ConsumerMessage) error { + handlerCalled = true + return context.Canceled + } + handler := Handler{ + Processor: handlerProcessor, + Config: HandlerConfig{ + ConsumerMaxRetries: Ptr(3), + DurationBeforeRetry: Ptr(1 * time.Millisecond), + }, + } + + tested := listener{ + handlers: map[string]Handler{"topic-test": handler}, + deadletterProducer: producer, + } + + // Act + err := tested.ConsumeClaim(consumerGroupSession, consumerGroupClaim) + + // Assert + assert.NoError(t, err) + assert.True(t, handlerCalled) + consumerGroupClaim.AssertExpectations(t) + consumerGroupSession.AssertExpectations(t) + producer.AssertExpectations(t) +} + func Test_handleErrorMessage_OmittedError(t *testing.T) { omittedError := errors.New("This error should be omitted") @@ -501,6 +553,7 @@ func Test_ConsumerClaim_HappyPath_WithTracing(t *testing.T) { consumerGroupClaim.On("Messages").Return((<-chan *sarama.ConsumerMessage)(msgChanel)) consumerGroupSession := &mocks.ConsumerGroupSession{} + consumerGroupSession.On("Context").Return(context.Background()) consumerGroupSession.On("MarkMessage", mock.Anything, mock.Anything).Return() handlerCalled := false