Skip to content

Commit

Permalink
Added tests for Stop()
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Piotrowski <[email protected]>
  • Loading branch information
piotrpio committed Jan 10, 2024
1 parent 3e47919 commit 595538e
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 7 deletions.
11 changes: 8 additions & 3 deletions jetstream/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type (

// Stop unsubscribes from the stream and cancels subscription. Calling
// Next after calling Stop will return ErrMsgIteratorClosed error.
// All messages that are already in the buffer are discarded.
Stop()

// Drain unsubscribes from the stream and cancels subscription. All
Expand All @@ -48,6 +49,7 @@ type (
ConsumeContext interface {
// Stop unsubscribes from the stream and cancels subscription.
// No more messages will be received after calling this method.
// All messages that are already in the buffer are discarded.
Stop()

// Drain unsubscribes from the stream and cancels subscription.
Expand Down Expand Up @@ -261,7 +263,8 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) (
return func(subject string) {
p.Lock()
defer p.Unlock()
delete(sub.consumer.subscriptions, sid)
delete(p.subscriptions, sid)
atomic.CompareAndSwapUint32(&sub.draining, 1, 0)
}
}(sub.id))

Expand Down Expand Up @@ -527,7 +530,9 @@ var (
func (s *pullSubscription) Next() (Msg, error) {
s.Lock()
defer s.Unlock()
if len(s.msgs) == 0 && (s.subscription == nil || !s.subscription.IsValid()) {
drainMode := atomic.LoadUint32(&s.draining) == 1
closed := atomic.LoadUint32(&s.closed) == 1
if closed && !drainMode {
return nil, ErrMsgIteratorClosed
}
hbMonitor := s.scheduleHeartbeatCheck(2 * s.consumeOpts.Heartbeat)
Expand Down Expand Up @@ -556,6 +561,7 @@ func (s *pullSubscription) Next() (Msg, error) {
if !ok {
// if msgs channel is closed, it means that subscription was either drained or stopped
delete(s.consumer.subscriptions, s.id)
atomic.CompareAndSwapUint32(&s.draining, 1, 0)
return nil, ErrMsgIteratorClosed
}
if hbMonitor != nil {
Expand Down Expand Up @@ -907,7 +913,6 @@ func (s *pullSubscription) cleanup() {
if s.hbMonitor != nil {
s.hbMonitor.Stop()
}
close(s.connStatusChanged)
drainMode := atomic.LoadUint32(&s.draining) == 1
if drainMode {
s.subscription.Drain()
Expand Down
14 changes: 10 additions & 4 deletions jetstream/test/jetstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,10 +426,16 @@ func TestCreateStreamMirrorCrossDomains(t *testing.T) {
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

if lStream.CachedInfo().State.Msgs != 3 {
t.Fatalf("Expected 3 msgs in stream; got: %d", lStream.CachedInfo().State.Msgs)
}
checkFor(t, 2*time.Second, 15*time.Millisecond, func() error {
info, err := lStream.Info(ctx)
if err != nil {
return fmt.Errorf("Unexpected error when getting stream info: %v", err)
}
if info.State.Msgs != 3 {
return fmt.Errorf("Expected 3 msgs in stream; got: %d", lStream.CachedInfo().State.Msgs)
}
return nil
})

rjs, err := jetstream.NewWithDomain(lnc, "HUB")
if err != nil {
Expand Down
103 changes: 103 additions & 0 deletions jetstream/test/pull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1359,6 +1359,61 @@ func TestPullConsumerMessages(t *testing.T) {
}
})

t.Run("no messages received after stop", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()

ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

msgs := make([]jetstream.Msg, 0)
it, err := c.Messages()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

publishTestMsgs(t, nc)
go func() {
time.Sleep(100 * time.Millisecond)
it.Stop()
}()
for i := 0; i < 2; i++ {
msg, err := it.Next()
if err != nil {
t.Fatal(err)
}
time.Sleep(80 * time.Millisecond)
msg.Ack()
msgs = append(msgs, msg)
}
_, err = it.Next()
if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
t.Fatalf("Expected error: %v; got: %v", jetstream.ErrMsgIteratorClosed, err)
}

if len(msgs) != 2 {
t.Fatalf("Unexpected received message count after drain; want %d; got %d", len(testMsgs), len(msgs))
}
})

t.Run("drain mode", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
Expand Down Expand Up @@ -2096,6 +2151,54 @@ func TestPullConsumerConsume(t *testing.T) {
wg.Wait()
})

t.Run("no messages received after stop", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
wg := &sync.WaitGroup{}
wg.Add(2)
publishTestMsgs(t, nc)
msgs := make([]jetstream.Msg, 0)
cc, err := c.Consume(func(msg jetstream.Msg) {
time.Sleep(80 * time.Millisecond)
msg.Ack()
msgs = append(msgs, msg)
wg.Done()
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
time.Sleep(100 * time.Millisecond)
cc.Stop()
wg.Wait()
// wait for some time to make sure no new messages are received
time.Sleep(100 * time.Millisecond)

if len(msgs) != 2 {
t.Fatalf("Unexpected received message count after stop; want 2; got %d", len(msgs))
}
})

t.Run("drain mode", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
Expand Down

0 comments on commit 595538e

Please sign in to comment.