diff --git a/floodsub.go b/floodsub.go index fc67d22f..b1247d5a 100644 --- a/floodsub.go +++ b/floodsub.go @@ -17,7 +17,12 @@ import ( timecache "github.com/whyrusleeping/timecache" ) -const ID = protocol.ID("/floodsub/1.0.0") +const ( + ID = protocol.ID("/floodsub/1.0.0") + defaultValidateTimeout = 150 * time.Millisecond + defaultValidateConcurrency = 100 + defaultValidateThrottle = 8192 +) var log = logging.Logger("floodsub") @@ -54,6 +59,18 @@ type PubSub struct { // topics tracks which topics each of our peers are subscribed to topics map[string]map[peer.ID]struct{} + // sendMsg carries messages that passed validation back to processLoop for publishing + sendMsg chan *sendReq + + // addVal handles validator registration requests + addVal chan *addValReq + + // topicVals tracks per topic validators + topicVals map[string]*topicVal + + // validateThrottle limits the number of active validation goroutines; when it is full, pushMsg drops the message + validateThrottle chan struct{} + peers map[peer.ID]chan *RPC seenMessages *timecache.TimeCache @@ -78,24 +95,37 @@ type RPC struct { from peer.ID } +type Option func(*PubSub) error + // NewFloodSub returns a new FloodSub management object -func NewFloodSub(ctx context.Context, h host.Host) *PubSub { +func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, error) { ps := &PubSub{ - host: h, - ctx: ctx, - incoming: make(chan *RPC, 32), - publish: make(chan *Message), - newPeers: make(chan inet.Stream), - peerDead: make(chan peer.ID), - cancelCh: make(chan *Subscription), - getPeers: make(chan *listPeerReq), - addSub: make(chan *addSubReq), - getTopics: make(chan *topicReq), - myTopics: make(map[string]map[*Subscription]struct{}), - topics: make(map[string]map[peer.ID]struct{}), - peers: make(map[peer.ID]chan *RPC), - seenMessages: timecache.NewTimeCache(time.Second * 30), - counter: uint64(time.Now().UnixNano()), + host: h, + ctx: ctx, + incoming: make(chan *RPC, 32), + publish: make(chan *Message), + newPeers: make(chan inet.Stream), + peerDead: make(chan peer.ID), + 
cancelCh: make(chan *Subscription), + getPeers: make(chan *listPeerReq), + addSub: make(chan *addSubReq), + getTopics: make(chan *topicReq), + sendMsg: make(chan *sendReq, 32), + addVal: make(chan *addValReq), + validateThrottle: make(chan struct{}, defaultValidateThrottle), + myTopics: make(map[string]map[*Subscription]struct{}), + topics: make(map[string]map[peer.ID]struct{}), + peers: make(map[peer.ID]chan *RPC), + topicVals: make(map[string]*topicVal), + seenMessages: timecache.NewTimeCache(time.Second * 30), + counter: uint64(time.Now().UnixNano()), + } + + for _, opt := range opts { + err := opt(ps) + if err != nil { + return nil, err + } } h.SetStreamHandler(ID, ps.handleNewStream) @@ -103,7 +133,14 @@ func NewFloodSub(ctx context.Context, h host.Host) *PubSub { go ps.processLoop(ctx) - return ps + return ps, nil +} + +func WithValidateThrottle(n int) Option { + return func(ps *PubSub) error { + ps.validateThrottle = make(chan struct{}, n) + return nil + } } // processLoop handles all inputs arriving on the channels @@ -176,7 +213,15 @@ func (p *PubSub) processLoop(ctx context.Context) { continue } case msg := <-p.publish: - p.maybePublishMessage(p.host.ID(), msg.Message) + vals := p.getValidators(msg) + p.pushMsg(vals, p.host.ID(), msg) + + case req := <-p.sendMsg: + p.maybePublishMessage(req.from, req.msg.Message) + + case req := <-p.addVal: + p.addValidator(req) + case <-ctx.Done(): log.Info("pubsub processloop shutting down") return @@ -210,24 +255,22 @@ func (p *PubSub) handleRemoveSubscription(sub *Subscription) { // subscribes to the topic. // Only called from processLoop. 
func (p *PubSub) handleAddSubscription(req *addSubReq) { - subs := p.myTopics[req.topic] + sub := req.sub + subs := p.myTopics[sub.topic] // announce we want this topic if len(subs) == 0 { - p.announce(req.topic, true) + p.announce(sub.topic, true) } // make new if not there if subs == nil { - p.myTopics[req.topic] = make(map[*Subscription]struct{}) - subs = p.myTopics[req.topic] + p.myTopics[sub.topic] = make(map[*Subscription]struct{}) + subs = p.myTopics[sub.topic] } - sub := &Subscription{ - ch: make(chan *Message, 32), - topic: req.topic, - cancelCh: p.cancelCh, - } + sub.ch = make(chan *Message, 32) + sub.cancelCh = p.cancelCh p.myTopics[sub.topic][sub] = struct{}{} @@ -314,8 +357,11 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error { continue } - p.maybePublishMessage(rpc.from, pmsg) + msg := &Message{pmsg} + vals := p.getValidators(msg) + p.pushMsg(vals, rpc.from, msg) } + return nil } @@ -324,6 +370,75 @@ func msgID(pmsg *pb.Message) string { return string(pmsg.GetFrom()) + string(pmsg.GetSeqno()) } +// pushMsg pushes a message performing validation as necessary +func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) { + if len(vals) > 0 { + // validation is asynchronous and globally throttled with the validateThrottle semaphore. + // the purpose of the global throttle is to bound the concurrency possible from incoming + // network traffic; each validator also has an individual throttle to preclude + // slow (or faulty) validators from starving other topics; see validate below. 
+ select { + case p.validateThrottle <- struct{}{}: + go func() { + p.validate(vals, src, msg) + <-p.validateThrottle + }() + default: + log.Warningf("message validation throttled; dropping message from %s", src) + } + return + } + + p.maybePublishMessage(src, msg.Message) +} + +// validate performs validation and only sends the message if all validators succeed +func (p *PubSub) validate(vals []*topicVal, src peer.ID, msg *Message) { + ctx, cancel := context.WithCancel(p.ctx) + defer cancel() + + rch := make(chan bool, len(vals)) + rcount := 0 + throttle := false + +loop: + for _, val := range vals { + rcount++ + + select { + case val.validateThrottle <- struct{}{}: + go func(val *topicVal) { + rch <- val.validateMsg(ctx, msg) + <-val.validateThrottle + }(val) + + default: + log.Debugf("validation throttled for topic %s", val.topic) + throttle = true + break loop + } + } + + if throttle { + log.Warningf("message validation throttled; dropping message from %s", src) + return + } + + for i := 0; i < rcount; i++ { + valid := <-rch + if !valid { + log.Warningf("message validation failed; dropping message from %s", src) + return + } + } + + // all validators were successful, send the message; NOTE(review): this send is unconditional, consider also selecting on p.ctx.Done() so the goroutine cannot block once processLoop has returned + p.sendMsg <- &sendReq{ + from: src, + msg: msg, + } +} + func (p *PubSub) maybePublishMessage(from peer.ID, pmsg *pb.Message) { id := msgID(pmsg) if p.seenMessage(id) { @@ -348,7 +463,7 @@ func (p *PubSub) publishMessage(from peer.ID, msg *pb.Message) error { continue } - for p, _ := range tmap { + for p := range tmap { tosend[p] = struct{}{} } } @@ -375,20 +490,38 @@ func (p *PubSub) publishMessage(from peer.ID, msg *pb.Message) error { return nil } +// getValidators returns all validators that apply to a given message +func (p *PubSub) getValidators(msg *Message) []*topicVal { + var vals []*topicVal + + for _, topic := range msg.GetTopicIDs() { + val, ok := p.topicVals[topic] + if !ok { + continue + } + + vals = append(vals, val) + } + + return vals +} + type addSubReq struct { 
- topic string - resp chan *Subscription + sub *Subscription + resp chan *Subscription } +type SubOpt func(sub *Subscription) error + // Subscribe returns a new Subscription for the given topic -func (p *PubSub) Subscribe(topic string) (*Subscription, error) { +func (p *PubSub) Subscribe(topic string, opts ...SubOpt) (*Subscription, error) { td := pb.TopicDescriptor{Name: &topic} - return p.SubscribeByTopicDescriptor(&td) + return p.SubscribeByTopicDescriptor(&td, opts...) } // SubscribeByTopicDescriptor lets you subscribe a topic using a pb.TopicDescriptor -func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor) (*Subscription, error) { +func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubOpt) (*Subscription, error) { if td.GetAuth().GetMode() != pb.TopicDescriptor_AuthOpts_NONE { return nil, fmt.Errorf("auth mode not yet supported") } @@ -397,10 +530,21 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor) (*Subscripti return nil, fmt.Errorf("encryption mode not yet supported") } + sub := &Subscription{ + topic: td.GetName(), + } + + for _, opt := range opts { + err := opt(sub) + if err != nil { + return nil, err + } + } + out := make(chan *Subscription, 1) p.addSub <- &addSubReq{ - topic: td.GetName(), - resp: out, + sub: sub, + resp: out, } return <-out, nil @@ -439,6 +583,12 @@ type listPeerReq struct { topic string } +// sendReq is a request to call maybePublishMessage. It is issued after message validation is done. +type sendReq struct { + from peer.ID + msg *Message +} + // ListPeers returns a list of peers we are connected to. 
func (p *PubSub) ListPeers(topic string) []peer.ID { out := make(chan []peer.ID) @@ -448,3 +598,100 @@ func (p *PubSub) ListPeers(topic string) []peer.ID { } return <-out } + +// addValReq is a request, handled by processLoop, to register a per-topic validator +type addValReq struct { + topic string + validate Validator + timeout time.Duration + throttle int + resp chan error +} + +type topicVal struct { + topic string + validate Validator + validateTimeout time.Duration + validateThrottle chan struct{} +} + +// Validator is a function that validates a message +type Validator func(context.Context, *Message) bool + +// ValidatorOpt is an option for RegisterTopicValidator +type ValidatorOpt func(addVal *addValReq) error + +// WithValidatorTimeout is an option that sets the topic validator timeout +func WithValidatorTimeout(timeout time.Duration) ValidatorOpt { + return func(addVal *addValReq) error { + addVal.timeout = timeout + return nil + } +} + +// WithValidatorConcurrency is an option that sets the topic validator throttle +func WithValidatorConcurrency(n int) ValidatorOpt { + return func(addVal *addValReq) error { + addVal.throttle = n + return nil + } +} + +// RegisterTopicValidator registers a validator for topic +func (p *PubSub) RegisterTopicValidator(topic string, val Validator, opts ...ValidatorOpt) error { + addVal := &addValReq{ + topic: topic, + validate: val, + resp: make(chan error, 1), + } + + for _, opt := range opts { + err := opt(addVal) + if err != nil { + return err + } + } + + p.addVal <- addVal + return <-addVal.resp +} + +func (ps *PubSub) addValidator(req *addValReq) { + topic := req.topic + + _, ok := ps.topicVals[topic] + if ok { + req.resp <- fmt.Errorf("Duplicate validator for topic %s", topic) + return + } + + val := &topicVal{ + topic: topic, + validate: req.validate, + validateTimeout: defaultValidateTimeout, + validateThrottle: make(chan struct{}, defaultValidateConcurrency), + } + + if req.timeout > 0 { + val.validateTimeout = req.timeout + } + + if req.throttle > 0 { + val.validateThrottle = 
make(chan struct{}, req.throttle) + } + + ps.topicVals[topic] = val + req.resp <- nil +} + +func (val *topicVal) validateMsg(ctx context.Context, msg *Message) bool { + vctx, cancel := context.WithTimeout(ctx, val.validateTimeout) + defer cancel() + + valid := val.validate(vctx, msg) + if !valid { + log.Debugf("validation failed for topic %s", val.topic) + } + + return valid +} diff --git a/floodsub_test.go b/floodsub_test.go index 28c10781..bf5ac80b 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -6,6 +6,7 @@ import ( "fmt" "math/rand" "sort" + "sync" "testing" "time" @@ -80,10 +81,14 @@ func connectAll(t *testing.T, hosts []host.Host) { } } -func getPubsubs(ctx context.Context, hs []host.Host) []*PubSub { +func getPubsubs(ctx context.Context, hs []host.Host, opts ...Option) []*PubSub { var psubs []*PubSub for _, h := range hs { - psubs = append(psubs, NewFloodSub(ctx, h)) + ps, err := NewFloodSub(ctx, h, opts...) + if err != nil { + panic(err) + } + psubs = append(psubs, ps) } return psubs } @@ -289,11 +294,14 @@ func TestSelfReceive(t *testing.T) { host := getNetHosts(t, ctx, 1)[0] - psub := NewFloodSub(ctx, host) + psub, err := NewFloodSub(ctx, host) + if err != nil { + t.Fatal(err) + } msg := []byte("hello world") - err := psub.Publish("foobar", msg) + err = psub.Publish("foobar", msg) if err != nil { t.Fatal(err) } @@ -323,14 +331,181 @@ func TestOneToOne(t *testing.T) { connect(t, hosts[0], hosts[1]) - ch, err := psubs[1].Subscribe("foobar") + sub, err := psubs[1].Subscribe("foobar") if err != nil { t.Fatal(err) } time.Sleep(time.Millisecond * 50) - checkMessageRouting(t, "foobar", psubs, []*Subscription{ch}) + checkMessageRouting(t, "foobar", psubs, []*Subscription{sub}) +} + +func TestValidate(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getNetHosts(t, ctx, 2) + psubs := getPubsubs(ctx, hosts) + + connect(t, hosts[0], hosts[1]) + topic := "foobar" + + err := 
psubs[1].RegisterTopicValidator(topic, func(ctx context.Context, msg *Message) bool { + return !bytes.Contains(msg.Data, []byte("illegal")) + }) + if err != nil { + t.Fatal(err) + } + + sub, err := psubs[1].Subscribe(topic) + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 50) + + msgs := []struct { + msg []byte + validates bool + }{ + {msg: []byte("this is a legal message"), validates: true}, + {msg: []byte("there also is nothing controversial about this message"), validates: true}, + {msg: []byte("openly illegal content will be censored"), validates: false}, + {msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true}, + } + + for _, tc := range msgs { + for _, p := range psubs { + err := p.Publish(topic, tc.msg) + if err != nil { + t.Fatal(err) + } + + select { + case msg := <-sub.ch: + if !tc.validates { + t.Log(msg) + t.Error("expected message validation to filter out the message") + } + case <-time.After(333 * time.Millisecond): + if tc.validates { + t.Error("expected message validation to accept the message") + } + } + } + } +} + +func TestValidateOverload(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + type msg struct { + msg []byte + validates bool + } + + tcs := []struct { + msgs []msg + + maxConcurrency int + }{ + { + maxConcurrency: 10, + msgs: []msg{ + {msg: []byte("this is a legal message"), validates: true}, + {msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true}, + {msg: []byte("there also is nothing controversial about this message"), validates: true}, + {msg: []byte("also fine"), validates: true}, + {msg: []byte("still, all good"), validates: true}, + {msg: []byte("this is getting boring"), validates: true}, + {msg: []byte("foo"), validates: true}, + {msg: []byte("foobar"), validates: true}, + {msg: []byte("foofoo"), validates: true}, + {msg: []byte("barfoo"), validates: true}, + {msg: 
[]byte("oh no!"), validates: false}, + }, + }, + { + maxConcurrency: 2, + msgs: []msg{ + {msg: []byte("this is a legal message"), validates: true}, + {msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true}, + {msg: []byte("oh no!"), validates: false}, + }, + }, + } + + for _, tc := range tcs { + + hosts := getNetHosts(t, ctx, 2) + psubs := getPubsubs(ctx, hosts) + + connect(t, hosts[0], hosts[1]) + topic := "foobar" + + block := make(chan struct{}) + + err := psubs[1].RegisterTopicValidator(topic, + func(ctx context.Context, msg *Message) bool { + <-block + return true + }, + WithValidatorConcurrency(tc.maxConcurrency)) + + if err != nil { + t.Fatal(err) + } + + sub, err := psubs[1].Subscribe(topic) + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 50) + + if len(tc.msgs) != tc.maxConcurrency+1 { + t.Fatalf("expected number of messages sent to be maxConcurrency+1. Got %d, expected %d", len(tc.msgs), tc.maxConcurrency+1) + } + + p := psubs[0] + + var wg sync.WaitGroup + wg.Add(1) + go func() { + for _, tmsg := range tc.msgs { + select { + case msg := <-sub.ch: + if !tmsg.validates { + t.Log(msg) + t.Error("expected message validation to drop the message because all validator goroutines are taken") + } + case <-time.After(333 * time.Millisecond): + if tmsg.validates { + t.Error("expected message validation to accept the message") + } + } + } + wg.Done() + }() + + for i, tmsg := range tc.msgs { + err := p.Publish(topic, tmsg.msg) + if err != nil { + t.Fatal(err) + } + + // wait a bit to let pubsub's internal state machine start validating the message + time.Sleep(10 * time.Millisecond) + + // unblock validator goroutines after we sent one too many + if i == len(tc.msgs)-1 { + close(block) + } + } + wg.Wait() + } } func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) { @@ -414,7 +589,10 @@ func TestSubReporting(t *testing.T) { defer cancel() host := getNetHosts(t, ctx, 1)[0] - 
psub := NewFloodSub(ctx, host) + psub, err := NewFloodSub(ctx, host) + if err != nil { + t.Fatal(err) + } fooSub, err := psub.Subscribe("foo") if err != nil {