diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 3b3f7e9f4..629877719 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -91,7 +91,9 @@ jobs: go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage-ci/resetter.txt -covermode=atomic ./tests/plugins/resetter go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage-ci/rpc.txt -covermode=atomic ./tests/plugins/rpc go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage-ci/kv_plugin.txt -covermode=atomic ./tests/plugins/kv + go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage-ci/broadcast_plugin.txt -covermode=atomic ./tests/plugins/broadcast go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage-ci/websockets.txt -covermode=atomic ./tests/plugins/websockets + go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage-ci/ws_origin.txt -covermode=atomic ./plugins/websockets docker-compose -f ./tests/docker-compose.yaml down cat ./coverage-ci/*.txt > ./coverage-ci/summary.txt diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index 92d76d2cc..f23f9b5d5 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -67,7 +67,6 @@ jobs: - name: Run golang tests on Windows run: | docker-compose -f ./tests/docker-compose.yaml up -d - mkdir ./coverage-ci go test -v -race ./pkg/transport/pipe go test -v -race ./pkg/transport/socket go test -v -race ./pkg/pool @@ -91,6 +90,7 @@ jobs: go test -v -race ./tests/plugins/resetter go test -v -race ./tests/plugins/rpc go test -v -race ./tests/plugins/kv + go test -v -race ./tests/plugins/broadcast go test -v -race ./tests/plugins/websockets + go test -v -race ./plugins/websockets docker-compose -f ./tests/docker-compose.yaml down - cat ./coverage-ci/*.txt > ./coverage-ci/summary.txt diff --git a/CHANGELOG.md b/CHANGELOG.md index cbb9936be..1f8e1733c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,11 +3,16 @@ CHANGELOG v2.3.1 (_.06.2021) ------------------- +## 👀 New: + +- ✏️ Rework `broadcast` plugin. Add architecture diagrams to the `doc` folder. [PR](https://github.com/spiral/roadrunner/pull/732) + ## 🩹 Fixes: - 🐛 Fix: Bugs with `boltdb` storage: [Boom](https://github.com/spiral/roadrunner/issues/717), [Boom](https://github.com/spiral/roadrunner/issues/718), [Boom](https://github.com/spiral/roadrunner/issues/719) - 🐛 Fix: Bug with incorrect redis initialization and usage [Bug](https://github.com/spiral/roadrunner/issues/720) - 🐛 Fix: Bug, Goridge duplicate error messages [Bug](https://github.com/spiral/goridge/issues/128) +- 🐛 Fix: Bug, incorrect request `origin` check [Bug](https://github.com/spiral/roadrunner/issues/727) ## 📦 Packages: diff --git a/Makefile b/Makefile index 15f9e3942..bcfce79df 100755 --- a/Makefile +++ b/Makefile @@ -31,7 +31,9 @@ test_coverage: go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage/resetter.out -covermode=atomic ./tests/plugins/resetter go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage/rpc.out -covermode=atomic ./tests/plugins/rpc go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage/kv_plugin.out -covermode=atomic ./tests/plugins/kv + go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage/broadcast_plugin.out -covermode=atomic ./tests/plugins/broadcast go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage/ws_plugin.out -covermode=atomic ./tests/plugins/websockets + go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage/ws_origin.out -covermode=atomic ./plugins/websockets cat ./coverage/*.out > ./coverage/summary.out docker-compose -f tests/docker-compose.yaml down @@ -60,7 +62,9 @@ test: ## Run application tests go test -v -race -tags=debug ./tests/plugins/resetter go test -v -race -tags=debug ./tests/plugins/rpc go test -v -race -tags=debug ./tests/plugins/kv + go test -v -race -tags=debug ./tests/plugins/broadcast go test -v -race -tags=debug ./tests/plugins/websockets + go test -v -race -tags=debug ./plugins/websockets docker-compose -f tests/docker-compose.yaml down testGo1.17beta1: ## Run application tests @@ -89,4 +93,6 @@ testGo1.17beta1: ## Run application tests go1.17beta1 test -v -race -tags=debug ./tests/plugins/rpc go1.17beta1 test -v -race -tags=debug ./tests/plugins/kv go1.17beta1 test -v -race -tags=debug ./tests/plugins/websockets + go1.17beta1 test -v -race -tags=debug ./tests/plugins/broadcast + go1.17beta1 test -v -race -tags=debug ./plugins/websockets docker-compose -f tests/docker-compose.yaml down diff --git a/pkg/pubsub/interface.go b/pkg/pubsub/interface.go index d021dbbe0..06252d70e 100644 --- a/pkg/pubsub/interface.go +++ b/pkg/pubsub/interface.go @@ -1,7 +1,5 @@ package pubsub -import websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" - /* This interface is in BETA. It might be changed. */ @@ -16,6 +14,11 @@ type PubSub interface { Reader } +type SubReader interface { + Subscriber + Reader +} + // Subscriber defines the ability to operate as message passing broker. // BETA interface type Subscriber interface { @@ -33,18 +36,19 @@ type Subscriber interface { // BETA interface type Publisher interface { // Publish one or multiple Channel. - Publish(messages []byte) error + Publish(message *Message) error // PublishAsync publish message and return immediately // If error occurred it will be printed into the logger - PublishAsync(messages []byte) + PublishAsync(message *Message) } // Reader interface should return next message type Reader interface { - Next() (*websocketsv1.Message, error) + Next() (*Message, error) } -type PSProvider interface { - PSProvide(key string) (PubSub, error) +// Constructor is a special pub-sub interface made to return a constructed PubSub type +type Constructor interface { + PSConstruct(key string) (PubSub, error) } diff --git a/pkg/pubsub/psmessage.go b/pkg/pubsub/psmessage.go new file mode 100644 index 000000000..e33d9284a --- /dev/null +++ b/pkg/pubsub/psmessage.go @@ -0,0 +1,15 @@ +package pubsub + +import json "github.com/json-iterator/go" + +// Message represents a single message with payload bound to a particular topic +type Message struct { + // Topic (channel in terms of redis) + Topic string `json:"topic"` + // Payload (on some decode stages might be represented as base64 string) + Payload []byte `json:"payload"` +} + +func (m *Message) MarshalBinary() (data []byte, err error) { + return json.Marshal(m) +} diff --git a/plugins/broadcast/config.go b/plugins/broadcast/config.go new file mode 100644 index 000000000..4f1e5213e --- /dev/null +++ b/plugins/broadcast/config.go @@ -0,0 +1,25 @@ +package broadcast + +/* + +# Global redis config (priority - 2) + +websockets: # <----- one of possible subscribers + path: /ws + broker: default # <------ broadcast broker to use --------------- | + | match +broadcast: # <-------- broadcast entry point plugin | + default: # <----------------------------------------------------- | + driver: redis + # local redis config (priority - 1) + test: + driver: memory + + +priority local -> global +*/ + +// Config ... +type Config struct { + Data map[string]interface{} `mapstructure:"broadcast"` +} diff --git a/plugins/broadcast/doc/broadcast_arch.drawio b/plugins/broadcast/doc/broadcast_arch.drawio new file mode 100644 index 000000000..fd5ff1f9c --- /dev/null +++ b/plugins/broadcast/doc/broadcast_arch.drawio @@ -0,0 +1 @@ +7V1bc6M4Fv41rk1vVVIg7o+Jk8l01fR2Np7e7n7a4iLbbDB4AMdJ//qVQGCQZBsHEMSTviRGIAznfj4dSRNlunq5j+318kvkwWACJO9lotxOALBkE/3EDa95AzAVJW9ZxL6Xt8m7hpn/C5JGibRufA8mtQvTKApSf11vdKMwhG5aa7PjONrWL5tHQf1b1/YCMg0z1w7Y1u++ly7Ja0iStDvxO/QXy5Q+s7KLq0lDsrS9aFtpUu4myjSOojT/tHqZwgBTryBM3u+3PWfLJ4thmDbpMFV1/fPNfXhp/FAuZ+Hz8kd0f6mr+W2e7WBDXpk8bfpa0CCONqEH8V2kiXKzXfopnK1tF5/dIq6jtmW6CtCRjD7O/SCYRkEUZ30Vz4bm3EXtSRpHT7ByRndN6MzRGfY9yKs9wziFL5Um8l73MFrBNH5Fl5CzikloTMRM1cjxtsazvG1ZZZdFGm0iJ4vy3jtKog+EmCcQVtP7Jexcw3+5hM3+4B5RmFba8z/dEFyW5DrFDYOluGpyKK7oWl8U186b4qBO8ctSoAckec/WY2iSU0KuSNrQFLcY+kIPOSxyGMXpMlpEoR3c7Vpv6hzYXfNHFK0J3f8H0/SVeF97k0Z1rsAXP/2Bu19p5OgnuRn+fPtSPXgtDkL0upVO+PBncT98sOuWHRX99vItiTaxCw+QpggP7HgB00PXEYZhwh0UgxgGduo/1yMBHkdJ14fIR89cio8BpCtZBaZmkJ81WdIkyvPkz03uQUlJ+VBvF5xCIoeSHFAVHbmh6FQFpyJHe0SHNh6mC11u6OGYGo4QRAhbEV91LGzXcWy/Vi5YY7FJ9suibmo16dNVKlakr5fbXW/Kh69XJKnV9ZqlUUqSU6RTlSkcxs653Tx+vb6dXs/+7NbLtRDU5t5M1ygSciI2WeU4M62vEFkd0iLJFXO082unGSRwqkWazwGf0Z7u6JreyiKBpu6va4PUSgYAo2OPd7efZ6jp4dvN7NtNt4o2h/oe+huWI3WkaCoVqJcRYlXRgEhFG9T1d6FoY9Iz5V3qmcLo2Ze7L18ff56TonE9mlBFM967ohlj0jT9XWoai/tdf/n3w0T5Df87I23T9aG1zWQojV49wKMA6Muy7vaz7Qe2g4iMXn7jJBsHffBiJANxwrAA0SGt09kO/EWIPruIahCR9AZTy3ft4JqcWPmelysxTPxf2RflDCT5GLqvdjPRbvG9kN4muQrLDKvCKIT9IOGaynKJj4T3xSVZGdIojhyykpqauHHZuOK5q6lxHNmeaydY+XysLnNsvbo0c54GTU/lmTkTOIp+0NeckCbrb4ze1d4USP9QoGN6cVyBtK4V6E04nEE50WL8cR/uZRlmq+v1I9fLsiUd6tAPsKaBYSXaaC7SAkRT7dy2v000DaUmCbJ5ouhQHXoSnUGjiXqKZTTNsUZnDYE1CpEzKQlSSay4T+To62VZO1VGqR49DRywIwcToAc4u1jXZFf/a4Orf25WiGE+Siiu0Vlp/YJ+ZnIg5e2XKRZTfE6tnMPpyiXJUPA5kqSU90SfFuS3vcKRVcAeVQI10lzvUhxlFztx5ZZ5S9FQBn/ZA5Av+s2hr0dta7ptydy1OxKRcwGcp8fpl13FUO8eprdZvnjxBF+zB4v9cPEJfZptnEdoe7U3rrwdZaSSpb3GH+OsSK1qWyJ083mQFYbhgHdPWNwh7nJCLEyN6imclL+s1KvGwkZvsfBH/cPREPe49RdUAMFYX42GgnqueZDZirDv0Eki9wmmLP7TJjd1ZM+bSzx9lCVDsToCd2RZoRFv44otSeKmp0pfFUmywRD5c+inFxhqK/3CJ4baZ4G2MUV5OuCgBWLhNhYVfYQuxFoMJKcRSHOmrOEgobIukjWlJX9PuUvVdzUYH3q7+yrIftx9AUHuC9BYuiJdYYDd0sjPRs7s1KRILqMnCsLfm+SwHUxKVnvIccCg1TvjjsSai3I/qCQ3EruSJUVVTWBopg7qUI9Kz4LoOS4r6FPxUV+3Ye6fnnBCJSUoTfKjMHNTk2x2z9xfnEXEJnO8kNDhBDDseFzNC4GGXkiue6GjdXctVBc0Vd3Ok6h2XAWMShW4hOc/F7DEdpf9YMyjxC0ql3B6lXBQ6CTrHDwhipoBJx6c25sgbXM7OwiiLfT+G8V+CfLssJh/1mCZN3/JOpuQV781us82OX77NnaniJ1zFGogM6RQM1lUi5M28op/e0NywMeo5tGw4LgRMkTFD6Z1tZvIohbfW+CCouMHFtcpkYarHWJbGilioTBqe0lj21Uc9/2FFxJV1K8qbHgBuDPUetNrFg6aIXHAwgRfoIs0EhF6momk7aFfK5gk9gIiIywFWIvfIRPoGE/hIQ1iYzzzw7jupY3V0Lgq/UzdajBIWc47FWVPLUZnv6bLLBF7CDaLLAVjsavx13IxmolS4OEBc4XNfv++gLksWRxjKRQxV9jR+g/EPFcfHlghFDJXwJCObOyQeeGhjrsyXZAroyFzyxoEMpeLVXgaY+aohwDQXPmohe9AmE1xcdkB0NwSXMxQ0GdI0NyTXAgBL2zTdE02+kqorKYT+3pLqJRhV8+o+aHxgeZK05l63adU7bjKW6OLgZLzVKg7vDyFSQpa3GxfNaY40FqQGaBBa0sbGrRWBp3gO3L/bTY1Av0Un58KWpui/TdbmDU20FqQXtOgtcmBAMSC1goLgM1g/Nzx/MVhgieTs0yg2OCpuPGH1dw/+ev4LDFRKTyNRtNi0bOZVFk0ro5Fs4hQq7UTbM/yuCtM2lCVFdCTUurDw9AqYAj9N4Khafhn8LptlV2h5wOFrkX4g4HQ6ogWqhsfCK02Tf5VUbgdDUKDOgYtm2ojj9YahVZOxKCJ69t7PVPEVbu+H8RaHXjZ2FHHbk0lv8h3BkasL2kL2Xcox8JrwhFr0fHdcEnWsNDUyBFqtSk41X2a1Y6rLGi0D6HOs6NprkLdotXyKTcbHJ0WpPK0Ny5nGx9Fp3vL6D4mxx91wcd9dT+r8ZyKTl/Kgp11QaARw9OCFJuGpy85K8CJhac1FhPrHp4eCAm7lAcf3NfPvipH5M4YmtHQ0o5rSUuNnbfw+DCt4237YptkbYdFWwnVXT1snMBPlhdkhsOnSiRU7dCpEkPZ06DB46ylG4rd1bpASn1dIIM3xGQIVeFBl1Y4SYXRwQOMffTe2I8OFQ41zYe0zmeYtdNSNh/aJJwwo1hua7MKrl085WgHff9hOzB4iBI/Qx6UWydK02g1YbHxbEmx2kpdmzTwQ6RUxVaTB3nQXJ0Mido8h7N7ocLRJr03bWKBmxkMvUk5WwtLToS/Gf0nZo7DhLMYj1AokNbgrQfMs3W9DUforKd6iKNn38Nsqa3ElzMpqg2fnum65zq1QRRoulRdF2y6MR7+9f3PG/mv6eNqMdV/RL/Pvg27FwS131rjqBIICSspZ3LUhXHpyymc4F5ntvRX/PRdo5YJlzXaGneXr/Nff9ApnvVByaZ4b0v5aryXsFD5kkUFRNxvx/1HY2aabuvYKnkdqRSALszMqUPOjBUy9GNjzvLhHq1HkQ9Rkb9dRJcYVlMPxAQWHOnYG2to9KxijbeaSl8QFt8MMATO8dpMAhc+IodNBloTdwlX9miDwFaMAQxjOJlUb7VDXMaMqHRoXFvvdWKdOXvx7bc/Q7loTmEftcUljSt+Xq0DuIKZ6cf59Szbt2oC8NDKw2wahYi4mwzW6NR4Nt1GrGMd5S0A0Nc+tFwGsasJkX3EkvzbjhH9LGwnjXPw+SLUdrIoVIUP6ISLQXRpXWIfXuHyWI1y7ARi/CpzgekS7vrh1rzuaFP6yGyp/8ZMJlUI42cxoDASWee4Ry7S2MXIJpfFA1csGG/KXoxa+rJziSN0kByYf3wOskhJPoKkYWWgn0IXnV7dSpOkq7w8VdJNC7+uSNyMHUCa3T3+544ThIX+biPRfPNQHAqgmMpd4vYYcirnkjV0/bmfeRV/52v6KGwVEq0xqa7atFqji8Epfqo78GzCt/gMqtD11F2WBUNejTGvLuzFyXuGGSplTo7sAWape0CBxh10st5ovxAZZxUyKk2ULrBN8lHc+SszMHY2JJvg1TWh94m1YLOyuv4Cvtg4p/xErdjcxh4dWSKiqci2Mk/0yLms6iwSB4rlYrteIoLPx/e4rrHRBpBvXHPUjXUCY7ZOLL5ebJu+19rQ4dHJHQz5cIdyT6I9HXqyZywk3dPeW43lr5WpKXeO3PkRkXN++DRmK0Hu8wnDq9qO93MMWpFQdIXfqZh3ddAbnAWspZu0M+dVc/Pgxt5grWIPwyreuITuU4Uv58kKpmTCZJ21JZITgI25fuKVyHHZ72f8EHamTOnST6rMmWbPh/M/KTvxinzWZAc/niXvaKejS5xAS6gaAd6CCekmxvEuqUvMx03KAm4cLa+L9awvYuhl3CNpes5iEiWfJQ8NqmpbNjgrKyhCa+SKTcRqQy+VgRYphDCH6HN9w5rnQNfeJFj5thnoH4X/wMHOUxhtJyyUb6el5yOADsqXFvCKzZZKedn62YM7uI/9bPtBxjxSTOnkcE42PJQ9mPNafsFd6G0yqxCFQRY0zvMqzOoXZxxLEuhdNVnY7xzFkAlfAWtJVN440xsGIdBhHOGdfXaRLYoml18iD+Ir/g8= \ No newline at end of file diff --git a/plugins/broadcast/interface.go b/plugins/broadcast/interface.go new file mode 100644 index 000000000..46709d71e --- /dev/null +++ b/plugins/broadcast/interface.go @@ -0,0 +1,7 @@ +package broadcast + +import "github.com/spiral/roadrunner/v2/pkg/pubsub" + +type Broadcaster interface { + GetDriver(key string) (pubsub.SubReader, error) +} diff --git a/plugins/broadcast/plugin.go b/plugins/broadcast/plugin.go new file mode 100644 index 000000000..6ddef8066 --- /dev/null +++ b/plugins/broadcast/plugin.go @@ -0,0 +1,208 @@ +package broadcast + +import ( + "fmt" + "sync" + + "github.com/google/uuid" + endure "github.com/spiral/endure/pkg/container" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +const ( + PluginName string = "broadcast" + // driver is the mandatory field which should present in every storage + driver string = "driver" + + redis string = "redis" + memory string = "memory" +) + +type Plugin struct { + sync.RWMutex + + cfg *Config + cfgPlugin config.Configurer + log logger.Logger + // publishers implement Publisher interface + // and able to receive a payload + publishers map[string]pubsub.PubSub + constructors map[string]pubsub.Constructor +} + +func (p *Plugin) Init(cfg config.Configurer, log logger.Logger) error { + const op = errors.Op("broadcast_plugin_init") + if !cfg.Has(PluginName) { + return errors.E(op, errors.Disabled) + } + p.cfg = &Config{} + // unmarshal config section + err := cfg.UnmarshalKey(PluginName, &p.cfg.Data) + if err != nil { + return errors.E(op, err) + } + + p.publishers = make(map[string]pubsub.PubSub) + p.constructors = make(map[string]pubsub.Constructor) + + p.log = log + p.cfgPlugin = cfg + return nil +} + +func (p *Plugin) Serve() chan error { + return make(chan error) +} + +func (p *Plugin) Stop() error { + return nil +} + +func (p *Plugin) Collects() []interface{} { + return []interface{}{ + p.CollectPublishers, + } +} + +// CollectPublishers collect all plugins who implement pubsub.Publisher interface +func (p *Plugin) CollectPublishers(name endure.Named, constructor pubsub.Constructor) { + // key redis, value - interface + p.constructors[name.Name()] = constructor +} + +// Publish is an entry point to the websocket PUBSUB +func (p *Plugin) Publish(m *pubsub.Message) error { + p.Lock() + defer p.Unlock() + + const op = errors.Op("broadcast_plugin_publish") + + // check if any publisher registered + if len(p.publishers) > 0 { + for j := range p.publishers { + err := p.publishers[j].Publish(m) + if err != nil { + return errors.E(op, err) + } + } + return nil + } else { + p.log.Warn("no publishers registered") + } + + return nil +} + +func (p *Plugin) PublishAsync(m *pubsub.Message) { + go func() { + p.Lock() + defer p.Unlock() + // check if any publisher registered + if len(p.publishers) > 0 { + for j := range p.publishers { + err := p.publishers[j].Publish(m) + if err != nil { + p.log.Error("publishAsync", "error", err) + // continue publish to other registered publishers + continue + } + } + } else { + p.log.Warn("no publishers registered") + } + }() +} + +func (p *Plugin) GetDriver(key string) (pubsub.SubReader, error) { //nolint:gocognit + const op = errors.Op("broadcast_plugin_get_driver") + + // choose a driver + if val, ok := p.cfg.Data[key]; ok { + // check type of the v + // should be a map[string]interface{} + switch t := val.(type) { + // correct type + case map[string]interface{}: + if _, ok := t[driver]; !ok { + panic(errors.E(op, errors.Errorf("could not find mandatory driver field in the %s storage", val))) + } + default: + return nil, errors.E(op, errors.Str("wrong type detected in the configuration, please, check yaml indentation")) + } + + // config key for the particular sub-driver kv.memcached + configKey := fmt.Sprintf("%s.%s", PluginName, key) + + switch val.(map[string]interface{})[driver] { + case memory: + if _, ok := p.constructors[memory]; !ok { + return nil, errors.E(op, errors.Errorf("no memory drivers registered, registered: %s", p.publishers)) + } + ps, err := p.constructors[memory].PSConstruct(configKey) + if err != nil { + return nil, errors.E(op, err) + } + + // save the initialized publisher channel + // for the in-memory, register new publishers + p.publishers[uuid.NewString()] = ps + + return ps, nil + case redis: + if _, ok := p.constructors[redis]; !ok { + return nil, errors.E(op, errors.Errorf("no redis drivers registered, registered: %s", p.publishers)) + } + + // first - try local configuration + switch { + case p.cfgPlugin.Has(configKey): + ps, err := p.constructors[redis].PSConstruct(configKey) + if err != nil { + return nil, errors.E(op, err) + } + + // if section already exists, return new connection + if _, ok := p.publishers[configKey]; ok { + return ps, nil + } + + // if not - initialize a connection + p.publishers[configKey] = ps + return ps, nil + + // then try global if local does not exist + case p.cfgPlugin.Has(redis): + ps, err := p.constructors[redis].PSConstruct(configKey) + if err != nil { + return nil, errors.E(op, err) + } + + // if section already exists, return new connection + if _, ok := p.publishers[configKey]; ok { + return ps, nil + } + + // if not - initialize a connection + p.publishers[configKey] = ps + return ps, nil + } + } + } + return nil, errors.E(op, errors.Str("could not find driver by provided key")) +} + +func (p *Plugin) RPC() interface{} { + return &rpc{ + plugin: p, + log: p.log, + } +} + +func (p *Plugin) Name() string { + return PluginName +} + +func (p *Plugin) Available() {} diff --git a/plugins/broadcast/rpc.go b/plugins/broadcast/rpc.go new file mode 100644 index 000000000..2ee211f81 --- /dev/null +++ b/plugins/broadcast/rpc.go @@ -0,0 +1,87 @@ +package broadcast + +import ( + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/logger" + websocketsv1 "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" +) + +// rpc collectors struct +type rpc struct { + plugin *Plugin + log logger.Logger +} + +// Publish ... msg is a proto decoded payload +// see: root/proto +func (r *rpc) Publish(in *websocketsv1.Request, out *websocketsv1.Response) error { + const op = errors.Op("broadcast_publish") + + // just return in case of nil message + if in == nil { + out.Ok = false + return nil + } + + r.log.Debug("message published", "msg", in.String()) + msgLen := len(in.GetMessages()) + + for i := 0; i < msgLen; i++ { + for j := 0; j < len(in.GetMessages()[i].GetTopics()); j++ { + if in.GetMessages()[i].GetTopics()[j] == "" { + r.log.Warn("message with empty topic, skipping") + // skip empty topics + continue + } + + tmp := &pubsub.Message{ + Topic: in.GetMessages()[i].GetTopics()[j], + Payload: in.GetMessages()[i].GetPayload(), + } + + err := r.plugin.Publish(tmp) + if err != nil { + out.Ok = false + return errors.E(op, err) + } + } + } + + out.Ok = true + return nil +} + +// PublishAsync ... +// see: root/proto +func (r *rpc) PublishAsync(in *websocketsv1.Request, out *websocketsv1.Response) error { + // just return in case of nil message + if in == nil { + out.Ok = false + return nil + } + + r.log.Debug("message published", "msg", in.GetMessages()) + + msgLen := len(in.GetMessages()) + + for i := 0; i < msgLen; i++ { + for j := 0; j < len(in.GetMessages()[i].GetTopics()); j++ { + if in.GetMessages()[i].GetTopics()[j] == "" { + r.log.Warn("message with empty topic, skipping") + // skip empty topics + continue + } + + tmp := &pubsub.Message{ + Topic: in.GetMessages()[i].GetTopics()[j], + Payload: in.GetMessages()[i].GetPayload(), + } + + r.plugin.PublishAsync(tmp) + } + } + + out.Ok = true + return nil +} diff --git a/plugins/kv/config.go b/plugins/kv/config.go index 66095817f..09ba79cd2 100644 --- a/plugins/kv/config.go +++ b/plugins/kv/config.go @@ -1,6 +1,6 @@ package kv -// Config represents general storage configuration with keys as the user defined kv-names and values as the drivers +// Config represents general storage configuration with keys as the user defined kv-names and values as the constructors type Config struct { Data map[string]interface{} `mapstructure:"kv"` } diff --git a/plugins/kv/drivers/boltdb/driver.go b/plugins/kv/drivers/boltdb/driver.go index 5f4d98b19..4b6752718 100644 --- a/plugins/kv/drivers/boltdb/driver.go +++ b/plugins/kv/drivers/boltdb/driver.go @@ -9,10 +9,10 @@ import ( "time" "github.com/spiral/errors" - kvv1 "github.com/spiral/roadrunner/v2/pkg/proto/kv/v1beta" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/kv" "github.com/spiral/roadrunner/v2/plugins/logger" + kvv1 "github.com/spiral/roadrunner/v2/proto/kv/v1beta" "github.com/spiral/roadrunner/v2/utils" bolt "go.etcd.io/bbolt" ) diff --git a/plugins/kv/drivers/boltdb/plugin.go b/plugins/kv/drivers/boltdb/plugin.go index 28e2a89cc..6ae1a1f64 100644 --- a/plugins/kv/drivers/boltdb/plugin.go +++ b/plugins/kv/drivers/boltdb/plugin.go @@ -46,7 +46,7 @@ func (s *Plugin) Stop() error { return nil } -func (s *Plugin) KVProvide(key string) (kv.Storage, error) { +func (s *Plugin) KVConstruct(key string) (kv.Storage, error) { const op = errors.Op("boltdb_plugin_provide") st, err := NewBoltDBDriver(s.log, key, s.cfgPlugin, s.stop) if err != nil { diff --git a/plugins/kv/drivers/memcached/driver.go b/plugins/kv/drivers/memcached/driver.go index c1f79cbb6..a2787d729 100644 --- a/plugins/kv/drivers/memcached/driver.go +++ b/plugins/kv/drivers/memcached/driver.go @@ -6,10 +6,10 @@ import ( "github.com/bradfitz/gomemcache/memcache" "github.com/spiral/errors" - kvv1 "github.com/spiral/roadrunner/v2/pkg/proto/kv/v1beta" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/kv" "github.com/spiral/roadrunner/v2/plugins/logger" + kvv1 "github.com/spiral/roadrunner/v2/proto/kv/v1beta" ) type Driver struct { diff --git a/plugins/kv/drivers/memcached/plugin.go b/plugins/kv/drivers/memcached/plugin.go index 936b20477..22ea5ccac 100644 --- a/plugins/kv/drivers/memcached/plugin.go +++ b/plugins/kv/drivers/memcached/plugin.go @@ -34,7 +34,7 @@ func (s *Plugin) Name() string { // Available interface implementation func (s *Plugin) Available() {} -func (s *Plugin) KVProvide(key string) (kv.Storage, error) { +func (s *Plugin) KVConstruct(key string) (kv.Storage, error) { const op = errors.Op("boltdb_plugin_provide") st, err := NewMemcachedDriver(s.log, key, s.cfgPlugin) if err != nil { diff --git a/plugins/kv/interface.go b/plugins/kv/interface.go index 5aedd5c3c..ffdbbe62c 100644 --- a/plugins/kv/interface.go +++ b/plugins/kv/interface.go @@ -1,6 +1,6 @@ package kv -import kvv1 "github.com/spiral/roadrunner/v2/pkg/proto/kv/v1beta" +import kvv1 "github.com/spiral/roadrunner/v2/proto/kv/v1beta" // Storage represents single abstract storage. type Storage interface { @@ -29,13 +29,8 @@ type Storage interface { Delete(keys ...string) error } -// StorageDriver interface provide storage -type StorageDriver interface { - Provider -} - -// Provider provides storage based on the config -type Provider interface { - // Provide provides Storage based on the config key - KVProvide(key string) (Storage, error) +// Constructor provides storage based on the config +type Constructor interface { + // KVConstruct provides Storage based on the config key + KVConstruct(key string) (Storage, error) } diff --git a/plugins/kv/plugin.go b/plugins/kv/plugin.go index efe922529..03dbaed69 100644 --- a/plugins/kv/plugin.go +++ b/plugins/kv/plugin.go @@ -24,8 +24,8 @@ const ( // Plugin for the unified storage type Plugin struct { log logger.Logger - // drivers contains general storage drivers, such as boltdb, memory, memcached, redis. - drivers map[string]StorageDriver + // constructors contains general storage constructors, such as boltdb, memory, memcached, redis. + constructors map[string]Constructor // storages contains user-defined storages, such as boltdb-north, memcached-us and so on. storages map[string]Storage // KV configuration @@ -43,7 +43,7 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger) error { if err != nil { return errors.E(op, err) } - p.drivers = make(map[string]StorageDriver, 5) + p.constructors = make(map[string]Constructor, 5) p.storages = make(map[string]Storage, 5) p.log = log p.cfgPlugin = cfg @@ -81,13 +81,27 @@ func (p *Plugin) Serve() chan error { //nolint:gocognit addr: [ "localhost:11211" ] - For this config we should have 3 drivers: memory, boltdb and memcached but 4 KVs: default, boltdb-south, boltdb-north and memcached + For this config we should have 3 constructors: memory, boltdb and memcached but 4 KVs: default, boltdb-south, boltdb-north and memcached when user requests for example boltdb-south, we should provide that particular preconfigured storage */ for k, v := range p.cfg.Data { - if _, ok := v.(map[string]interface{})[driver]; !ok { - errCh <- errors.E(op, errors.Errorf("could not find mandatory driver field in the %s storage", k)) - return errCh + // for example if the key not properly formatted (yaml) + if v == nil { + continue + } + + // check type of the v + // should be a map[string]interface{} + switch t := v.(type) { + // correct type + case map[string]interface{}: + if _, ok := t[driver]; !ok { + errCh <- errors.E(op, errors.Errorf("could not find mandatory driver field in the %s storage", k)) + return errCh + } + default: + p.log.Warn("wrong type detected in the configuration, please, check yaml indentation") + continue } // config key for the particular sub-driver kv.memcached @@ -95,12 +109,12 @@ func (p *Plugin) Serve() chan error { //nolint:gocognit // at this point we know, that driver field present in the configuration switch v.(map[string]interface{})[driver] { case memcached: - if _, ok := p.drivers[memcached]; !ok { - p.log.Warn("no memcached drivers registered", "registered", p.drivers) + if _, ok := p.constructors[memcached]; !ok { + p.log.Warn("no memcached constructors registered", "registered", p.constructors) continue } - storage, err := p.drivers[memcached].KVProvide(configKey) + storage, err := p.constructors[memcached].KVConstruct(configKey) if err != nil { errCh <- errors.E(op, err) return errCh @@ -110,12 +124,12 @@ func (p *Plugin) Serve() chan error { //nolint:gocognit p.storages[k] = storage case boltdb: - if _, ok := p.drivers[boltdb]; !ok { - p.log.Warn("no boltdb drivers registered", "registered", p.drivers) + if _, ok := p.constructors[boltdb]; !ok { + p.log.Warn("no boltdb constructors registered", "registered", p.constructors) continue } - storage, err := p.drivers[boltdb].KVProvide(configKey) + storage, err := p.constructors[boltdb].KVConstruct(configKey) if err != nil { errCh <- errors.E(op, err) return errCh @@ -124,12 +138,12 @@ func (p *Plugin) Serve() chan error { //nolint:gocognit // save the storage p.storages[k] = storage case memory: - if _, ok := p.drivers[memory]; !ok { - p.log.Warn("no in-memory drivers registered", "registered", p.drivers) + if _, ok := p.constructors[memory]; !ok { + p.log.Warn("no in-memory constructors registered", "registered", p.constructors) continue } - storage, err := p.drivers[memory].KVProvide(configKey) + storage, err := p.constructors[memory].KVConstruct(configKey) if err != nil { errCh <- errors.E(op, err) return errCh @@ -138,15 +152,15 @@ func (p *Plugin) Serve() chan error { //nolint:gocognit // save the storage p.storages[k] = storage case redis: - if _, ok := p.drivers[redis]; !ok { - p.log.Warn("no redis drivers registered", "registered", p.drivers) + if _, ok := p.constructors[redis]; !ok { + p.log.Warn("no redis constructors registered", "registered", p.constructors) continue } // first - try local configuration switch { case p.cfgPlugin.Has(configKey): - storage, err := p.drivers[redis].KVProvide(configKey) + storage, err := p.constructors[redis].KVConstruct(configKey) if err != nil { errCh <- errors.E(op, err) return errCh @@ -155,7 +169,7 @@ func (p *Plugin) Serve() chan error { //nolint:gocognit // save the storage p.storages[k] = storage case p.cfgPlugin.Has(redis): - storage, err := p.drivers[redis].KVProvide(configKey) + storage, err := p.constructors[redis].KVConstruct(configKey) if err != nil { errCh <- errors.E(op, err) return errCh @@ -189,9 +203,9 @@ func (p *Plugin) Collects() []interface{} { } } -func (p *Plugin) GetAllStorageDrivers(name endure.Named, storage StorageDriver) { - // save the storage driver - p.drivers[name.Name()] = storage +func (p *Plugin) GetAllStorageDrivers(name endure.Named, constructor Constructor) { + // save the storage constructor + p.constructors[name.Name()] = constructor } // RPC returns associated rpc service. diff --git a/plugins/kv/rpc.go b/plugins/kv/rpc.go index ab1f7f311..af763600d 100644 --- a/plugins/kv/rpc.go +++ b/plugins/kv/rpc.go @@ -2,8 +2,8 @@ package kv import ( "github.com/spiral/errors" - kvv1 "github.com/spiral/roadrunner/v2/pkg/proto/kv/v1beta" "github.com/spiral/roadrunner/v2/plugins/logger" + kvv1 "github.com/spiral/roadrunner/v2/proto/kv/v1beta" ) // Wrapper for the plugin diff --git a/plugins/memory/kv.go b/plugins/memory/kv.go index 9b7d72599..1cf031d13 100644 --- a/plugins/memory/kv.go +++ b/plugins/memory/kv.go @@ -6,10 +6,10 @@ import ( "time" "github.com/spiral/errors" - kvv1 "github.com/spiral/roadrunner/v2/pkg/proto/kv/v1beta" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/kv" "github.com/spiral/roadrunner/v2/plugins/logger" + kvv1 "github.com/spiral/roadrunner/v2/proto/kv/v1beta" ) type Driver struct { diff --git a/plugins/memory/plugin.go b/plugins/memory/plugin.go index 6151ebf0d..70badf15b 100644 --- a/plugins/memory/plugin.go +++ b/plugins/memory/plugin.go @@ -41,11 +41,11 @@ func (p *Plugin) Stop() error { return nil } -func (p *Plugin) PSProvide(key string) (pubsub.PubSub, error) { +func (p *Plugin) PSConstruct(key string) (pubsub.PubSub, error) { return NewPubSubDriver(p.log, key) } -func (p *Plugin) KVProvide(key string) (kv.Storage, error) { +func (p *Plugin) KVConstruct(key string) (kv.Storage, error) { const op = errors.Op("inmemory_plugin_provide") st, err := NewInMemoryDriver(p.log, key, p.cfgPlugin, p.stop) if err != nil { diff --git a/plugins/memory/pubsub.go b/plugins/memory/pubsub.go index 75cd9d245..d027a8a5f 100644 --- a/plugins/memory/pubsub.go +++ b/plugins/memory/pubsub.go @@ -4,16 +4,14 @@ import ( "sync" "github.com/spiral/roadrunner/v2/pkg/bst" - websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/plugins/logger" - "google.golang.org/protobuf/proto" ) type PubSubDriver struct { sync.RWMutex // channel with the messages from the RPC - pushCh chan []byte + pushCh chan *pubsub.Message // user-subscribed topics storage bst.Storage log logger.Logger @@ -21,21 +19,21 @@ type PubSubDriver struct { func NewPubSubDriver(log logger.Logger, _ string) (pubsub.PubSub, error) { ps := &PubSubDriver{ - pushCh: make(chan []byte, 10), + pushCh: make(chan *pubsub.Message, 10), storage: bst.NewBST(), log: log, } return ps, nil } -func (p *PubSubDriver) Publish(message []byte) error { - p.pushCh <- message +func (p *PubSubDriver) Publish(msg *pubsub.Message) error { + p.pushCh <- msg return nil } -func (p *PubSubDriver) PublishAsync(message []byte) { +func (p *PubSubDriver) PublishAsync(msg *pubsub.Message) { go func() { - p.pushCh <- message + p.pushCh <- msg }() } @@ -67,7 +65,7 @@ func (p *PubSubDriver) Connections(topic string, res map[string]struct{}) { } } -func (p *PubSubDriver) Next() (*websocketsv1.Message, error) { +func (p *PubSubDriver) Next() (*pubsub.Message, error) { msg := <-p.pushCh if msg == nil { return nil, nil @@ -76,20 +74,13 @@ func (p *PubSubDriver) Next() (*websocketsv1.Message, error) { p.RLock() defer p.RUnlock() - m := &websocketsv1.Message{} - err := proto.Unmarshal(msg, m) - if err != nil { - return nil, err - } - - // push only messages, which are subscribed + // push only messages, which topics are subscibed // TODO better??? - for i := 0; i < len(m.GetTopics()); i++ { - // if we have active subscribers - send a message to a topic - // or send nil instead - if ok := p.storage.Contains(m.GetTopics()[i]); ok { - return m, nil - } + // if we have active subscribers - send a message to a topic + // or send nil instead + if ok := p.storage.Contains(msg.Topic); ok { + return msg, nil } + return nil, nil } diff --git a/plugins/redis/channel.go b/plugins/redis/channel.go new file mode 100644 index 000000000..5817853c9 --- /dev/null +++ b/plugins/redis/channel.go @@ -0,0 +1,97 @@ +package redis + +import ( + "context" + "sync" + + "github.com/go-redis/redis/v8" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/utils" +) + +type redisChannel struct { + sync.Mutex + + // redis client + client redis.UniversalClient + pubsub *redis.PubSub + + log logger.Logger + + // out channel with all subs + out chan *pubsub.Message + + exit chan struct{} +} + +func newRedisChannel(redisClient redis.UniversalClient, log logger.Logger) *redisChannel { + out := make(chan *pubsub.Message, 100) + fi := &redisChannel{ + out: out, + client: redisClient, + pubsub: redisClient.Subscribe(context.Background()), + exit: make(chan struct{}), + log: log, + } + + // start reading messages + go fi.read() + + return fi +} + +func (r *redisChannel) sub(topics ...string) error { + const op = errors.Op("redis_sub") + err := r.pubsub.Subscribe(context.Background(), topics...) + if err != nil { + return errors.E(op, err) + } + return nil +} + +// read reads messages from the pubsub subscription +func (r *redisChannel) read() { + for { + select { + // here we receive message from us (which we sent before in Publish) + // it should be compatible with the pubsub.Message structure + // payload should be in the redis.message.payload field + + case msg, ok := <-r.pubsub.Channel(): + // channel closed + if !ok { + return + } + + r.out <- &pubsub.Message{ + Topic: msg.Channel, + Payload: utils.AsBytes(msg.Payload), + } + + case <-r.exit: + return + } + } +} + +func (r *redisChannel) unsub(topic string) error { + const op = errors.Op("redis_unsub") + err := r.pubsub.Unsubscribe(context.Background(), topic) + if err != nil { + return errors.E(op, err) + } + return nil +} + +func (r *redisChannel) stop() error { + r.exit <- struct{}{} + close(r.out) + close(r.exit) + return nil +} + +func (r *redisChannel) message() *pubsub.Message { + return <-r.out +} diff --git a/plugins/redis/fanin.go b/plugins/redis/fanin.go deleted file mode 100644 index ac9ebcc2f..000000000 --- a/plugins/redis/fanin.go +++ /dev/null @@ -1,102 +0,0 @@ -package redis - -import ( - "context" - "sync" - - websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" - "github.com/spiral/roadrunner/v2/plugins/logger" - "google.golang.org/protobuf/proto" - - "github.com/go-redis/redis/v8" - "github.com/spiral/errors" - "github.com/spiral/roadrunner/v2/utils" -) - -type FanIn struct { - sync.Mutex - - // redis client - client redis.UniversalClient - pubsub *redis.PubSub - - log logger.Logger - - // out channel with all subs - out chan *websocketsv1.Message - - exit chan struct{} -} - -func newFanIn(redisClient redis.UniversalClient, log logger.Logger) *FanIn { - out := make(chan *websocketsv1.Message, 100) - fi := &FanIn{ - out: out, - client: redisClient, - pubsub: redisClient.Subscribe(context.Background()), - exit: make(chan struct{}), - log: log, - } - - // start reading messages - go fi.read() - - return fi -} - -func (fi *FanIn) sub(topics ...string) error { - const op = errors.Op("fanin_addchannel") - err := fi.pubsub.Subscribe(context.Background(), topics...) - if err != nil { - return errors.E(op, err) - } - return nil -} - -// read reads messages from the pubsub subscription -func (fi *FanIn) read() { - for { - select { - // here we receive message from us (which we sent before in Publish) - // it should be compatible with the websockets.Msg interface - // payload should be in the redis.message.payload field - - case msg, ok := <-fi.pubsub.Channel(): - // channel closed - if !ok { - return - } - - m := &websocketsv1.Message{} - err := proto.Unmarshal(utils.AsBytes(msg.Payload), m) - if err != nil { - fi.log.Error("message unmarshal") - continue - } - - fi.out <- m - case <-fi.exit: - return - } - } -} - -func (fi *FanIn) unsub(topic string) error { - const op = errors.Op("fanin_remove") - err := fi.pubsub.Unsubscribe(context.Background(), topic) - if err != nil { - return errors.E(op, err) - } - return nil -} - -func (fi *FanIn) stop() error { - fi.exit <- struct{}{} - close(fi.out) - close(fi.exit) - return nil -} - -func (fi *FanIn) consume() <-chan *websocketsv1.Message { - return fi.out -} diff --git a/plugins/redis/kv.go b/plugins/redis/kv.go index 66cb83846..320b74437 100644 --- a/plugins/redis/kv.go +++ b/plugins/redis/kv.go @@ -7,10 +7,10 @@ import ( "github.com/go-redis/redis/v8" "github.com/spiral/errors" - kvv1 "github.com/spiral/roadrunner/v2/pkg/proto/kv/v1beta" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/kv" "github.com/spiral/roadrunner/v2/plugins/logger" + kvv1 "github.com/spiral/roadrunner/v2/proto/kv/v1beta" "github.com/spiral/roadrunner/v2/utils" ) diff --git a/plugins/redis/plugin.go b/plugins/redis/plugin.go index 24c21b558..9d98790b2 100644 --- a/plugins/redis/plugin.go +++ b/plugins/redis/plugin.go @@ -59,8 +59,8 @@ func (p *Plugin) Name() string { // Available interface implementation func (p *Plugin) Available() {} -// KVProvide provides KV storage implementation over the redis plugin -func (p *Plugin) KVProvide(key string) (kv.Storage, error) { +// KVConstruct provides KV storage implementation over the redis plugin +func (p *Plugin) KVConstruct(key string) (kv.Storage, error) { const op = errors.Op("redis_plugin_provide") st, err := NewRedisDriver(p.log, key, p.cfgPlugin) if err != nil { @@ -70,6 +70,6 @@ func (p *Plugin) KVProvide(key string) (kv.Storage, error) { return st, nil } -func (p *Plugin) PSProvide(key string) (pubsub.PubSub, error) { +func (p *Plugin) PSConstruct(key string) (pubsub.PubSub, error) { return NewPubSubDriver(p.log, key, p.cfgPlugin, p.stopCh) } diff --git a/plugins/redis/pubsub.go b/plugins/redis/pubsub.go index dbda7ea4a..4e41acb52 100644 --- a/plugins/redis/pubsub.go +++ b/plugins/redis/pubsub.go @@ -6,11 +6,9 @@ import ( "github.com/go-redis/redis/v8" "github.com/spiral/errors" - websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/logger" - "google.golang.org/protobuf/proto" ) type PubSubDriver struct { @@ -18,7 +16,7 @@ type PubSubDriver struct { cfg *Config `mapstructure:"redis"` log logger.Logger - fanin *FanIn + channel *redisChannel universalClient redis.UniversalClient stopCh chan struct{} } @@ -62,7 +60,12 @@ func NewPubSubDriver(log logger.Logger, key string, cfgPlugin config.Configurer, MasterName: ps.cfg.MasterName, }) - ps.fanin = newFanIn(ps.universalClient, log) + statusCmd := ps.universalClient.Ping(context.Background()) + if statusCmd.Err() != nil { + return nil, statusCmd.Err() + } + + ps.channel = newRedisChannel(ps.universalClient, log) ps.stop() @@ -72,47 +75,32 @@ func NewPubSubDriver(log logger.Logger, key string, cfgPlugin config.Configurer, func (p *PubSubDriver) stop() { go func() { for range p.stopCh { - _ = p.fanin.stop() + _ = p.channel.stop() return } }() } -func (p *PubSubDriver) Publish(msg []byte) error { +func (p *PubSubDriver) Publish(msg *pubsub.Message) error { p.Lock() defer p.Unlock() - m := &websocketsv1.Message{} - err := proto.Unmarshal(msg, m) - if err != nil { - return errors.E(err) + f := p.universalClient.Publish(context.Background(), msg.Topic, msg.Payload) + if f.Err() != nil { + return f.Err() } - for j := 0; j < len(m.GetTopics()); j++ { - f := p.universalClient.Publish(context.Background(), m.GetTopics()[j], msg) - if f.Err() != nil { - return f.Err() - } - } return nil } -func (p *PubSubDriver) PublishAsync(msg []byte) { +func (p *PubSubDriver) PublishAsync(msg *pubsub.Message) { go func() { p.Lock() defer p.Unlock() - m := &websocketsv1.Message{} - err := proto.Unmarshal(msg, m) - if err != nil { - p.log.Error("message unmarshal error") - return - } - for j := 0; j < len(m.GetTopics()); j++ { - f := p.universalClient.Publish(context.Background(), m.GetTopics()[j], msg) - if f.Err() != nil { - p.log.Error("redis publish", "error", f.Err()) - } + f := p.universalClient.Publish(context.Background(), msg.Topic, msg.Payload) + if f.Err() != nil { + p.log.Error("redis publish", "error", f.Err()) } }() } @@ -128,13 +116,13 @@ func (p *PubSubDriver) Subscribe(connectionID string, topics ...string) error { return err } if res == 0 { - p.log.Warn("could not subscribe to the provided topic", "connectionID", connectionID, "topic", topics[i]) + p.log.Warn("could not subscribe to the provided topic, you might be already subscribed to it", "connectionID", connectionID, "topic", topics[i]) continue } } // and subscribe after - return p.fanin.sub(topics...) + return p.channel.sub(topics...) } func (p *PubSubDriver) Unsubscribe(connectionID string, topics ...string) error { @@ -160,7 +148,7 @@ func (p *PubSubDriver) Unsubscribe(connectionID string, topics ...string) error } // else - unsubscribe - err = p.fanin.unsub(topics[i]) + err = p.channel.unsub(topics[i]) if err != nil { return err } @@ -176,7 +164,7 @@ func (p *PubSubDriver) Connections(topic string, res map[string]struct{}) { panic(err) } - // assighn connections + // assign connections // res expected to be from the sync.Pool for k := range r { res[k] = struct{}{} @@ -184,6 +172,6 @@ func (p *PubSubDriver) Connections(topic string, res map[string]struct{}) { } // Next return next message -func (p *PubSubDriver) Next() (*websocketsv1.Message, error) { - return <-p.fanin.consume(), nil +func (p *PubSubDriver) Next() (*pubsub.Message, error) { + return p.channel.message(), nil } diff --git a/plugins/websockets/config.go b/plugins/websockets/config.go index 93d9ac3b3..933a12e0e 100644 --- a/plugins/websockets/config.go +++ b/plugins/websockets/config.go @@ -1,75 +1,45 @@ package websockets import ( + "strings" "time" + "github.com/spiral/errors" "github.com/spiral/roadrunner/v2/pkg/pool" ) /* -# GLOBAL -redis: - addrs: - - 'localhost:6379' - websockets: - # pubsubs should implement PubSub interface to be collected via endure.Collects - - pubsubs:["redis", "amqp", "memory"] - # OR local - redis: - addrs: - - 'localhost:6379' - - # path used as websockets path + broker: default + allowed_origin: "*" path: "/ws" */ -type RedisConfig struct { - Addrs []string `mapstructure:"addrs"` - DB int `mapstructure:"db"` - Username string `mapstructure:"username"` - Password string `mapstructure:"password"` - MasterName string `mapstructure:"master_name"` - SentinelPassword string `mapstructure:"sentinel_password"` - RouteByLatency bool `mapstructure:"route_by_latency"` - RouteRandomly bool `mapstructure:"route_randomly"` - MaxRetries int `mapstructure:"max_retries"` - DialTimeout time.Duration `mapstructure:"dial_timeout"` - MinRetryBackoff time.Duration `mapstructure:"min_retry_backoff"` - MaxRetryBackoff time.Duration `mapstructure:"max_retry_backoff"` - PoolSize int `mapstructure:"pool_size"` - MinIdleConns int `mapstructure:"min_idle_conns"` - MaxConnAge time.Duration `mapstructure:"max_conn_age"` - ReadTimeout time.Duration `mapstructure:"read_timeout"` - WriteTimeout time.Duration `mapstructure:"write_timeout"` - PoolTimeout time.Duration `mapstructure:"pool_timeout"` - IdleTimeout time.Duration `mapstructure:"idle_timeout"` - IdleCheckFreq time.Duration `mapstructure:"idle_check_freq"` - ReadOnly bool `mapstructure:"read_only"` -} - // Config represents configuration for the ws plugin type Config struct { // http path for the websocket - Path string `mapstructure:"path"` - // ["redis", "amqp", "memory"] - PubSubs []string `mapstructure:"pubsubs"` - Middleware []string `mapstructure:"middleware"` + Path string `mapstructure:"path"` + AllowedOrigin string `mapstructure:"allowed_origin"` + Broker string `mapstructure:"broker"` - Redis *RedisConfig `mapstructure:"redis"` + // wildcard origin + allowedWOrigins []wildcard + allowedOrigins []string + allowedAll bool + // Pool with the workers for the websockets Pool *pool.Config `mapstructure:"pool"` } // InitDefault initialize default values for the ws config -func (c *Config) InitDefault() { +func (c *Config) InitDefault() error { if c.Path == "" { c.Path = "/ws" } - if len(c.PubSubs) == 0 { - // memory used by default - c.PubSubs = append(c.PubSubs, "memory") + + // broker is mandatory + if c.Broker == "" { + return errors.Str("broker key should be specified") } if c.Pool == nil { @@ -86,16 +56,28 @@ func (c *Config) InitDefault() { if c.Pool.DestroyTimeout == 0 { c.Pool.DestroyTimeout = time.Minute } - if c.Pool.Supervisor == nil { - return + if c.Pool.Supervisor != nil { + c.Pool.Supervisor.InitDefaults() } - c.Pool.Supervisor.InitDefaults() } - if c.Redis != nil { - if c.Redis.Addrs == nil { - // append default - c.Redis.Addrs = append(c.Redis.Addrs, "localhost:6379") - } + if c.AllowedOrigin == "" { + c.AllowedOrigin = "*" } + + // Normalize + origin := strings.ToLower(c.AllowedOrigin) + if origin == "*" { + // If "*" is present in the list, turn the whole list into a match all + c.allowedAll = true + return nil + } else if i := strings.IndexByte(origin, '*'); i >= 0 { + // Split the origin in two: start and end string without the * + w := wildcard{origin[0:i], origin[i+1:]} + c.allowedWOrigins = append(c.allowedWOrigins, w) + } else { + c.allowedOrigins = append(c.allowedOrigins, origin) + } + + return nil } diff --git a/plugins/websockets/connection/connection.go b/plugins/websockets/connection/connection.go index 2b8471730..04c29d832 100644 --- a/plugins/websockets/connection/connection.go +++ b/plugins/websockets/connection/connection.go @@ -22,7 +22,7 @@ func NewConnection(wsConn *websocket.Conn, log logger.Logger) *Connection { } } -func (c *Connection) Write(mt int, data []byte) error { +func (c *Connection) Write(data []byte) error { c.Lock() defer c.Unlock() @@ -34,7 +34,7 @@ func (c *Connection) Write(mt int, data []byte) error { } }() - err := c.conn.WriteMessage(mt, data) + err := c.conn.WriteMessage(websocket.TextMessage, data) if err != nil { return errors.E(op, err) } diff --git a/plugins/websockets/executor/executor.go b/plugins/websockets/executor/executor.go index e3d47166b..664b4dfd2 100644 --- a/plugins/websockets/executor/executor.go +++ b/plugins/websockets/executor/executor.go @@ -5,15 +5,14 @@ import ( "net/http" "sync" - "github.com/fasthttp/websocket" json "github.com/json-iterator/go" "github.com/spiral/errors" - websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/plugins/logger" "github.com/spiral/roadrunner/v2/plugins/websockets/commands" "github.com/spiral/roadrunner/v2/plugins/websockets/connection" "github.com/spiral/roadrunner/v2/plugins/websockets/validator" + websocketsv1 "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" ) type Response struct { @@ -23,14 +22,15 @@ type Response struct { type Executor struct { sync.Mutex + // raw ws connection conn *connection.Connection log logger.Logger // associated connection ID connID string - // map with the pubsub drivers - pubsub map[string]pubsub.PubSub + // subscriber drivers + sub pubsub.Subscriber actualTopics map[string]struct{} req *http.Request @@ -39,12 +39,12 @@ type Executor struct { // NewExecutor creates protected connection and starts command loop func NewExecutor(conn *connection.Connection, log logger.Logger, - connID string, pubsubs map[string]pubsub.PubSub, av validator.AccessValidatorFn, r *http.Request) *Executor { + connID string, sub pubsub.Subscriber, av validator.AccessValidatorFn, r *http.Request) *Executor { return &Executor{ conn: conn, connID: connID, log: log, - pubsub: pubsubs, + sub: sub, accessValidator: av, actualTopics: make(map[string]struct{}, 10), req: r, @@ -68,20 +68,20 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit err = json.Unmarshal(data, msg) if err != nil { - e.log.Error("error unmarshal message", "error", err) + e.log.Error("unmarshal message", "error", err) continue } // nil message, continue if msg == nil { - e.log.Warn("get nil message, skipping") + e.log.Warn("nil message, skipping") continue } switch msg.Command { // handle leave case commands.Join: - e.log.Debug("get join command", "msg", msg) + e.log.Debug("received join command", "msg", msg) val, err := e.accessValidator(e.req, msg.Topics...) if err != nil { @@ -96,13 +96,13 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit packet, errJ := json.Marshal(resp) if errJ != nil { - e.log.Error("error marshal the body", "error", errJ) + e.log.Error("marshal the body", "error", errJ) return errors.E(op, fmt.Errorf("%v,%v", err, errJ)) } - errW := e.conn.Write(websocket.BinaryMessage, packet) + errW := e.conn.Write(packet) if errW != nil { - e.log.Error("error writing payload to the connection", "payload", packet, "error", errW) + e.log.Error("write payload to the connection", "payload", packet, "error", errW) return errors.E(op, fmt.Errorf("%v,%v", err, errW)) } @@ -116,27 +116,25 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit packet, err := json.Marshal(resp) if err != nil { - e.log.Error("error marshal the body", "error", err) + e.log.Error("marshal the body", "error", err) return errors.E(op, err) } - err = e.conn.Write(websocket.BinaryMessage, packet) + err = e.conn.Write(packet) if err != nil { - e.log.Error("error writing payload to the connection", "payload", packet, "error", err) + e.log.Error("write payload to the connection", "payload", packet, "error", err) return errors.E(op, err) } // subscribe to the topic - if br, ok := e.pubsub[msg.Broker]; ok { - err = e.Set(br, msg.Topics) - if err != nil { - return errors.E(op, err) - } + err = e.Set(msg.Topics) + if err != nil { + return errors.E(op, err) } // handle leave case commands.Leave: - e.log.Debug("get leave command", "msg", msg) + e.log.Debug("received leave command", "msg", msg) // prepare response resp := &Response{ @@ -146,21 +144,19 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit packet, err := json.Marshal(resp) if err != nil { - e.log.Error("error marshal the body", "error", err) + e.log.Error("marshal the body", "error", err) return errors.E(op, err) } - err = e.conn.Write(websocket.BinaryMessage, packet) + err = e.conn.Write(packet) if err != nil { - e.log.Error("error writing payload to the connection", "payload", packet, "error", err) + e.log.Error("write payload to the connection", "payload", packet, "error", err) return errors.E(op, err) } - if br, ok := e.pubsub[msg.Broker]; ok { - err = e.Leave(br, msg.Topics) - if err != nil { - return errors.E(op, err) - } + err = e.Leave(msg.Topics) + if err != nil { + return errors.E(op, err) } case commands.Headers: @@ -171,13 +167,13 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit } } -func (e *Executor) Set(br pubsub.PubSub, topics []string) error { +func (e *Executor) Set(topics []string) error { // associate connection with topics - err := br.Subscribe(e.connID, topics...) + err := e.sub.Subscribe(e.connID, topics...) if err != nil { - e.log.Error("error subscribing to the provided topics", "topics", topics, "error", err.Error()) + e.log.Error("subscribe to the provided topics", "topics", topics, "error", err.Error()) // in case of error, unsubscribe connection from the dead topics - _ = br.Unsubscribe(e.connID, topics...) + _ = e.sub.Unsubscribe(e.connID, topics...) return err } @@ -189,11 +185,11 @@ func (e *Executor) Set(br pubsub.PubSub, topics []string) error { return nil } -func (e *Executor) Leave(br pubsub.PubSub, topics []string) error { +func (e *Executor) Leave(topics []string) error { // remove associated connections from the storage - err := br.Unsubscribe(e.connID, topics...) + err := e.sub.Unsubscribe(e.connID, topics...) if err != nil { - e.log.Error("error subscribing to the provided topics", "topics", topics, "error", err.Error()) + e.log.Error("subscribe to the provided topics", "topics", topics, "error", err.Error()) return err } @@ -208,10 +204,7 @@ func (e *Executor) Leave(br pubsub.PubSub, topics []string) error { func (e *Executor) CleanUp() { // unsubscribe particular connection from the topics for topic := range e.actualTopics { - // here - for _, ps := range e.pubsub { - _ = ps.Unsubscribe(e.connID, topic) - } + _ = e.sub.Unsubscribe(e.connID, topic) } // clean up the actualTopics data diff --git a/plugins/websockets/origin.go b/plugins/websockets/origin.go new file mode 100644 index 000000000..c6d9c9b8b --- /dev/null +++ b/plugins/websockets/origin.go @@ -0,0 +1,28 @@ +package websockets + +import ( + "strings" +) + +func isOriginAllowed(origin string, cfg *Config) bool { + if cfg.allowedAll { + return true + } + + origin = strings.ToLower(origin) + // simple case + origin = strings.ToLower(origin) + for _, o := range cfg.allowedOrigins { + if o == origin { + return true + } + } + // check wildcards + for _, w := range cfg.allowedWOrigins { + if w.match(origin) { + return true + } + } + + return false +} diff --git a/plugins/websockets/origin_test.go b/plugins/websockets/origin_test.go new file mode 100644 index 000000000..bbc49bbb6 --- /dev/null +++ b/plugins/websockets/origin_test.go @@ -0,0 +1,73 @@ +package websockets + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConfig_Origin(t *testing.T) { + cfg := &Config{ + AllowedOrigin: "*", + Broker: "any", + } + + err := cfg.InitDefault() + assert.NoError(t, err) + + assert.True(t, isOriginAllowed("http://some.some.some.sssome", cfg)) + assert.True(t, isOriginAllowed("http://", cfg)) + assert.True(t, isOriginAllowed("http://google.com", cfg)) + assert.True(t, isOriginAllowed("ws://*", cfg)) + assert.True(t, isOriginAllowed("*", cfg)) + assert.True(t, isOriginAllowed("you are bad programmer", cfg)) // True :( + assert.True(t, isOriginAllowed("****", cfg)) + assert.True(t, isOriginAllowed("asde!@#!!@#!%", cfg)) + assert.True(t, isOriginAllowed("http://*.domain.com", cfg)) +} + +func TestConfig_OriginWildCard(t *testing.T) { + cfg := &Config{ + AllowedOrigin: "https://*my.site.com", + Broker: "any", + } + + err := cfg.InitDefault() + assert.NoError(t, err) + + assert.True(t, isOriginAllowed("https://my.site.com", cfg)) + assert.False(t, isOriginAllowed("http://", cfg)) + assert.False(t, isOriginAllowed("http://google.com", cfg)) + assert.False(t, isOriginAllowed("ws://*", cfg)) + assert.False(t, isOriginAllowed("*", cfg)) + assert.False(t, isOriginAllowed("you are bad programmer", cfg)) // True :( + assert.False(t, isOriginAllowed("****", cfg)) + assert.False(t, isOriginAllowed("asde!@#!!@#!%", cfg)) + assert.False(t, isOriginAllowed("http://*.domain.com", cfg)) + + assert.False(t, isOriginAllowed("https://*site.com", cfg)) + assert.True(t, isOriginAllowed("https://some.my.site.com", cfg)) +} + +func TestConfig_OriginWildCard2(t *testing.T) { + cfg := &Config{ + AllowedOrigin: "https://my.*.com", + Broker: "any", + } + + err := cfg.InitDefault() + assert.NoError(t, err) + + assert.True(t, isOriginAllowed("https://my.site.com", cfg)) + assert.False(t, isOriginAllowed("http://", cfg)) + assert.False(t, isOriginAllowed("http://google.com", cfg)) + assert.False(t, isOriginAllowed("ws://*", cfg)) + assert.False(t, isOriginAllowed("*", cfg)) + assert.False(t, isOriginAllowed("you are bad programmer", cfg)) // True :( + assert.False(t, isOriginAllowed("****", cfg)) + assert.False(t, isOriginAllowed("asde!@#!!@#!%", cfg)) + assert.False(t, isOriginAllowed("http://*.domain.com", cfg)) + + assert.False(t, isOriginAllowed("https://*site.com", cfg)) + assert.True(t, isOriginAllowed("https://my.bad.com", cfg)) +} diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go index a1002bddb..ca5f2f593 100644 --- a/plugins/websockets/plugin.go +++ b/plugins/websockets/plugin.go @@ -2,7 +2,6 @@ package websockets import ( "context" - "fmt" "net/http" "sync" "time" @@ -10,14 +9,13 @@ import ( "github.com/fasthttp/websocket" "github.com/google/uuid" json "github.com/json-iterator/go" - endure "github.com/spiral/endure/pkg/container" "github.com/spiral/errors" "github.com/spiral/roadrunner/v2/pkg/payload" phpPool "github.com/spiral/roadrunner/v2/pkg/pool" "github.com/spiral/roadrunner/v2/pkg/process" - websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/pkg/worker" + "github.com/spiral/roadrunner/v2/plugins/broadcast" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/http/attributes" "github.com/spiral/roadrunner/v2/plugins/logger" @@ -26,7 +24,6 @@ import ( "github.com/spiral/roadrunner/v2/plugins/websockets/executor" "github.com/spiral/roadrunner/v2/plugins/websockets/pool" "github.com/spiral/roadrunner/v2/plugins/websockets/validator" - "google.golang.org/protobuf/proto" ) const ( @@ -35,14 +32,14 @@ const ( type Plugin struct { sync.RWMutex - // Collection with all available pubsubs - pubsubs map[string]pubsub.PubSub - psProviders map[string]pubsub.PSProvider + // subscriber+reader interfaces + subReader pubsub.SubReader + // broadcaster + broadcaster broadcast.Broadcaster - cfg *Config - cfgPlugin config.Configurer - log logger.Logger + cfg *Config + log logger.Logger // global connections map connections sync.Map @@ -53,14 +50,16 @@ type Plugin struct { wsUpgrade *websocket.Upgrader serveExit chan struct{} + // workers pool phpPool phpPool.Pool - server server.Server + // server which produces commands to the pool + server server.Server // function used to validate access to the requested resource accessValidator validator.AccessValidatorFn } -func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Server) error { +func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Server, b broadcast.Broadcaster) error { const op = errors.Op("websockets_plugin_init") if !cfg.Has(PluginName) { return errors.E(op, errors.Disabled) @@ -71,36 +70,32 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Se return errors.E(op, err) } - p.cfg.InitDefault() - p.pubsubs = make(map[string]pubsub.PubSub) - p.psProviders = make(map[string]pubsub.PSProvider) - - p.log = log - p.cfgPlugin = cfg + err = p.cfg.InitDefault() + if err != nil { + return errors.E(op, err) + } p.wsUpgrade = &websocket.Upgrader{ HandshakeTimeout: time.Second * 60, ReadBufferSize: 1024, WriteBufferSize: 1024, - WriteBufferPool: nil, - Subprotocols: nil, - Error: nil, CheckOrigin: func(r *http.Request) bool { - return true + return isOriginAllowed(r.Header.Get("Origin"), p.cfg) }, - EnableCompression: false, } p.serveExit = make(chan struct{}) p.server = server - + p.log = log + p.broadcaster = b return nil } func (p *Plugin) Serve() chan error { - errCh := make(chan error, 1) const op = errors.Op("websockets_plugin_serve") - - err := p.initPubSubs() + errCh := make(chan error, 1) + // init broadcaster + var err error + p.subReader, err = p.broadcaster.GetDriver(p.cfg.Broker) if err != nil { errCh <- errors.E(op, err) return errCh @@ -126,76 +121,26 @@ func (p *Plugin) Serve() chan error { p.accessValidator = p.defaultAccessValidator(p.phpPool) }() - p.workersPool = pool.NewWorkersPool(p.pubsubs, &p.connections, p.log) - - // run all pubsubs drivers - for _, v := range p.pubsubs { - go func(ps pubsub.PubSub) { - for { - select { - case <-p.serveExit: - return - default: - data, err := ps.Next() - if err != nil { - errCh <- err - return - } - p.workersPool.Queue(data) - } - } - }(v) - } - - return errCh -} - -func (p *Plugin) initPubSubs() error { - for i := 0; i < len(p.cfg.PubSubs); i++ { - // don't need to have a section for the in-memory - if p.cfg.PubSubs[i] == "memory" { - if provider, ok := p.psProviders[p.cfg.PubSubs[i]]; ok { - r, err := provider.PSProvide("") - if err != nil { - return err - } + p.workersPool = pool.NewWorkersPool(p.subReader, &p.connections, p.log) - // append default in-memory provider - p.pubsubs["memory"] = r - } - continue - } - // key - memory, redis - if provider, ok := p.psProviders[p.cfg.PubSubs[i]]; ok { - // try local key - switch { - // try local config first - case p.cfgPlugin.Has(fmt.Sprintf("%s.%s", PluginName, p.cfg.PubSubs[i])): - r, err := provider.PSProvide(fmt.Sprintf("%s.%s", PluginName, p.cfg.PubSubs[i])) - if err != nil { - return err - } - - // append redis provider - p.pubsubs[p.cfg.PubSubs[i]] = r - case p.cfgPlugin.Has(p.cfg.PubSubs[i]): - r, err := provider.PSProvide(p.cfg.PubSubs[i]) + // we need here only Reader part of the interface + go func(ps pubsub.Reader) { + for { + select { + case <-p.serveExit: + return + default: + data, err := ps.Next() if err != nil { - return err + errCh <- err + return } - - // append redis provider - p.pubsubs[p.cfg.PubSubs[i]] = r - default: - return errors.Errorf("could not find configuration sections for the %s", p.cfg.PubSubs[i]) + p.workersPool.Queue(data) } - } else { - // no such driver - p.log.Warn("no such driver", "requested", p.cfg.PubSubs[i], "available", p.psProviders) } - } + }(p.subReader) - return nil + return errCh } func (p *Plugin) Stop() error { @@ -212,30 +157,12 @@ func (p *Plugin) Stop() error { return nil } -func (p *Plugin) Collects() []interface{} { - return []interface{}{ - p.GetPublishers, - } -} - func (p *Plugin) Available() {} -func (p *Plugin) RPC() interface{} { - return &rpc{ - plugin: p, - log: p.log, - } -} - func (p *Plugin) Name() string { return PluginName } -// GetPublishers collects all pubsubs -func (p *Plugin) GetPublishers(name endure.Named, pub pubsub.PSProvider) { - p.psProviders[name.Name()] = pub -} - func (p *Plugin) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != p.cfg.Path { @@ -281,7 +208,7 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler { p.connections.Store(connectionID, safeConn) // Executor wraps a connection to have a safe abstraction - e := executor.NewExecutor(safeConn, p.log, connectionID, p.pubsubs, p.accessValidator, r) + e := executor.NewExecutor(safeConn, p.log, connectionID, p.subReader, p.accessValidator, r) p.log.Info("websocket client connected", "uuid", connectionID) err = e.StartCommandLoop() @@ -365,55 +292,6 @@ func (p *Plugin) Reset() error { return nil } -// Publish is an entry point to the websocket PUBSUB -func (p *Plugin) Publish(m []byte) error { - p.Lock() - defer p.Unlock() - - msg := &websocketsv1.Message{} - err := proto.Unmarshal(m, msg) - if err != nil { - return err - } - - // Get payload - for i := 0; i < len(msg.GetTopics()); i++ { - if br, ok := p.pubsubs[msg.GetBroker()]; ok { - err := br.Publish(m) - if err != nil { - return errors.E(err) - } - } else { - p.log.Warn("no such broker", "available", p.pubsubs, "requested", msg.GetBroker()) - } - } - return nil -} - -func (p *Plugin) PublishAsync(m []byte) { - go func() { - p.Lock() - defer p.Unlock() - msg := &websocketsv1.Message{} - err := proto.Unmarshal(m, msg) - if err != nil { - p.log.Error("message unmarshal") - } - - // Get payload - for i := 0; i < len(msg.GetTopics()); i++ { - if br, ok := p.pubsubs[msg.GetBroker()]; ok { - err := br.Publish(m) - if err != nil { - p.log.Error("publish async error", "error", err) - } - } else { - p.log.Warn("no such broker", "available", p.pubsubs, "requested", msg.GetBroker()) - } - } - }() -} - func (p *Plugin) defaultAccessValidator(pool phpPool.Pool) validator.AccessValidatorFn { return func(r *http.Request, topics ...string) (*validator.AccessValidator, error) { const op = errors.Op("access_validator") diff --git a/plugins/websockets/pool/workers_pool.go b/plugins/websockets/pool/workers_pool.go index 1a7c6f8a8..752ba3ce7 100644 --- a/plugins/websockets/pool/workers_pool.go +++ b/plugins/websockets/pool/workers_pool.go @@ -3,29 +3,29 @@ package pool import ( "sync" - "github.com/fasthttp/websocket" - websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" + json "github.com/json-iterator/go" "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/plugins/logger" "github.com/spiral/roadrunner/v2/plugins/websockets/connection" + "github.com/spiral/roadrunner/v2/utils" ) type WorkersPool struct { - storage map[string]pubsub.PubSub + subscriber pubsub.Subscriber connections *sync.Map resPool sync.Pool log logger.Logger - queue chan *websocketsv1.Message + queue chan *pubsub.Message exit chan struct{} } // NewWorkersPool constructs worker pool for the websocket connections -func NewWorkersPool(pubsubs map[string]pubsub.PubSub, connections *sync.Map, log logger.Logger) *WorkersPool { +func NewWorkersPool(subscriber pubsub.Subscriber, connections *sync.Map, log logger.Logger) *WorkersPool { wp := &WorkersPool{ connections: connections, - queue: make(chan *websocketsv1.Message, 100), - storage: pubsubs, + queue: make(chan *pubsub.Message, 100), + subscriber: subscriber, log: log, exit: make(chan struct{}), } @@ -42,7 +42,7 @@ func NewWorkersPool(pubsubs map[string]pubsub.PubSub, connections *sync.Map, log return wp } -func (wp *WorkersPool) Queue(msg *websocketsv1.Message) { +func (wp *WorkersPool) Queue(msg *pubsub.Message) { wp.queue <- msg } @@ -67,6 +67,12 @@ func (wp *WorkersPool) get() map[string]struct{} { return wp.resPool.Get().(map[string]struct{}) } +// Response from the server +type Response struct { + Topic string `json:"topic"` + Payload string `json:"payload"` +} + func (wp *WorkersPool) do() { //nolint:gocognit go func() { for { @@ -76,57 +82,50 @@ func (wp *WorkersPool) do() { //nolint:gocognit return } _ = msg - if msg == nil { - continue - } - if len(msg.GetTopics()) == 0 { - continue - } - - br, ok := wp.storage[msg.Broker] - if !ok { - wp.log.Warn("no such broker", "requested", msg.GetBroker(), "available", wp.storage) + if msg == nil || msg.Topic == "" { continue } + // get free map res := wp.get() - for i := 0; i < len(msg.GetTopics()); i++ { - // get connections for the particular topic - br.Connections(msg.GetTopics()[i], res) - } + // get connections for the particular topic + wp.subscriber.Connections(msg.Topic, res) if len(res) == 0 { - for i := 0; i < len(msg.GetTopics()); i++ { - wp.log.Info("no such topic", "topic", msg.GetTopics()[i]) - } + wp.log.Info("no connections associated with provided topic", "topic", msg.Topic) wp.put(res) continue } - for i := range res { - c, ok := wp.connections.Load(i) + // res is a map with a connectionsID + for connID := range res { + c, ok := wp.connections.Load(connID) if !ok { - for i := 0; i < len(msg.GetTopics()); i++ { - wp.log.Warn("the user disconnected connection before the message being written to it", "broker", msg.GetBroker(), "topics", msg.GetTopics()[i]) - } + wp.log.Warn("the websocket disconnected before the message being written to it", "topics", msg.Topic) + wp.put(res) continue } - conn := c.(*connection.Connection) + d, err := json.Marshal(&Response{ + Topic: msg.Topic, + Payload: utils.AsString(msg.Payload), + }) + + if err != nil { + wp.log.Error("error marshaling response", "error", err) + wp.put(res) + break + } // put data into the bytes buffer - err := conn.Write(websocket.BinaryMessage, msg.GetPayload()) + err = c.(*connection.Connection).Write(d) if err != nil { - for i := 0; i < len(msg.GetTopics()); i++ { - wp.log.Error("error sending payload over the connection", "error", err, "broker", msg.GetBroker(), "topics", msg.GetTopics()[i]) - } + wp.log.Error("error sending payload over the connection", "error", err, "topic", msg.Topic) + wp.put(res) continue } } - - // put map with results back - wp.put(res) case <-wp.exit: wp.log.Info("get exit signal, exiting from the workers pool") return diff --git a/plugins/websockets/rpc.go b/plugins/websockets/rpc.go deleted file mode 100644 index 341e0b2aa..000000000 --- a/plugins/websockets/rpc.go +++ /dev/null @@ -1,75 +0,0 @@ -package websockets - -import ( - "github.com/spiral/errors" - websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" - "github.com/spiral/roadrunner/v2/plugins/logger" - "google.golang.org/protobuf/proto" -) - -// rpc collectors struct -type rpc struct { - plugin *Plugin - log logger.Logger -} - -// Publish ... msg is a proto decoded payload -// see: pkg/pubsub/message.fbs -func (r *rpc) Publish(in *websocketsv1.Request, out *websocketsv1.Response) error { - const op = errors.Op("broadcast_publish") - - // just return in case of nil message - if in == nil { - out.Ok = false - return nil - } - - r.log.Debug("message published", "msg", in.Messages) - - msgLen := len(in.GetMessages()) - - for i := 0; i < msgLen; i++ { - bb, err := proto.Marshal(in.GetMessages()[i]) - if err != nil { - return errors.E(op, err) - } - - err = r.plugin.Publish(bb) - if err != nil { - out.Ok = false - return errors.E(op, err) - } - } - - out.Ok = true - return nil -} - -// PublishAsync ... -// see: pkg/pubsub/message.fbs -func (r *rpc) PublishAsync(in *websocketsv1.Request, out *websocketsv1.Response) error { - const op = errors.Op("publish_async") - - // just return in case of nil message - if in == nil { - out.Ok = false - return nil - } - - r.log.Debug("message published", "msg", in.Messages) - - msgLen := len(in.GetMessages()) - - for i := 0; i < msgLen; i++ { - bb, err := proto.Marshal(in.GetMessages()[i]) - if err != nil { - out.Ok = false - return errors.E(op, err) - } - - r.plugin.PublishAsync(bb) - } - - out.Ok = true - return nil -} diff --git a/plugins/websockets/wildcard.go b/plugins/websockets/wildcard.go new file mode 100644 index 000000000..2f1c6601d --- /dev/null +++ b/plugins/websockets/wildcard.go @@ -0,0 +1,12 @@ +package websockets + +import "strings" + +type wildcard struct { + prefix string + suffix string +} + +func (w wildcard) match(s string) bool { + return len(s) >= len(w.prefix)+len(w.suffix) && strings.HasPrefix(s, w.prefix) && strings.HasSuffix(s, w.suffix) +} diff --git a/pkg/proto/kv/v1beta/kv.pb.go b/proto/kv/v1beta/kv.pb.go similarity index 100% rename from pkg/proto/kv/v1beta/kv.pb.go rename to proto/kv/v1beta/kv.pb.go diff --git a/pkg/proto/kv/v1beta/kv.proto b/proto/kv/v1beta/kv.proto similarity index 100% rename from pkg/proto/kv/v1beta/kv.proto rename to proto/kv/v1beta/kv.proto diff --git a/pkg/proto/websockets/v1beta/websockets.pb.go b/proto/websockets/v1beta/websockets.pb.go similarity index 81% rename from pkg/proto/websockets/v1beta/websockets.pb.go rename to proto/websockets/v1beta/websockets.pb.go index d39b55daa..ad4ebbe73 100644 --- a/pkg/proto/websockets/v1beta/websockets.pb.go +++ b/proto/websockets/v1beta/websockets.pb.go @@ -26,9 +26,8 @@ type Message struct { unknownFields protoimpl.UnknownFields Command string `protobuf:"bytes,1,opt,name=command,proto3" json:"command,omitempty"` - Broker string `protobuf:"bytes,2,opt,name=broker,proto3" json:"broker,omitempty"` - Topics []string `protobuf:"bytes,3,rep,name=topics,proto3" json:"topics,omitempty"` - Payload []byte `protobuf:"bytes,4,opt,name=payload,proto3" json:"payload,omitempty"` + Topics []string `protobuf:"bytes,2,rep,name=topics,proto3" json:"topics,omitempty"` + Payload []byte `protobuf:"bytes,3,opt,name=payload,proto3" json:"payload,omitempty"` } func (x *Message) Reset() { @@ -70,13 +69,6 @@ func (x *Message) GetCommand() string { return "" } -func (x *Message) GetBroker() string { - if x != nil { - return x.Broker - } - return "" -} - func (x *Message) GetTopics() []string { if x != nil { return x.Topics @@ -91,6 +83,7 @@ func (x *Message) GetPayload() []byte { return nil } +// RPC request with messages type Request struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -138,6 +131,7 @@ func (x *Request) GetMessages() []*Message { return nil } +// RPC response (false in case of error) type Response struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -190,22 +184,20 @@ var File_websockets_proto protoreflect.FileDescriptor var file_websockets_proto_rawDesc = []byte{ 0x0a, 0x10, 0x77, 0x65, 0x62, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x11, 0x77, 0x65, 0x62, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x2e, 0x76, - 0x31, 0x62, 0x65, 0x74, 0x61, 0x22, 0x6d, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x31, 0x62, 0x65, 0x74, 0x61, 0x22, 0x55, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x07, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x62, 0x72, - 0x6f, 0x6b, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x62, 0x72, 0x6f, 0x6b, - 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x74, 0x6f, 0x70, 0x69, 0x63, 0x73, 0x18, 0x03, 0x20, 0x03, - 0x28, 0x09, 0x52, 0x06, 0x74, 0x6f, 0x70, 0x69, 0x63, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x61, - 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x70, 0x61, 0x79, - 0x6c, 0x6f, 0x61, 0x64, 0x22, 0x41, 0x0a, 0x07, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, - 0x36, 0x0a, 0x08, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x1a, 0x2e, 0x77, 0x65, 0x62, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x2e, 0x76, - 0x31, 0x62, 0x65, 0x74, 0x61, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x52, 0x08, 0x6d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x22, 0x1a, 0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x6f, 0x6b, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x02, 0x6f, 0x6b, 0x42, 0x15, 0x5a, 0x13, 0x2e, 0x2f, 0x3b, 0x77, 0x65, 0x62, 0x73, 0x6f, 0x63, - 0x6b, 0x65, 0x74, 0x73, 0x76, 0x31, 0x62, 0x65, 0x74, 0x61, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x09, 0x52, 0x07, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x74, 0x6f, + 0x70, 0x69, 0x63, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x74, 0x6f, 0x70, 0x69, + 0x63, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x22, 0x41, 0x0a, 0x07, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x36, 0x0a, 0x08, 0x6d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x77, 0x65, 0x62, 0x73, + 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x2e, 0x76, 0x31, 0x62, 0x65, 0x74, 0x61, 0x2e, 0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x52, 0x08, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x22, + 0x1a, 0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x6f, + 0x6b, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x02, 0x6f, 0x6b, 0x42, 0x15, 0x5a, 0x13, 0x2e, + 0x2f, 0x3b, 0x77, 0x65, 0x62, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x76, 0x31, 0x62, 0x65, + 0x74, 0x61, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/pkg/proto/websockets/v1beta/websockets.proto b/proto/websockets/v1beta/websockets.proto similarity index 79% rename from pkg/proto/websockets/v1beta/websockets.proto rename to proto/websockets/v1beta/websockets.proto index ede3cde93..5be6f70fc 100644 --- a/pkg/proto/websockets/v1beta/websockets.proto +++ b/proto/websockets/v1beta/websockets.proto @@ -5,9 +5,8 @@ option go_package = "./;websocketsv1beta"; message Message { string command = 1; - string broker = 2; - repeated string topics = 3; - bytes payload = 4; + repeated string topics = 2; + bytes payload = 3; } // RPC request with messages diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml index 67d5476b2..b6ba0f661 100644 --- a/tests/docker-compose.yaml +++ b/tests/docker-compose.yaml @@ -9,3 +9,7 @@ services: image: redis:6 ports: - "6379:6379" + redis2: + image: redis:6 + ports: + - "6378:6379" diff --git a/tests/plugins/broadcast/broadcast_plugin_test.go b/tests/plugins/broadcast/broadcast_plugin_test.go new file mode 100644 index 000000000..2cd4b451c --- /dev/null +++ b/tests/plugins/broadcast/broadcast_plugin_test.go @@ -0,0 +1,474 @@ +package broadcast + +import ( + "net" + "net/rpc" + "os" + "os/signal" + "sync" + "syscall" + "testing" + "time" + + "github.com/golang/mock/gomock" + endure "github.com/spiral/endure/pkg/container" + goridgeRpc "github.com/spiral/goridge/v3/pkg/rpc" + "github.com/spiral/roadrunner/v2/plugins/broadcast" + "github.com/spiral/roadrunner/v2/plugins/config" + httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/plugins/memory" + "github.com/spiral/roadrunner/v2/plugins/redis" + rpcPlugin "github.com/spiral/roadrunner/v2/plugins/rpc" + "github.com/spiral/roadrunner/v2/plugins/server" + "github.com/spiral/roadrunner/v2/plugins/websockets" + websocketsv1 "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" + "github.com/spiral/roadrunner/v2/tests/mocks" + "github.com/spiral/roadrunner/v2/tests/plugins/broadcast/plugins" + "github.com/stretchr/testify/assert" +) + +func TestBroadcastInit(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-broadcast-init.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &broadcast.Plugin{}, + &rpcPlugin.Plugin{}, + &logger.ZapLogger{}, + &server.Plugin{}, + &redis.Plugin{}, + &websockets.Plugin{}, + &httpPlugin.Plugin{}, + &memory.Plugin{}, + ) + + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + if err != nil { + t.Fatal(err) + } + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + stopCh := make(chan struct{}, 1) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-stopCh: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + stopCh <- struct{}{} + + wg.Wait() +} + +func TestBroadcastConfigError(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-broadcast-config-error.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &broadcast.Plugin{}, + &rpcPlugin.Plugin{}, + &logger.ZapLogger{}, + &server.Plugin{}, + &redis.Plugin{}, + &websockets.Plugin{}, + &httpPlugin.Plugin{}, + &memory.Plugin{}, + + &plugins.Plugin1{}, + ) + + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + _, err = cont.Serve() + assert.Error(t, err) +} + +func TestBroadcastNoConfig(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-broadcast-no-config.yaml", + Prefix: "rr", + } + + controller := gomock.NewController(t) + mockLogger := mocks.NewMockLogger(controller) + + mockLogger.EXPECT().Debug("worker destructed", "pid", gomock.Any()).AnyTimes() + mockLogger.EXPECT().Debug("worker constructed", "pid", gomock.Any()).AnyTimes() + mockLogger.EXPECT().Debug("Started RPC service", "address", "tcp://127.0.0.1:6001", "services", []string{}).MinTimes(1) + + err = cont.RegisterAll( + cfg, + &broadcast.Plugin{}, + &rpcPlugin.Plugin{}, + mockLogger, + &server.Plugin{}, + &redis.Plugin{}, + &websockets.Plugin{}, + &httpPlugin.Plugin{}, + &memory.Plugin{}, + ) + + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + // should be just disabled + _, err = cont.Serve() + assert.NoError(t, err) +} + +func TestBroadcastSameSubscriber(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-broadcast-same-section.yaml", + Prefix: "rr", + } + + controller := gomock.NewController(t) + mockLogger := mocks.NewMockLogger(controller) + + mockLogger.EXPECT().Debug("worker destructed", "pid", gomock.Any()).AnyTimes() + mockLogger.EXPECT().Debug("worker constructed", "pid", gomock.Any()).AnyTimes() + mockLogger.EXPECT().Debug("Started RPC service", "address", "tcp://127.0.0.1:6002", "services", []string{"broadcast"}).MinTimes(1) + mockLogger.EXPECT().Debug("message published", "msg", gomock.Any()).MinTimes(1) + + mockLogger.EXPECT().Info(`plugin1: {foo hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin1: {foo2 hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin1: {foo3 hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin2: {foo hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin3: {foo hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin4: {foo hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin5: {foo hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin6: {foo hello}`).Times(3) + + err = cont.RegisterAll( + cfg, + &broadcast.Plugin{}, + &rpcPlugin.Plugin{}, + mockLogger, + &server.Plugin{}, + &redis.Plugin{}, + &websockets.Plugin{}, + &httpPlugin.Plugin{}, + &memory.Plugin{}, + + // test - redis + // test2 - redis (port 6378) + // test3 - memory + // test4 - memory + &plugins.Plugin1{}, // foo, foo2, foo3 test + &plugins.Plugin2{}, // foo, test + &plugins.Plugin3{}, // foo, test2 + &plugins.Plugin4{}, // foo, test3 + &plugins.Plugin5{}, // foo, test4 + &plugins.Plugin6{}, // foo, test3 + ) + + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + if err != nil { + t.Fatal(err) + } + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + stopCh := make(chan struct{}, 1) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-stopCh: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second * 2) + + t.Run("PublishHelloFooFoo2Foo3", BroadcastPublishFooFoo2Foo3("6002")) + t.Run("PublishHelloFoo2", BroadcastPublishFoo2("6002")) + t.Run("PublishHelloFoo3", BroadcastPublishFoo3("6002")) + t.Run("PublishAsyncHelloFooFoo2Foo3", BroadcastPublishAsyncFooFoo2Foo3("6002")) + + time.Sleep(time.Second * 4) + stopCh <- struct{}{} + + wg.Wait() +} + +func TestBroadcastSameSubscriberGlobal(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-broadcast-global.yaml", + Prefix: "rr", + } + + controller := gomock.NewController(t) + mockLogger := mocks.NewMockLogger(controller) + + mockLogger.EXPECT().Debug("worker destructed", "pid", gomock.Any()).AnyTimes() + mockLogger.EXPECT().Debug("worker constructed", "pid", gomock.Any()).AnyTimes() + mockLogger.EXPECT().Debug("Started RPC service", "address", "tcp://127.0.0.1:6003", "services", []string{"broadcast"}).MinTimes(1) + mockLogger.EXPECT().Debug("message published", "msg", gomock.Any()).MinTimes(1) + + mockLogger.EXPECT().Info(`plugin1: {foo hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin1: {foo2 hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin1: {foo3 hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin2: {foo hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin3: {foo hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin4: {foo hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin5: {foo hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin6: {foo hello}`).Times(3) + + err = cont.RegisterAll( + cfg, + &broadcast.Plugin{}, + &rpcPlugin.Plugin{}, + mockLogger, + &server.Plugin{}, + &redis.Plugin{}, + &websockets.Plugin{}, + &httpPlugin.Plugin{}, + &memory.Plugin{}, + + // test - redis + // test2 - redis (port 6378) + // test3 - memory + // test4 - memory + &plugins.Plugin1{}, // foo, foo2, foo3 test + &plugins.Plugin2{}, // foo, test + &plugins.Plugin3{}, // foo, test2 + &plugins.Plugin4{}, // foo, test3 + &plugins.Plugin5{}, // foo, test4 + &plugins.Plugin6{}, // foo, test3 + ) + + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + if err != nil { + t.Fatal(err) + } + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + stopCh := make(chan struct{}, 1) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-stopCh: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second * 2) + + t.Run("PublishHelloFooFoo2Foo3", BroadcastPublishFooFoo2Foo3("6003")) + t.Run("PublishHelloFoo2", BroadcastPublishFoo2("6003")) + t.Run("PublishHelloFoo3", BroadcastPublishFoo3("6003")) + t.Run("PublishAsyncHelloFooFoo2Foo3", BroadcastPublishAsyncFooFoo2Foo3("6003")) + + time.Sleep(time.Second * 4) + stopCh <- struct{}{} + + wg.Wait() +} + +func BroadcastPublishFooFoo2Foo3(port string) func(t *testing.T) { + return func(t *testing.T) { + conn, err := net.Dial("tcp", "127.0.0.1:"+port) + if err != nil { + t.Fatal(err) + } + + client := rpc.NewClientWithCodec(goridgeRpc.NewClientCodec(conn)) + + ret := &websocketsv1.Response{} + err = client.Call("broadcast.Publish", makeMessage([]byte("hello"), "foo", "foo2", "foo3"), ret) + if err != nil { + t.Fatal(err) + } + } +} + +func BroadcastPublishFoo2(port string) func(t *testing.T) { + return func(t *testing.T) { + conn, err := net.Dial("tcp", "127.0.0.1:"+port) + if err != nil { + t.Fatal(err) + } + + client := rpc.NewClientWithCodec(goridgeRpc.NewClientCodec(conn)) + + ret := &websocketsv1.Response{} + err = client.Call("broadcast.Publish", makeMessage([]byte("hello"), "foo"), ret) + if err != nil { + t.Fatal(err) + } + } +} + +func BroadcastPublishFoo3(port string) func(t *testing.T) { + return func(t *testing.T) { + conn, err := net.Dial("tcp", "127.0.0.1:"+port) + if err != nil { + t.Fatal(err) + } + + client := rpc.NewClientWithCodec(goridgeRpc.NewClientCodec(conn)) + + ret := &websocketsv1.Response{} + err = client.Call("broadcast.Publish", makeMessage([]byte("hello"), "foo3"), ret) + if err != nil { + t.Fatal(err) + } + } +} +func BroadcastPublishAsyncFooFoo2Foo3(port string) func(t *testing.T) { + return func(t *testing.T) { + conn, err := net.Dial("tcp", "127.0.0.1:"+port) + if err != nil { + t.Fatal(err) + } + + client := rpc.NewClientWithCodec(goridgeRpc.NewClientCodec(conn)) + + ret := &websocketsv1.Response{} + err = client.Call("broadcast.PublishAsync", makeMessage([]byte("hello"), "foo", "foo2", "foo3"), ret) + if err != nil { + t.Fatal(err) + } + } +} + +func makeMessage(payload []byte, topics ...string) *websocketsv1.Request { + m := &websocketsv1.Request{ + Messages: []*websocketsv1.Message{ + { + Topics: topics, + Payload: payload, + }, + }, + } + + return m +} diff --git a/tests/plugins/broadcast/configs/.rr-broadcast-config-error.yaml b/tests/plugins/broadcast/configs/.rr-broadcast-config-error.yaml new file mode 100644 index 000000000..d8daa2510 --- /dev/null +++ b/tests/plugins/broadcast/configs/.rr-broadcast-config-error.yaml @@ -0,0 +1,33 @@ +rpc: + listen: tcp://127.0.0.1:6001 + +server: + command: "php ../../psr-worker-bench.php" + user: "" + group: "" + relay: "pipes" + relay_timeout: "20s" + +http: + address: 127.0.0.1:21345 + max_request_size: 1024 + middleware: [ "websockets" ] + trusted_subnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + num_workers: 2 + max_jobs: 0 + allocate_timeout: 60s + destroy_timeout: 60s + +broadcast: + default: + driver: redis + +logs: + mode: development + level: debug + +endure: + grace_period: 120s + print_graph: false + log_level: error diff --git a/tests/plugins/websockets/configs/.rr-websockets-redis-memory.yaml b/tests/plugins/broadcast/configs/.rr-broadcast-global.yaml similarity index 69% rename from tests/plugins/websockets/configs/.rr-websockets-redis-memory.yaml rename to tests/plugins/broadcast/configs/.rr-broadcast-global.yaml index eedf5377f..2ca970559 100644 --- a/tests/plugins/websockets/configs/.rr-websockets-redis-memory.yaml +++ b/tests/plugins/broadcast/configs/.rr-broadcast-global.yaml @@ -1,5 +1,5 @@ rpc: - listen: tcp://127.0.0.1:6001 + listen: tcp://127.0.0.1:6003 server: command: "php ../../psr-worker-bench.php" @@ -9,7 +9,7 @@ server: relay_timeout: "20s" http: - address: 127.0.0.1:13235 + address: 127.0.0.1:21543 max_request_size: 1024 middleware: [ "websockets" ] trusted_subnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] @@ -23,11 +23,18 @@ redis: addrs: - "localhost:6379" -websockets: - # pubsubs should implement PubSub interface to be collected via endure.Collects - # pubsubs might use general config section - pubsubs: [ "redis", "memory" ] - path: "/ws" +broadcast: + test: + driver: redis + test2: + driver: redis + addrs: + - "localhost:6378" + test3: + driver: memory + test4: + driver: memory + logs: mode: development diff --git a/tests/plugins/broadcast/configs/.rr-broadcast-init.yaml b/tests/plugins/broadcast/configs/.rr-broadcast-init.yaml new file mode 100644 index 000000000..aa80330e4 --- /dev/null +++ b/tests/plugins/broadcast/configs/.rr-broadcast-init.yaml @@ -0,0 +1,35 @@ +rpc: + listen: tcp://127.0.0.1:6001 + +server: + command: "php ../../psr-worker-bench.php" + user: "" + group: "" + relay: "pipes" + relay_timeout: "20s" + +http: + address: 127.0.0.1:21345 + max_request_size: 1024 + middleware: [ "websockets" ] + trusted_subnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + num_workers: 2 + max_jobs: 0 + allocate_timeout: 60s + destroy_timeout: 60s + +broadcast: + default: + driver: redis + addrs: + - "localhost:6379" + +logs: + mode: development + level: error + +endure: + grace_period: 120s + print_graph: false + log_level: error diff --git a/tests/plugins/broadcast/configs/.rr-broadcast-no-config.yaml b/tests/plugins/broadcast/configs/.rr-broadcast-no-config.yaml new file mode 100644 index 000000000..907908690 --- /dev/null +++ b/tests/plugins/broadcast/configs/.rr-broadcast-no-config.yaml @@ -0,0 +1,29 @@ +rpc: + listen: tcp://127.0.0.1:6001 + +server: + command: "php ../../psr-worker-bench.php" + user: "" + group: "" + relay: "pipes" + relay_timeout: "20s" + +http: + address: 127.0.0.1:21345 + max_request_size: 1024 + middleware: [ "websockets" ] + trusted_subnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + num_workers: 2 + max_jobs: 0 + allocate_timeout: 60s + destroy_timeout: 60s + +logs: + mode: development + level: debug + +endure: + grace_period: 120s + print_graph: false + log_level: error diff --git a/tests/plugins/broadcast/configs/.rr-broadcast-same-section.yaml b/tests/plugins/broadcast/configs/.rr-broadcast-same-section.yaml new file mode 100644 index 000000000..360e05e5e --- /dev/null +++ b/tests/plugins/broadcast/configs/.rr-broadcast-same-section.yaml @@ -0,0 +1,43 @@ +rpc: + listen: tcp://127.0.0.1:6002 + +server: + command: "php ../../psr-worker-bench.php" + user: "" + group: "" + relay: "pipes" + relay_timeout: "20s" + +http: + address: 127.0.0.1:21345 + max_request_size: 1024 + middleware: [ "websockets" ] + trusted_subnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + num_workers: 2 + max_jobs: 0 + allocate_timeout: 60s + destroy_timeout: 60s + +broadcast: + test: + driver: redis + addrs: + - "localhost:6379" + test2: + driver: redis + addrs: + - "localhost:6378" + test3: + driver: memory + test4: + driver: memory + +logs: + mode: development + level: debug + +endure: + grace_period: 120s + print_graph: false + log_level: error diff --git a/tests/plugins/broadcast/plugins/plugin1.go b/tests/plugins/broadcast/plugins/plugin1.go new file mode 100644 index 000000000..d3b16256e --- /dev/null +++ b/tests/plugins/broadcast/plugins/plugin1.go @@ -0,0 +1,67 @@ +package plugins + +import ( + "fmt" + + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/broadcast" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +const Plugin1Name = "plugin1" + +type Plugin1 struct { + log logger.Logger + b broadcast.Broadcaster + driver pubsub.SubReader +} + +func (p *Plugin1) Init(log logger.Logger, b broadcast.Broadcaster) error { + p.log = log + p.b = b + return nil +} + +func (p *Plugin1) Serve() chan error { + errCh := make(chan error, 1) + + var err error + p.driver, err = p.b.GetDriver("test") + if err != nil { + errCh <- err + return errCh + } + + err = p.driver.Subscribe("1", "foo", "foo2", "foo3") + if err != nil { + panic(err) + } + + go func() { + for { + msg, err := p.driver.Next() + if err != nil { + panic(err) + } + + if msg == nil { + continue + } + + p.log.Info(fmt.Sprintf("%s: %s", Plugin1Name, *msg)) + } + }() + + return errCh +} + +func (p *Plugin1) Stop() error { + _ = p.driver.Unsubscribe("1", "foo") + _ = p.driver.Unsubscribe("1", "foo2") + _ = p.driver.Unsubscribe("1", "foo3") + return nil +} + +func (p *Plugin1) Name() string { + return Plugin1Name +} diff --git a/tests/plugins/broadcast/plugins/plugin2.go b/tests/plugins/broadcast/plugins/plugin2.go new file mode 100644 index 000000000..2bd819d2f --- /dev/null +++ b/tests/plugins/broadcast/plugins/plugin2.go @@ -0,0 +1,64 @@ +package plugins + +import ( + "fmt" + + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/broadcast" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +const Plugin2Name = "plugin2" + +type Plugin2 struct { + log logger.Logger + b broadcast.Broadcaster + driver pubsub.SubReader +} + +func (p *Plugin2) Init(log logger.Logger, b broadcast.Broadcaster) error { + p.log = log + p.b = b + return nil +} + +func (p *Plugin2) Serve() chan error { + errCh := make(chan error, 1) + + var err error + p.driver, err = p.b.GetDriver("test") + if err != nil { + panic(err) + } + + err = p.driver.Subscribe("2", "foo") + if err != nil { + panic(err) + } + + go func() { + for { + msg, err := p.driver.Next() + if err != nil { + panic(err) + } + + if msg == nil { + continue + } + + p.log.Info(fmt.Sprintf("%s: %s", Plugin2Name, *msg)) + } + }() + + return errCh +} + +func (p *Plugin2) Stop() error { + _ = p.driver.Unsubscribe("2", "foo") + return nil +} + +func (p *Plugin2) Name() string { + return Plugin2Name +} diff --git a/tests/plugins/broadcast/plugins/plugin3.go b/tests/plugins/broadcast/plugins/plugin3.go new file mode 100644 index 000000000..ef9262224 --- /dev/null +++ b/tests/plugins/broadcast/plugins/plugin3.go @@ -0,0 +1,64 @@ +package plugins + +import ( + "fmt" + + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/broadcast" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +const Plugin3Name = "plugin3" + +type Plugin3 struct { + log logger.Logger + b broadcast.Broadcaster + driver pubsub.SubReader +} + +func (p *Plugin3) Init(log logger.Logger, b broadcast.Broadcaster) error { + p.log = log + p.b = b + return nil +} + +func (p *Plugin3) Serve() chan error { + errCh := make(chan error, 1) + + var err error + p.driver, err = p.b.GetDriver("test2") + if err != nil { + panic(err) + } + + err = p.driver.Subscribe("3", "foo") + if err != nil { + panic(err) + } + + go func() { + for { + msg, err := p.driver.Next() + if err != nil { + panic(err) + } + + if msg == nil { + continue + } + + p.log.Info(fmt.Sprintf("%s: %s", Plugin3Name, *msg)) + } + }() + + return errCh +} + +func (p *Plugin3) Stop() error { + _ = p.driver.Unsubscribe("3", "foo") + return nil +} + +func (p *Plugin3) Name() string { + return Plugin3Name +} diff --git a/tests/plugins/broadcast/plugins/plugin4.go b/tests/plugins/broadcast/plugins/plugin4.go new file mode 100644 index 000000000..c9b947778 --- /dev/null +++ b/tests/plugins/broadcast/plugins/plugin4.go @@ -0,0 +1,64 @@ +package plugins + +import ( + "fmt" + + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/broadcast" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +const Plugin4Name = "plugin4" + +type Plugin4 struct { + log logger.Logger + b broadcast.Broadcaster + driver pubsub.SubReader +} + +func (p *Plugin4) Init(log logger.Logger, b broadcast.Broadcaster) error { + p.log = log + p.b = b + return nil +} + +func (p *Plugin4) Serve() chan error { + errCh := make(chan error, 1) + + var err error + p.driver, err = p.b.GetDriver("test3") + if err != nil { + panic(err) + } + + err = p.driver.Subscribe("4", "foo") + if err != nil { + panic(err) + } + + go func() { + for { + msg, err := p.driver.Next() + if err != nil { + panic(err) + } + + if msg == nil { + continue + } + + p.log.Info(fmt.Sprintf("%s: %s", Plugin4Name, *msg)) + } + }() + + return errCh +} + +func (p *Plugin4) Stop() error { + _ = p.driver.Unsubscribe("4", "foo") + return nil +} + +func (p *Plugin4) Name() string { + return Plugin4Name +} diff --git a/tests/plugins/broadcast/plugins/plugin5.go b/tests/plugins/broadcast/plugins/plugin5.go new file mode 100644 index 000000000..01562a8f8 --- /dev/null +++ b/tests/plugins/broadcast/plugins/plugin5.go @@ -0,0 +1,64 @@ +package plugins + +import ( + "fmt" + + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/broadcast" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +const Plugin5Name = "plugin5" + +type Plugin5 struct { + log logger.Logger + b broadcast.Broadcaster + driver pubsub.SubReader +} + +func (p *Plugin5) Init(log logger.Logger, b broadcast.Broadcaster) error { + p.log = log + p.b = b + return nil +} + +func (p *Plugin5) Serve() chan error { + errCh := make(chan error, 1) + + var err error + p.driver, err = p.b.GetDriver("test4") + if err != nil { + panic(err) + } + + err = p.driver.Subscribe("5", "foo") + if err != nil { + panic(err) + } + + go func() { + for { + msg, err := p.driver.Next() + if err != nil { + panic(err) + } + + if msg == nil { + continue + } + + p.log.Info(fmt.Sprintf("%s: %s", Plugin5Name, *msg)) + } + }() + + return errCh +} + +func (p *Plugin5) Stop() error { + _ = p.driver.Unsubscribe("5", "foo") + return nil +} + +func (p *Plugin5) Name() string { + return Plugin5Name +} diff --git a/tests/plugins/broadcast/plugins/plugin6.go b/tests/plugins/broadcast/plugins/plugin6.go new file mode 100644 index 000000000..76f2d6e81 --- /dev/null +++ b/tests/plugins/broadcast/plugins/plugin6.go @@ -0,0 +1,64 @@ +package plugins + +import ( + "fmt" + + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/broadcast" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +const Plugin6Name = "plugin6" + +type Plugin6 struct { + log logger.Logger + b broadcast.Broadcaster + driver pubsub.SubReader +} + +func (p *Plugin6) Init(log logger.Logger, b broadcast.Broadcaster) error { + p.log = log + p.b = b + return nil +} + +func (p *Plugin6) Serve() chan error { + errCh := make(chan error, 1) + + var err error + p.driver, err = p.b.GetDriver("test") + if err != nil { + panic(err) + } + + err = p.driver.Subscribe("6", "foo") + if err != nil { + panic(err) + } + + go func() { + for { + msg, err := p.driver.Next() + if err != nil { + panic(err) + } + + if msg == nil { + continue + } + + p.log.Info(fmt.Sprintf("%s: %s", Plugin6Name, *msg)) + } + }() + + return errCh +} + +func (p *Plugin6) Stop() error { + _ = p.driver.Unsubscribe("6", "foo") + return nil +} + +func (p *Plugin6) Name() string { + return Plugin6Name +} diff --git a/tests/plugins/kv/storage_plugin_test.go b/tests/plugins/kv/storage_plugin_test.go index 24b66ae17..1e466e066 100644 --- a/tests/plugins/kv/storage_plugin_test.go +++ b/tests/plugins/kv/storage_plugin_test.go @@ -12,7 +12,6 @@ import ( endure "github.com/spiral/endure/pkg/container" goridgeRpc "github.com/spiral/goridge/v3/pkg/rpc" - payload "github.com/spiral/roadrunner/v2/pkg/proto/kv/v1beta" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/kv" "github.com/spiral/roadrunner/v2/plugins/kv/drivers/boltdb" @@ -21,6 +20,7 @@ import ( "github.com/spiral/roadrunner/v2/plugins/memory" "github.com/spiral/roadrunner/v2/plugins/redis" rpcPlugin "github.com/spiral/roadrunner/v2/plugins/rpc" + payload "github.com/spiral/roadrunner/v2/proto/kv/v1beta" "github.com/stretchr/testify/assert" ) diff --git a/tests/plugins/websockets/configs/.rr-websockets-memory-allow.yaml b/tests/plugins/websockets/configs/.rr-websockets-allow.yaml similarity index 85% rename from tests/plugins/websockets/configs/.rr-websockets-memory-allow.yaml rename to tests/plugins/websockets/configs/.rr-websockets-allow.yaml index 896cee05c..e6c438579 100644 --- a/tests/plugins/websockets/configs/.rr-websockets-memory-allow.yaml +++ b/tests/plugins/websockets/configs/.rr-websockets-allow.yaml @@ -9,7 +9,7 @@ server: relay_timeout: "20s" http: - address: 127.0.0.1:11113 + address: 127.0.0.1:41278 max_request_size: 1024 middleware: [ "websockets" ] trusted_subnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] @@ -23,8 +23,13 @@ redis: addrs: - "localhost:6379" +broadcast: + test: + driver: memory + websockets: - pubsubs: [ "memory" ] + broker: test + allowed_origin: "*" path: "/ws" logs: diff --git a/tests/plugins/websockets/configs/.rr-websockets-allow2.yaml b/tests/plugins/websockets/configs/.rr-websockets-allow2.yaml new file mode 100644 index 000000000..d537a80b2 --- /dev/null +++ b/tests/plugins/websockets/configs/.rr-websockets-allow2.yaml @@ -0,0 +1,44 @@ +rpc: + listen: tcp://127.0.0.1:6001 + +server: + command: "php ../../worker-ok.php" + user: "" + group: "" + relay: "pipes" + relay_timeout: "20s" + +http: + address: 127.0.0.1:41270 + max_request_size: 1024 + middleware: [ "websockets" ] + trusted_subnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + num_workers: 2 + max_jobs: 0 + allocate_timeout: 60s + destroy_timeout: 60s + +redis: + addrs: + - "localhost:6379" + +broadcast: + test: + driver: redis + addrs: + - "localhost:6379" + +websockets: + broker: test + allowed_origin: "*" + path: "/ws" + +logs: + mode: development + level: error + +endure: + grace_period: 120s + print_graph: false + log_level: error diff --git a/tests/plugins/websockets/configs/.rr-websockets-redis-no-section.yaml b/tests/plugins/websockets/configs/.rr-websockets-broker-no-section.yaml similarity index 89% rename from tests/plugins/websockets/configs/.rr-websockets-redis-no-section.yaml rename to tests/plugins/websockets/configs/.rr-websockets-broker-no-section.yaml index fd1257947..ada238457 100644 --- a/tests/plugins/websockets/configs/.rr-websockets-redis-no-section.yaml +++ b/tests/plugins/websockets/configs/.rr-websockets-broker-no-section.yaml @@ -19,9 +19,13 @@ http: allocate_timeout: 60s destroy_timeout: 60s +broadcast: + test1: + driver: no websockets: - pubsubs: [ "redis", "memory" ] + broker: test + allowed_origin: "*" path: "/ws" logs: diff --git a/tests/plugins/websockets/configs/.rr-websockets-memory-deny.yaml b/tests/plugins/websockets/configs/.rr-websockets-deny.yaml similarity index 84% rename from tests/plugins/websockets/configs/.rr-websockets-memory-deny.yaml rename to tests/plugins/websockets/configs/.rr-websockets-deny.yaml index e3bf52181..594a746de 100644 --- a/tests/plugins/websockets/configs/.rr-websockets-memory-deny.yaml +++ b/tests/plugins/websockets/configs/.rr-websockets-deny.yaml @@ -9,7 +9,7 @@ server: relay_timeout: "20s" http: - address: 127.0.0.1:11112 + address: 127.0.0.1:15587 max_request_size: 1024 middleware: [ "websockets" ] trusted_subnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] @@ -19,12 +19,13 @@ http: allocate_timeout: 60s destroy_timeout: 60s -redis: - addrs: - - "localhost:6379" +broadcast: + test: + driver: memory websockets: - pubsubs: [ "memory" ] + broker: test + allowed_origin: "*" path: "/ws" logs: diff --git a/tests/plugins/websockets/configs/.rr-websockets-deny2.yaml b/tests/plugins/websockets/configs/.rr-websockets-deny2.yaml new file mode 100644 index 000000000..4deea30a2 --- /dev/null +++ b/tests/plugins/websockets/configs/.rr-websockets-deny2.yaml @@ -0,0 +1,40 @@ +rpc: + listen: tcp://127.0.0.1:6001 + +server: + command: "php ../../worker-deny.php" + user: "" + group: "" + relay: "pipes" + relay_timeout: "20s" + +http: + address: 127.0.0.1:15588 + max_request_size: 1024 + middleware: [ "websockets" ] + trusted_subnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + num_workers: 2 + max_jobs: 0 + allocate_timeout: 60s + destroy_timeout: 60s + +broadcast: + test: + driver: redis + addrs: + - "localhost:6379" + +websockets: + broker: test + allowed_origin: "*" + path: "/ws" + +logs: + mode: development + level: error + +endure: + grace_period: 120s + print_graph: false + log_level: error diff --git a/tests/plugins/websockets/configs/.rr-websockets-init.yaml b/tests/plugins/websockets/configs/.rr-websockets-init.yaml index dc073be30..115f9a715 100644 --- a/tests/plugins/websockets/configs/.rr-websockets-init.yaml +++ b/tests/plugins/websockets/configs/.rr-websockets-init.yaml @@ -19,16 +19,16 @@ http: allocate_timeout: 60s destroy_timeout: 60s -redis: - addrs: - - "localhost:6379" +broadcast: + default: + driver: memory websockets: - # pubsubs should implement PubSub interface to be collected via endure.Collects - # pubsubs might use general config section or its own - pubsubs: [ "redis" ] + broker: default + allowed_origin: "*" path: "/ws" + logs: mode: development level: error diff --git a/tests/plugins/websockets/configs/.rr-websockets-redis-memory-local.yaml b/tests/plugins/websockets/configs/.rr-websockets-redis.yaml similarity index 88% rename from tests/plugins/websockets/configs/.rr-websockets-redis-memory-local.yaml rename to tests/plugins/websockets/configs/.rr-websockets-redis.yaml index 27eab5579..3557f5f1d 100644 --- a/tests/plugins/websockets/configs/.rr-websockets-redis-memory-local.yaml +++ b/tests/plugins/websockets/configs/.rr-websockets-redis.yaml @@ -19,13 +19,17 @@ http: allocate_timeout: 60s destroy_timeout: 60s - -websockets: - pubsubs: [ "redis", "memory" ] - redis: +redis: addrs: - "localhost:6379" +broadcast: + test: + driver: redis + +websockets: + broker: test + allowed_origin: "*" path: "/ws" logs: diff --git a/tests/plugins/websockets/configs/.rr-websockets-memory-stop.yaml b/tests/plugins/websockets/configs/.rr-websockets-stop.yaml similarity index 88% rename from tests/plugins/websockets/configs/.rr-websockets-memory-stop.yaml rename to tests/plugins/websockets/configs/.rr-websockets-stop.yaml index 0614f4e7f..5377aef23 100644 --- a/tests/plugins/websockets/configs/.rr-websockets-memory-stop.yaml +++ b/tests/plugins/websockets/configs/.rr-websockets-stop.yaml @@ -19,12 +19,13 @@ http: allocate_timeout: 60s destroy_timeout: 60s -redis: - addrs: - - "localhost:6379" +broadcast: + test: + driver: memory websockets: - pubsubs: [ "memory" ] + broker: test + allowed_origin: "*" path: "/ws" logs: diff --git a/tests/plugins/websockets/websocket_plugin_test.go b/tests/plugins/websockets/websocket_plugin_test.go index 07ee5f122..5ed0c3f34 100644 --- a/tests/plugins/websockets/websocket_plugin_test.go +++ b/tests/plugins/websockets/websocket_plugin_test.go @@ -16,7 +16,7 @@ import ( json "github.com/json-iterator/go" endure "github.com/spiral/endure/pkg/container" goridgeRpc "github.com/spiral/goridge/v3/pkg/rpc" - websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" + "github.com/spiral/roadrunner/v2/plugins/broadcast" "github.com/spiral/roadrunner/v2/plugins/config" httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" "github.com/spiral/roadrunner/v2/plugins/logger" @@ -25,11 +25,12 @@ import ( rpcPlugin "github.com/spiral/roadrunner/v2/plugins/rpc" "github.com/spiral/roadrunner/v2/plugins/server" "github.com/spiral/roadrunner/v2/plugins/websockets" + websocketsv1 "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" "github.com/spiral/roadrunner/v2/utils" "github.com/stretchr/testify/assert" ) -func TestBroadcastInit(t *testing.T) { +func TestWebsocketsInit(t *testing.T) { cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) assert.NoError(t, err) @@ -47,6 +48,7 @@ func TestBroadcastInit(t *testing.T) { &websockets.Plugin{}, &httpPlugin.Plugin{}, &memory.Plugin{}, + &broadcast.Plugin{}, ) assert.NoError(t, err) @@ -98,52 +100,20 @@ func TestBroadcastInit(t *testing.T) { time.Sleep(time.Second * 1) t.Run("TestWSInit", wsInit) + t.Run("RPCWsMemoryPubAsync", RPCWsPubAsync("11111")) + t.Run("RPCWsMemory", RPCWsPub("11111")) stopCh <- struct{}{} wg.Wait() } -func wsInit(t *testing.T) { - da := websocket.Dialer{ - Proxy: http.ProxyFromEnvironment, - HandshakeTimeout: time.Second * 20, - } - - connURL := url.URL{Scheme: "ws", Host: "localhost:11111", Path: "/ws"} - - c, resp, err := da.Dial(connURL.String(), nil) - assert.NoError(t, err) - - defer func() { - _ = resp.Body.Close() - }() - - d, err := json.Marshal(messageWS("join", "memory", []byte("hello websockets"), "foo", "foo2")) - if err != nil { - panic(err) - } - - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) - - _, msg, err := c.ReadMessage() - retMsg := utils.AsString(msg) - assert.NoError(t, err) - - // subscription done - assert.Equal(t, `{"topic":"@join","payload":["foo","foo2"]}`, retMsg) - - err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) - assert.NoError(t, err) -} - -func TestWSRedisAndMemory(t *testing.T) { +func TestWSRedis(t *testing.T) { cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) assert.NoError(t, err) cfg := &config.Viper{ - Path: "configs/.rr-websockets-redis-memory.yaml", + Path: "configs/.rr-websockets-redis.yaml", Prefix: "rr", } @@ -155,7 +125,7 @@ func TestWSRedisAndMemory(t *testing.T) { &redis.Plugin{}, &websockets.Plugin{}, &httpPlugin.Plugin{}, - &memory.Plugin{}, + &broadcast.Plugin{}, ) assert.NoError(t, err) @@ -205,21 +175,20 @@ func TestWSRedisAndMemory(t *testing.T) { }() time.Sleep(time.Second * 1) - t.Run("RPCWsMemoryPubAsync", RPCWsMemoryPubAsync) - t.Run("RPCWsMemory", RPCWsMemory) - t.Run("RPCWsRedis", RPCWsRedis) + t.Run("RPCWsRedisPubAsync", RPCWsPubAsync("13235")) + t.Run("RPCWsRedisPub", RPCWsPub("13235")) stopCh <- struct{}{} wg.Wait() } -func TestWSRedisAndMemoryGlobal(t *testing.T) { +func TestWSRedisNoSection(t *testing.T) { cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) assert.NoError(t, err) cfg := &config.Viper{ - Path: "configs/.rr-websockets-redis-memory.yaml", + Path: "configs/.rr-websockets-broker-no-section.yaml", Prefix: "rr", } @@ -231,7 +200,37 @@ func TestWSRedisAndMemoryGlobal(t *testing.T) { &redis.Plugin{}, &websockets.Plugin{}, &httpPlugin.Plugin{}, + &broadcast.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + _, err = cont.Serve() + assert.Error(t, err) +} + +func TestWSDeny(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-websockets-deny.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &rpcPlugin.Plugin{}, + &logger.ZapLogger{}, + &server.Plugin{}, + &websockets.Plugin{}, + &httpPlugin.Plugin{}, &memory.Plugin{}, + &broadcast.Plugin{}, ) assert.NoError(t, err) @@ -281,21 +280,19 @@ func TestWSRedisAndMemoryGlobal(t *testing.T) { }() time.Sleep(time.Second * 1) - t.Run("RPCWsMemoryPubAsync", RPCWsMemoryPubAsync) - t.Run("RPCWsMemory", RPCWsMemory) - t.Run("RPCWsRedis", RPCWsRedis) + t.Run("RPCWsMemoryDeny", RPCWsDeny("15587")) stopCh <- struct{}{} wg.Wait() } -func TestWSRedisNoSection(t *testing.T) { +func TestWSDeny2(t *testing.T) { cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) assert.NoError(t, err) cfg := &config.Viper{ - Path: "configs/.rr-websockets-redis-no-section.yaml", + Path: "configs/.rr-websockets-deny2.yaml", Prefix: "rr", } @@ -304,10 +301,10 @@ func TestWSRedisNoSection(t *testing.T) { &rpcPlugin.Plugin{}, &logger.ZapLogger{}, &server.Plugin{}, - &redis.Plugin{}, &websockets.Plugin{}, &httpPlugin.Plugin{}, - &memory.Plugin{}, + &redis.Plugin{}, + &broadcast.Plugin{}, ) assert.NoError(t, err) @@ -316,234 +313,60 @@ func TestWSRedisNoSection(t *testing.T) { t.Fatal(err) } - _, err = cont.Serve() - assert.Error(t, err) -} - -func RPCWsMemoryPubAsync(t *testing.T) { - da := websocket.Dialer{ - Proxy: http.ProxyFromEnvironment, - HandshakeTimeout: time.Second * 20, - } - - connURL := url.URL{Scheme: "ws", Host: "localhost:13235", Path: "/ws"} - - c, resp, err := da.Dial(connURL.String(), nil) - assert.NoError(t, err) - - defer func() { - _ = resp.Body.Close() - }() - - d, err := json.Marshal(messageWS("join", "memory", []byte("hello websockets"), "foo", "foo2")) - if err != nil { - panic(err) - } - - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) - - _, msg, err := c.ReadMessage() - retMsg := utils.AsString(msg) - assert.NoError(t, err) - - // subscription done - assert.Equal(t, `{"topic":"@join","payload":["foo","foo2"]}`, retMsg) - - publishAsync(t, "", "memory", "foo") - - // VERIFY a makeMessage - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - assert.Equal(t, "hello, PHP", retMsg) - - // //// LEAVE foo, foo2 ///////// - d, err = json.Marshal(messageWS("leave", "memory", []byte("hello websockets"), "foo")) + ch, err := cont.Serve() if err != nil { - panic(err) + t.Fatal(err) } - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) - - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) - // subscription done - assert.Equal(t, `{"topic":"@leave","payload":["foo"]}`, retMsg) + wg := &sync.WaitGroup{} + wg.Add(1) - // TRY TO PUBLISH TO UNSUBSCRIBED TOPIC - publishAsync(t, "", "memory", "foo") + stopCh := make(chan struct{}, 1) go func() { - time.Sleep(time.Second * 5) - publishAsync2(t, "", "memory", "foo2") - }() - - // should be only makeMessage from the subscribed foo2 topic - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - assert.Equal(t, "hello, PHP2", retMsg) - - err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) - assert.NoError(t, err) -} - -func RPCWsMemory(t *testing.T) { - da := websocket.Dialer{ - Proxy: http.ProxyFromEnvironment, - HandshakeTimeout: time.Second * 20, - } - - connURL := url.URL{Scheme: "ws", Host: "localhost:13235", Path: "/ws"} - - c, resp, err := da.Dial(connURL.String(), nil) - assert.NoError(t, err) - - defer func() { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-stopCh: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } } }() - d, err := json.Marshal(messageWS("join", "memory", []byte("hello websockets"), "foo", "foo2")) - if err != nil { - panic(err) - } - - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) - - _, msg, err := c.ReadMessage() - retMsg := utils.AsString(msg) - assert.NoError(t, err) - - // subscription done - assert.Equal(t, `{"topic":"@join","payload":["foo","foo2"]}`, retMsg) - - publish("", "memory", "foo") - - // VERIFY a makeMessage - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - assert.Equal(t, "hello, PHP", retMsg) - - // //// LEAVE foo, foo2 ///////// - d, err = json.Marshal(messageWS("leave", "memory", []byte("hello websockets"), "foo")) - if err != nil { - panic(err) - } - - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) - - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - - // subscription done - assert.Equal(t, `{"topic":"@leave","payload":["foo"]}`, retMsg) - - // TRY TO PUBLISH TO UNSUBSCRIBED TOPIC - publish("", "memory", "foo") - - go func() { - time.Sleep(time.Second * 5) - publish2(t, "", "memory", "foo2") - }() - - // should be only makeMessage from the subscribed foo2 topic - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - assert.Equal(t, "hello, PHP2", retMsg) - - err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) - assert.NoError(t, err) -} - -func RPCWsRedis(t *testing.T) { - da := websocket.Dialer{ - Proxy: http.ProxyFromEnvironment, - HandshakeTimeout: time.Second * 20, - } - - connURL := url.URL{Scheme: "ws", Host: "localhost:13235", Path: "/ws"} - - c, resp, err := da.Dial(connURL.String(), nil) - assert.NoError(t, err) - - defer func() { - _ = resp.Body.Close() - }() - - d, err := json.Marshal(messageWS("join", "redis", []byte("hello websockets"), "foo", "foo2")) - if err != nil { - panic(err) - } - - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) - - _, msg, err := c.ReadMessage() - retMsg := utils.AsString(msg) - assert.NoError(t, err) - - // subscription done - assert.Equal(t, `{"topic":"@join","payload":["foo","foo2"]}`, retMsg) - - publish("", "redis", "foo") - - // VERIFY a makeMessage - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - assert.Equal(t, "hello, PHP", retMsg) - - // //// LEAVE foo, foo2 ///////// - d, err = json.Marshal(messageWS("leave", "redis", []byte("hello websockets"), "foo")) - if err != nil { - panic(err) - } - - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) - - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - - // subscription done - assert.Equal(t, `{"topic":"@leave","payload":["foo"]}`, retMsg) - - // TRY TO PUBLISH TO UNSUBSCRIBED TOPIC - publish("", "redis", "foo") - - go func() { - time.Sleep(time.Second * 5) - publish2(t, "", "redis", "foo2") - }() + time.Sleep(time.Second * 1) + t.Run("RPCWsRedisDeny", RPCWsDeny("15588")) - // should be only makeMessage from the subscribed foo2 topic - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - assert.Equal(t, "hello, PHP2", retMsg) + stopCh <- struct{}{} - err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) - assert.NoError(t, err) + wg.Wait() } -func TestWSMemoryDeny(t *testing.T) { +func TestWSStop(t *testing.T) { cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) assert.NoError(t, err) cfg := &config.Viper{ - Path: "configs/.rr-websockets-memory-deny.yaml", + Path: "configs/.rr-websockets-stop.yaml", Prefix: "rr", } @@ -556,6 +379,7 @@ func TestWSMemoryDeny(t *testing.T) { &websockets.Plugin{}, &httpPlugin.Plugin{}, &memory.Plugin{}, + &broadcast.Plugin{}, ) assert.NoError(t, err) @@ -605,73 +429,40 @@ func TestWSMemoryDeny(t *testing.T) { }() time.Sleep(time.Second * 1) - t.Run("RPCWsMemoryDeny", RPCWsMemoryDeny) + t.Run("RPCWsStop", RPCWsMemoryStop("11114")) stopCh <- struct{}{} wg.Wait() } -func RPCWsMemoryDeny(t *testing.T) { - da := websocket.Dialer{ - Proxy: http.ProxyFromEnvironment, - HandshakeTimeout: time.Second * 20, - } - - connURL := url.URL{Scheme: "ws", Host: "localhost:11112", Path: "/ws"} +func RPCWsMemoryStop(port string) func(t *testing.T) { + return func(t *testing.T) { + da := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: time.Second * 20, + } - c, resp, err := da.Dial(connURL.String(), nil) - assert.NoError(t, err) - assert.NotNil(t, c) - assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + connURL := url.URL{Scheme: "ws", Host: "localhost:" + port, Path: "/ws"} - defer func() { - if resp != nil && resp.Body != nil { + c, resp, err := da.Dial(connURL.String(), nil) + assert.NotNil(t, resp) + assert.Error(t, err) + assert.Nil(t, c) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) //nolint:staticcheck + assert.Equal(t, resp.Header.Get("Stop"), "we-dont-like-you") //nolint:staticcheck + if resp != nil && resp.Body != nil { //nolint:staticcheck _ = resp.Body.Close() } - }() - - d, err := json.Marshal(messageWS("join", "memory", []byte("hello websockets"), "foo", "foo2")) - if err != nil { - panic(err) } - - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) - - _, msg, err := c.ReadMessage() - retMsg := utils.AsString(msg) - assert.NoError(t, err) - - // subscription done - assert.Equal(t, `{"topic":"#join","payload":["foo","foo2"]}`, retMsg) - - // //// LEAVE foo, foo2 ///////// - d, err = json.Marshal(messageWS("leave", "memory", []byte("hello websockets"), "foo")) - if err != nil { - panic(err) - } - - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) - - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - - // subscription done - assert.Equal(t, `{"topic":"@leave","payload":["foo"]}`, retMsg) - - err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) - assert.NoError(t, err) } -func TestWSMemoryStop(t *testing.T) { +func TestWSAllow(t *testing.T) { cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) assert.NoError(t, err) cfg := &config.Viper{ - Path: "configs/.rr-websockets-memory-stop.yaml", + Path: "configs/.rr-websockets-allow.yaml", Prefix: "rr", } @@ -684,6 +475,7 @@ func TestWSMemoryStop(t *testing.T) { &websockets.Plugin{}, &httpPlugin.Plugin{}, &memory.Plugin{}, + &broadcast.Plugin{}, ) assert.NoError(t, err) @@ -733,38 +525,19 @@ func TestWSMemoryStop(t *testing.T) { }() time.Sleep(time.Second * 1) - t.Run("RPCWsMemoryStop", RPCWsMemoryStop) + t.Run("RPCWsMemoryAllow", RPCWsPub("41278")) stopCh <- struct{}{} wg.Wait() } -func RPCWsMemoryStop(t *testing.T) { - da := websocket.Dialer{ - Proxy: http.ProxyFromEnvironment, - HandshakeTimeout: time.Second * 20, - } - - connURL := url.URL{Scheme: "ws", Host: "localhost:11114", Path: "/ws"} - - c, resp, err := da.Dial(connURL.String(), nil) - assert.NotNil(t, resp) - assert.Error(t, err) - assert.Nil(t, c) - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) //nolint:staticcheck - assert.Equal(t, resp.Header.Get("Stop"), "we-dont-like-you") //nolint:staticcheck - if resp != nil && resp.Body != nil { //nolint:staticcheck - _ = resp.Body.Close() - } -} - -func TestWSMemoryOk(t *testing.T) { +func TestWSAllow2(t *testing.T) { cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) assert.NoError(t, err) cfg := &config.Viper{ - Path: "configs/.rr-websockets-memory-allow.yaml", + Path: "configs/.rr-websockets-allow2.yaml", Prefix: "rr", } @@ -777,6 +550,7 @@ func TestWSMemoryOk(t *testing.T) { &websockets.Plugin{}, &httpPlugin.Plugin{}, &memory.Plugin{}, + &broadcast.Plugin{}, ) assert.NoError(t, err) @@ -826,33 +600,29 @@ func TestWSMemoryOk(t *testing.T) { }() time.Sleep(time.Second * 1) - t.Run("RPCWsMemoryAllow", RPCWsMemoryAllow) + t.Run("RPCWsMemoryAllow", RPCWsPub("41270")) stopCh <- struct{}{} wg.Wait() } -func RPCWsMemoryAllow(t *testing.T) { +func wsInit(t *testing.T) { da := websocket.Dialer{ Proxy: http.ProxyFromEnvironment, HandshakeTimeout: time.Second * 20, } - connURL := url.URL{Scheme: "ws", Host: "localhost:11113", Path: "/ws"} + connURL := url.URL{Scheme: "ws", Host: "localhost:11111", Path: "/ws"} c, resp, err := da.Dial(connURL.String(), nil) assert.NoError(t, err) - assert.NotNil(t, c) - assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) defer func() { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() - } + _ = resp.Body.Close() }() - d, err := json.Marshal(messageWS("join", "memory", []byte("hello websockets"), "foo", "foo2")) + d, err := json.Marshal(messageWS("join", []byte("hello websockets"), "foo", "foo2")) if err != nil { panic(err) } @@ -867,64 +637,219 @@ func RPCWsMemoryAllow(t *testing.T) { // subscription done assert.Equal(t, `{"topic":"@join","payload":["foo","foo2"]}`, retMsg) - publish("", "memory", "foo") - - // VERIFY a makeMessage - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) + err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) assert.NoError(t, err) - assert.Equal(t, "hello, PHP", retMsg) +} - // //// LEAVE foo, foo2 ///////// - d, err = json.Marshal(messageWS("leave", "memory", []byte("hello websockets"), "foo")) - if err != nil { - panic(err) - } +func RPCWsPubAsync(port string) func(t *testing.T) { + return func(t *testing.T) { + da := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: time.Second * 18, + } - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) + connURL := url.URL{Scheme: "ws", Host: "localhost:" + port, Path: "/ws"} - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) + c, resp, err := da.Dial(connURL.String(), nil) + assert.NoError(t, err) - // subscription done - assert.Equal(t, `{"topic":"@leave","payload":["foo"]}`, retMsg) + defer func() { + _ = resp.Body.Close() + }() - // TRY TO PUBLISH TO UNSUBSCRIBED TOPIC - publish("", "memory", "foo") + d, err := json.Marshal(messageWS("join", []byte("hello websockets"), "foo", "foo2")) + if err != nil { + panic(err) + } - go func() { - time.Sleep(time.Second * 5) - publish2(t, "", "memory", "foo2") - }() + err = c.WriteMessage(websocket.BinaryMessage, d) + assert.NoError(t, err) - // should be only makeMessage from the subscribed foo2 topic - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - assert.Equal(t, "hello, PHP2", retMsg) + _, msg, err := c.ReadMessage() + retMsg := utils.AsString(msg) + assert.NoError(t, err) - err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) - assert.NoError(t, err) + // subscription done + assert.Equal(t, `{"topic":"@join","payload":["foo","foo2"]}`, retMsg) + + publishAsync(t, "placeholder", "foo") + + // VERIFY a makeMessage + _, msg, err = c.ReadMessage() + retMsg = utils.AsString(msg) + assert.NoError(t, err) + assert.Equal(t, "{\"topic\":\"foo\",\"payload\":\"hello, PHP\"}", retMsg) + + // //// LEAVE foo ///////// + d, err = json.Marshal(messageWS("leave", []byte("hello websockets"), "foo")) + if err != nil { + panic(err) + } + + err = c.WriteMessage(websocket.BinaryMessage, d) + assert.NoError(t, err) + + _, msg, err = c.ReadMessage() + retMsg = utils.AsString(msg) + assert.NoError(t, err) + + // subscription done + assert.Equal(t, `{"topic":"@leave","payload":["foo"]}`, retMsg) + + // TRY TO PUBLISH TO UNSUBSCRIBED TOPIC + publishAsync(t, "placeholder", "foo") + + go func() { + time.Sleep(time.Second * 3) + publishAsync(t, "placeholder", "foo2") + }() + + // should be only makeMessage from the subscribed foo0 topic + _, msg, err = c.ReadMessage() + retMsg = utils.AsString(msg) + assert.NoError(t, err) + assert.Equal(t, "{\"topic\":\"foo2\",\"payload\":\"hello, PHP\"}", retMsg) + + err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) + assert.NoError(t, err) + } } -func publish(command string, broker string, topics ...string) { - conn, err := net.Dial("tcp", "127.0.0.1:6001") - if err != nil { - panic(err) +func RPCWsPub(port string) func(t *testing.T) { + return func(t *testing.T) { + da := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: time.Second * 20, + } + + connURL := url.URL{Scheme: "ws", Host: "localhost:" + port, Path: "/ws"} + + c, resp, err := da.Dial(connURL.String(), nil) + assert.NoError(t, err) + + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + + d, err := json.Marshal(messageWS("join", []byte("hello websockets"), "foo", "foo2")) + if err != nil { + panic(err) + } + + err = c.WriteMessage(websocket.BinaryMessage, d) + assert.NoError(t, err) + + _, msg, err := c.ReadMessage() + retMsg := utils.AsString(msg) + assert.NoError(t, err) + + // subscription done + assert.Equal(t, `{"topic":"@join","payload":["foo","foo2"]}`, retMsg) + + publish("", "foo") + + // VERIFY a makeMessage + _, msg, err = c.ReadMessage() + retMsg = utils.AsString(msg) + assert.NoError(t, err) + assert.Equal(t, "{\"topic\":\"foo\",\"payload\":\"hello, PHP\"}", retMsg) + + // //// LEAVE foo ///////// + d, err = json.Marshal(messageWS("leave", []byte("hello websockets"), "foo")) + if err != nil { + panic(err) + } + + err = c.WriteMessage(websocket.BinaryMessage, d) + assert.NoError(t, err) + + _, msg, err = c.ReadMessage() + retMsg = utils.AsString(msg) + assert.NoError(t, err) + + // subscription done + assert.Equal(t, `{"topic":"@leave","payload":["foo"]}`, retMsg) + + // TRY TO PUBLISH TO UNSUBSCRIBED TOPIC + publish("", "foo") + + go func() { + time.Sleep(time.Second * 5) + publish2(t, "", "foo2") + }() + + // should be only makeMessage from the subscribed foo2 topic + _, msg, err = c.ReadMessage() + retMsg = utils.AsString(msg) + assert.NoError(t, err) + assert.Equal(t, "{\"topic\":\"foo2\",\"payload\":\"hello, PHP2\"}", retMsg) + + err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) + assert.NoError(t, err) } +} - client := rpc.NewClientWithCodec(goridgeRpc.NewClientCodec(conn)) +func RPCWsDeny(port string) func(t *testing.T) { + return func(t *testing.T) { + da := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: time.Second * 20, + } - ret := &websocketsv1.Response{} - err = client.Call("websockets.Publish", makeMessage(command, broker, []byte("hello, PHP"), topics...), ret) - if err != nil { - panic(err) + connURL := url.URL{Scheme: "ws", Host: "localhost:" + port, Path: "/ws"} + + c, resp, err := da.Dial(connURL.String(), nil) + assert.NoError(t, err) + assert.NotNil(t, c) + assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + + d, err := json.Marshal(messageWS("join", []byte("hello websockets"), "foo", "foo2")) + if err != nil { + panic(err) + } + + err = c.WriteMessage(websocket.BinaryMessage, d) + assert.NoError(t, err) + + _, msg, err := c.ReadMessage() + retMsg := utils.AsString(msg) + assert.NoError(t, err) + + // subscription done + assert.Equal(t, `{"topic":"#join","payload":["foo","foo2"]}`, retMsg) + + // //// LEAVE foo, foo2 ///////// + d, err = json.Marshal(messageWS("leave", []byte("hello websockets"), "foo")) + if err != nil { + panic(err) + } + + err = c.WriteMessage(websocket.BinaryMessage, d) + assert.NoError(t, err) + + _, msg, err = c.ReadMessage() + retMsg = utils.AsString(msg) + assert.NoError(t, err) + + // subscription done + assert.Equal(t, `{"topic":"@leave","payload":["foo"]}`, retMsg) + + err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) + assert.NoError(t, err) } } -func publishAsync(t *testing.T, command string, broker string, topics ...string) { +// --------------------------------------------------------------------------------------------------- + +func publish(command string, topics ...string) { conn, err := net.Dial("tcp", "127.0.0.1:6001") if err != nil { panic(err) @@ -933,12 +858,13 @@ func publishAsync(t *testing.T, command string, broker string, topics ...string) client := rpc.NewClientWithCodec(goridgeRpc.NewClientCodec(conn)) ret := &websocketsv1.Response{} - err = client.Call("websockets.PublishAsync", makeMessage(command, broker, []byte("hello, PHP"), topics...), ret) - assert.NoError(t, err) - assert.True(t, ret.Ok) + err = client.Call("broadcast.Publish", makeMessage(command, []byte("hello, PHP"), topics...), ret) + if err != nil { + panic(err) + } } -func publishAsync2(t *testing.T, command string, broker string, topics ...string) { +func publishAsync(t *testing.T, command string, topics ...string) { conn, err := net.Dial("tcp", "127.0.0.1:6001") if err != nil { panic(err) @@ -947,12 +873,12 @@ func publishAsync2(t *testing.T, command string, broker string, topics ...string client := rpc.NewClientWithCodec(goridgeRpc.NewClientCodec(conn)) ret := &websocketsv1.Response{} - err = client.Call("websockets.PublishAsync", makeMessage(command, broker, []byte("hello, PHP2"), topics...), ret) + err = client.Call("broadcast.PublishAsync", makeMessage(command, []byte("hello, PHP"), topics...), ret) assert.NoError(t, err) assert.True(t, ret.Ok) } -func publish2(t *testing.T, command string, broker string, topics ...string) { +func publish2(t *testing.T, command string, topics ...string) { conn, err := net.Dial("tcp", "127.0.0.1:6001") if err != nil { panic(err) @@ -961,27 +887,25 @@ func publish2(t *testing.T, command string, broker string, topics ...string) { client := rpc.NewClientWithCodec(goridgeRpc.NewClientCodec(conn)) ret := &websocketsv1.Response{} - err = client.Call("websockets.Publish", makeMessage(command, broker, []byte("hello, PHP2"), topics...), ret) + err = client.Call("broadcast.Publish", makeMessage(command, []byte("hello, PHP2"), topics...), ret) assert.NoError(t, err) assert.True(t, ret.Ok) } -func messageWS(command string, broker string, payload []byte, topics ...string) *websocketsv1.Message { +func messageWS(command string, payload []byte, topics ...string) *websocketsv1.Message { return &websocketsv1.Message{ Topics: topics, Command: command, - Broker: broker, Payload: payload, } } -func makeMessage(command string, broker string, payload []byte, topics ...string) *websocketsv1.Request { +func makeMessage(command string, payload []byte, topics ...string) *websocketsv1.Request { m := &websocketsv1.Request{ Messages: []*websocketsv1.Message{ { Topics: topics, Command: command, - Broker: broker, Payload: payload, }, }, diff --git a/tests/worker-origin.php b/tests/worker-origin.php new file mode 100644 index 000000000..6ce4de59c --- /dev/null +++ b/tests/worker-origin.php @@ -0,0 +1,14 @@ +waitRequest()) { + $http->respond(200, 'Response', [ + 'Access-Control-Allow-Origin' => ['*'] + ]); +}