diff --git a/server.conf.in b/server.conf.in index a52090ad..6488e4e7 100644 --- a/server.conf.in +++ b/server.conf.in @@ -45,17 +45,21 @@ blockkey = -encryption-key- internalsecret = the-shared-secret-for-internal-clients [backend] -# Comma-separated list of hostnames that are allowed to be used as backend -# endpoints. -allowed = nextcloud.domain.invalid +# Comma-separated list of backend ids from which clients are allowed to connect +# from. Each backend will have isolated rooms, i.e. clients connecting to room +# "abc12345" on backend 1 will be in a different room than clients connected to +# a room with the same name on backend 2. Also sessions connected from different +# backends will not be able to communicate with each other. +#backends = backend-id, another-backend # Allow any hostname as backend endpoint. This is extremely insecure and should # only be used while running the benchmark client against the server. allowall = false -# Shared secret for requests from and to the backend servers. This must be the -# same value as configured in the Nextcloud admin ui. -secret = the-shared-secret +# Common shared secret for requests from and to the backend servers if +# "allowall" is enabled. This must be the same value as configured in the +# Nextcloud admin ui. +# secret = the-shared-secret # Timeout in seconds for requests to the backend. timeout = 10 @@ -68,6 +72,24 @@ connectionsperhost = 8 # certificates. #skipverify = false +# Backend configurations as defined in the "[backend]" section above. The +# section names must match the ids used in "backends" above. +#[backend-id] +# URL of the Nextcloud instance +#url = https://cloud.domain.invalid + +# Shared secret for requests from and to the backend servers. This must be the +# same value as configured in the Nextcloud admin ui. +#secret = the-shared-secret + +#[another-backend] +# URL of the Nextcloud instance +#url = https://cloud.otherdomain.invalid + +# Shared secret for requests from and to the backend servers. This must be the +# same value as configured in the Nextcloud admin ui. +#secret = the-shared-secret + [nats] # Url of NATS backend to use. This can also be a list of URLs to connect to # multiple backends. For local development, this can be set to ":loopback:" diff --git a/src/signaling/api_backend.go b/src/signaling/api_backend.go index ad227697..37d2e5e4 100644 --- a/src/signaling/api_backend.go +++ b/src/signaling/api_backend.go @@ -36,6 +36,7 @@ const ( HeaderBackendSignalingRandom = "Spreed-Signaling-Random" HeaderBackendSignalingChecksum = "Spreed-Signaling-Checksum" + HeaderBackendServer = "Spreed-Signaling-Backend" ) func newRandomString(length int) string { diff --git a/src/signaling/api_signaling.go b/src/signaling/api_signaling.go index cf4b587e..5f43969e 100644 --- a/src/signaling/api_signaling.go +++ b/src/signaling/api_signaling.go @@ -196,6 +196,20 @@ const ( type ClientTypeInternalAuthParams struct { Random string `json:"random"` Token string `json:"token"` + + Backend string `json:"backend"` + parsedBackend *url.URL +} + +func (p *ClientTypeInternalAuthParams) CheckValid() error { + if p.Backend == "" { + return fmt.Errorf("backend missing") + } else if u, err := url.Parse(p.Backend); err != nil { + return err + } else { + p.parsedBackend = u + } + return nil } type HelloClientMessageAuth struct { @@ -247,6 +261,8 @@ func (m *HelloClientMessage) CheckValid() error { case HelloClientTypeInternal: if err := json.Unmarshal(*m.Auth.Params, &m.Auth.internalParams); err != nil { return err + } else if err := m.Auth.internalParams.CheckValid(); err != nil { + return err } default: return fmt.Errorf("unsupported auth type") diff --git a/src/signaling/api_signaling_test.go b/src/signaling/api_signaling_test.go index 92f04f3d..2ddbd185 100644 --- a/src/signaling/api_signaling_test.go +++ b/src/signaling/api_signaling_test.go @@ -87,6 +87,7 @@ func TestClientMessage(t *testing.T) { } func TestHelloClientMessage(t *testing.T) { + internalAuthParams := []byte("{\"backend\":\"https://domain.invalid\"}") valid_messages := []testCheckValid{ &HelloClientMessage{ Version: HelloVersion, @@ -107,7 +108,7 @@ func TestHelloClientMessage(t *testing.T) { Version: HelloVersion, Auth: HelloClientMessageAuth{ Type: "internal", - Params: &json.RawMessage{'{', '}'}, + Params: (*json.RawMessage)(&internalAuthParams), }, }, &HelloClientMessage{ @@ -145,6 +146,13 @@ func TestHelloClientMessage(t *testing.T) { Url: "invalid-url", }, }, + &HelloClientMessage{ + Version: HelloVersion, + Auth: HelloClientMessageAuth{ + Type: "internal", + Params: &json.RawMessage{'{', '}'}, + }, + }, &HelloClientMessage{ Version: HelloVersion, Auth: HelloClientMessageAuth{ diff --git a/src/signaling/backend_client.go b/src/signaling/backend_client.go index ba08c9c0..993394a2 100644 --- a/src/signaling/backend_client.go +++ b/src/signaling/backend_client.go @@ -44,11 +44,9 @@ var ( ) type BackendClient struct { - transport *http.Transport - whitelist map[string]bool - whitelistAll bool - secret []byte - version string + transport *http.Transport + version string + backends *BackendConfiguration mu sync.Mutex @@ -57,31 +55,9 @@ type BackendClient struct { } func NewBackendClient(config *goconf.ConfigFile, maxConcurrentRequestsPerHost int, version string) (*BackendClient, error) { - whitelist := make(map[string]bool) - whitelistAll, _ := config.GetBool("backend", "allowall") - if whitelistAll { - log.Println("WARNING: All backend hostnames are allowed, only use for development!") - } else { - urls, _ := config.GetString("backend", "allowed") - for _, u := range strings.Split(urls, ",") { - u = strings.TrimSpace(u) - if idx := strings.IndexByte(u, '/'); idx != -1 { - log.Printf("WARNING: Removing path from allowed hostname \"%s\", check your configuration!", u) - u = u[:idx] - } - if u != "" { - whitelist[strings.ToLower(u)] = true - } - } - if len(whitelist) == 0 { - log.Println("WARNING: No backend hostnames are allowed, check your configuration!") - } else { - hosts := make([]string, 0, len(whitelist)) - for u := range whitelist { - hosts = append(hosts, u) - } - log.Printf("Allowed backend hostnames: %s\n", hosts) - } + backends, err := NewBackendConfiguration(config) + if err != nil { + return nil, err } skipverify, _ := config.GetBool("backend", "skipverify") @@ -89,8 +65,6 @@ func NewBackendClient(config *goconf.ConfigFile, maxConcurrentRequestsPerHost in log.Println("WARNING: Backend verification is disabled!") } - secret, _ := config.GetString("backend", "secret") - tlsconfig := &tls.Config{ InsecureSkipVerify: skipverify, } @@ -100,11 +74,9 @@ func NewBackendClient(config *goconf.ConfigFile, maxConcurrentRequestsPerHost in } return &BackendClient{ - transport: transport, - whitelist: whitelist, - whitelistAll: whitelistAll, - secret: []byte(secret), - version: version, + transport: transport, + version: version, + backends: backends, maxConcurrentRequestsPerHost: maxConcurrentRequestsPerHost, clients: make(map[string]*HttpClientPool), @@ -135,13 +107,20 @@ func (b *BackendClient) getPool(url *url.URL) (*HttpClientPool, error) { return pool, nil } -func (b *BackendClient) IsUrlAllowed(u *url.URL) bool { - if u == nil { - // Reject all invalid URLs. - return false - } +func (b *BackendClient) GetCompatBackend() *Backend { + return b.backends.GetCompatBackend() +} + +func (b *BackendClient) GetBackend(u *url.URL) *Backend { + return b.backends.GetBackend(u) +} - return b.whitelistAll || b.whitelist[u.Host] +func (b *BackendClient) GetBackends() []*Backend { + return b.backends.GetBackends() +} + +func (b *BackendClient) IsUrlAllowed(u *url.URL) bool { + return b.backends.IsUrlAllowed(u) } func isOcsRequest(u *url.URL) bool { @@ -304,6 +283,11 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ return fmt.Errorf("No url passed to perform JSON request %+v", request) } + secret := b.backends.GetSecret(u) + if secret == nil { + return fmt.Errorf("No backend secret configured for for %s", u) + } + pool, err := b.getPool(u) if err != nil { log.Printf("Could not get client pool for host %s: %s\n", u.Host, err) @@ -338,7 +322,7 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ req.Header.Set("User-Agent", "nextcloud-spreed-signaling/"+b.version) // Add checksum so the backend can validate the request. - AddBackendChecksum(req, data, b.secret) + AddBackendChecksum(req, data, secret) resp, err := performRequestWithRedirects(ctx, c, req, data) if err != nil { diff --git a/src/signaling/backend_client_test.go b/src/signaling/backend_client_test.go index e4e9da87..073ed042 100644 --- a/src/signaling/backend_client_test.go +++ b/src/signaling/backend_client_test.go @@ -35,71 +35,6 @@ import ( "golang.org/x/net/context" ) -func testUrls(t *testing.T, client *BackendClient, valid_urls []string, invalid_urls []string) { - for _, u := range valid_urls { - parsed, err := url.ParseRequestURI(u) - if err != nil { - t.Errorf("The url %s should be valid, got %s", u, err) - continue - } - if !client.IsUrlAllowed(parsed) { - t.Errorf("The url %s should be allowed", u) - } - } - for _, u := range invalid_urls { - parsed, _ := url.ParseRequestURI(u) - if client.IsUrlAllowed(parsed) { - t.Errorf("The url %s should not be allowed", u) - } - } -} - -func TestIsUrlAllowed(t *testing.T) { - valid_urls := []string{ - "http://domain.invalid", - "https://domain.invalid", - } - invalid_urls := []string{ - "http://otherdomain.invalid", - "https://otherdomain.invalid", - "domain.invalid", - } - client := &BackendClient{ - whitelistAll: false, - whitelist: map[string]bool{ - "domain.invalid": true, - }, - } - testUrls(t, client, valid_urls, invalid_urls) -} - -func TestIsUrlAllowed_EmptyWhitelist(t *testing.T) { - valid_urls := []string{} - invalid_urls := []string{ - "http://domain.invalid", - "https://domain.invalid", - "domain.invalid", - } - client := &BackendClient{ - whitelistAll: false, - } - testUrls(t, client, valid_urls, invalid_urls) -} - -func TestIsUrlAllowed_WhitelistAll(t *testing.T) { - valid_urls := []string{ - "http://domain.invalid", - "https://domain.invalid", - } - invalid_urls := []string{ - "domain.invalid", - } - client := &BackendClient{ - whitelistAll: true, - } - testUrls(t, client, valid_urls, invalid_urls) -} - func TestPostOnRedirect(t *testing.T) { r := mux.NewRouter() r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) { @@ -142,17 +77,20 @@ func TestPostOnRedirect(t *testing.T) { server := httptest.NewServer(r) defer server.Close() - config := &goconf.ConfigFile{} - client, err := NewBackendClient(config, 1, "0.0") + u, err := url.Parse(server.URL + "/ocs/v2.php/one") if err != nil { t.Fatal(err) } - ctx := context.Background() - u, err := url.Parse(server.URL + "/ocs/v2.php/one") + config := goconf.NewConfigFile() + config.AddOption("backend", "allowed", u.Host) + config.AddOption("backend", "secret", string(testBackendSecret)) + client, err := NewBackendClient(config, 1, "0.0") if err != nil { t.Fatal(err) } + + ctx := context.Background() request := map[string]string{ "foo": "bar", } diff --git a/src/signaling/backend_configuration.go b/src/signaling/backend_configuration.go new file mode 100644 index 00000000..ffc547ce --- /dev/null +++ b/src/signaling/backend_configuration.go @@ -0,0 +1,212 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2020 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "log" + "net/url" + "strings" + + "github.com/dlintw/goconf" +) + +type Backend struct { + id string + url string + secret []byte + compat bool +} + +func (b *Backend) Id() string { + return b.id +} + +func (b *Backend) Secret() []byte { + return b.secret +} + +func (b *Backend) IsCompat() bool { + return b.compat +} + +type BackendConfiguration struct { + backends map[string][]*Backend + + // Deprecated + allowAll bool + commonSecret []byte + compatBackend *Backend +} + +func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration, error) { + allowAll, _ := config.GetBool("backend", "allowall") + commonSecret, _ := config.GetString("backend", "secret") + backends := make(map[string][]*Backend) + var compatBackend *Backend + if allowAll { + log.Println("WARNING: All backend hostnames are allowed, only use for development!") + compatBackend = &Backend{ + id: "compat", + secret: []byte(commonSecret), + compat: true, + } + } else if backendIds, _ := config.GetString("backend", "backends"); backendIds != "" { + seenIds := make(map[string]bool) + for _, id := range strings.Split(backendIds, ",") { + id = strings.TrimSpace(id) + if id == "" { + continue + } + + if seenIds[id] { + continue + } + seenIds[id] = true + + u, _ := config.GetString(id, "url") + secret, _ := config.GetString(id, "secret") + if u == "" || secret == "" { + log.Printf("Backend %s is missing or incomplete, skipping", id) + continue + } + + if u[len(u)-1] != '/' { + u += "/" + } + parsed, err := url.Parse(u) + if err != nil { + log.Printf("Backend %s has an invalid url %s configured (%s), skipping", id, u, err) + continue + } + + backends[parsed.Host] = append(backends[parsed.Host], &Backend{ + id: id, + url: u, + secret: []byte(secret), + }) + log.Printf("Backend %s added for %s", id, u) + } + } else if allowedUrls, _ := config.GetString("backend", "allowed"); allowedUrls != "" { + // Old-style configuration, only hosts are configured and are using a common secret. + allowMap := make(map[string]bool) + for _, u := range strings.Split(allowedUrls, ",") { + u = strings.TrimSpace(u) + if idx := strings.IndexByte(u, '/'); idx != -1 { + log.Printf("WARNING: Removing path from allowed hostname \"%s\", check your configuration!", u) + u = u[:idx] + } + if u != "" { + allowMap[strings.ToLower(u)] = true + } + } + + if len(allowMap) == 0 { + log.Println("WARNING: No backend hostnames are allowed, check your configuration!") + } else { + compatBackend = &Backend{ + id: "compat", + secret: []byte(commonSecret), + compat: true, + } + hosts := make([]string, 0, len(allowMap)) + for host := range allowMap { + hosts = append(hosts, host) + backends[host] = []*Backend{compatBackend} + } + if len(hosts) > 1 { + log.Println("WARNING: Using deprecated backend configuration. Please migrate the \"allowed\" setting to the new \"backends\" configuration.") + } + log.Printf("Allowed backend hostnames: %s\n", hosts) + } + } + + return &BackendConfiguration{ + backends: backends, + + allowAll: allowAll, + commonSecret: []byte(commonSecret), + compatBackend: compatBackend, + }, nil +} + +func (b *BackendConfiguration) GetCompatBackend() *Backend { + return b.compatBackend +} + +func (b *BackendConfiguration) GetBackend(u *url.URL) *Backend { + entries, found := b.backends[u.Host] + if !found { + if b.allowAll { + return b.compatBackend + } + return nil + } + + s := u.String() + if s[len(s)-1] != '/' { + s += "/" + } + for _, entry := range entries { + if entry.url == "" { + // Old-style configuration, only hosts are configured. + return entry + } else if strings.HasPrefix(s, entry.url) { + return entry + } + } + + return nil +} + +func (b *BackendConfiguration) GetBackends() []*Backend { + var result []*Backend + for _, entries := range b.backends { + for _, entry := range entries { + result = append(result, entry) + } + } + return result +} + +func (b *BackendConfiguration) IsUrlAllowed(u *url.URL) bool { + if u == nil { + // Reject all invalid URLs. + return false + } + + backend := b.GetBackend(u) + return backend != nil +} + +func (b *BackendConfiguration) GetSecret(u *url.URL) []byte { + if u == nil { + // Reject all invalid URLs. + return nil + } + + entry := b.GetBackend(u) + if entry == nil { + return nil + } + + return entry.secret +} diff --git a/src/signaling/backend_configuration_test.go b/src/signaling/backend_configuration_test.go new file mode 100644 index 00000000..2aea66cf --- /dev/null +++ b/src/signaling/backend_configuration_test.go @@ -0,0 +1,165 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2020 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "bytes" + "net/url" + "testing" + + "github.com/dlintw/goconf" +) + +func testUrls(t *testing.T, config *BackendConfiguration, valid_urls []string, invalid_urls []string) { + for _, u := range valid_urls { + parsed, err := url.ParseRequestURI(u) + if err != nil { + t.Errorf("The url %s should be valid, got %s", u, err) + continue + } + if !config.IsUrlAllowed(parsed) { + t.Errorf("The url %s should be allowed", u) + } + if secret := config.GetSecret(parsed); !bytes.Equal(secret, testBackendSecret) { + t.Errorf("Expected secret %s for url %s, got %s", string(testBackendSecret), u, string(secret)) + } + } + for _, u := range invalid_urls { + parsed, _ := url.ParseRequestURI(u) + if config.IsUrlAllowed(parsed) { + t.Errorf("The url %s should not be allowed", u) + } + } +} + +func testBackends(t *testing.T, config *BackendConfiguration, valid_urls [][]string, invalid_urls []string) { + for _, entry := range valid_urls { + u := entry[0] + parsed, err := url.ParseRequestURI(u) + if err != nil { + t.Errorf("The url %s should be valid, got %s", u, err) + continue + } + if !config.IsUrlAllowed(parsed) { + t.Errorf("The url %s should be allowed", u) + } + s := entry[1] + if secret := config.GetSecret(parsed); !bytes.Equal(secret, []byte(s)) { + t.Errorf("Expected secret %s for url %s, got %s", string(s), u, string(secret)) + } + } + for _, u := range invalid_urls { + parsed, _ := url.ParseRequestURI(u) + if config.IsUrlAllowed(parsed) { + t.Errorf("The url %s should not be allowed", u) + } + } +} + +func TestIsUrlAllowed_Compat(t *testing.T) { + // Old-style configuration + valid_urls := []string{ + "http://domain.invalid", + "https://domain.invalid", + } + invalid_urls := []string{ + "http://otherdomain.invalid", + "https://otherdomain.invalid", + "domain.invalid", + } + config := goconf.NewConfigFile() + config.AddOption("backend", "allowed", "domain.invalid") + config.AddOption("backend", "secret", string(testBackendSecret)) + cfg, err := NewBackendConfiguration(config) + if err != nil { + t.Fatal(err) + } + testUrls(t, cfg, valid_urls, invalid_urls) +} + +func TestIsUrlAllowed(t *testing.T) { + valid_urls := [][]string{ + []string{"https://domain.invalid/foo", string(testBackendSecret) + "-foo"}, + []string{"https://domain.invalid/foo/", string(testBackendSecret) + "-foo"}, + []string{"https://domain.invalid/foo/folder", string(testBackendSecret) + "-foo"}, + []string{"https://domain.invalid/bar", string(testBackendSecret) + "-bar"}, + []string{"https://domain.invalid/bar/", string(testBackendSecret) + "-bar"}, + []string{"https://domain.invalid/bar/folder/", string(testBackendSecret) + "-bar"}, + []string{"https://otherdomain.invalid/", string(testBackendSecret) + "-lala"}, + []string{"https://otherdomain.invalid/folder/", string(testBackendSecret) + "-lala"}, + } + invalid_urls := []string{ + "https://domain.invalid", + "https://domain.invalid/", + "https://www.domain.invalid/foo/", + "https://domain.invalid/baz/", + } + config := goconf.NewConfigFile() + config.AddOption("backend", "backends", "foo, bar, lala, missing") + config.AddOption("foo", "url", "https://domain.invalid/foo") + config.AddOption("foo", "secret", string(testBackendSecret)+"-foo") + config.AddOption("bar", "url", "https://domain.invalid/bar/") + config.AddOption("bar", "secret", string(testBackendSecret)+"-bar") + config.AddOption("lala", "url", "https://otherdomain.invalid/") + config.AddOption("lala", "secret", string(testBackendSecret)+"-lala") + cfg, err := NewBackendConfiguration(config) + if err != nil { + t.Fatal(err) + } + testBackends(t, cfg, valid_urls, invalid_urls) +} + +func TestIsUrlAllowed_EmptyAllowlist(t *testing.T) { + valid_urls := []string{} + invalid_urls := []string{ + "http://domain.invalid", + "https://domain.invalid", + "domain.invalid", + } + config := goconf.NewConfigFile() + config.AddOption("backend", "allowed", "") + config.AddOption("backend", "secret", string(testBackendSecret)) + cfg, err := NewBackendConfiguration(config) + if err != nil { + t.Fatal(err) + } + testUrls(t, cfg, valid_urls, invalid_urls) +} + +func TestIsUrlAllowed_AllowAll(t *testing.T) { + valid_urls := []string{ + "http://domain.invalid", + "https://domain.invalid", + } + invalid_urls := []string{ + "domain.invalid", + } + config := goconf.NewConfigFile() + config.AddOption("backend", "allowall", "true") + config.AddOption("backend", "allowed", "") + config.AddOption("backend", "secret", string(testBackendSecret)) + cfg, err := NewBackendConfiguration(config) + if err != nil { + t.Fatal(err) + } + testUrls(t, cfg, valid_urls, invalid_urls) +} diff --git a/src/signaling/backend_server.go b/src/signaling/backend_server.go index 1a8d939f..d84a3935 100644 --- a/src/signaling/backend_server.go +++ b/src/signaling/backend_server.go @@ -23,6 +23,7 @@ package signaling import ( "crypto/hmac" + "crypto/rand" "crypto/sha1" "encoding/base64" "encoding/json" @@ -32,6 +33,7 @@ import ( "log" "net" "net/http" + "net/url" "reflect" "strings" "sync" @@ -57,19 +59,16 @@ type BackendServer struct { version string welcomeMessage string - secret []byte - turnapikey string turnsecret []byte turnvalid time.Duration turnservers []string statsAllowedIps map[string]bool + invalidSecret []byte } func NewBackendServer(config *goconf.ConfigFile, hub *Hub, version string) (*BackendServer, error) { - secret, _ := config.GetString("backend", "secret") - turnapikey, _ := config.GetString("turn", "apikey") turnsecret, _ := config.GetString("turn", "secret") turnservers, _ := config.GetString("turn", "servers") @@ -117,21 +116,24 @@ func NewBackendServer(config *goconf.ConfigFile, hub *Hub, version string) (*Bac } } + invalidSecret := make([]byte, 32) + if _, err := rand.Read(invalidSecret); err != nil { + return nil, err + } + return &BackendServer{ hub: hub, nats: hub.nats, roomSessions: hub.roomSessions, version: version, - secret: []byte(secret), - - turnapikey: turnapikey, - + turnapikey: turnapikey, turnsecret: []byte(turnsecret), turnvalid: turnvalid, turnservers: turnserverslist, statsAllowedIps: statsAllowedIps, + invalidSecret: invalidSecret, }, nil } @@ -148,7 +150,7 @@ func (b *BackendServer) Start(r *mux.Router) error { } s := r.PathPrefix("/api/v1").Subrouter() s.HandleFunc("/welcome", b.setComonHeaders(b.welcomeFunc)).Methods("GET") - s.HandleFunc("/room/{roomid}", b.setComonHeaders(b.validateBackendRequest(b.roomHandler))).Methods("POST") + s.HandleFunc("/room/{roomid}", b.setComonHeaders(b.parseRequestBody(b.roomHandler))).Methods("POST") s.HandleFunc("/stats", b.setComonHeaders(b.validateStatsRequest(b.statsHandler))).Methods("GET") // Provide a REST service to get TURN credentials. @@ -236,7 +238,7 @@ func (b *BackendServer) getTurnCredentials(w http.ResponseWriter, r *http.Reques w.Write(data) } -func (b *BackendServer) validateBackendRequest(f func(http.ResponseWriter, *http.Request, []byte)) func(http.ResponseWriter, *http.Request) { +func (b *BackendServer) parseRequestBody(f func(http.ResponseWriter, *http.Request, []byte)) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { // Sanity checks if r.ContentLength == -1 { @@ -265,16 +267,12 @@ func (b *BackendServer) validateBackendRequest(f func(http.ResponseWriter, *http http.Error(w, "Could not read body", http.StatusBadRequest) return } - if !ValidateBackendChecksum(r, body, b.secret) { - http.Error(w, "Authentication check failed", http.StatusForbidden) - return - } f(w, r, body) } } -func (b *BackendServer) sendRoomInvite(roomid string, userids []string, properties *json.RawMessage) { +func (b *BackendServer) sendRoomInvite(roomid string, backend *Backend, userids []string, properties *json.RawMessage) { msg := &ServerMessage{ Type: "event", Event: &EventServerMessage{ @@ -287,11 +285,11 @@ func (b *BackendServer) sendRoomInvite(roomid string, userids []string, properti }, } for _, userid := range userids { - b.nats.PublishMessage(GetSubjectForUserId(userid), msg) + b.nats.PublishMessage(GetSubjectForUserId(userid, backend), msg) } } -func (b *BackendServer) sendRoomDisinvite(roomid string, reason string, userids []string, sessionids []string) { +func (b *BackendServer) sendRoomDisinvite(roomid string, backend *Backend, reason string, userids []string, sessionids []string) { msg := &ServerMessage{ Type: "event", Event: &EventServerMessage{ @@ -306,7 +304,7 @@ func (b *BackendServer) sendRoomDisinvite(roomid string, reason string, userids }, } for _, userid := range userids { - b.nats.PublishMessage(GetSubjectForUserId(userid), msg) + b.nats.PublishMessage(GetSubjectForUserId(userid, backend), msg) } timeout := time.Second @@ -330,7 +328,7 @@ func (b *BackendServer) sendRoomDisinvite(roomid string, reason string, userids wg.Wait() } -func (b *BackendServer) sendRoomUpdate(roomid string, notified_userids []string, all_userids []string, properties *json.RawMessage) { +func (b *BackendServer) sendRoomUpdate(roomid string, backend *Backend, notified_userids []string, all_userids []string, properties *json.RawMessage) { msg := &ServerMessage{ Type: "event", Event: &EventServerMessage{ @@ -352,7 +350,7 @@ func (b *BackendServer) sendRoomUpdate(roomid string, notified_userids []string, continue } - b.nats.PublishMessage(GetSubjectForUserId(userid), msg) + b.nats.PublishMessage(GetSubjectForUserId(userid, backend), msg) } } @@ -431,7 +429,7 @@ func (b *BackendServer) fixupUserSessions(cache *ConcurrentStringStringMap, user return result } -func (b *BackendServer) sendRoomIncall(roomid string, request *BackendServerRoomRequest) error { +func (b *BackendServer) sendRoomIncall(roomid string, backend *Backend, request *BackendServerRoomRequest) error { timeout := time.Second var cache ConcurrentStringStringMap @@ -444,10 +442,10 @@ func (b *BackendServer) sendRoomIncall(roomid string, request *BackendServerRoom return nil } - return b.nats.PublishBackendServerRoomRequest("backend.room."+roomid, request) + return b.nats.PublishBackendServerRoomRequest(GetSubjectForBackendRoomId(roomid, backend), request) } -func (b *BackendServer) sendRoomParticipantsUpdate(roomid string, request *BackendServerRoomRequest) error { +func (b *BackendServer) sendRoomParticipantsUpdate(roomid string, backend *Backend, request *BackendServerRoomRequest) error { timeout := time.Second // Convert (Nextcloud) session ids to signaling session ids. @@ -497,14 +495,57 @@ loop: } wg.Wait() - return b.nats.PublishBackendServerRoomRequest("backend.room."+roomid, request) + return b.nats.PublishBackendServerRoomRequest(GetSubjectForBackendRoomId(roomid, backend), request) } -func (b *BackendServer) sendRoomMessage(roomid string, request *BackendServerRoomRequest) error { - return b.nats.PublishBackendServerRoomRequest("backend.room."+roomid, request) +func (b *BackendServer) sendRoomMessage(roomid string, backend *Backend, request *BackendServerRoomRequest) error { + return b.nats.PublishBackendServerRoomRequest(GetSubjectForBackendRoomId(roomid, backend), request) } func (b *BackendServer) roomHandler(w http.ResponseWriter, r *http.Request, body []byte) { + v := mux.Vars(r) + roomid := v["roomid"] + + var backend *Backend + backendUrl := r.Header.Get(HeaderBackendServer) + if backendUrl != "" { + if u, err := url.Parse(backendUrl); err == nil { + backend = b.hub.backend.GetBackend(u) + } + + if backend == nil { + // Unknown backend URL passed, return immediately. + http.Error(w, "Authentication check failed", http.StatusForbidden) + return + } + } + + if backend == nil { + if compatBackend := b.hub.backend.GetCompatBackend(); compatBackend != nil { + // Old-style configuration using a single secret for all backends. + backend = compatBackend + } else { + // Old-style Talk, find backend that created the checksum. + // TODO(fancycode): Remove once all supported Talk versions send the backend header. + for _, b := range b.hub.backend.GetBackends() { + if ValidateBackendChecksum(r, body, b.Secret()) { + backend = b + break + } + } + } + + if backend == nil { + http.Error(w, "Authentication check failed", http.StatusForbidden) + return + } + } + + if !ValidateBackendChecksum(r, body, backend.Secret()) { + http.Error(w, "Authentication check failed", http.StatusForbidden) + return + } + var request BackendServerRoomRequest if err := json.Unmarshal(body, &request); err != nil { log.Printf("Error decoding body %s: %s\n", string(body), err) @@ -514,28 +555,26 @@ func (b *BackendServer) roomHandler(w http.ResponseWriter, r *http.Request, body request.ReceivedTime = time.Now().UnixNano() - v := mux.Vars(r) - roomid := v["roomid"] var err error switch request.Type { case "invite": - b.sendRoomInvite(roomid, request.Invite.UserIds, request.Invite.Properties) - b.sendRoomUpdate(roomid, request.Invite.UserIds, request.Invite.AllUserIds, request.Invite.Properties) + b.sendRoomInvite(roomid, backend, request.Invite.UserIds, request.Invite.Properties) + b.sendRoomUpdate(roomid, backend, request.Invite.UserIds, request.Invite.AllUserIds, request.Invite.Properties) case "disinvite": - b.sendRoomDisinvite(roomid, DisinviteReasonDisinvited, request.Disinvite.UserIds, request.Disinvite.SessionIds) - b.sendRoomUpdate(roomid, request.Disinvite.UserIds, request.Disinvite.AllUserIds, request.Disinvite.Properties) + b.sendRoomDisinvite(roomid, backend, DisinviteReasonDisinvited, request.Disinvite.UserIds, request.Disinvite.SessionIds) + b.sendRoomUpdate(roomid, backend, request.Disinvite.UserIds, request.Disinvite.AllUserIds, request.Disinvite.Properties) case "update": - err = b.nats.PublishBackendServerRoomRequest("backend.room."+roomid, &request) - b.sendRoomUpdate(roomid, nil, request.Update.UserIds, request.Update.Properties) + err = b.nats.PublishBackendServerRoomRequest(GetSubjectForBackendRoomId(roomid, backend), &request) + b.sendRoomUpdate(roomid, backend, nil, request.Update.UserIds, request.Update.Properties) case "delete": - err = b.nats.PublishBackendServerRoomRequest("backend.room."+roomid, &request) - b.sendRoomDisinvite(roomid, DisinviteReasonDeleted, request.Delete.UserIds, nil) + err = b.nats.PublishBackendServerRoomRequest(GetSubjectForBackendRoomId(roomid, backend), &request) + b.sendRoomDisinvite(roomid, backend, DisinviteReasonDeleted, request.Delete.UserIds, nil) case "incall": - err = b.sendRoomIncall(roomid, &request) + err = b.sendRoomIncall(roomid, backend, &request) case "participants": - err = b.sendRoomParticipantsUpdate(roomid, &request) + err = b.sendRoomParticipantsUpdate(roomid, backend, &request) case "message": - err = b.sendRoomMessage(roomid, &request) + err = b.sendRoomMessage(roomid, backend, &request) default: http.Error(w, "Unsupported request type: "+request.Type, http.StatusBadRequest) return diff --git a/src/signaling/backend_server_test.go b/src/signaling/backend_server_test.go index d3098bcf..c800906f 100644 --- a/src/signaling/backend_server_test.go +++ b/src/signaling/backend_server_test.go @@ -123,6 +123,7 @@ func performBackendRequest(url string, body []byte) (*http.Response, error) { check := CalculateBackendChecksum(rnd, body, testBackendSecret) request.Header.Set("Spreed-Signaling-Random", rnd) request.Header.Set("Spreed-Signaling-Checksum", check) + request.Header.Set("Spreed-Signaling-Backend", url) client := &http.Client{} return client.Do(request) } @@ -212,6 +213,56 @@ func TestBackendServer_InvalidAuth(t *testing.T) { } } +func TestBackendServer_OldCompatAuth(t *testing.T) { + _, _, _, _, _, server, shutdown := CreateBackendServerForTest(t) + defer shutdown() + + roomId := "the-room-id" + userid := "the-user-id" + roomProperties := json.RawMessage("{\"foo\":\"bar\"}") + msg := &BackendServerRoomRequest{ + Type: "invite", + Invite: &BackendRoomInviteRequest{ + UserIds: []string{ + userid, + }, + AllUserIds: []string{ + userid, + }, + Properties: &roomProperties, + }, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatal(err) + } + + request, err := http.NewRequest("POST", server.URL+"/api/v1/room/"+roomId, bytes.NewReader(data)) + if err != nil { + t.Fatal(err) + } + request.Header.Set("Content-Type", "application/json") + rnd := newRandomString(32) + check := CalculateBackendChecksum(rnd, data, testBackendSecret) + request.Header.Set("Spreed-Signaling-Random", rnd) + request.Header.Set("Spreed-Signaling-Checksum", check) + client := &http.Client{} + res, err := client.Do(request) + if err != nil { + t.Fatal(err) + } + + defer res.Body.Close() + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Error(err) + } + if res.StatusCode != http.StatusOK { + t.Errorf("Expected success, got %s: %s", res.Status, string(body)) + } +} + func TestBackendServer_InvalidBody(t *testing.T) { _, _, _, _, _, server, shutdown := CreateBackendServerForTest(t) defer shutdown() @@ -260,14 +311,20 @@ func TestBackendServer_UnsupportedRequest(t *testing.T) { } func TestBackendServer_RoomInvite(t *testing.T) { - _, _, n, _, _, server, shutdown := CreateBackendServerForTest(t) + _, _, n, hub, _, server, shutdown := CreateBackendServerForTest(t) defer shutdown() + u, err := url.Parse(server.URL) + if err != nil { + t.Fatal(err) + } + userid := "test-userid" roomProperties := json.RawMessage("{\"foo\":\"bar\"}") + backend := hub.backend.GetBackend(u) natsChan := make(chan *nats.Msg, 1) - subject := GetSubjectForUserId(userid) + subject := GetSubjectForUserId(userid, backend) sub, err := n.Subscribe(subject, natsChan) if err != nil { t.Fatal(err) @@ -321,6 +378,13 @@ func TestBackendServer_RoomDisinvite(t *testing.T) { _, _, n, hub, _, server, shutdown := CreateBackendServerForTest(t) defer shutdown() + u, err := url.Parse(server.URL) + if err != nil { + t.Fatal(err) + } + + backend := hub.backend.GetBackend(u) + client := NewTestClient(t, server, hub) defer client.CloseWithBye() if err := client.SendHello(testDefaultUserId); err != nil { @@ -355,7 +419,7 @@ func TestBackendServer_RoomDisinvite(t *testing.T) { roomProperties := json.RawMessage("{\"foo\":\"bar\"}") natsChan := make(chan *nats.Msg, 1) - subject := GetSubjectForUserId(testDefaultUserId) + subject := GetSubjectForUserId(testDefaultUserId, backend) sub, err := n.Subscribe(subject, natsChan) if err != nil { t.Fatal(err) @@ -556,9 +620,18 @@ func TestBackendServer_RoomUpdate(t *testing.T) { _, _, n, hub, _, server, shutdown := CreateBackendServerForTest(t) defer shutdown() + u, err := url.Parse(server.URL) + if err != nil { + t.Fatal(err) + } + roomId := "the-room-id" emptyProperties := json.RawMessage("{}") - room, err := hub.createRoom(roomId, &emptyProperties) + backend := hub.backend.GetBackend(u) + if backend == nil { + t.Fatalf("Did not find backend") + } + room, err := hub.createRoom(roomId, &emptyProperties, backend) if err != nil { t.Fatalf("Could not create room: %s", err) } @@ -568,7 +641,7 @@ func TestBackendServer_RoomUpdate(t *testing.T) { roomProperties := json.RawMessage("{\"foo\":\"bar\"}") natsChan := make(chan *nats.Msg, 1) - subject := GetSubjectForUserId(userid) + subject := GetSubjectForUserId(userid, backend) sub, err := n.Subscribe(subject, natsChan) if err != nil { t.Fatal(err) @@ -629,16 +702,25 @@ func TestBackendServer_RoomDelete(t *testing.T) { _, _, n, hub, _, server, shutdown := CreateBackendServerForTest(t) defer shutdown() + u, err := url.Parse(server.URL) + if err != nil { + t.Fatal(err) + } + roomId := "the-room-id" emptyProperties := json.RawMessage("{}") - if _, err := hub.createRoom(roomId, &emptyProperties); err != nil { + backend := hub.backend.GetBackend(u) + if backend == nil { + t.Fatalf("Did not find backend") + } + if _, err := hub.createRoom(roomId, &emptyProperties, backend); err != nil { t.Fatalf("Could not create room: %s", err) } userid := "test-userid" natsChan := make(chan *nats.Msg, 1) - subject := GetSubjectForUserId(userid) + subject := GetSubjectForUserId(userid, backend) sub, err := n.Subscribe(subject, natsChan) if err != nil { t.Fatal(err) diff --git a/src/signaling/clientsession.go b/src/signaling/clientsession.go index 7065e95a..b46d03e5 100644 --- a/src/signaling/clientsession.go +++ b/src/signaling/clientsession.go @@ -22,7 +22,6 @@ package signaling import ( - "encoding/base64" "encoding/json" "log" "net/url" @@ -59,6 +58,7 @@ type ClientSession struct { supportsPermissions bool permissions map[Permission]bool + backend *Backend backendUrl string parsedBackendUrl *url.URL @@ -83,7 +83,7 @@ type ClientSession struct { pendingClientMessages []*NatsMessage } -func NewClientSession(hub *Hub, privateId string, publicId string, data *SessionIdData, hello *HelloClientMessage, auth *BackendClientAuthResponse) (*ClientSession, error) { +func NewClientSession(hub *Hub, privateId string, publicId string, data *SessionIdData, backend *Backend, hello *HelloClientMessage, auth *BackendClientAuthResponse) (*ClientSession, error) { s := &ClientSession{ hub: hub, privateId: privateId, @@ -95,6 +95,7 @@ func NewClientSession(hub *Hub, privateId string, publicId string, data *Session userId: auth.UserId, userData: auth.User, + backend: backend, backendUrl: hello.Auth.Url, parsedBackendUrl: hello.Auth.parsedUrl, @@ -197,6 +198,10 @@ func (s *ClientSession) SetPermissions(permissions []Permission) { log.Printf("Permissions of session %s changed: %s", s.PublicId(), permissions) } +func (s *ClientSession) Backend() *Backend { + return s.backend +} + func (s *ClientSession) BackendUrl() string { return s.backendUrl } @@ -301,11 +306,12 @@ func (s *ClientSession) closeAndWait(wait bool) { } } -func GetSubjectForUserId(userId string) string { - // The NATS client doesn't work if a subject contains spaces. As the user id - // can have an arbitrary format, we need to make sure the subject is valid. - // See "https://github.com/nats-io/nats.js/issues/158" for a similar report. - return "user." + base64.StdEncoding.EncodeToString([]byte(userId)) +func GetSubjectForUserId(userId string, backend *Backend) string { + if backend == nil || backend.IsCompat() { + return GetEncodedSubject("user", userId) + } else { + return GetEncodedSubject("user", userId+"|"+backend.Id()) + } } func (s *ClientSession) SubscribeNats(n NatsClient) error { @@ -314,7 +320,7 @@ func (s *ClientSession) SubscribeNats(n NatsClient) error { var err error if s.userId != "" { - if s.userSubscription, err = n.Subscribe(GetSubjectForUserId(s.userId), s.natsReceiver); err != nil { + if s.userSubscription, err = n.Subscribe(GetSubjectForUserId(s.userId, s.backend), s.natsReceiver); err != nil { return err } } @@ -331,7 +337,7 @@ func (s *ClientSession) SubscribeRoomNats(n NatsClient, roomid string, roomSessi defer s.mu.Unlock() var err error - if s.roomSubscription, err = n.Subscribe("room."+roomid, s.natsReceiver); err != nil { + if s.roomSubscription, err = n.Subscribe(GetSubjectForRoomId(roomid, s.Backend()), s.natsReceiver); err != nil { return err } diff --git a/src/signaling/hub.go b/src/signaling/hub.go index 1bf2ae46..85451218 100644 --- a/src/signaling/hub.go +++ b/src/signaling/hub.go @@ -565,7 +565,7 @@ func (h *Hub) processNewClient(client *Client) { h.startExpectHello(client) } -func (h *Hub) processRegister(client *Client, message *ClientMessage, auth *BackendClientResponse) { +func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *Backend, auth *BackendClientResponse) { if !client.IsConnected() { // Client disconnected while waiting for "hello" response. return @@ -584,8 +584,9 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, auth *Back sid = atomic.AddUint64(&h.sid, 1) } sessionIdData := &SessionIdData{ - Sid: sid, - Created: time.Now(), + Sid: sid, + Created: time.Now(), + BackendId: backend.Id(), } privateSessionId, err := h.encodeSessionId(sessionIdData, privateSessionName) if err != nil { @@ -600,14 +601,14 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, auth *Back userId := auth.Auth.UserId if userId != "" { - log.Printf("Register user %s from %s in %s (%s) %s (private=%s)", userId, client.RemoteAddr(), client.Country(), client.UserAgent(), publicSessionId, privateSessionId) + log.Printf("Register user %s@%s from %s in %s (%s) %s (private=%s)", userId, backend.Id(), client.RemoteAddr(), client.Country(), client.UserAgent(), publicSessionId, privateSessionId) } else if message.Hello.Auth.Type != HelloClientTypeClient { - log.Printf("Register %s from %s in %s (%s) %s (private=%s)", message.Hello.Auth.Type, client.RemoteAddr(), client.Country(), client.UserAgent(), publicSessionId, privateSessionId) + log.Printf("Register %s@%s from %s in %s (%s) %s (private=%s)", message.Hello.Auth.Type, backend.Id(), client.RemoteAddr(), client.Country(), client.UserAgent(), publicSessionId, privateSessionId) } else { - log.Printf("Register anonymous from %s in %s (%s) %s (private=%s)", client.RemoteAddr(), client.Country(), client.UserAgent(), publicSessionId, privateSessionId) + log.Printf("Register anonymous@%s from %s in %s (%s) %s (private=%s)", backend.Id(), client.RemoteAddr(), client.Country(), client.UserAgent(), publicSessionId, privateSessionId) } - session, err := NewClientSession(h, privateSessionId, publicSessionId, sessionIdData, message.Hello, auth.Auth) + session, err := NewClientSession(h, privateSessionId, publicSessionId, sessionIdData, backend, message.Hello, auth.Auth) if err != nil { client.SendMessage(message.NewWrappedErrorServerMessage(err)) return @@ -754,7 +755,8 @@ func (h *Hub) processHelloClient(client *Client, message *ClientMessage) { defer h.startExpectHello(client) url := message.Hello.Auth.parsedUrl - if !h.backend.IsUrlAllowed(url) { + backend := h.backend.GetBackend(url) + if backend == nil { client.SendMessage(message.NewErrorServerMessage(InvalidBackendUrl)) return } @@ -772,7 +774,7 @@ func (h *Hub) processHelloClient(client *Client, message *ClientMessage) { // TODO(jojo): Validate response - h.processRegister(client, message, &auth) + h.processRegister(client, message, backend, &auth) } func (h *Hub) processHelloInternal(client *Client, message *ClientMessage) { @@ -792,11 +794,17 @@ func (h *Hub) processHelloInternal(client *Client, message *ClientMessage) { return } + backend := h.backend.GetBackend(message.Hello.Auth.internalParams.parsedBackend) + if backend == nil { + client.SendMessage(message.NewErrorServerMessage(InvalidBackendUrl)) + return + } + auth := &BackendClientResponse{ Type: "auth", Auth: &BackendClientAuthResponse{}, } - h.processRegister(client, message, auth) + h.processRegister(client, message, backend, auth) } func (h *Hub) disconnectByRoomSessionId(roomSessionId string) { @@ -852,7 +860,7 @@ func (h *Hub) processRoom(client *Client, message *ClientMessage) { } if session != nil { - if room := h.getRoom(roomId); room != nil && room.HasSession(session) { + if room := h.getRoomForBackend(roomId, session.Backend()); room != nil && room.HasSession(session) { // Session already is in that room, no action needed. return } @@ -896,26 +904,43 @@ func (h *Hub) processRoom(client *Client, message *ClientMessage) { h.processJoinRoom(client, message, &room) } +func (h *Hub) getRoomForBackend(id string, backend *Backend) *Room { + internalRoomId := getRoomIdForBackend(id, backend) + + h.ru.RLock() + defer h.ru.RUnlock() + return h.rooms[internalRoomId] +} + func (h *Hub) getRoom(id string) *Room { h.ru.RLock() defer h.ru.RUnlock() - return h.rooms[id] + // TODO: The same room might exist on different backends. + for _, room := range h.rooms { + if room.Id() == id { + return room + } + } + + return nil } func (h *Hub) removeRoom(room *Room) { + internalRoomId := getRoomIdForBackend(room.Id(), room.Backend()) h.ru.Lock() - delete(h.rooms, room.Id()) + delete(h.rooms, internalRoomId) h.ru.Unlock() } -func (h *Hub) createRoom(id string, properties *json.RawMessage) (*Room, error) { +func (h *Hub) createRoom(id string, properties *json.RawMessage, backend *Backend) (*Room, error) { // Note the write lock must be held. - room, err := NewRoom(id, properties, h, h.nats) + room, err := NewRoom(id, properties, h, h.nats, backend) if err != nil { return nil, err } - h.rooms[id] = room + internalRoomId := getRoomIdForBackend(id, backend) + h.rooms[internalRoomId] = room return room, nil } @@ -937,6 +962,7 @@ func (h *Hub) processJoinRoom(client *Client, message *ClientMessage, room *Back session.LeaveRoom(true) roomId := room.Room.RoomId + internalRoomId := getRoomIdForBackend(roomId, session.Backend()) if err := session.SubscribeRoomNats(h.nats, roomId, message.Room.SessionId); err != nil { client.SendMessage(message.NewWrappedErrorServerMessage(err)) // The client (implicitly) left the room due to an error. @@ -945,10 +971,10 @@ func (h *Hub) processJoinRoom(client *Client, message *ClientMessage, room *Back } h.ru.Lock() - r, found := h.rooms[roomId] + r, found := h.rooms[internalRoomId] if !found { var err error - if r, err = h.createRoom(roomId, room.Room.Properties); err != nil { + if r, err = h.createRoom(roomId, room.Room.Properties, session.Backend()); err != nil { h.ru.Unlock() client.SendMessage(message.NewWrappedErrorServerMessage(err)) // The client (implicitly) left the room due to an error. @@ -1011,6 +1037,11 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) { case RecipientTypeSession: data := h.decodeSessionId(msg.Recipient.SessionId, publicSessionName) if data != nil { + if data.BackendId != session.Backend().Id() { + // Clients are only allowed to send to sessions from the same backend. + return + } + if h.mcu != nil { // Maybe this is a message to be processed by the MCU. var data MessageClientMessageData @@ -1054,12 +1085,12 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) { return } - subject = GetSubjectForUserId(msg.Recipient.UserId) + subject = GetSubjectForUserId(msg.Recipient.UserId, session.Backend()) } case RecipientTypeRoom: if session != nil { if room := session.GetRoom(); room != nil { - subject = "room." + room.Id() + subject = GetSubjectForRoomId(room.Id(), room.Backend()) if h.mcu != nil { var data MessageClientMessageData @@ -1190,12 +1221,12 @@ func (h *Hub) processControlMsg(client *Client, message *ClientMessage) { return } - subject = GetSubjectForUserId(msg.Recipient.UserId) + subject = GetSubjectForUserId(msg.Recipient.UserId, session.Backend()) } case RecipientTypeRoom: if session != nil { if room := session.GetRoom(); room != nil { - subject = "room." + room.Id() + subject = GetSubjectForRoomId(room.Id(), room.Backend()) } } } diff --git a/src/signaling/hub_test.go b/src/signaling/hub_test.go index 4baae61e..4e17cbca 100644 --- a/src/signaling/hub_test.go +++ b/src/signaling/hub_test.go @@ -62,7 +62,25 @@ func getTestConfig(server *httptest.Server) (*goconf.ConfigFile, error) { return config, nil } -func CreateHubForTest(t *testing.T) (*Hub, NatsClient, *mux.Router, *httptest.Server, func()) { +func getTestConfigWithMultipleBackends(server *httptest.Server) (*goconf.ConfigFile, error) { + config, err := getTestConfig(server) + if err != nil { + return nil, err + } + + config.RemoveOption("backend", "allowed") + config.RemoveOption("backend", "secret") + config.AddOption("backend", "backends", "backend1, backend2") + + config.AddOption("backend1", "url", server.URL+"/one") + config.AddOption("backend1", "secret", string(testBackendSecret)) + + config.AddOption("backend2", "url", server.URL+"/two/") + config.AddOption("backend2", "secret", string(testBackendSecret)) + return config, nil +} + +func CreateHubForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Server) (*goconf.ConfigFile, error)) (*Hub, NatsClient, *mux.Router, *httptest.Server, func()) { r := mux.NewRouter() registerBackendHandler(t, r) @@ -71,7 +89,7 @@ func CreateHubForTest(t *testing.T) (*Hub, NatsClient, *mux.Router, *httptest.Se if err != nil { t.Fatal(err) } - config, err := getTestConfig(server) + config, err := getConfigFunc(server) if err != nil { t.Fatal(err) } @@ -94,6 +112,17 @@ func CreateHubForTest(t *testing.T) (*Hub, NatsClient, *mux.Router, *httptest.Se return h, nats, r, server, shutdown } +func CreateHubForTest(t *testing.T) (*Hub, NatsClient, *mux.Router, *httptest.Server, func()) { + return CreateHubForTestWithConfig(t, getTestConfig) +} + +func CreateHubWithMultipleBackendsForTest(t *testing.T) (*Hub, NatsClient, *mux.Router, *httptest.Server, func()) { + h, nats, r, server, shutdown := CreateHubForTestWithConfig(t, getTestConfigWithMultipleBackends) + registerBackendHandlerUrl(t, r, "/one") + registerBackendHandlerUrl(t, r, "/two") + return h, nats, r, server, shutdown +} + func WaitForHub(ctx context.Context, t *testing.T, h *Hub) { h.Stop() for { @@ -212,7 +241,11 @@ func processRoomRequest(t *testing.T, w http.ResponseWriter, r *http.Request, re } func registerBackendHandler(t *testing.T, router *mux.Router) { - router.HandleFunc("/", validateBackendChecksum(t, func(w http.ResponseWriter, r *http.Request, request *BackendClientRequest) *BackendClientResponse { + registerBackendHandlerUrl(t, router, "/") +} + +func registerBackendHandlerUrl(t *testing.T, router *mux.Router, url string) { + router.HandleFunc(url, validateBackendChecksum(t, func(w http.ResponseWriter, r *http.Request, request *BackendClientRequest) *BackendClientResponse { switch request.Type { case "auth": return processAuthRequest(t, w, r, request) @@ -323,6 +356,41 @@ func TestClientHelloWithSpaces(t *testing.T) { } } +func TestClientHelloAllowAll(t *testing.T) { + hub, _, _, server, shutdown := CreateHubForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) { + config, err := getTestConfig(server) + if err != nil { + return nil, err + } + + config.RemoveOption("backend", "allowed") + config.AddOption("backend", "allowall", "true") + return config, nil + }) + defer shutdown() + + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() + + if err := client.SendHello(testDefaultUserId); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + if hello, err := client.RunUntilHello(ctx); err != nil { + t.Error(err) + } else { + if hello.Hello.UserId != testDefaultUserId { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) + } + if hello.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello.Hello) + } + } +} + func TestSessionIdsUnordered(t *testing.T) { hub, _, _, server, shutdown := CreateHubForTest(t) defer shutdown() @@ -1692,8 +1760,8 @@ func TestRoomParticipantsListUpdateWhileDisconnected(t *testing.T) { "inCall": 1, }, } - room, found := hub.rooms[roomId] - if !found { + room := hub.getRoom(roomId) + if room == nil { t.Fatalf("Could not find room %s", roomId) } room.PublishUsersInCallChanged(users, users) @@ -2031,3 +2099,181 @@ func TestClientSendOfferPermissions(t *testing.T) { t.Errorf("Expected no payload, got %+v", msg) } } + +func TestNoSendBetweenSessionsOnDifferentBackends(t *testing.T) { + // Clients can't send messages to sessions connected from other backends. + hub, _, _, server, shutdown := CreateHubWithMultipleBackendsForTest(t) + defer shutdown() + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client1 := NewTestClient(t, server, hub) + defer client1.CloseWithBye() + + params1 := TestBackendClientAuthParams{ + UserId: "user1", + } + if err := client1.SendHelloParams(server.URL+"/one", "client", params1); err != nil { + t.Fatal(err) + } + hello1, err := client1.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } + + client2 := NewTestClient(t, server, hub) + defer client2.CloseWithBye() + + params2 := TestBackendClientAuthParams{ + UserId: "user2", + } + if err := client2.SendHelloParams(server.URL+"/two", "client", params2); err != nil { + t.Fatal(err) + } + hello2, err := client2.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } + + recipient1 := MessageClientMessageRecipient{ + Type: "session", + SessionId: hello1.Hello.SessionId, + } + recipient2 := MessageClientMessageRecipient{ + Type: "session", + SessionId: hello2.Hello.SessionId, + } + + data1 := "from-1-to-2" + client1.SendMessage(recipient2, data1) + data2 := "from-2-to-1" + client2.SendMessage(recipient1, data2) + + var payload string + ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel2() + if err := checkReceiveClientMessage(ctx2, client1, "session", hello2.Hello, &payload); err != nil { + if err != NoMessageReceivedError { + t.Error(err) + } + } else { + t.Errorf("Expected no payload, got %+v", payload) + } + + ctx3, cancel3 := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel3() + if err := checkReceiveClientMessage(ctx3, client2, "session", hello1.Hello, &payload); err != nil { + if err != NoMessageReceivedError { + t.Error(err) + } + } else { + t.Errorf("Expected no payload, got %+v", payload) + } +} + +func TestNoSameRoomOnDifferentBackends(t *testing.T) { + hub, _, _, server, shutdown := CreateHubWithMultipleBackendsForTest(t) + defer shutdown() + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client1 := NewTestClient(t, server, hub) + defer client1.CloseWithBye() + + params1 := TestBackendClientAuthParams{ + UserId: "user1", + } + if err := client1.SendHelloParams(server.URL+"/one", "client", params1); err != nil { + t.Fatal(err) + } + hello1, err := client1.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } + + client2 := NewTestClient(t, server, hub) + defer client2.CloseWithBye() + + params2 := TestBackendClientAuthParams{ + UserId: "user2", + } + if err := client2.SendHelloParams(server.URL+"/two", "client", params2); err != nil { + t.Fatal(err) + } + hello2, err := client2.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } + + // Join room by id. + roomId := "test-room" + if room, err := client1.JoinRoom(ctx, roomId); err != nil { + t.Fatal(err) + } else if room.Room.RoomId != roomId { + t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) + } + msg1, err := client1.RunUntilMessage(ctx) + if err != nil { + t.Error(err) + } + if err := client1.checkMessageJoined(msg1, hello1.Hello); err != nil { + t.Error(err) + } + + if room, err := client2.JoinRoom(ctx, roomId); err != nil { + t.Fatal(err) + } else if room.Room.RoomId != roomId { + t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) + } + msg2, err := client2.RunUntilMessage(ctx) + if err != nil { + t.Error(err) + } + if err := client2.checkMessageJoined(msg2, hello2.Hello); err != nil { + t.Error(err) + } + + hub.ru.RLock() + roomCount := 0 + for _, room := range hub.rooms { + defer room.Close() + roomCount++ + } + hub.ru.RUnlock() + + if roomCount != 2 { + t.Errorf("Expected 2 rooms, got %d", roomCount) + } + + recipient := MessageClientMessageRecipient{ + Type: "room", + } + + data1 := "from-1-to-2" + client1.SendMessage(recipient, data1) + data2 := "from-2-to-1" + client2.SendMessage(recipient, data2) + + var payload string + ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel2() + if err := checkReceiveClientMessage(ctx2, client1, "session", hello2.Hello, &payload); err != nil { + if err != NoMessageReceivedError { + t.Error(err) + } + } else { + t.Errorf("Expected no payload, got %+v", payload) + } + + ctx3, cancel3 := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel3() + if err := checkReceiveClientMessage(ctx3, client2, "session", hello1.Hello, &payload); err != nil { + if err != NoMessageReceivedError { + t.Error(err) + } + } else { + t.Errorf("Expected no payload, got %+v", payload) + } +} diff --git a/src/signaling/natsclient.go b/src/signaling/natsclient.go index 56757f18..d1670c23 100644 --- a/src/signaling/natsclient.go +++ b/src/signaling/natsclient.go @@ -22,6 +22,7 @@ package signaling import ( + "encoding/base64" "fmt" "log" "os" @@ -63,6 +64,13 @@ type NatsClient interface { Decode(msg *nats.Msg, v interface{}) error } +// The NATS client doesn't work if a subject contains spaces. As the room id +// can have an arbitrary format, we need to make sure the subject is valid. +// See "https://github.com/nats-io/nats.js/issues/158" for a similar report. +func GetEncodedSubject(prefix string, suffix string) string { + return prefix + "." + base64.StdEncoding.EncodeToString([]byte(suffix)) +} + type natsClient struct { nc *nats.Conn conn *nats.EncodedConn diff --git a/src/signaling/room.go b/src/signaling/room.go index cdbbfd10..ee520edb 100644 --- a/src/signaling/room.go +++ b/src/signaling/room.go @@ -47,9 +47,10 @@ var ( ) type Room struct { - id string - hub *Hub - nats NatsClient + id string + hub *Hub + nats NatsClient + backend *Backend properties *json.RawMessage roomType int @@ -72,18 +73,43 @@ type Room struct { lastNatsRoomRequests map[string]int64 } -func NewRoom(roomId string, properties *json.RawMessage, hub *Hub, n NatsClient) (*Room, error) { +func GetSubjectForRoomId(roomId string, backend *Backend) string { + if backend == nil || backend.IsCompat() { + return GetEncodedSubject("room", roomId) + } else { + return GetEncodedSubject("room", roomId+"|"+backend.Id()) + } +} + +func GetSubjectForBackendRoomId(roomId string, backend *Backend) string { + if backend == nil || backend.IsCompat() { + return GetEncodedSubject("backend.room", roomId) + } else { + return GetEncodedSubject("backend.room", roomId+"|"+backend.Id()) + } +} + +func getRoomIdForBackend(id string, backend *Backend) string { + if id == "" { + return "" + } + + return backend.Id() + "|" + id +} + +func NewRoom(roomId string, properties *json.RawMessage, hub *Hub, n NatsClient, backend *Backend) (*Room, error) { natsReceiver := make(chan *nats.Msg, 64) - backendSubscription, err := n.Subscribe("backend.room."+roomId, natsReceiver) + backendSubscription, err := n.Subscribe(GetSubjectForBackendRoomId(roomId, backend), natsReceiver) if err != nil { close(natsReceiver) return nil, err } room := &Room{ - id: roomId, - hub: hub, - nats: n, + id: roomId, + hub: hub, + nats: n, + backend: backend, properties: properties, @@ -115,6 +141,10 @@ func (r *Room) Properties() *json.RawMessage { return r.properties } +func (r *Room) Backend() *Backend { + return r.backend +} + func (r *Room) run() { ticker := time.NewTicker(updateActiveSessionsInterval) loop: @@ -278,7 +308,7 @@ func (r *Room) RemoveSession(session Session) bool { } func (r *Room) publish(message *ServerMessage) { - r.nats.PublishMessage("room."+r.id, message) + r.nats.PublishMessage(GetSubjectForRoomId(r.id, r.backend), message) } func (r *Room) UpdateProperties(properties *json.RawMessage) { diff --git a/src/signaling/room_test.go b/src/signaling/room_test.go index a7b3b3fa..3a127f3a 100644 --- a/src/signaling/room_test.go +++ b/src/signaling/room_test.go @@ -194,11 +194,7 @@ loop: break loop default: // The internal room has been updated with the new properties. - hub.ru.Lock() - room, found := hub.rooms[roomId] - hub.ru.Unlock() - - if !found { + if room := hub.getRoom(roomId); room == nil { err = fmt.Errorf("Room %s not found in hub", roomId) } else if room.Properties() == nil || !bytes.Equal(*room.Properties(), roomProperties) { err = fmt.Errorf("Expected room properties %s, got %+v", string(roomProperties), room.Properties()) diff --git a/src/signaling/roomsessions_test.go b/src/signaling/roomsessions_test.go index 7bc6d24b..cac5e36e 100644 --- a/src/signaling/roomsessions_test.go +++ b/src/signaling/roomsessions_test.go @@ -56,6 +56,10 @@ func (s *DummySession) UserData() *json.RawMessage { return nil } +func (s *DummySession) Backend() *Backend { + return nil +} + func (s *DummySession) BackendUrl() string { return "" } diff --git a/src/signaling/session.go b/src/signaling/session.go index 00654068..155cb8b8 100644 --- a/src/signaling/session.go +++ b/src/signaling/session.go @@ -36,8 +36,9 @@ var ( ) type SessionIdData struct { - Sid uint64 - Created time.Time + Sid uint64 + Created time.Time + BackendId string } type Session interface { @@ -49,6 +50,7 @@ type Session interface { UserId() string UserData() *json.RawMessage + Backend() *Backend BackendUrl() string ParsedBackendUrl() *url.URL diff --git a/src/signaling/testclient_test.go b/src/signaling/testclient_test.go index 20d57cce..99393d97 100644 --- a/src/signaling/testclient_test.go +++ b/src/signaling/testclient_test.go @@ -376,10 +376,12 @@ func (c *TestClient) SendHelloInternal() error { mac := hmac.New(sha256.New, testInternalSecret) mac.Write([]byte(random)) token := hex.EncodeToString(mac.Sum(nil)) + backend := c.server.URL params := ClientTypeInternalAuthParams{ - Random: random, - Token: token, + Random: random, + Token: token, + Backend: backend, } return c.SendHelloParams("", "internal", params) }