diff --git a/api_backend.go b/api_backend.go index 58ee9c71..0f87e86a 100644 --- a/api_backend.go +++ b/api_backend.go @@ -37,6 +37,11 @@ const ( HeaderBackendSignalingRandom = "Spreed-Signaling-Random" HeaderBackendSignalingChecksum = "Spreed-Signaling-Checksum" HeaderBackendServer = "Spreed-Signaling-Backend" + + ConfigGroupSignaling = "signaling" + + ConfigKeyHelloV2TokenKey = "hello-v2-token-key" + ConfigKeySessionPingLimit = "session-ping-limit" ) func newRandomString(length int) string { diff --git a/api_proxy.go b/api_proxy.go index 093f227f..562456aa 100644 --- a/api_proxy.go +++ b/api_proxy.go @@ -142,7 +142,7 @@ type HelloProxyClientMessage struct { } func (m *HelloProxyClientMessage) CheckValid() error { - if m.Version != HelloVersion { + if m.Version != HelloVersionV1 { return fmt.Errorf("unsupported hello version: %s", m.Version) } if m.ResumeId == "" { diff --git a/api_signaling.go b/api_signaling.go index 930663c7..851284fa 100644 --- a/api_signaling.go +++ b/api_signaling.go @@ -27,11 +27,16 @@ import ( "net/url" "sort" "strings" + + "github.com/golang-jwt/jwt" ) const ( - // Version that must be sent in a "hello" message. - HelloVersion = "1.0" + // Version 1.0 validates auth params against the Nextcloud instance. + HelloVersionV1 = "1.0" + + // Version 2.0 validates auth params encoded as JWT. + HelloVersionV2 = "2.0" ) // ClientMessage is a message that is sent from a client to the server. @@ -325,6 +330,23 @@ func (p *ClientTypeInternalAuthParams) CheckValid() error { return nil } +type HelloV2AuthParams struct { + Token string `json:"token"` +} + +func (p *HelloV2AuthParams) CheckValid() error { + if p.Token == "" { + return fmt.Errorf("token missing") + } + return nil +} + +type HelloV2TokenClaims struct { + jwt.StandardClaims + + UserData *json.RawMessage `json:"userdata,omitempty"` +} + type HelloClientMessageAuth struct { // The client type that is connecting. Leave empty to use the default // "HelloClientTypeClient" @@ -336,6 +358,7 @@ type HelloClientMessageAuth struct { parsedUrl *url.URL internalParams ClientTypeInternalAuthParams + helloV2Params HelloV2AuthParams } // Type "hello" @@ -352,8 +375,8 @@ type HelloClientMessage struct { } func (m *HelloClientMessage) CheckValid() error { - if m.Version != HelloVersion { - return fmt.Errorf("unsupported hello version: %s", m.Version) + if m.Version != HelloVersionV1 && m.Version != HelloVersionV2 { + return InvalidHelloVersion } if m.ResumeId == "" { if m.Auth.Params == nil || len(*m.Auth.Params) == 0 { @@ -375,6 +398,17 @@ func (m *HelloClientMessage) CheckValid() error { m.Auth.parsedUrl = u } + + switch m.Version { + case HelloVersionV1: + // No additional validation necessary. + case HelloVersionV2: + if err := json.Unmarshal(*m.Auth.Params, &m.Auth.helloV2Params); err != nil { + return err + } else if err := m.Auth.helloV2Params.CheckValid(); err != nil { + return err + } + } case HelloClientTypeInternal: if err := json.Unmarshal(*m.Auth.Params, &m.Auth.internalParams); err != nil { return err diff --git a/api_signaling_test.go b/api_signaling_test.go index 6e9bc7ab..94e54c7a 100644 --- a/api_signaling_test.go +++ b/api_signaling_test.go @@ -90,16 +90,18 @@ func TestClientMessage(t *testing.T) { func TestHelloClientMessage(t *testing.T) { internalAuthParams := []byte("{\"backend\":\"https://domain.invalid\"}") + tokenAuthParams := []byte("{\"token\":\"invalid-token\"}") valid_messages := []testCheckValid{ + // Hello version 1 &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, Auth: HelloClientMessageAuth{ Params: &json.RawMessage{'{', '}'}, Url: "https://domain.invalid", }, }, &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, Auth: HelloClientMessageAuth{ Type: "client", Params: &json.RawMessage{'{', '}'}, @@ -107,61 +109,116 @@ func TestHelloClientMessage(t *testing.T) { }, }, &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, Auth: HelloClientMessageAuth{ Type: "internal", Params: (*json.RawMessage)(&internalAuthParams), }, }, &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, + ResumeId: "the-resume-id", + }, + // Hello version 2 + &HelloClientMessage{ + Version: HelloVersionV2, + Auth: HelloClientMessageAuth{ + Params: (*json.RawMessage)(&tokenAuthParams), + Url: "https://domain.invalid", + }, + }, + &HelloClientMessage{ + Version: HelloVersionV2, + Auth: HelloClientMessageAuth{ + Type: "client", + Params: (*json.RawMessage)(&tokenAuthParams), + Url: "https://domain.invalid", + }, + }, + &HelloClientMessage{ + Version: HelloVersionV2, ResumeId: "the-resume-id", }, } invalid_messages := []testCheckValid{ + // Hello version 1 &HelloClientMessage{}, &HelloClientMessage{Version: "0.0"}, - &HelloClientMessage{Version: HelloVersion}, + &HelloClientMessage{Version: HelloVersionV1}, &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, Auth: HelloClientMessageAuth{ Params: &json.RawMessage{'{', '}'}, Type: "invalid-type", }, }, &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, Auth: HelloClientMessageAuth{ Url: "https://domain.invalid", }, }, &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, Auth: HelloClientMessageAuth{ Params: &json.RawMessage{'{', '}'}, }, }, &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, Auth: HelloClientMessageAuth{ Params: &json.RawMessage{'{', '}'}, Url: "invalid-url", }, }, &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, Auth: HelloClientMessageAuth{ Type: "internal", Params: &json.RawMessage{'{', '}'}, }, }, &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, Auth: HelloClientMessageAuth{ Type: "internal", Params: &json.RawMessage{'x', 'y', 'z'}, // Invalid JSON. }, }, + // Hello version 2 + &HelloClientMessage{ + Version: HelloVersionV2, + Auth: HelloClientMessageAuth{ + Url: "https://domain.invalid", + }, + }, + &HelloClientMessage{ + Version: HelloVersionV2, + Auth: HelloClientMessageAuth{ + Params: (*json.RawMessage)(&tokenAuthParams), + }, + }, + &HelloClientMessage{ + Version: HelloVersionV2, + Auth: HelloClientMessageAuth{ + Params: (*json.RawMessage)(&tokenAuthParams), + Url: "invalid-url", + }, + }, + &HelloClientMessage{ + Version: HelloVersionV2, + Auth: HelloClientMessageAuth{ + Params: (*json.RawMessage)(&internalAuthParams), + Url: "https://domain.invalid", + }, + }, + &HelloClientMessage{ + Version: HelloVersionV2, + Auth: HelloClientMessageAuth{ + Params: &json.RawMessage{'x', 'y', 'z'}, // Invalid JSON. + Url: "https://domain.invalid", + }, + }, } testMessages(t, "hello", valid_messages, invalid_messages) diff --git a/client/main.go b/client/main.go index 1338de90..a9b22294 100644 --- a/client/main.go +++ b/client/main.go @@ -603,7 +603,7 @@ func main() { request := &signaling.ClientMessage{ Type: "hello", Hello: &signaling.HelloClientMessage{ - Version: signaling.HelloVersion, + Version: signaling.HelloVersionV1, Auth: signaling.HelloClientMessageAuth{ Url: backendUrl + "/auth", Params: &json.RawMessage{'{', '}'}, diff --git a/clientsession_test.go b/clientsession_test.go index 427761f4..eb152ddf 100644 --- a/clientsession_test.go +++ b/clientsession_test.go @@ -238,7 +238,7 @@ func TestBandwidth_Backend(t *testing.T) { params := TestBackendClientAuthParams{ UserId: testDefaultUserId, } - if err := client.SendHelloParams(server.URL+"/one", "client", params); err != nil { + if err := client.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", params); err != nil { t.Fatal(err) } diff --git a/hub.go b/hub.go index 81ca1b15..8fd2c1e5 100644 --- a/hub.go +++ b/hub.go @@ -39,20 +39,24 @@ import ( "time" "github.com/dlintw/goconf" + "github.com/golang-jwt/jwt" "github.com/gorilla/mux" "github.com/gorilla/securecookie" "github.com/gorilla/websocket" ) var ( - DuplicateClient = NewError("duplicate_client", "Client already registered.") - HelloExpected = NewError("hello_expected", "Expected Hello request.") - UserAuthFailed = NewError("auth_failed", "The user could not be authenticated.") - RoomJoinFailed = NewError("room_join_failed", "Could not join the room.") - InvalidClientType = NewError("invalid_client_type", "The client type is not supported.") - InvalidBackendUrl = NewError("invalid_backend", "The backend URL is not supported.") - InvalidToken = NewError("invalid_token", "The passed token is invalid.") - NoSuchSession = NewError("no_such_session", "The session to resume does not exist.") + DuplicateClient = NewError("duplicate_client", "Client already registered.") + HelloExpected = NewError("hello_expected", "Expected Hello request.") + InvalidHelloVersion = NewError("invalid_hello_version", "The hello version is not supported.") + UserAuthFailed = NewError("auth_failed", "The user could not be authenticated.") + RoomJoinFailed = NewError("room_join_failed", "Could not join the room.") + InvalidClientType = NewError("invalid_client_type", "The client type is not supported.") + InvalidBackendUrl = NewError("invalid_backend", "The backend URL is not supported.") + InvalidToken = NewError("invalid_token", "The passed token is invalid.") + NoSuchSession = NewError("no_such_session", "The session to resume does not exist.") + TokenNotValidYet = NewError("token_not_valid_yet", "The token is not valid yet.") + TokenExpired = NewError("token_expired", "The token is expired.") // Maximum number of concurrent requests to a backend. defaultMaxConcurrentRequestsPerHost = 8 @@ -803,10 +807,18 @@ func (h *Hub) processMessage(client *Client, data []byte) { if err := message.CheckValid(); err != nil { if session := client.GetSession(); session != nil { log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) - session.SendMessage(message.NewErrorServerMessage(InvalidFormat)) + if err, ok := err.(*Error); ok { + session.SendMessage(message.NewErrorServerMessage(err)) + } else { + session.SendMessage(message.NewErrorServerMessage(InvalidFormat)) + } } else { log.Printf("Invalid message %+v from %s: %v", message, client.RemoteAddr(), err) - client.SendMessage(message.NewErrorServerMessage(InvalidFormat)) + if err, ok := err.(*Error); ok { + client.SendMessage(message.NewErrorServerMessage(err)) + } else { + client.SendMessage(message.NewErrorServerMessage(InvalidFormat)) + } } return } @@ -849,7 +861,7 @@ func (h *Hub) sendHelloResponse(session *ClientSession, message *ClientMessage) Id: message.Id, Type: "hello", Hello: &HelloServerMessage{ - Version: HelloVersion, + Version: message.Hello.Version, SessionId: session.PublicId(), ResumeId: session.PrivateId(), UserId: session.UserId(), @@ -928,31 +940,137 @@ func (h *Hub) processHello(client *Client, message *ClientMessage) { } } -func (h *Hub) processHelloClient(client *Client, message *ClientMessage) { - // Make sure the client must send another "hello" in case of errors. - defer h.startExpectHello(client) - +func (h *Hub) processHelloV1(client *Client, message *ClientMessage) (*Backend, *BackendClientResponse, error) { url := message.Hello.Auth.parsedUrl backend := h.backend.GetBackend(url) if backend == nil { - client.SendMessage(message.NewErrorServerMessage(InvalidBackendUrl)) - return + return nil, nil, InvalidBackendUrl } // Run in timeout context to prevent blocking too long. ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout) defer cancel() - request := NewBackendClientAuthRequest(message.Hello.Auth.Params) var auth BackendClientResponse + request := NewBackendClientAuthRequest(message.Hello.Auth.Params) if err := h.backend.PerformJSONRequest(ctx, url, request, &auth); err != nil { - client.SendMessage(message.NewWrappedErrorServerMessage(err)) - return + return nil, nil, err } // TODO(jojo): Validate response - h.processRegister(client, message, backend, &auth) + return backend, &auth, nil +} + +func (h *Hub) processHelloV2(client *Client, message *ClientMessage) (*Backend, *BackendClientResponse, error) { + url := message.Hello.Auth.parsedUrl + backend := h.backend.GetBackend(url) + if backend == nil { + return nil, nil, InvalidBackendUrl + } + + token, err := jwt.ParseWithClaims(message.Hello.Auth.helloV2Params.Token, &HelloV2TokenClaims{}, func(token *jwt.Token) (interface{}, error) { + // Only public-private-key algorithms are supported. + var loadKeyFunc func([]byte) (interface{}, error) + switch token.Method.(type) { + case *jwt.SigningMethodRSA: + loadKeyFunc = func(data []byte) (interface{}, error) { + return jwt.ParseRSAPublicKeyFromPEM(data) + } + case *jwt.SigningMethodECDSA: + loadKeyFunc = func(data []byte) (interface{}, error) { + return jwt.ParseECPublicKeyFromPEM(data) + } + case *jwt.SigningMethodEd25519: + loadKeyFunc = func(data []byte) (interface{}, error) { + return jwt.ParseEdPublicKeyFromPEM(data) + } + default: + log.Printf("Unexpected signing method: %v", token.Header["alg"]) + return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) + } + + // Run in timeout context to prevent blocking too long. + ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout) + defer cancel() + + keyData, found := h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey) + if !found { + return nil, fmt.Errorf("No key found for issuer") + } + + key, err := loadKeyFunc([]byte(keyData)) + if err != nil { + return nil, fmt.Errorf("Could not parse token key: %w", err) + } + + return key, nil + }) + if err != nil { + if err, ok := err.(*jwt.ValidationError); ok { + if err.Errors&jwt.ValidationErrorIssuedAt == jwt.ValidationErrorIssuedAt { + return nil, nil, TokenNotValidYet + } + if err.Errors&jwt.ValidationErrorExpired == jwt.ValidationErrorExpired { + return nil, nil, TokenExpired + } + } + + return nil, nil, InvalidToken + } + + claims, ok := token.Claims.(*HelloV2TokenClaims) + if !ok || !token.Valid { + return nil, nil, InvalidToken + } + now := time.Now() + if !claims.VerifyIssuedAt(now.Unix(), true) { + return nil, nil, TokenNotValidYet + } + if !claims.VerifyExpiresAt(now.Unix(), true) { + return nil, nil, TokenExpired + } + + auth := &BackendClientResponse{ + Type: "auth", + Auth: &BackendClientAuthResponse{ + Version: message.Hello.Version, + UserId: claims.Subject, + User: claims.UserData, + }, + } + return backend, auth, nil +} + +func (h *Hub) processHelloClient(client *Client, message *ClientMessage) { + // Make sure the client must send another "hello" in case of errors. + defer h.startExpectHello(client) + + var authFunc func(*Client, *ClientMessage) (*Backend, *BackendClientResponse, error) + switch message.Hello.Version { + case HelloVersionV1: + // Auth information contains a ticket that must be validated against the + // Nextcloud instance. + authFunc = h.processHelloV1 + case HelloVersionV2: + // Auth information contains a JWT that contains all information of the user. + authFunc = h.processHelloV2 + default: + client.SendMessage(message.NewErrorServerMessage(InvalidHelloVersion)) + return + } + + backend, auth, err := authFunc(client, message) + if err != nil { + if e, ok := err.(*Error); ok { + client.SendMessage(message.NewErrorServerMessage(e)) + } else { + client.SendMessage(message.NewWrappedErrorServerMessage(err)) + } + return + } + + h.processRegister(client, message, backend, auth) } func (h *Hub) processHelloInternal(client *Client, message *ClientMessage) { diff --git a/hub_test.go b/hub_test.go index 2583c71e..8ee01afd 100644 --- a/hub_test.go +++ b/hub_test.go @@ -23,11 +23,20 @@ package signaling import ( "context" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/base64" "encoding/json" + "encoding/pem" "io" "net/http" "net/http/httptest" "net/url" + "os" "reflect" "strings" "sync" @@ -36,6 +45,7 @@ import ( "time" "github.com/dlintw/goconf" + "github.com/golang-jwt/jwt" "github.com/gorilla/mux" "github.com/gorilla/websocket" ) @@ -47,6 +57,14 @@ const ( testTimeout = 10 * time.Second ) +var ( + testHelloV2Algorithms = []string{ + "RSA", + "ECDSA", + "Ed25519", + } +) + // Only used for testing. func (h *Hub) getRoom(id string) *Room { h.ru.RLock() @@ -376,6 +394,131 @@ func processPingRequest(t *testing.T, w http.ResponseWriter, r *http.Request, re return response } +func ensureAuthTokens(t *testing.T) (string, string) { + if privateKey := os.Getenv("PRIVATE_AUTH_TOKEN_" + t.Name()); privateKey != "" { + publicKey := os.Getenv("PUBLIC_AUTH_TOKEN_" + t.Name()) + if publicKey == "" { + // should not happen, always both keys are created + t.Fatal("public key is empty") + } + return privateKey, publicKey + } + + var private []byte + var public []byte + + if strings.Contains(t.Name(), "ECDSA") { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + private, err = x509.MarshalECPrivateKey(key) + if err != nil { + t.Fatal(err) + } + private = pem.EncodeToMemory(&pem.Block{ + Type: "ECDSA PRIVATE KEY", + Bytes: private, + }) + + public, err = x509.MarshalPKIXPublicKey(&key.PublicKey) + if err != nil { + t.Fatal(err) + } + public = pem.EncodeToMemory(&pem.Block{ + Type: "ECDSA PUBLIC KEY", + Bytes: public, + }) + } else if strings.Contains(t.Name(), "Ed25519") { + publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + + private, err = x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + t.Fatal(err) + } + private = pem.EncodeToMemory(&pem.Block{ + Type: "Ed25519 PRIVATE KEY", + Bytes: private, + }) + + public, err = x509.MarshalPKIXPublicKey(publicKey) + if err != nil { + t.Fatal(err) + } + public = pem.EncodeToMemory(&pem.Block{ + Type: "Ed25519 PUBLIC KEY", + Bytes: public, + }) + } else { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatal(err) + } + + private = pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + + public, err = x509.MarshalPKIXPublicKey(&key.PublicKey) + if err != nil { + t.Fatal(err) + } + public = pem.EncodeToMemory(&pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: public, + }) + } + + privateKey := base64.StdEncoding.EncodeToString(private) + t.Setenv("PRIVATE_AUTH_TOKEN_"+t.Name(), privateKey) + publicKey := base64.StdEncoding.EncodeToString(public) + t.Setenv("PUBLIC_AUTH_TOKEN_"+t.Name(), publicKey) + return privateKey, publicKey +} + +func getPrivateAuthToken(t *testing.T) (key interface{}) { + private, _ := ensureAuthTokens(t) + data, err := base64.StdEncoding.DecodeString(private) + if err != nil { + t.Fatal(err) + } + if strings.Contains(t.Name(), "ECDSA") { + key, err = jwt.ParseECPrivateKeyFromPEM(data) + } else if strings.Contains(t.Name(), "Ed25519") { + key, err = jwt.ParseEdPrivateKeyFromPEM(data) + } else { + key, err = jwt.ParseRSAPrivateKeyFromPEM(data) + } + if err != nil { + t.Fatal(err) + } + return key +} + +func getPublicAuthToken(t *testing.T) (key interface{}) { + _, public := ensureAuthTokens(t) + data, err := base64.StdEncoding.DecodeString(public) + if err != nil { + t.Fatal(err) + } + if strings.Contains(t.Name(), "ECDSA") { + key, err = jwt.ParseECPublicKeyFromPEM(data) + } else if strings.Contains(t.Name(), "Ed25519") { + key, err = jwt.ParseEdPublicKeyFromPEM(data) + } else { + key, err = jwt.ParseRSAPublicKeyFromPEM(data) + } + if err != nil { + t.Fatal(err) + } + return key +} + func registerBackendHandler(t *testing.T, router *mux.Router) { registerBackendHandlerUrl(t, router, "/") } @@ -420,6 +563,27 @@ func registerBackendHandlerUrl(t *testing.T, router *mux.Router, url string) { if strings.Contains(t.Name(), "MultiRoom") { signaling[ConfigKeySessionPingLimit] = 2 } + if strings.Contains(t.Name(), "V2") { + key := getPublicAuthToken(t) + public, err := x509.MarshalPKIXPublicKey(key) + if err != nil { + t.Fatal(err) + } + var pemType string + if strings.Contains(t.Name(), "ECDSA") { + pemType = "ECDSA PUBLIC KEY" + } else if strings.Contains(t.Name(), "Ed25519") { + pemType = "Ed25519 PUBLIC KEY" + } else { + pemType = "RSA PUBLIC KEY" + } + + public = pem.EncodeToMemory(&pem.Block{ + Type: pemType, + Bytes: public, + }) + signaling[ConfigKeyHelloV2TokenKey] = string(public) + } spreedCapa, _ := json.Marshal(map[string]interface{}{ "features": features, "config": config, @@ -530,7 +694,35 @@ func TestExpectClientHello(t *testing.T) { } } -func TestClientHello(t *testing.T) { +func TestExpectClientHelloUnsupportedVersion(t *testing.T) { + hub, _, _, server := CreateHubForTest(t) + + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() + + params := TestBackendClientAuthParams{ + UserId: testDefaultUserId, + } + if err := client.SendHelloParams(server.URL, "0.0", "", params); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + message, err := client.RunUntilMessage(ctx) + if err := checkUnexpectedClose(err); err != nil { + t.Fatal(err) + } + + if err := checkMessageType(message, "error"); err != nil { + t.Error(err) + } else if message.Error.Code != "invalid_hello_version" { + t.Errorf("Expected \"invalid_hello_version\" reason, got %+v", message.Error) + } +} + +func TestClientHelloV1(t *testing.T) { hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) @@ -555,6 +747,179 @@ func TestClientHello(t *testing.T) { } } +func TestClientHelloV2(t *testing.T) { + for _, algo := range testHelloV2Algorithms { + t.Run(algo, func(t *testing.T) { + hub, _, _, server := CreateHubForTest(t) + + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() + + if err := client.SendHelloV2(testDefaultUserId); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + hello, err := client.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } + if hello.Hello.UserId != testDefaultUserId { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) + } + if hello.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello.Hello) + } + + data := hub.decodeSessionId(hello.Hello.SessionId, publicSessionName) + if data == nil { + t.Fatalf("Could not decode session id: %s", hello.Hello.SessionId) + } + + hub.mu.RLock() + session := hub.sessions[data.Sid] + hub.mu.RUnlock() + if session == nil { + t.Fatalf("Could not get session for id %+v", data) + } + + var userdata map[string]string + if err := json.Unmarshal(*session.UserData(), &userdata); err != nil { + t.Fatal(err) + } + + if expected := "Displayname " + testDefaultUserId; userdata["displayname"] != expected { + t.Errorf("Expected displayname %s, got %s", expected, userdata["displayname"]) + } + }) + } +} + +func TestClientHelloV2_IssuedInFuture(t *testing.T) { + for _, algo := range testHelloV2Algorithms { + t.Run(algo, func(t *testing.T) { + hub, _, _, server := CreateHubForTest(t) + + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() + + issuedAt := time.Now().Add(time.Minute) + expiresAt := issuedAt.Add(time.Second) + if err := client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, expiresAt); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + message, err := client.RunUntilMessage(ctx) + if err := checkUnexpectedClose(err); err != nil { + t.Fatal(err) + } + + if err := checkMessageType(message, "error"); err != nil { + t.Error(err) + } else if message.Error.Code != "token_not_valid_yet" { + t.Errorf("Expected \"token_not_valid_yet\" reason, got %+v", message.Error) + } + }) + } +} + +func TestClientHelloV2_Expired(t *testing.T) { + for _, algo := range testHelloV2Algorithms { + t.Run(algo, func(t *testing.T) { + hub, _, _, server := CreateHubForTest(t) + + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() + + issuedAt := time.Now().Add(-time.Minute) + if err := client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, issuedAt.Add(time.Second)); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + message, err := client.RunUntilMessage(ctx) + if err := checkUnexpectedClose(err); err != nil { + t.Fatal(err) + } + + if err := checkMessageType(message, "error"); err != nil { + t.Error(err) + } else if message.Error.Code != "token_expired" { + t.Errorf("Expected \"token_expired\" reason, got %+v", message.Error) + } + }) + } +} + +func TestClientHelloV2_IssuedAtMissing(t *testing.T) { + for _, algo := range testHelloV2Algorithms { + t.Run(algo, func(t *testing.T) { + hub, _, _, server := CreateHubForTest(t) + + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() + + var issuedAt time.Time + expiresAt := time.Now().Add(time.Minute) + if err := client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, expiresAt); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + message, err := client.RunUntilMessage(ctx) + if err := checkUnexpectedClose(err); err != nil { + t.Fatal(err) + } + + if err := checkMessageType(message, "error"); err != nil { + t.Error(err) + } else if message.Error.Code != "token_not_valid_yet" { + t.Errorf("Expected \"token_not_valid_yet\" reason, got %+v", message.Error) + } + }) + } +} + +func TestClientHelloV2_ExpiresAtMissing(t *testing.T) { + for _, algo := range testHelloV2Algorithms { + t.Run(algo, func(t *testing.T) { + hub, _, _, server := CreateHubForTest(t) + + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() + + issuedAt := time.Now().Add(-time.Minute) + var expiresAt time.Time + if err := client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, expiresAt); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + message, err := client.RunUntilMessage(ctx) + if err := checkUnexpectedClose(err); err != nil { + t.Fatal(err) + } + + if err := checkMessageType(message, "error"); err != nil { + t.Error(err) + } else if message.Error.Code != "token_expired" { + t.Errorf("Expected \"token_expired\" reason, got %+v", message.Error) + } + }) + } +} + func TestClientHelloWithSpaces(t *testing.T) { hub, _, _, server := CreateHubForTest(t) @@ -644,7 +1009,7 @@ func TestClientHelloSessionLimit(t *testing.T) { params1 := TestBackendClientAuthParams{ UserId: testDefaultUserId, } - if err := client.SendHelloParams(server.URL+"/one", "client", params1); err != nil { + if err := client.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", params1); err != nil { t.Fatal(err) } @@ -669,7 +1034,7 @@ func TestClientHelloSessionLimit(t *testing.T) { params2 := TestBackendClientAuthParams{ UserId: testDefaultUserId + "2", } - if err := client2.SendHelloParams(server.URL+"/one", "client", params2); err != nil { + if err := client2.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", params2); err != nil { t.Fatal(err) } @@ -685,7 +1050,7 @@ func TestClientHelloSessionLimit(t *testing.T) { } // The client can connect to a different backend. - if err := client2.SendHelloParams(server.URL+"/two", "client", params2); err != nil { + if err := client2.SendHelloParams(server.URL+"/two", HelloVersionV1, "client", params2); err != nil { t.Fatal(err) } @@ -712,7 +1077,7 @@ func TestClientHelloSessionLimit(t *testing.T) { params3 := TestBackendClientAuthParams{ UserId: testDefaultUserId + "3", } - if err := client3.SendHelloParams(server.URL+"/one", "client", params3); err != nil { + if err := client3.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", params3); err != nil { t.Fatal(err) } @@ -1314,7 +1679,7 @@ func TestClientHelloClient_V3Api(t *testing.T) { } // The "/api/v1/signaling/" URL will be changed to use "v3" as the "signaling-v3" // feature is returned by the capabilities endpoint. - if err := client.SendHelloParams(server.URL+"/ocs/v2.php/apps/spreed/api/v1/signaling/backend", "client", params); err != nil { + if err := client.SendHelloParams(server.URL+"/ocs/v2.php/apps/spreed/api/v1/signaling/backend", HelloVersionV1, "client", params); err != nil { t.Fatal(err) } @@ -3128,7 +3493,7 @@ func TestNoSendBetweenSessionsOnDifferentBackends(t *testing.T) { params1 := TestBackendClientAuthParams{ UserId: "user1", } - if err := client1.SendHelloParams(server.URL+"/one", "client", params1); err != nil { + if err := client1.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", params1); err != nil { t.Fatal(err) } hello1, err := client1.RunUntilHello(ctx) @@ -3142,7 +3507,7 @@ func TestNoSendBetweenSessionsOnDifferentBackends(t *testing.T) { params2 := TestBackendClientAuthParams{ UserId: "user2", } - if err := client2.SendHelloParams(server.URL+"/two", "client", params2); err != nil { + if err := client2.SendHelloParams(server.URL+"/two", HelloVersionV1, "client", params2); err != nil { t.Fatal(err) } hello2, err := client2.RunUntilHello(ctx) @@ -3198,7 +3563,7 @@ func TestNoSameRoomOnDifferentBackends(t *testing.T) { params1 := TestBackendClientAuthParams{ UserId: "user1", } - if err := client1.SendHelloParams(server.URL+"/one", "client", params1); err != nil { + if err := client1.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", params1); err != nil { t.Fatal(err) } hello1, err := client1.RunUntilHello(ctx) @@ -3212,7 +3577,7 @@ func TestNoSameRoomOnDifferentBackends(t *testing.T) { params2 := TestBackendClientAuthParams{ UserId: "user2", } - if err := client2.SendHelloParams(server.URL+"/two", "client", params2); err != nil { + if err := client2.SendHelloParams(server.URL+"/two", HelloVersionV1, "client", params2); err != nil { t.Fatal(err) } hello2, err := client2.RunUntilHello(ctx) diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index 72a2febe..232c9031 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -589,7 +589,7 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { Id: message.Id, Type: "hello", Hello: &signaling.HelloProxyServerMessage{ - Version: signaling.HelloVersion, + Version: signaling.HelloVersionV1, SessionId: session.PublicId(), Server: &signaling.WelcomeServerMessage{ Version: s.version, diff --git a/room_ping.go b/room_ping.go index d3cf2318..2e83fe91 100644 --- a/room_ping.go +++ b/room_ping.go @@ -29,11 +29,6 @@ import ( "time" ) -const ( - ConfigGroupSignaling = "signaling" - ConfigKeySessionPingLimit = "session-ping-limit" -) - type pingEntries struct { url *url.URL diff --git a/testclient_test.go b/testclient_test.go index 4cf40f4d..ce8afed9 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -37,6 +37,7 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt" "github.com/gorilla/websocket" ) @@ -328,11 +329,14 @@ func (c *TestClient) WaitForSessionRemoved(ctx context.Context, sessionId string } func (c *TestClient) WriteJSON(data interface{}) error { - if msg, ok := data.(*ClientMessage); ok { - if err := msg.CheckValid(); err != nil { - return err + if !strings.Contains(c.t.Name(), "HelloUnsupportedVersion") { + if msg, ok := data.(*ClientMessage); ok { + if err := msg.CheckValid(); err != nil { + return err + } } } + return c.conn.WriteJSON(data) } @@ -343,10 +347,63 @@ func (c *TestClient) EnsuerWriteJSON(data interface{}) { } func (c *TestClient) SendHello(userid string) error { + return c.SendHelloV1(userid) +} + +func (c *TestClient) SendHelloV1(userid string) error { params := TestBackendClientAuthParams{ UserId: userid, } - return c.SendHelloParams(c.server.URL, "", params) + return c.SendHelloParams(c.server.URL, HelloVersionV1, "", params) +} + +func (c *TestClient) SendHelloV2(userid string) error { + now := time.Now() + return c.SendHelloV2WithTimes(userid, now, now.Add(time.Second)) +} + +func (c *TestClient) SendHelloV2WithTimes(userid string, issuedAt time.Time, expiresAt time.Time) error { + userdata := map[string]string{ + "displayname": "Displayname " + userid, + } + + data, err := json.Marshal(userdata) + if err != nil { + c.t.Fatal(err) + } + + claims := &HelloV2TokenClaims{ + StandardClaims: jwt.StandardClaims{ + Issuer: c.server.URL, + Subject: userid, + }, + UserData: (*json.RawMessage)(&data), + } + if !issuedAt.IsZero() { + claims.IssuedAt = issuedAt.Unix() + } + if !expiresAt.IsZero() { + claims.ExpiresAt = expiresAt.Unix() + } + + var token *jwt.Token + if strings.Contains(c.t.Name(), "ECDSA") { + token = jwt.NewWithClaims(jwt.SigningMethodES256, claims) + } else if strings.Contains(c.t.Name(), "Ed25519") { + token = jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims) + } else { + token = jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + } + private := getPrivateAuthToken(c.t) + tokenString, err := token.SignedString(private) + if err != nil { + c.t.Fatal(err) + } + + params := HelloV2AuthParams{ + Token: tokenString, + } + return c.SendHelloParams(c.server.URL, HelloVersionV2, "", params) } func (c *TestClient) SendHelloResume(resumeId string) error { @@ -354,7 +411,7 @@ func (c *TestClient) SendHelloResume(resumeId string) error { Id: "1234", Type: "hello", Hello: &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, ResumeId: resumeId, }, } @@ -365,7 +422,7 @@ func (c *TestClient) SendHelloClient(userid string) error { params := TestBackendClientAuthParams{ UserId: userid, } - return c.SendHelloParams(c.server.URL, "client", params) + return c.SendHelloParams(c.server.URL, HelloVersionV1, "client", params) } func (c *TestClient) SendHelloInternal() error { @@ -380,10 +437,10 @@ func (c *TestClient) SendHelloInternal() error { Token: token, Backend: backend, } - return c.SendHelloParams("", "internal", params) + return c.SendHelloParams("", HelloVersionV1, "internal", params) } -func (c *TestClient) SendHelloParams(url string, clientType string, params interface{}) error { +func (c *TestClient) SendHelloParams(url string, version string, clientType string, params interface{}) error { data, err := json.Marshal(params) if err != nil { c.t.Fatal(err) @@ -393,7 +450,7 @@ func (c *TestClient) SendHelloParams(url string, clientType string, params inter Id: "1234", Type: "hello", Hello: &HelloClientMessage{ - Version: HelloVersion, + Version: version, Auth: HelloClientMessageAuth{ Type: clientType, Url: url,