diff --git a/.attic/topics/mem/README.md b/.attic/topics/mem/README.md new file mode 100644 index 0000000..e69de29 diff --git a/.attic/topics/mem/node.go b/.attic/topics/mem/node.go new file mode 100644 index 0000000..ceb8fbc --- /dev/null +++ b/.attic/topics/mem/node.go @@ -0,0 +1,342 @@ +package mem + +import ( + "strings" + + "github.com/VolantMQ/vlapi/mqttp" + "github.com/VolantMQ/vlapi/vlmonitoring" + "github.com/VolantMQ/vlapi/vlsubscriber" + "github.com/VolantMQ/vlapi/vltypes" + + topicstypes "github.com/VolantMQ/volantmq/topics/types" +) + +type topicSubscriber struct { + s topicstypes.Subscriber + p vlsubscriber.SubscriptionParams +} + +type subscribers map[uintptr]*topicSubscriber + +func (s *topicSubscriber) acquire() *publish { + pe := &publish{ + s: s.s, + qos: s.p.Granted, + ops: s.p.Ops, + } + + if s.p.ID > 0 { + pe.ids = []uint32{s.p.ID} + } + + return pe +} + +type publish struct { + s topicstypes.Subscriber + ops mqttp.SubscriptionOptions + qos mqttp.QosType + ids []uint32 +} + +type publishes map[uintptr][]*publish + +type node struct { + retained interface{} + subs subscribers + parent *node + children map[string]*node + getSubscribers func(uintptr, *publishes) +} + +func newNode(overlap bool, parent *node) *node { + n := &node{ + subs: make(subscribers), + children: make(map[string]*node), + parent: parent, + } + + if overlap { + n.getSubscribers = n.overlappingSubscribers + } else { + n.getSubscribers = n.nonOverlappingSubscribers + } + + return n +} + +func (mT *provider) leafInsertNode(levels []string) *node { + root := mT.root + + for _, level := range levels { + // Add node if it doesn't already exist + n, ok := root.children[level] + if !ok { + n = newNode(mT.allowOverlapping, root) + + root.children[level] = n + } + + root = n + } + + return root +} + +func (mT *provider) leafSearchNode(levels []string) *node { + root := mT.root + + // run down and try get path matching given topic + for _, token := range levels { + n, ok := root.children[token] + if !ok { + return nil + } + + root = n + } + + return root +} + +func (mT *provider) subscriptionInsert(filter string, sub topicstypes.Subscriber, p vlsubscriber.SubscriptionParams) bool { + levels := strings.Split(filter, "/") + + root := mT.leafInsertNode(levels) + + // Let's see if the subscriber is already on the list and just update QoS if so + // Otherwise create new entry + exists := false + if s, ok := root.subs[sub.Hash()]; !ok { + root.subs[sub.Hash()] = &topicSubscriber{ + s: sub, + p: p, + } + } else { + s.p = p + exists = true + } + + return exists +} + +func (mT *provider) subscriptionRemove(topic string, sub topicstypes.Subscriber) error { + levels := strings.Split(topic, "/") + + var err error + + root := mT.leafSearchNode(levels) + if root == nil { + return topicstypes.ErrNotFound + } + + // path matching the topic exists. + // if subscriber argument is nil remove all of subscribers + // otherwise try remove subscriber or set error if not exists + if sub == nil { + // If subscriber == nil, then it's signal to remove ALL subscribers + root.subs = make(subscribers) + } else { + id := sub.Hash() + if _, ok := root.subs[id]; ok { + delete(root.subs, id) + } else { + err = topicstypes.ErrNotFound + } + } + + // Run up and on each level and check if level has subscriptions and nested nodes + // If both are empty tell parent node to remove that token + level := len(levels) + for leafNode := root; leafNode != nil; leafNode = leafNode.parent { + // If there are no more subscribers or inner nodes or retained messages remove this node from parent + if len(leafNode.subs) == 0 && len(leafNode.children) == 0 && leafNode.retained == nil { + // if this is not root node + mT.onCleanUnsubscribe(levels[:level]) + if leafNode.parent != nil { + delete(leafNode.parent.children, levels[level-1]) + } + } + + level-- + } + + return err +} + +func subscriptionRecurseSearch(root *node, levels []string, publishID uintptr, p *publishes) { + if len(levels) == 0 { + // leaf level of the topic + // get all subscribers and return + root.getSubscribers(publishID, p) + if n, ok := root.children[topicstypes.MWC]; ok { + n.getSubscribers(publishID, p) + } + } else { + if n, ok := root.children[topicstypes.MWC]; ok && len(levels[0]) != 0 { + n.getSubscribers(publishID, p) + } + + if n, ok := root.children[levels[0]]; ok { + subscriptionRecurseSearch(n, levels[1:], publishID, p) + } + + if n, ok := root.children[topicstypes.SWC]; ok { + subscriptionRecurseSearch(n, levels[1:], publishID, p) + } + } +} + +func (mT *provider) subscriptionSearch(topic string, publishID uintptr, p *publishes) { + root := mT.root + levels := strings.Split(topic, "/") + level := levels[0] + + if !strings.HasPrefix(level, "$") { + subscriptionRecurseSearch(root, levels, publishID, p) + } else if n, ok := root.children[level]; ok { + subscriptionRecurseSearch(n, levels[1:], publishID, p) + } +} + +func (mT *provider) retainInsert(topic string, obj vltypes.RetainObject) { + levels := strings.Split(topic, "/") + + root := mT.leafInsertNode(levels) + + root.retained = obj +} + +func (mT *provider) retainRemove(topic string) error { + levels := strings.Split(topic, "/") + + root := mT.leafSearchNode(levels) + if root == nil { + return topicstypes.ErrNotFound + } + + root.retained = nil + + // Run up and on each level and check if level has subscriptions and nested nodes + // If both are empty tell parent node to remove that token + level := len(levels) + for leafNode := root; leafNode != nil; leafNode = leafNode.parent { + // If there are no more subscribers or inner nodes or retained messages remove this node from parent + if len(leafNode.subs) == 0 && len(leafNode.children) == 0 && leafNode.retained == nil { + // if this is not root node + if leafNode.parent != nil { + delete(leafNode.parent.children, levels[level-1]) + } + } + + level-- + } + + return nil +} + +func retainRecurseSearch(root *node, levels []string, retained *[]*mqttp.Publish) { + if len(levels) == 0 { + // leaf level of the topic + root.getRetained(retained) + if n, ok := root.children[topicstypes.MWC]; ok { + n.allRetained(retained) + } + } else { + switch levels[0] { + case topicstypes.MWC: + // If '#', add all retained messages starting this node + root.allRetained(retained) + return + case topicstypes.SWC: + // If '+', check all nodes at this level. Next levels must be matched. + for _, n := range root.children { + retainRecurseSearch(n, levels[1:], retained) + } + default: + if n, ok := root.children[levels[0]]; ok { + retainRecurseSearch(n, levels[1:], retained) + } + } + } +} + +func (mT *provider) retainSearch(filter string, retained *[]*mqttp.Publish) { + levels := strings.Split(filter, "/") + level := levels[0] + + if level == topicstypes.MWC { + for t, n := range mT.root.children { + if t != "" && !strings.HasPrefix(t, "$") { + n.allRetained(retained) + } + } + } else if strings.HasPrefix(level, "$") && mT.root.children[level] != nil { + retainRecurseSearch(mT.root.children[level], levels[1:], retained) + } else { + retainRecurseSearch(mT.root, levels, retained) + } +} + +func (sn *node) getRetained(retained *[]*mqttp.Publish) { + if sn.retained != nil { + var p *mqttp.Publish + + switch t := sn.retained.(type) { + case *mqttp.Publish: + p = t + case vlmonitoring.DynamicIFace: + p = t.Retained() + default: + panic("unknown retained type") + } + + // if publish has expiration set check if there time left to live + if _, _, expired := p.Expired(); !expired { + *retained = append(*retained, p) + } else { + // publish has expired, thus nobody should get it + sn.retained = nil + } + } +} + +func (sn *node) allRetained(retained *[]*mqttp.Publish) { + sn.getRetained(retained) + + for _, n := range sn.children { + n.allRetained(retained) + } +} + +func (sn *node) overlappingSubscribers(publishID uintptr, p *publishes) { + for id, sub := range sn.subs { + if s, ok := (*p)[id]; ok { + if sub.p.ID > 0 { + s[0].ids = append(s[0].ids, sub.p.ID) + } + + if s[0].qos < sub.p.Granted { + s[0].qos = sub.p.Granted + } + } else { + if !sub.p.Ops.NL() || id != publishID { + pe := sub.acquire() + (*p)[id] = append((*p)[id], pe) + } + } + } +} + +func (sn *node) nonOverlappingSubscribers(publishID uintptr, p *publishes) { + for id, sub := range sn.subs { + if !sub.p.Ops.NL() || id != publishID { + pe := sub.acquire() + if _, ok := (*p)[id]; ok { + (*p)[id] = append((*p)[id], pe) + } else { + (*p)[id] = []*publish{pe} + } + } + } +} diff --git a/.attic/topics/mem/topics.go b/.attic/topics/mem/topics.go new file mode 100644 index 0000000..3845935 --- /dev/null +++ b/.attic/topics/mem/topics.go @@ -0,0 +1,352 @@ +// Copyright (c) 2014 The VolantMQ Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mem + +import ( + "sync" + "time" + + "github.com/VolantMQ/vlapi/mqttp" + "github.com/VolantMQ/vlapi/vlpersistence" + "github.com/VolantMQ/vlapi/vltypes" + "go.uber.org/zap" + + "github.com/VolantMQ/volantmq/configuration" + "github.com/VolantMQ/volantmq/metrics" + topicstypes "github.com/VolantMQ/volantmq/topics/types" +) + +type provider struct { + smu sync.RWMutex + root *node + metricsPackets metrics.Packets + metricsSubs metrics.Subscriptions + persist vlpersistence.Retained + log *zap.SugaredLogger + onCleanUnsubscribe func([]string) + wgPublisher sync.WaitGroup + wgPublisherStarted sync.WaitGroup + inbound chan *mqttp.Publish + inRetained chan vltypes.RetainObject + subIn chan topicstypes.SubscribeReq + unSubIn chan topicstypes.UnSubscribeReq + allowOverlapping bool +} + +var _ topicstypes.Provider = (*provider)(nil) + +// NewMemProvider returns an new instance of the provider, which is implements the +// TopicsProvider interface. provider is a hidden struct that stores the topic +// subscriptions and retained messages in memory. The content is not persistent so +// when the server goes, everything will be gone. Use with care. +func NewMemProvider(config *topicstypes.MemConfig) (topicstypes.Provider, error) { + p := &provider{ + metricsPackets: config.MetricsPackets, + metricsSubs: config.MetricsSubs, + persist: config.Persist, + onCleanUnsubscribe: config.OnCleanUnsubscribe, + inbound: make(chan *mqttp.Publish, 1024*512), + inRetained: make(chan vltypes.RetainObject, 1024*512), + subIn: make(chan topicstypes.SubscribeReq, 1024*512), + unSubIn: make(chan topicstypes.UnSubscribeReq, 1024*512), + } + p.root = newNode(p.allowOverlapping, nil) + + p.log = configuration.GetLogger().Named("topics").Named(config.Name) + + if p.persist != nil { + entries, err := p.persist.Load() + if err != nil && err != vlpersistence.ErrNotFound { + return nil, err + } + + for _, d := range entries { + v := mqttp.ProtocolVersion(d.Data[0]) + var pkt mqttp.IFace + pkt, _, err = mqttp.Decode(v, d.Data[1:]) + if err != nil { + p.log.Error("Couldn't decode retained message", zap.Error(err)) + } else { + if m, ok := pkt.(*mqttp.Publish); ok { + if len(d.ExpireAt) > 0 { + var tm time.Time + if tm, err = time.Parse(time.RFC3339, d.ExpireAt); err == nil { + m.SetExpireAt(tm) + } else { + p.log.Error("Decode publish expire at", zap.Error(err)) + } + } + _ = p.Retain(m) // nolint: errcheck + } else { + p.log.Warn("Unsupported retained message type", zap.String("type", m.Type().Name())) + } + } + } + } + + publisherCount := 1 + subsCount := 1 + unSunCount := 1 + + p.wgPublisher.Add(publisherCount + subsCount + unSunCount + 1) + p.wgPublisherStarted.Add(publisherCount + subsCount + unSunCount + 1) + + for i := 0; i < publisherCount; i++ { + go p.publisher() + } + + go p.retainer() + + for i := 0; i < subsCount; i++ { + go p.subscriber() + } + + for i := 0; i < unSunCount; i++ { + go p.unSubscriber() + } + + p.wgPublisherStarted.Wait() + + return p, nil +} + +func (mT *provider) Subscribe(req topicstypes.SubscribeReq) topicstypes.SubscribeResp { + cAllocated := false + + if req.Chan == nil { + cAllocated = true + req.Chan = make(chan topicstypes.SubscribeResp) + } + + mT.subIn <- req + + resp := <-req.Chan + + if cAllocated { + close(req.Chan) + } + + return resp +} + +func (mT *provider) UnSubscribe(req topicstypes.UnSubscribeReq) topicstypes.UnSubscribeResp { + cAllocated := false + + if req.Chan == nil { + cAllocated = true + req.Chan = make(chan topicstypes.UnSubscribeResp) + } + + mT.unSubIn <- req + + resp := <-req.Chan + + if cAllocated { + close(req.Chan) + } + + return resp +} + +func (mT *provider) unSubscribe(topic string, sub topicstypes.Subscriber) error { + defer mT.smu.Unlock() + mT.smu.Lock() + + return mT.subscriptionRemove(topic, sub) +} + +func (mT *provider) Publish(m interface{}) error { + msg, ok := m.(*mqttp.Publish) + if !ok { + return topicstypes.ErrUnexpectedObjectType + } + mT.inbound <- msg + + return nil +} + +func (mT *provider) Retain(obj vltypes.RetainObject) error { + mT.inRetained <- obj + + return nil +} + +func (mT *provider) Retained(filter string) ([]*mqttp.Publish, error) { + // [MQTT-3.3.1-5] + var r []*mqttp.Publish + + defer mT.smu.Unlock() + mT.smu.Lock() + + // [MQTT-3.3.1-5] + mT.retainSearch(filter, &r) + + return r, nil +} + +func (mT *provider) Shutdown() error { + defer mT.smu.Unlock() + mT.smu.Lock() + + close(mT.inbound) + close(mT.inRetained) + close(mT.subIn) + close(mT.unSubIn) + + mT.wgPublisher.Wait() + + if mT.persist != nil { + var res []*mqttp.Publish + // [MQTT-3.3.1-5] + mT.retainSearch("#", &res) + mT.retainSearch("/#", &res) + + var encoded []*vlpersistence.PersistedPacket + + for _, pkt := range res { + // Discard retained expired and QoS0 messages + if expireAt, _, expired := pkt.Expired(); !expired && pkt.QoS() != mqttp.QoS0 { + if buf, err := mqttp.Encode(pkt); err != nil { + mT.log.Error("Couldn't encode retained message", zap.Error(err)) + } else { + entry := &vlpersistence.PersistedPacket{ + Data: buf, + } + if !expireAt.IsZero() { + entry.ExpireAt = expireAt.Format(time.RFC3339) + } + encoded = append(encoded, entry) + } + } + } + if len(encoded) > 0 { + mT.log.Debug("Storing retained messages", zap.Int("amount", len(encoded))) + if err := mT.persist.Store(encoded); err != nil { + mT.log.Error("Couldn't persist retained messages", zap.Error(err)) + } + } + } + + mT.root = nil + return nil +} + +func (mT *provider) retain(obj vltypes.RetainObject) { + insert := true + + mT.smu.Lock() + + switch t := obj.(type) { + case *mqttp.Publish: + // [MQTT-3.3.1-10] + // [MQTT-3.3.1-7] + if len(t.Payload()) == 0 || t.QoS() == mqttp.QoS0 { + _ = mT.retainRemove(obj.Topic()) // nolint: errcheck + if len(t.Payload()) == 0 { + insert = false + } + } + } + + if insert { + mT.retainInsert(obj.Topic(), obj) + mT.metricsPackets.OnAddRetain() + } else { + mT.metricsPackets.OnSubRetain() + } + + mT.smu.Unlock() +} + +func (mT *provider) subscriber() { + defer mT.wgPublisher.Done() + mT.wgPublisherStarted.Done() + + for req := range mT.subIn { + var resp topicstypes.SubscribeResp + + if req.S == nil { + resp.Err = topicstypes.ErrInvalidArgs + } else { + if req.Params.Ops.QoS() > mqttp.QoS2 { + resp.Err = mqttp.ErrInvalidQoS + } else { + req.Params.Granted = req.Params.Ops.QoS() + resp.Params = req.Params + + exists := mT.subscriptionInsert(req.Filter, req.S, req.Params) + + var r []*mqttp.Publish + + // [MQTT-3.3.1-5] + rh := req.Params.Ops.RetainHandling() + if (rh == mqttp.RetainHandlingRetain) || ((rh == mqttp.RetainHandlingIfNotExists) && !exists) { + mT.retainSearch(req.Filter, &r) + } + + resp.Retained = r + + if !exists { + mT.metricsSubs.OnSubscribe() + } + } + } + + req.Chan <- resp + } +} + +func (mT *provider) unSubscriber() { + defer mT.wgPublisher.Done() + mT.wgPublisherStarted.Done() + + for req := range mT.unSubIn { + err := mT.unSubscribe(req.Filter, req.S) + req.Chan <- topicstypes.UnSubscribeResp{Err: err} + if err == nil { + mT.metricsSubs.OnUnsubscribe() + } + } +} + +func (mT *provider) retainer() { + defer mT.wgPublisher.Done() + mT.wgPublisherStarted.Done() + + for obj := range mT.inRetained { + mT.retain(obj) + } +} + +func (mT *provider) publisher() { + defer mT.wgPublisher.Done() + mT.wgPublisherStarted.Done() + + for msg := range mT.inbound { + pubEntries := publishes{} + + mT.smu.Lock() + mT.subscriptionSearch(msg.Topic(), msg.PublishID(), &pubEntries) + mT.smu.Unlock() + + for _, pub := range pubEntries { + for _, e := range pub { + if err := e.s.Publish(msg, e.qos, e.ops, e.ids); err != nil { + mT.log.Error("Publish error", zap.Error(err)) + } + } + } + } +} diff --git a/.attic/topics/mem/trie_test.go b/.attic/topics/mem/trie_test.go new file mode 100644 index 0000000..ba0ad59 --- /dev/null +++ b/.attic/topics/mem/trie_test.go @@ -0,0 +1,726 @@ +package mem + +import ( + "testing" + + "github.com/VolantMQ/vlapi/mqttp" + "github.com/VolantMQ/vlapi/vlsubscriber" + "github.com/VolantMQ/vlapi/vltypes" + "github.com/stretchr/testify/require" + + "github.com/VolantMQ/volantmq/metrics" + "github.com/VolantMQ/volantmq/subscriber" + + topicstypes "github.com/VolantMQ/volantmq/topics/types" +) + +var config *topicstypes.MemConfig +var retainedSystree []vltypes.RetainObject + +func init() { + metric := metrics.New() + + config = topicstypes.NewMemConfig() + config.MetricsSubs = metric.Subs() + config.MetricsPackets = metric.Packets() +} + +func allocProvider(t *testing.T) *provider { + prov, err := NewMemProvider(config) + require.NoError(t, err) + + if p, ok := prov.(*provider); ok { + return p + } + + t.Fail() + return nil +} + +func TestMatch1(t *testing.T) { + prov := allocProvider(t) + sub := &subscriber.Type{} + + req := topicstypes.SubscribeReq{ + Filter: "sport/tennis/player1/#", + S: sub, + Params: vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS1), + }, + } + + resp := prov.Subscribe(req) + require.NotNil(t, resp) + require.NoError(t, resp.Err) + + subs := publishes{} + + prov.subscriptionSearch("sport/tennis/player1/anzel", 0, &subs) + require.Equal(t, 1, len(subs)) +} + +func TestMatch2(t *testing.T) { + prov := allocProvider(t) + + sub := &subscriber.Type{} + + req := topicstypes.SubscribeReq{ + Filter: "sport/tennis/player1/#", + S: sub, + Params: vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS2), + }, + } + + resp := prov.Subscribe(req) + require.NotNil(t, resp) + require.NoError(t, resp.Err) + + subs := publishes{} + + prov.subscriptionSearch("sport/tennis/player1/anzel", 0, &subs) + require.Equal(t, 1, len(subs)) +} + +func TestSNodeMatch3(t *testing.T) { + prov := allocProvider(t) + + sub := &subscriber.Type{} + + req := topicstypes.SubscribeReq{ + Filter: "sport/tennis/#", + S: sub, + Params: vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS2), + }, + } + + resp := prov.Subscribe(req) + require.NotNil(t, resp) + require.NoError(t, resp.Err) + + subs := publishes{} + prov.subscriptionSearch("sport/tennis/player1/anzel", 0, &subs) + require.Equal(t, 1, len(subs)) +} + +func TestMatch4(t *testing.T) { + prov := allocProvider(t) + sub := &subscriber.Type{} + + req := topicstypes.SubscribeReq{ + Filter: "#", + S: sub, + Params: vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS2), + }, + } + + resp := prov.Subscribe(req) + require.NotNil(t, resp) + require.NoError(t, resp.Err) + + subs := publishes{} + + prov.subscriptionSearch("sport/tennis/player1/anzel", 0, &subs) + require.Equal(t, 1, len(subs), "should return subscribers") + + subs = publishes{} + prov.subscriptionSearch("/sport/tennis/player1/anzel", 0, &subs) + require.Equal(t, 0, len(subs), "should not return subscribers") + + err := prov.subscriptionRemove("#", sub) + require.NoError(t, err) + + subs = publishes{} + prov.subscriptionSearch("#", 0, &subs) + require.Equal(t, 0, len(subs), "should not return subscribers") + + prov.subscriptionInsert("/#", sub, req.Params) + + subs = publishes{} + prov.subscriptionSearch("bla", 0, &subs) + require.Equal(t, 0, len(subs), "should not return subscribers") + + subs = publishes{} + prov.subscriptionSearch("/bla", 0, &subs) + require.Equal(t, 1, len(subs), "should return subscribers") + + err = prov.subscriptionRemove("/#", sub) + require.NoError(t, err) + + prov.subscriptionInsert("bla/bla/#", sub, req.Params) + + subs = publishes{} + prov.subscriptionSearch("bla", 0, &subs) + require.Equal(t, 0, len(subs), "should not return subscribers") + + subs = publishes{} + prov.subscriptionSearch("bla/bla", 0, &subs) + require.Equal(t, 1, len(subs), "should return subscribers") + + subs = publishes{} + prov.subscriptionSearch("bla/bla/bla", 0, &subs) + require.Equal(t, 1, len(subs), "should return subscribers") + + subs = publishes{} + prov.subscriptionSearch("bla/bla/bla/bla", 0, &subs) + require.Equal(t, 1, len(subs), "should return subscribers") +} + +func TestMatch5(t *testing.T) { + prov := allocProvider(t) + sub1 := &subscriber.Type{} + sub2 := &subscriber.Type{} + + p := vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS1), + } + + prov.subscriptionInsert("sport/tennis/+/+/#", sub1, p) + prov.subscriptionInsert("sport/tennis/player1/anzel", sub2, p) + + subs := publishes{} + prov.subscriptionSearch("sport/tennis/player1/anzel", 0, &subs) + + require.Equal(t, 2, len(subs)) +} + +func TestMatch6(t *testing.T) { + prov := allocProvider(t) + sub1 := &subscriber.Type{} + sub2 := &subscriber.Type{} + p := vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS1), + } + + prov.subscriptionInsert("sport/tennis/+/+/+/+/#", sub1, p) + prov.subscriptionInsert("sport/tennis/player1/anzel", sub2, p) + + subs := publishes{} + prov.subscriptionSearch("sport/tennis/player1/anzel/bla/bla", 0, &subs) + require.Equal(t, 1, len(subs)) +} + +func TestMatch7(t *testing.T) { + prov := allocProvider(t) + + sub1 := &subscriber.Type{} + sub2 := &subscriber.Type{} + p := vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS2), + } + + prov.subscriptionInsert("sport/tennis/#", sub1, p) + + p.Ops = mqttp.SubscriptionOptions(mqttp.QoS1) + + prov.subscriptionInsert("sport/tennis", sub2, p) + + subs := publishes{} + prov.subscriptionSearch("sport/tennis/player1/anzel", 0, &subs) + require.Equal(t, 1, len(subs)) + require.Equal(t, sub1, subs[sub1.Hash()][0].s) +} + +func TestMatch8(t *testing.T) { + prov := allocProvider(t) + + sub := &subscriber.Type{} + p := vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS2), + } + + prov.subscriptionInsert("+/+", sub, p) + + subs := publishes{} + + prov.subscriptionSearch("/finance", 0, &subs) + require.Equal(t, 1, len(subs)) +} + +func TestMatch9(t *testing.T) { + prov := allocProvider(t) + + sub1 := &subscriber.Type{} + p := vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS2), + } + + prov.subscriptionInsert("/+", sub1, p) + + subs := publishes{} + + prov.subscriptionSearch("/finance", 0, &subs) + require.Equal(t, 1, len(subs)) +} + +func TestMatch10(t *testing.T) { + prov := allocProvider(t) + + sub1 := &subscriber.Type{} + p := vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS2), + } + + prov.subscriptionInsert("+", sub1, p) + + subs := publishes{} + + prov.subscriptionSearch("/finance", 0, &subs) + require.Equal(t, 0, len(subs)) +} + +func TestInsertRemove(t *testing.T) { + prov := allocProvider(t) + sub := &subscriber.Type{} + p := vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS2), + } + + prov.subscriptionInsert("#", sub, p) + + subs := publishes{} + prov.subscriptionSearch("bla", 0, &subs) + require.Equal(t, 1, len(subs)) + + subs = publishes{} + prov.subscriptionSearch("/bla", 0, &subs) + require.Equal(t, 0, len(subs)) + + err := prov.subscriptionRemove("#", sub) + require.NoError(t, err) + + subs = publishes{} + prov.subscriptionSearch("#", 0, &subs) + require.Equal(t, 0, len(subs)) + + prov.subscriptionInsert("/#", sub, p) + + subs = publishes{} + prov.subscriptionSearch("bla", 0, &subs) + require.Equal(t, 0, len(subs)) + + subs = publishes{} + prov.subscriptionSearch("/bla", 0, &subs) + require.Equal(t, 1, len(subs)) + + err = prov.subscriptionRemove("#", sub) + require.EqualError(t, err, topicstypes.ErrNotFound.Error()) + + err = prov.subscriptionRemove("/#", sub) + require.NoError(t, err) +} + +func TestInsert1(t *testing.T) { + prov := allocProvider(t) + topic := "sport/tennis/player1/#" + + sub1 := &subscriber.Type{} + p := vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS1), + } + + prov.subscriptionInsert(topic, sub1, p) + require.Equal(t, 1, len(prov.root.children)) + require.Equal(t, 0, len(prov.root.subs)) + + level2, ok := prov.root.children["sport"] + require.True(t, ok) + require.Equal(t, 1, len(level2.children)) + require.Equal(t, 0, len(level2.subs)) + + level3, ok := level2.children["tennis"] + + require.True(t, ok) + require.Equal(t, 1, len(level3.children)) + require.Equal(t, 0, len(level3.subs)) + + level4, ok := level3.children["player1"] + + require.True(t, ok) + require.Equal(t, 1, len(level4.children)) + require.Equal(t, 0, len(level4.subs)) + + level5, ok := level4.children["#"] + + require.True(t, ok) + require.Equal(t, 0, len(level5.children)) + require.Equal(t, 1, len(level5.subs)) + + var e *topicSubscriber + + e, ok = level5.subs[sub1.Hash()] + require.Equal(t, true, ok) + require.Equal(t, sub1, e.s) +} + +func TestSNodeInsert2(t *testing.T) { + prov := allocProvider(t) + topic := "#" + + sub1 := &subscriber.Type{} + p := vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS1), + } + + prov.subscriptionInsert(topic, sub1, p) + require.Equal(t, 1, len(prov.root.children)) + require.Equal(t, 0, len(prov.root.subs)) + + n2, ok := prov.root.children[topic] + + require.True(t, ok) + require.Equal(t, 0, len(n2.children)) + require.Equal(t, 1, len(n2.subs)) + + var e *topicSubscriber + + e, ok = n2.subs[sub1.Hash()] + require.Equal(t, true, ok) + require.Equal(t, sub1, e.s) +} + +func TestSNodeInsert3(t *testing.T) { + prov := allocProvider(t) + topic := "+/tennis/#" + + sub1 := &subscriber.Type{} + p := vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS1), + } + + prov.subscriptionInsert(topic, sub1, p) + require.Equal(t, 1, len(prov.root.children)) + require.Equal(t, 0, len(prov.root.subs)) + + n2, ok := prov.root.children["+"] + + require.True(t, ok) + require.Equal(t, 1, len(n2.children)) + require.Equal(t, 0, len(n2.subs)) + + n3, ok := n2.children["tennis"] + + require.True(t, ok) + require.Equal(t, 1, len(n3.children)) + require.Equal(t, 0, len(n3.subs)) + + n4, ok := n3.children["#"] + + require.True(t, ok) + require.Equal(t, 0, len(n4.children)) + require.Equal(t, 1, len(n4.subs)) + + var e *topicSubscriber + + e, ok = n4.subs[sub1.Hash()] + require.Equal(t, true, ok) + require.Equal(t, sub1, e.s) +} + +func TestSNodeInsert4(t *testing.T) { + prov := allocProvider(t) + topic := "/finance" + + sub1 := &subscriber.Type{} + p := vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS1), + } + + prov.subscriptionInsert(topic, sub1, p) + require.Equal(t, 1, len(prov.root.children)) + require.Equal(t, 0, len(prov.root.subs)) + + n2, ok := prov.root.children[""] + + require.True(t, ok) + require.Equal(t, 1, len(n2.children)) + require.Equal(t, 0, len(n2.subs)) + + n3, ok := n2.children["finance"] + + require.True(t, ok) + require.Equal(t, 0, len(n3.children)) + require.Equal(t, 1, len(n3.subs)) + + var e *topicSubscriber + + e, ok = n3.subs[sub1.Hash()] + require.Equal(t, true, ok) + require.Equal(t, sub1, e.s) +} + +func TestSNodeInsertDup(t *testing.T) { + prov := allocProvider(t) + topic := "/finance" + + sub1 := &subscriber.Type{} + p := vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS1), + } + + prov.subscriptionInsert(topic, sub1, p) + prov.subscriptionInsert(topic, sub1, p) + + require.Equal(t, 1, len(prov.root.children)) + require.Equal(t, 0, len(prov.root.subs)) + + n2, ok := prov.root.children[""] + + require.True(t, ok) + require.Equal(t, 1, len(n2.children)) + require.Equal(t, 0, len(n2.subs)) + + n3, ok := n2.children["finance"] + + require.True(t, ok) + require.Equal(t, 0, len(n3.children)) + require.Equal(t, 1, len(n3.subs)) + + var e *topicSubscriber + + e, ok = n3.subs[sub1.Hash()] + require.Equal(t, true, ok) + require.Equal(t, sub1, e.s) +} + +func TestSNodeRemove1(t *testing.T) { + prov := allocProvider(t) + topic := "sport/tennis/player1/#" + + sub1 := &subscriber.Type{} + p := vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS1), + } + + prov.subscriptionInsert(topic, sub1, p) + + err := prov.subscriptionRemove(topic, sub1) + require.NoError(t, err) + + require.Equal(t, 0, len(prov.root.children)) + require.Equal(t, 0, len(prov.root.subs)) +} + +func TestSNodeRemove2(t *testing.T) { + prov := allocProvider(t) + topic := "sport/tennis/player1/#" + + sub1 := &subscriber.Type{} + p := vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS1), + } + + prov.subscriptionInsert(topic, sub1, p) + + err := prov.subscriptionRemove("sport/tennis/player1", sub1) + require.EqualError(t, err, topicstypes.ErrNotFound.Error()) +} + +func TestSNodeRemove3(t *testing.T) { + prov := allocProvider(t) + topic := "sport/tennis/player1/#" + + sub1 := &subscriber.Type{} + sub2 := &subscriber.Type{} + + p := vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS1), + } + + prov.subscriptionInsert(topic, sub1, p) + prov.subscriptionInsert(topic, sub2, p) + + err := prov.subscriptionRemove("sport/tennis/player1/#", nil) + require.NoError(t, err) + require.Equal(t, 0, len(prov.root.children)) + require.Equal(t, 0, len(prov.root.subs)) +} + +func TestRetain1(t *testing.T) { + prov := allocProvider(t) + sub := &subscriber.Type{} + + for _, m := range retainedSystree { + prov.retain(m) + } + + req := topicstypes.SubscribeReq{ + Filter: "#", + S: sub, + Params: vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS1), + }, + } + + resp := prov.Subscribe(req) + require.NotNil(t, resp) + require.NoError(t, resp.Err) + require.Equal(t, 0, len(resp.Retained)) + + req.Filter = "$SYS/#" + resp = prov.Subscribe(req) + require.NotNil(t, resp) + require.NoError(t, resp.Err) + require.Equal(t, len(retainedSystree), len(resp.Retained)) +} + +func TestRetain2(t *testing.T) { + prov := allocProvider(t) + sub := &subscriber.Type{} + + for _, m := range retainedSystree { + prov.retain(m) + } + + msg := newPublishMessageLarge("sport/tennis/player1/ricardo", mqttp.QoS1) + prov.retain(msg) + + req := topicstypes.SubscribeReq{ + Filter: "#", + S: sub, + Params: vlsubscriber.SubscriptionParams{ + Ops: mqttp.SubscriptionOptions(mqttp.QoS1), + }, + } + + resp := prov.Subscribe(req) + require.NotNil(t, resp) + require.NoError(t, resp.Err) + + var rMsg []*mqttp.Publish + prov.retainSearch("#", &rMsg) + require.Equal(t, 1, len(rMsg)) + + req.Filter = "$SYS/#" + resp = prov.Subscribe(req) + require.NotNil(t, resp) + require.NoError(t, resp.Err) + require.Equal(t, len(retainedSystree), len(resp.Retained)) +} + +func TestRNodeInsertRemove(t *testing.T) { + prov := allocProvider(t) + + // --- Insert msg1 + + msg := newPublishMessageLarge("sport/tennis/player1/ricardo", 1) + + n := prov.root + prov.retain(msg) + require.Equal(t, 1, len(n.children)) + require.Nil(t, n.retained) + + n2, ok := n.children["sport"] + + require.True(t, ok) + require.Equal(t, 1, len(n2.children)) + require.Nil(t, n2.retained) + + n3, ok := n2.children["tennis"] + + require.True(t, ok) + require.Equal(t, 1, len(n3.children)) + require.Nil(t, n3.retained) + + n4, ok := n3.children["player1"] + + require.True(t, ok) + require.Equal(t, 1, len(n4.children)) + require.Nil(t, n4.retained) + + n5, ok := n4.children["ricardo"] + + require.True(t, ok) + require.Equal(t, 0, len(n5.children)) + require.NotNil(t, n5.retained) + + var rMsg *mqttp.Publish + rMsg, ok = n5.retained.(*mqttp.Publish) + require.True(t, ok) + require.Equal(t, msg.QoS(), rMsg.QoS()) + require.Equal(t, msg.Topic(), rMsg.Topic()) + require.Equal(t, msg.Payload(), rMsg.Payload()) + + // --- Insert msg2 + + msg2 := newPublishMessageLarge("sport/tennis/player1/andre", mqttp.QoS1) + + prov.retain(msg2) + require.Equal(t, 2, len(n4.children)) + + n6, ok := n4.children["andre"] + + require.True(t, ok) + require.Equal(t, 0, len(n6.children)) + require.NotNil(t, n6.retained) + + rMsg, ok = n6.retained.(*mqttp.Publish) + require.True(t, ok) + require.Equal(t, msg2.QoS(), rMsg.QoS()) + require.Equal(t, msg2.Topic(), rMsg.Topic()) + + // --- Remove + + msg2.SetPayload([]byte{}) + err := prov.retainRemove("sport/tennis/player1/andre") + require.NoError(t, err) + require.Equal(t, 1, len(n4.children)) +} + +func TestRNodeMatch(t *testing.T) { + prov := allocProvider(t) + + msg1 := newPublishMessageLarge("sport/tennis/ricardo/stats", mqttp.QoS1) + prov.retain(msg1) + + msg2 := newPublishMessageLarge("sport/tennis/andre/stats", mqttp.QoS1) + prov.retain(msg2) + msg3 := newPublishMessageLarge("sport/tennis/andre/bio", mqttp.QoS1) + prov.retain(msg3) + + var msglist []*mqttp.Publish + + // --- + + msglist, _ = prov.Retained(msg1.Topic()) + require.Equal(t, 1, len(msglist)) + + // --- + msglist, _ = prov.Retained(msg2.Topic()) + require.Equal(t, 1, len(msglist)) + + // --- + msglist, _ = prov.Retained(msg3.Topic()) + require.Equal(t, 1, len(msglist)) + + // --- + msglist, _ = prov.Retained("sport/tennis/andre/+") + require.Equal(t, 2, len(msglist)) + + // --- + msglist, _ = prov.Retained("sport/tennis/andre/#") + require.Equal(t, 2, len(msglist)) + + // --- + msglist, _ = prov.Retained("sport/tennis/+/stats") + require.Equal(t, 2, len(msglist)) + + // --- + msglist, _ = prov.Retained("sport/tennis/#") + require.Equal(t, 3, len(msglist)) +} + +// nolint:unparam +func newPublishMessageLarge(topic string, qos mqttp.QosType) *mqttp.Publish { + m, _ := mqttp.New(mqttp.ProtocolV311, mqttp.PUBLISH) + + msg := m.(*mqttp.Publish) + + msg.SetPayload(make([]byte, 1024)) + _ = msg.SetTopic(topic) + _ = msg.SetQoS(qos) + + return msg +} diff --git a/subscriber/subscriber.go b/subscriber/subscriber.go index 97ca19b..a52a7ef 100644 --- a/subscriber/subscriber.go +++ b/subscriber/subscriber.go @@ -27,11 +27,9 @@ type Type struct { lock sync.RWMutex publisher vlsubscriber.Publisher log *zap.SugaredLogger - // access sync.WaitGroup - subSignal chan topicsTypes.SubscribeResp - unSubSignal chan topicsTypes.UnSubscribeResp - quit chan struct{} - // stop sync.Once + subSignal chan topicsTypes.SubscribeResp + unSubSignal chan topicsTypes.UnSubscribeResp + quit chan struct{} Config } @@ -79,7 +77,7 @@ func (s *Type) Subscriptions() vlsubscriber.Subscriptions { return s.subscriptions } -// Subscribe to given topic +// Subscribe to topic func (s *Type) Subscribe(topic string, params vlsubscriber.SubscriptionParams) ([]*mqttp.Publish, error) { resp := s.Topics.Subscribe(topicsTypes.SubscribeReq{ Filter: topic, @@ -168,8 +166,8 @@ func (s *Type) Publish(p *mqttp.Publish, grantedQoS mqttp.QosType, ops mqttp.Sub return nil } -// Online moves subscriber to online state -// since this moment all of publishes are forwarded to provided callback +// Online move subscriber to online state +// since this moment all publish packets forwarded to provided callback func (s *Type) Online(c vlsubscriber.Publisher) { s.lock.Lock() s.publisher = c diff --git a/subscriber/subscriber_test.go b/subscriber/subscriber_test.go new file mode 100644 index 0000000..002f978 --- /dev/null +++ b/subscriber/subscriber_test.go @@ -0,0 +1 @@ +package subscriber