diff --git a/nats.go b/nats.go index b3ba1f7b6..40f18489f 100644 --- a/nats.go +++ b/nats.go @@ -20,6 +20,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/nats-io/go-nats/util" @@ -351,6 +352,12 @@ type Msg struct { Data []byte Sub *Subscription next *Msg + barrier *barrierInfo +} + +type barrierInfo struct { + refs int64 + f func() } // Tracks various stats received and sent on this connection, @@ -1571,6 +1578,13 @@ func (nc *Conn) waitForMsgs(s *Subscription) { if s.pHead == nil { s.pTail = nil } + if m.barrier != nil { + s.mu.Unlock() + if atomic.AddInt64(&m.barrier.refs, -1) == 0 { + m.barrier.f() + } + continue + } s.pMsgs-- s.pBytes -= len(m.Data) } @@ -1599,6 +1613,19 @@ func (nc *Conn) waitForMsgs(s *Subscription) { break } } + // Check for barrier messages + s.mu.Lock() + for m := s.pHead; m != nil; m = s.pHead { + if m.barrier != nil { + s.mu.Unlock() + if atomic.AddInt64(&m.barrier.refs, -1) == 0 { + m.barrier.f() + } + s.mu.Lock() + } + s.pHead = m.next + } + s.mu.Unlock() } // processMsg is called by parse and will place the msg on the @@ -3006,3 +3033,48 @@ func (nc *Conn) TLSRequired() bool { defer nc.mu.Unlock() return nc.info.TLSRequired } + +// Barrier schedules the given function `f` to all registered asynchronous +// subscriptions. +// Only the last subscription to see this barrier will invoke the function. +// If no subscription is registered at the time of this call, `f()` is invoked +// right away. +// ErrConnectionClosed is returned if the connection is closed prior to +// the call. +func (nc *Conn) Barrier(f func()) error { + nc.mu.Lock() + defer nc.mu.Unlock() + if nc.isClosed() { + return ErrConnectionClosed + } + nc.subsMu.Lock() + defer nc.subsMu.Unlock() + // Need to figure out how many non chan subscriptions there are + numSubs := 0 + for _, sub := range nc.subs { + if sub.typ == AsyncSubscription { + numSubs++ + } + } + if numSubs == 0 { + f() + return nil + } + barrier := &barrierInfo{refs: int64(numSubs), f: f} + for _, sub := range nc.subs { + sub.mu.Lock() + if sub.mch == nil { + msg := &Msg{barrier: barrier} + // Push onto the async pList + if sub.pTail != nil { + sub.pTail.next = msg + } else { + sub.pHead = msg + sub.pCond.Signal() + } + sub.pTail = msg + } + sub.mu.Unlock() + } + return nil +} diff --git a/test/conn_test.go b/test/conn_test.go index 6d2a77d47..b1ee5b0ce 100644 --- a/test/conn_test.go +++ b/test/conn_test.go @@ -1543,3 +1543,203 @@ func TestNewServers(t *testing.T) { t.Fatal("Did not get our callback") } } + +func TestBarrier(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + nc := NewDefaultConnection(t) + defer nc.Close() + + pubMsgs := int32(0) + ch := make(chan bool, 1) + + sub1, err := nc.Subscribe("pub", func(_ *nats.Msg) { + atomic.AddInt32(&pubMsgs, 1) + time.Sleep(250 * time.Millisecond) + }) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + + sub2, err := nc.Subscribe("close", func(_ *nats.Msg) { + // The "close" message was sent/received lat, but + // because we are dealing with different subscriptions, + // which are dispatched by different dispatchers, and + // because the "pub" subscription is delayed, this + // callback is likely to be invoked before the sub1's + // second callback is invoked. Using the Barrier call + // here will ensure that the given function will be invoked + // after the preceding messages have been dispatched. + nc.Barrier(func() { + res := atomic.LoadInt32(&pubMsgs) == 2 + ch <- res + }) + }) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + + // Send 2 "pub" messages followed by a "close" message + for i := 0; i < 2; i++ { + if err := nc.Publish("pub", []byte("pub msg")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + } + if err := nc.Publish("close", []byte("closing")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + + select { + case ok := <-ch: + if !ok { + t.Fatal("The barrier function was invoked before the second message") + } + case <-time.After(2 * time.Second): + t.Fatal("Waited for too long...") + } + + // Remove all subs + sub1.Unsubscribe() + sub2.Unsubscribe() + + // Barrier should be invoked in place. Since we use buffered channel + // we are ok. + nc.Barrier(func() { ch <- true }) + if err := Wait(ch); err != nil { + t.Fatal("Barrier function was not invoked") + } + + if _, err := nc.Subscribe("foo", func(m *nats.Msg) { + // To check that the Barrier() function works if the subscription + // is unsubscribed after the call was made, sleep a bit here. + time.Sleep(250 * time.Millisecond) + m.Sub.Unsubscribe() + }); err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + if err := nc.Publish("foo", []byte("hello")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + // We need to Flush here to make sure that message has been received + // and posted to subscription's internal queue before calling Barrier. + if err := nc.Flush(); err != nil { + t.Fatalf("Error on flush: %v", err) + } + nc.Barrier(func() { ch <- true }) + if err := Wait(ch); err != nil { + t.Fatal("Barrier function was not invoked") + } + + // Test with AutoUnsubscribe now... + sub1, err = nc.Subscribe("foo", func(m *nats.Msg) { + // Since we auto-unsubscribe with 1, there should not be another + // invocation of this callback, but the Barrier should still be + // invoked. + nc.Barrier(func() { ch <- true }) + + }) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + sub1.AutoUnsubscribe(1) + // Send 2 messages and flush + for i := 0; i < 2; i++ { + if err := nc.Publish("foo", []byte("hello")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + } + if err := nc.Flush(); err != nil { + t.Fatalf("Error on flush: %v", err) + } + // Check barrier was invoked + if err := Wait(ch); err != nil { + t.Fatal("Barrier function was not invoked") + } + + // Check that Barrier only affects asynchronous subscriptions + sub1, err = nc.Subscribe("foo", func(m *nats.Msg) { + nc.Barrier(func() { ch <- true }) + }) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + syncSub, err := nc.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + msgChan := make(chan *nats.Msg, 1) + chanSub, err := nc.ChanSubscribe("foo", msgChan) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + if err := nc.Publish("foo", []byte("hello")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + if err := nc.Flush(); err != nil { + t.Fatalf("Error on flush: %v", err) + } + // Check barrier was invoked even if we did not yet consume + // from the 2 other type of subscriptions + if err := Wait(ch); err != nil { + t.Fatal("Barrier function was not invoked") + } + if _, err := syncSub.NextMsg(time.Second); err != nil { + t.Fatalf("Sync sub did not receive the message") + } + select { + case <-msgChan: + case <-time.After(time.Second): + t.Fatal("Chan sub did not receive the message") + } + chanSub.Unsubscribe() + syncSub.Unsubscribe() + sub1.Unsubscribe() + + atomic.StoreInt32(&pubMsgs, 0) + // Check barrier does not prevent new messages to be delivered. + sub1, err = nc.Subscribe("foo", func(_ *nats.Msg) { + if pm := atomic.AddInt32(&pubMsgs, 1); pm == 1 { + nc.Barrier(func() { + nc.Publish("foo", []byte("second")) + nc.Flush() + }) + } else if pm == 2 { + ch <- true + } + }) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + if err := nc.Publish("foo", []byte("first")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + if err := Wait(ch); err != nil { + t.Fatal("Barrier function was not invoked") + } + sub1.Unsubscribe() + + // Check that barrier works if called before connection + // is closed. + if _, err := nc.Subscribe("bar", func(_ *nats.Msg) { + nc.Barrier(func() { ch <- true }) + nc.Close() + }); err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + if err := nc.Publish("bar", []byte("hello")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + if err := nc.Flush(); err != nil { + t.Fatalf("Error on flush: %v", err) + } + if err := Wait(ch); err != nil { + t.Fatal("Barrier function was not invoked") + } + + // Finally, check that if connection is closed, Barrier returns + // an error. + if err := nc.Barrier(func() { ch <- true }); err != nats.ErrConnectionClosed { + t.Fatalf("Expected error %v, got %v", nats.ErrConnectionClosed, err) + } +}