Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADDED] ConsumeContext.Closed() method for waiting for consume to be closed/drained #1691

Merged
merged 2 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 46 additions & 15 deletions jetstream/ordered.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ type (
cfg *OrderedConsumerConfig
stream string
currentConsumer *pullConsumer
currentSub ConsumeContext
currentSub *pullSubscription
cursor cursor
namePrefix string
serial int
consumerType consumerType
doReset chan struct{}
resetInProgress uint32
resetInProgress atomic.Uint32
userErrHandler ConsumeErrHandlerFunc
stopAfter int
stopAfterMsgsLeft chan int
Expand All @@ -52,7 +52,7 @@ type (
consumer *orderedConsumer
opts []PullMessagesOpt
done chan struct{}
closed uint32
closed atomic.Uint32
}

cursor struct {
Expand Down Expand Up @@ -138,7 +138,7 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt
if err != nil {
return nil, err
}
c.currentSub = cc
c.currentSub = cc.(*pullSubscription)

go func() {
for {
Expand Down Expand Up @@ -175,7 +175,7 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt
c.errHandler(c.serial)(cc, err)
} else {
c.Lock()
c.currentSub = cc
c.currentSub = cc.(*pullSubscription)
c.Unlock()
}
case <-sub.done:
Expand Down Expand Up @@ -210,8 +210,8 @@ func (c *orderedConsumer) errHandler(serial int) func(cc ConsumeContext, err err
errors.Is(err, ErrConsumerDeleted) ||
errors.Is(err, errConnected) {
// only reset if serial matches the current consumer serial and there is no reset in progress
if serial == c.serial && atomic.LoadUint32(&c.resetInProgress) == 0 {
atomic.StoreUint32(&c.resetInProgress, 1)
if serial == c.serial && c.resetInProgress.Load() == 0 {
c.resetInProgress.Store(1)
c.doReset <- struct{}{}
}
}
Expand Down Expand Up @@ -256,7 +256,7 @@ func (c *orderedConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, er
if err != nil {
return nil, err
}
c.currentSub = cc
c.currentSub = cc.(*pullSubscription)

sub := &orderedSubscription{
consumer: c,
Expand All @@ -270,7 +270,7 @@ func (c *orderedConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, er

func (s *orderedSubscription) Next() (Msg, error) {
for {
msg, err := s.consumer.currentSub.(*pullSubscription).Next()
msg, err := s.consumer.currentSub.Next()
if err != nil {
if errors.Is(err, ErrMsgIteratorClosed) {
s.Stop()
Expand All @@ -297,7 +297,7 @@ func (s *orderedSubscription) Next() (Msg, error) {
if err != nil {
return nil, err
}
s.consumer.currentSub = cc
s.consumer.currentSub = cc.(*pullSubscription)
continue
}

Expand All @@ -321,7 +321,7 @@ func (s *orderedSubscription) Next() (Msg, error) {
if err != nil {
return nil, err
}
s.consumer.currentSub = cc
s.consumer.currentSub = cc.(*pullSubscription)
continue
}
s.consumer.cursor.deliverSeq = dseq
Expand All @@ -331,7 +331,7 @@ func (s *orderedSubscription) Next() (Msg, error) {
}

func (s *orderedSubscription) Stop() {
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
if !s.closed.CompareAndSwap(0, 1) {
return
}
s.consumer.Lock()
Expand All @@ -343,7 +343,7 @@ func (s *orderedSubscription) Stop() {
}

func (s *orderedSubscription) Drain() {
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
if !s.closed.CompareAndSwap(0, 1) {
return
}
if s.consumer.currentSub != nil {
Expand All @@ -354,6 +354,37 @@ func (s *orderedSubscription) Drain() {
close(s.done)
}

// Closed returns a channel that is closed when the consuming is
// fully stopped/drained. When the channel is closed, no more messages
// will be received and processing is complete.
func (s *orderedSubscription) Closed() <-chan struct{} {
s.consumer.Lock()
defer s.consumer.Unlock()
closedCh := make(chan struct{})

go func() {
for {
s.consumer.Lock()
if s.consumer.currentSub == nil {
return
}

closed := s.consumer.currentSub.Closed()
s.consumer.Unlock()

// wait until the underlying pull consumer is closed
<-closed
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add comment what does <-closed does and what s.closed.Load() does.

// if the subscription is closed and ordered consumer is closed as well,
// send a signal that the Consume() is fully stopped
if s.closed.Load() == 1 {
close(closedCh)
return
}
}
}()
return closedCh
}

// Fetch is used to retrieve up to a provided number of messages from a
// stream. This method will always send a single request and wait until
// either all messages are retrieved or request times out.
Expand Down Expand Up @@ -495,7 +526,7 @@ func serialNumberFromConsumer(name string) int {
func (c *orderedConsumer) reset() error {
c.Lock()
defer c.Unlock()
defer atomic.StoreUint32(&c.resetInProgress, 0)
defer c.resetInProgress.Store(0)
if c.currentConsumer != nil {
c.currentConsumer.Lock()
if c.currentSub != nil {
Expand Down Expand Up @@ -524,7 +555,7 @@ func (c *orderedConsumer) reset() error {
cancel: c.subscription.done,
}
err = retryWithBackoff(func(attempt int) (bool, error) {
isClosed := atomic.LoadUint32(&c.subscription.closed) == 1
isClosed := c.subscription.closed.Load() == 1
if isClosed {
return false, errOrderedConsumerClosed
}
Expand Down
30 changes: 30 additions & 0 deletions jetstream/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ type (
// Drain unsubscribes from the stream and cancels subscription.
// All messages that are already in the buffer will be processed in callback function.
Drain()

// Closed returns a channel that is closed when the consuming is
// fully stopped/drained. When the channel is closed, no more messages
// will be received and processing is complete.
Closed() <-chan struct{}
}

// MessageHandler is a handler function used as callback in [Consume].
Expand Down Expand Up @@ -125,6 +130,7 @@ type (
fetchNext chan *pullRequest
consumeOpts *consumeOpts
delivered int
closedCh chan struct{}
}

pendingMsgs struct {
Expand Down Expand Up @@ -257,6 +263,12 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) (
return func(subject string) {
p.subs.Delete(sid)
sub.draining.CompareAndSwap(1, 0)
sub.Lock()
if sub.closedCh != nil {
close(sub.closedCh)
sub.closedCh = nil
}
sub.Unlock()
}
}(sub.id))

Expand Down Expand Up @@ -649,6 +661,24 @@ func (s *pullSubscription) Drain() {
}
}

// Closed returns a channel that is closed when consuming is
// fully stopped/drained. When the channel is closed, no more messages
// will be received and processing is complete.
func (s *pullSubscription) Closed() <-chan struct{} {
s.Lock()
defer s.Unlock()
closedCh := s.closedCh
if closedCh == nil {
closedCh = make(chan struct{})
s.closedCh = closedCh
}
if !s.subscription.IsValid() {
close(s.closedCh)
s.closedCh = nil
}
return closedCh
}

// Fetch sends a single request to retrieve given number of messages.
// It will wait up to provided expiry time if not all messages are available.
func (p *pullConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, error) {
Expand Down
129 changes: 129 additions & 0 deletions jetstream/test/ordered_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,135 @@ func TestOrderedConsumerConsume(t *testing.T) {
time.Sleep(50 * time.Millisecond)
}
})

t.Run("wait for closed after drain", func(t *testing.T) {
for i := 0; i < 10; i++ {
t.Run(fmt.Sprintf("run %d", i), func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
msgs := make([]jetstream.Msg, 0)
lock := sync.Mutex{}
publishTestMsgs(t, js)
cc, err := c.Consume(func(msg jetstream.Msg) {
time.Sleep(50 * time.Millisecond)
msg.Ack()
lock.Lock()
msgs = append(msgs, msg)
lock.Unlock()
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
closed := cc.Closed()
time.Sleep(100 * time.Millisecond)
if err := s.DeleteConsumer(context.Background(), c.CachedInfo().Name); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
publishTestMsgs(t, js)

// wait for the consumer to be recreated before calling drain
for i := 0; i < 5; i++ {
_, err = c.Info(ctx)
if err != nil {
if errors.Is(err, jetstream.ErrConsumerNotFound) {
time.Sleep(100 * time.Millisecond)
continue
}
t.Fatalf("Unexpected error: %v", err)
}
break
}

cc.Drain()

select {
case <-closed:
case <-time.After(5 * time.Second):
t.Fatalf("Timeout waiting for consume to be closed")
}

if len(msgs) != 2*len(testMsgs) {
t.Fatalf("Unexpected received message count after consume closed; want %d; got %d", 2*len(testMsgs), len(msgs))
}
})
}
})

t.Run("wait for closed on already closed consume", func(t *testing.T) {
for i := 0; i < 10; i++ {
t.Run(fmt.Sprintf("run %d", i), func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
msgs := make([]jetstream.Msg, 0)
lock := sync.Mutex{}
publishTestMsgs(t, js)
cc, err := c.Consume(func(msg jetstream.Msg) {
time.Sleep(50 * time.Millisecond)
msg.Ack()
lock.Lock()
msgs = append(msgs, msg)
lock.Unlock()
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
time.Sleep(100 * time.Millisecond)
if err := s.DeleteConsumer(context.Background(), c.CachedInfo().Name); err != nil {
t.Fatalf("Unexpected error: %v", err)
}

cc.Stop()

time.Sleep(100 * time.Millisecond)

select {
case <-cc.Closed():
case <-time.After(5 * time.Second):
t.Fatalf("Timeout waiting for consume to be closed")
}
})
}
})
}

func TestOrderedConsumerMessages(t *testing.T) {
Expand Down
Loading