diff --git a/clientapi/auth/auth.go b/clientapi/auth/auth.go index 481bd36b27..8fd8511845 100644 --- a/clientapi/auth/auth.go +++ b/clientapi/auth/auth.go @@ -43,6 +43,7 @@ type AccountDatabase interface { // Look up the account matching the given localpart. GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) GetAccountByPassword(ctx context.Context, localpart, password string) (*api.Account, error) + GetLocalpartForThreePID(ctx context.Context, address, medium string) (string, error) } // VerifyUserFromRequest authenticates the HTTP request, diff --git a/clientapi/auth/authtypes/logintypes.go b/clientapi/auth/authtypes/logintypes.go index f01e48f806..2b19b355c9 100644 --- a/clientapi/auth/authtypes/logintypes.go +++ b/clientapi/auth/authtypes/logintypes.go @@ -10,5 +10,6 @@ const ( LoginTypeSharedSecret = "org.matrix.login.shared_secret" LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeApplicationService = "m.login.application_service" + LoginTypeSSO = "m.login.sso" LoginTypeToken = "m.login.token" ) diff --git a/clientapi/auth/sso/github.go b/clientapi/auth/sso/github.go new file mode 100644 index 0000000000..9ef5cd2daf --- /dev/null +++ b/clientapi/auth/sso/github.go @@ -0,0 +1,37 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sso + +import ( + "github.com/matrix-org/dendrite/setup/config" +) + +// GitHubIdentityProvider is a GitHub-flavored identity provider. +var GitHubIdentityProvider IdentityProvider = githubIdentityProvider{ + baseOIDCIdentityProvider: &baseOIDCIdentityProvider{ + AuthURL: mustParseURLTemplate("https://github.com/login/oauth/authorize?scope=user:email"), + AccessTokenURL: mustParseURLTemplate("https://github.com/login/oauth/access_token"), + UserInfoURL: mustParseURLTemplate("https://api.github.com/user"), + UserInfoAccept: "application/vnd.github.v3+json", + UserInfoEmailPath: "email", + UserInfoSuggestedUserIDPath: "login", + }, +} + +type githubIdentityProvider struct { + *baseOIDCIdentityProvider +} + +func (githubIdentityProvider) DefaultBrand() string { return config.SSOBrandGitHub } diff --git a/clientapi/auth/sso/oidc_base.go b/clientapi/auth/sso/oidc_base.go new file mode 100644 index 0000000000..edc600d4c4 --- /dev/null +++ b/clientapi/auth/sso/oidc_base.go @@ -0,0 +1,262 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sso + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "mime" + "net/http" + "net/url" + "strings" + "text/template" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/tidwall/gjson" +) + +type baseOIDCIdentityProvider struct { + AuthURL *urlTemplate + AccessTokenURL *urlTemplate + UserInfoURL *urlTemplate + UserInfoAccept string + UserInfoEmailPath string + UserInfoSuggestedUserIDPath string +} + +func (p *baseOIDCIdentityProvider) AuthorizationURL(ctx context.Context, req *IdentityProviderRequest) (string, error) { + u, err := p.AuthURL.Execute(map[string]interface{}{ + "Config": req.System, + "State": req.DendriteNonce, + "RedirectURI": req.CallbackURL, + }, url.Values{ + "client_id": []string{req.System.OIDC.ClientID}, + "response_type": []string{"code"}, + "redirect_uri": []string{req.CallbackURL}, + "state": []string{req.DendriteNonce}, + }) + if err != nil { + return "", err + } + return u.String(), nil +} + +func (p *baseOIDCIdentityProvider) ProcessCallback(ctx context.Context, req *IdentityProviderRequest, values url.Values) (*CallbackResult, error) { + state := values.Get("state") + if state == "" { + return nil, jsonerror.MissingArgument("state parameter missing") + } + if state != req.DendriteNonce { + return nil, jsonerror.InvalidArgumentValue("state parameter not matching nonce") + } + + if error := values.Get("error"); error != "" { + if euri := values.Get("error_uri"); euri != "" { + return &CallbackResult{RedirectURL: euri}, nil + } + + desc := values.Get("error_description") + if desc == "" { + desc = error + } + switch error { + case "unauthorized_client", "access_denied": + return nil, jsonerror.Forbidden("SSO said no: " + desc) + default: + return nil, fmt.Errorf("SSO failed: %v", error) + } + } + + code := values.Get("code") + if code == "" { + return nil, jsonerror.MissingArgument("code parameter missing") + } + + oidcAccessToken, err := p.getOIDCAccessToken(ctx, req, code) + if err != nil { + return nil, err + } + + id, userID, err := p.getUserInfo(ctx, req, oidcAccessToken) + if err != nil { + return nil, err + } + + return &CallbackResult{Identifier: id, SuggestedUserID: userID}, nil +} + +func (p *baseOIDCIdentityProvider) getOIDCAccessToken(ctx context.Context, req *IdentityProviderRequest, code string) (string, error) { + u, err := p.AccessTokenURL.Execute(nil, nil) + if err != nil { + return "", err + } + + body := url.Values{ + "grant_type": []string{"authorization_code"}, + "code": []string{code}, + "redirect_uri": []string{req.CallbackURL}, + "client_id": []string{req.System.OIDC.ClientID}, + } + + hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), strings.NewReader(body.Encode())) + if err != nil { + return "", err + } + hreq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + hreq.Header.Set("Accept", "application/x-www-form-urlencoded") + + hresp, err := http.DefaultClient.Do(hreq) + if err != nil { + return "", err + } + defer hresp.Body.Close() + + ctype, _, err := mime.ParseMediaType(hresp.Header.Get("Content-Type")) + if err != nil { + return "", err + } + if ctype != "application/json" { + return "", fmt.Errorf("expected URL encoded response, got content type %q", ctype) + } + + var resp struct { + TokenType string `json:"token_type"` + AccessToken string `json:"access_token"` + + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + ErrorURI string `json:"error_uri"` + } + if err := json.NewDecoder(hresp.Body).Decode(&resp); err != nil { + return "", err + } + + if resp.Error != "" { + desc := resp.ErrorDescription + if desc == "" { + desc = resp.Error + } + return "", fmt.Errorf("failed to retrieve OIDC access token: %s", desc) + } + + if strings.ToLower(resp.TokenType) != "bearer" { + return "", fmt.Errorf("expected bearer token, got type %q", resp.TokenType) + } + + return resp.AccessToken, nil +} + +func (p *baseOIDCIdentityProvider) getUserInfo(ctx context.Context, req *IdentityProviderRequest, oidcAccessToken string) (userutil.Identifier, string, error) { + u, err := p.UserInfoURL.Execute(map[string]interface{}{ + "Config": req.System, + }, nil) + if err != nil { + return nil, "", err + } + + hreq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + if err != nil { + return nil, "", err + } + hreq.Header.Set("Authorization", "token "+oidcAccessToken) + hreq.Header.Set("Accept", p.UserInfoAccept) + + hresp, err := http.DefaultClient.Do(hreq) + if err != nil { + return nil, "", err + } + defer hresp.Body.Close() + + ctype, _, err := mime.ParseMediaType(hresp.Header.Get("Content-Type")) + if err != nil { + return nil, "", err + } + + var email string + var suggestedUserID string + switch ctype { + case "application/json": + body, err := ioutil.ReadAll(hresp.Body) + if err != nil { + return nil, "", err + } + + emailRes := gjson.GetBytes(body, p.UserInfoEmailPath) + if !emailRes.Exists() { + return nil, "", fmt.Errorf("no email in user info response body") + } + email = emailRes.String() + + // This is optional. + userIDRes := gjson.GetBytes(body, p.UserInfoSuggestedUserIDPath) + suggestedUserID = userIDRes.String() + + default: + return nil, "", fmt.Errorf("got unknown content type %q for user info", ctype) + } + + if email == "" { + return nil, "", fmt.Errorf("no email address in user info") + } + + return &userutil.ThirdPartyIdentifier{Medium: "email", Address: email}, suggestedUserID, nil +} + +type urlTemplate struct { + base *template.Template +} + +func parseURLTemplate(s string) (*urlTemplate, error) { + t, err := template.New("").Parse(s) + if err != nil { + return nil, err + } + return &urlTemplate{base: t}, nil +} + +func mustParseURLTemplate(s string) *urlTemplate { + t, err := parseURLTemplate(s) + if err != nil { + panic(err) + } + return t +} + +func (t *urlTemplate) Execute(params interface{}, defaultQuery url.Values) (*url.URL, error) { + var sb strings.Builder + err := t.base.Execute(&sb, params) + if err != nil { + return nil, err + } + + u, err := url.Parse(sb.String()) + if err != nil { + return nil, err + } + + if defaultQuery != nil { + q := u.Query() + for k, vs := range defaultQuery { + if q.Get(k) == "" { + q[k] = vs + } + } + u.RawQuery = q.Encode() + } + return u, nil +} diff --git a/clientapi/auth/sso/sso.go b/clientapi/auth/sso/sso.go new file mode 100644 index 0000000000..1b9215983d --- /dev/null +++ b/clientapi/auth/sso/sso.go @@ -0,0 +1,57 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sso + +import ( + "context" + "net/url" + + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/setup/config" +) + +type IdentityProvider interface { + DefaultBrand() string + + AuthorizationURL(context.Context, *IdentityProviderRequest) (string, error) + ProcessCallback(context.Context, *IdentityProviderRequest, url.Values) (*CallbackResult, error) +} + +type IdentityProviderRequest struct { + System *config.IdentityProvider + CallbackURL string + DendriteNonce string +} + +type CallbackResult struct { + RedirectURL string + Identifier userutil.Identifier + SuggestedUserID string +} + +type IdentityProviderType string + +const ( + TypeGitHub IdentityProviderType = config.SSOBrandGitHub +) + +func GetIdentityProvider(t IdentityProviderType) IdentityProvider { + switch t { + case TypeGitHub: + return GitHubIdentityProvider + default: + return nil + } +} diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index e3effbe99f..f7277cd26f 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -19,6 +19,8 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/auth" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/auth/sso" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" @@ -36,20 +38,54 @@ type loginResponse struct { } type flows struct { - Flows []flow `json:"flows"` + Flows []stage `json:"flows"` } -type flow struct { - Type string `json:"type"` +type stage struct { + Type string `json:"type"` + IdentityProviders []identityProvider `json:"identity_providers,omitempty"` } -func passwordLogin() flows { - f := flows{} - s := flow{ - Type: "m.login.password", +type identityProvider struct { + ID string `json:"id"` + Name string `json:"name"` + Brand string `json:"brand,omitempty"` + Icon string `json:"icon,omitempty"` +} + +func passwordLogin() []stage { + return []stage{ + {Type: authtypes.LoginTypePassword}, + } +} + +func ssoLogin(cfg *config.ClientAPI) []stage { + var idps []identityProvider + for _, idp := range cfg.Login.SSO.Providers { + brand := idp.Brand + if brand == "" { + typ := idp.Type + if typ == "" { + typ = idp.ID + } + idpType := sso.GetIdentityProvider(sso.IdentityProviderType(typ)) + if idpType != nil { + brand = idpType.DefaultBrand() + } + } + idps = append(idps, identityProvider{ + ID: idp.ID, + Name: idp.Name, + Brand: brand, + Icon: idp.Icon, + }) + } + return []stage{ + { + Type: authtypes.LoginTypeSSO, + IdentityProviders: idps, + }, } - f.Flows = append(f.Flows, s) - return f } // Login implements GET and POST /login @@ -58,10 +94,13 @@ func Login( cfg *config.ClientAPI, ) util.JSONResponse { if req.Method == http.MethodGet { - // TODO: support other forms of login other than password, depending on config options + allFlows := passwordLogin() + if cfg.Login.SSO.Enabled { + allFlows = append(allFlows, ssoLogin(cfg)...) + } return util.JSONResponse{ Code: http.StatusOK, - JSON: passwordLogin(), + JSON: flows{Flows: allFlows}, } } else if req.Method == http.MethodPost { login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, accountDB, userAPI, cfg) @@ -73,6 +112,7 @@ func Login( cleanup(req.Context(), &authzErr) return authzErr } + return util.JSONResponse{ Code: http.StatusMethodNotAllowed, JSON: jsonerror.NotFound("Bad method"), diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index dd9e9535a7..acad0b16c3 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -502,6 +502,25 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) + r0mux.Handle("/login/sso/callback", + httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { + return SSOCallback(req, accountDB, userAPI, cfg) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + r0mux.Handle("/login/sso/redirect", + httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { + return SSORedirect(req, "", cfg) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + r0mux.Handle("/login/sso/redirect/{idpID}", + httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { + vars := mux.Vars(req) + return SSORedirect(req, vars["idpID"], cfg) + }), + ).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/auth/{authType}/fallback/web", httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { vars := mux.Vars(req) diff --git a/clientapi/routing/sso.go b/clientapi/routing/sso.go new file mode 100644 index 0000000000..8014ba0558 --- /dev/null +++ b/clientapi/routing/sso.go @@ -0,0 +1,259 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "context" + "encoding/base64" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/matrix-org/dendrite/clientapi/auth" + "github.com/matrix-org/dendrite/clientapi/auth/sso" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +// SSORedirect implements /login/sso/redirect +// https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-login-sso-redirect +func SSORedirect( + req *http.Request, + idpID string, + cfg *config.ClientAPI, +) util.JSONResponse { + if !cfg.Login.SSO.Enabled { + return util.JSONResponse{ + Code: http.StatusNotImplemented, + JSON: jsonerror.NotFound("authentication method disabled"), + } + } + + redirectURL := req.URL.Query().Get("redirectUrl") + if redirectURL == "" { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MissingArgument("redirectUrl parameter missing"), + } + } + _, err := url.Parse(redirectURL) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("Invalid redirectURL: " + err.Error()), + } + } + + if idpID == "" { + // Check configuration if the client didn't provide an ID. + idpID = cfg.Login.SSO.DefaultProviderID + } + if idpID == "" && len(cfg.Login.SSO.Providers) > 0 { + // Fall back to the first provider. If there are no providers, getProvider("") will fail. + idpID = cfg.Login.SSO.Providers[0].ID + } + idpCfg, idpType := getProvider(cfg, idpID) + if idpType == nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("unknown identity provider"), + } + } + + idpReq := &sso.IdentityProviderRequest{ + System: idpCfg, + CallbackURL: req.URL.ResolveReference(&url.URL{Path: "../callback", RawQuery: url.Values{"provider": []string{idpID}}.Encode()}).String(), + DendriteNonce: formatNonce(redirectURL), + } + u, err := idpType.AuthorizationURL(req.Context(), idpReq) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: err, + } + } + + resp := util.RedirectResponse(u) + resp.Headers["Set-Cookie"] = (&http.Cookie{ + Name: "oidc_nonce", + Value: idpReq.DendriteNonce, + Expires: time.Now().Add(10 * time.Minute), + Secure: true, + SameSite: http.SameSiteStrictMode, + }).String() + return resp +} + +func SSOCallback( + req *http.Request, + accountDB auth.AccountDatabase, + userAPI auth.UserInternalAPIForLogin, + cfg *config.ClientAPI, +) util.JSONResponse { + ctx := req.Context() + + query := req.URL.Query() + idpID := query.Get("provider") + if idpID == "" { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MissingArgument("provider parameter missing"), + } + } + idpCfg, idpType := getProvider(cfg, idpID) + if idpType == nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("unknown identity provider"), + } + } + + nonce, err := req.Cookie("oidc_nonce") + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MissingArgument("no nonce cookie: " + err.Error()), + } + } + finalRedirectURL, err := parseNonce(nonce.Value) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: err, + } + } + + idpReq := &sso.IdentityProviderRequest{ + System: idpCfg, + CallbackURL: (&url.URL{Scheme: req.URL.Scheme, Host: req.URL.Host, Path: req.URL.Path, RawQuery: url.Values{"provider": []string{idpID}}.Encode()}).String(), + DendriteNonce: nonce.Value, + } + result, err := idpType.ProcessCallback(ctx, idpReq, query) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: err, + } + } + + if result.Identifier == nil { + // Not authenticated yet. + return util.RedirectResponse(result.RedirectURL) + } + + id, err := verifyUserIdentifier(ctx, accountDB, result.Identifier) + if err != nil { + util.GetLogger(ctx).WithError(err).WithField("identifier", result.Identifier.String()).Error("failed to find user") + return util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: jsonerror.Forbidden("ID not associated with a local account"), + } + } + + token, err := createLoginToken(ctx, userAPI, id) + if err != nil { + util.GetLogger(ctx).WithError(err).Errorf("PerformLoginTokenCreation failed") + return jsonerror.InternalServerError() + } + + rquery := finalRedirectURL.Query() + rquery.Set("loginToken", token.Token) + resp := util.RedirectResponse(finalRedirectURL.ResolveReference(&url.URL{RawQuery: rquery.Encode()}).String()) + resp.Headers["Set-Cookie"] = (&http.Cookie{ + Name: "oidc_nonce", + Value: "", + MaxAge: -1, + Secure: true, + }).String() + return resp +} + +func getProvider(cfg *config.ClientAPI, id string) (*config.IdentityProvider, sso.IdentityProvider) { + for _, idp := range cfg.Login.SSO.Providers { + if idp.ID == id { + switch sso.IdentityProviderType(id) { + case sso.TypeGitHub: + return &idp, sso.GitHubIdentityProvider + default: + return nil, nil + } + } + } + return nil, nil +} + +func formatNonce(redirectURL string) string { + return util.RandomString(16) + "." + base64.RawURLEncoding.EncodeToString([]byte(redirectURL)) +} + +func parseNonce(s string) (redirectURL *url.URL, _ error) { + if s == "" { + return nil, jsonerror.MissingArgument("empty OIDC nonce cookie") + } + + ss := strings.Split(s, ".") + if len(ss) < 2 { + return nil, jsonerror.InvalidArgumentValue("malformed OIDC nonce cookie") + } + + urlbs, err := base64.RawURLEncoding.DecodeString(ss[1]) + if err != nil { + return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in OIDC nonce cookie") + } + u, err := url.Parse(string(urlbs)) + if err != nil { + return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in OIDC nonce cookie: " + err.Error()) + } + + return u, nil +} + +func verifyUserIdentifier(ctx context.Context, accountDB auth.AccountDatabase, id userutil.Identifier) (*userutil.UserIdentifier, error) { + var localpart string + switch iid := id.(type) { + case *userutil.ThirdPartyIdentifier: + var err error + localpart, err = accountDB.GetLocalpartForThreePID(ctx, iid.Address, iid.Medium) + if err != nil { + return nil, err + } + + case *userutil.UserIdentifier: + localpart = iid.UserID + + default: + return nil, fmt.Errorf("unsupported ID type: %T", id) + } + + acc, err := accountDB.GetAccountByLocalpart(ctx, localpart) + if err != nil { + return nil, err + } + return &userutil.UserIdentifier{UserID: acc.UserID}, nil +} + +func createLoginToken(ctx context.Context, userAPI auth.UserInternalAPIForLogin, id *userutil.UserIdentifier) (*uapi.LoginTokenMetadata, error) { + req := uapi.PerformLoginTokenCreationRequest{Data: uapi.LoginTokenData{UserID: id.UserID}} + var resp uapi.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &req, &resp); err != nil { + return nil, err + } + return &resp.Metadata, nil +} diff --git a/clientapi/userutil/identifier.go b/clientapi/userutil/identifier.go new file mode 100644 index 0000000000..46c8a0f020 --- /dev/null +++ b/clientapi/userutil/identifier.go @@ -0,0 +1,153 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package userutil + +import ( + "bytes" + "encoding/json" + "errors" +) + +// An Identifier identifies a user. There are many kinds, and this is +// the common interface for them. +// +// If you need to handle an identifier as JSON, use the AnyIdentifier wrapper. +// Passing around identifiers in code, the raw Identifier is enough. +// +// See https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types +type Identifier interface { + // IdentifierType returns the identifier type, like "m.id.user". + IdentifierType() IdentifierType + + // String returns a debug-output string representation. The format + // is not specified. + String() string +} + +// A UserIdentifier contains an MXID. It may be only the local part. +type UserIdentifier struct { + UserID string `json:"user"` +} + +func (i *UserIdentifier) IdentifierType() IdentifierType { return IdentifierUser } +func (i *UserIdentifier) String() string { return i.UserID } + +// A ThirdPartyIdentifier references an identifier in another system. +type ThirdPartyIdentifier struct { + // Medium is normally MediumEmail. + Medium Medium `json:"medium"` + + // Address is the medium-specific identifier. + Address string `json:"address"` +} + +func (i *ThirdPartyIdentifier) IdentifierType() IdentifierType { return IdentifierThirdParty } +func (i *ThirdPartyIdentifier) String() string { return string(i.Medium) + ":" + i.Address } + +// A PhoneIdentifier references a phone number. +type PhoneIdentifier struct { + // Country is a ISO-3166-1 alpha-2 country code. + Country string `json:"country"` + + // PhoneNumber is a country-specific phone number, as it would be dialled from. + PhoneNumber string `json:"phone"` +} + +func (i *PhoneIdentifier) IdentifierType() IdentifierType { return IdentifierPhone } +func (i *PhoneIdentifier) String() string { return i.Country + ":" + i.PhoneNumber } + +// UnknownIdentifier is the catch-all for identifiers this code doesn't know about. +// It simply stores raw JSON. +type UnknownIdentifier struct { + json.RawMessage + Type IdentifierType +} + +func (i *UnknownIdentifier) IdentifierType() IdentifierType { return i.Type } +func (i *UnknownIdentifier) String() string { return "unknown/" + string(i.Type) } + +// AnyIdentifier is a wrapper that allows marshalling and unmarshalling the various +// types of identifiers to/from JSON. Always use this in data types that will be +// used in JSON manipulation. +type AnyIdentifier struct { + Identifier +} + +func (i AnyIdentifier) MarshalJSON() ([]byte, error) { + v := struct { + *UserIdentifier + *ThirdPartyIdentifier + *PhoneIdentifier + Type IdentifierType `json:"type"` + }{ + Type: i.Identifier.IdentifierType(), + } + switch iid := i.Identifier.(type) { + case *UserIdentifier: + v.UserIdentifier = iid + case *ThirdPartyIdentifier: + v.ThirdPartyIdentifier = iid + case *PhoneIdentifier: + v.PhoneIdentifier = iid + case *UnknownIdentifier: + return iid.RawMessage, nil + } + return json.Marshal(v) +} + +func (i *AnyIdentifier) UnmarshalJSON(bs []byte) error { + var hdr struct { + Type IdentifierType `json:"type"` + } + if err := json.Unmarshal(bs, &hdr); err != nil { + return err + } + switch hdr.Type { + case IdentifierUser: + var ui UserIdentifier + i.Identifier = &ui + return json.Unmarshal(bs, &ui) + case IdentifierThirdParty: + var tpi ThirdPartyIdentifier + i.Identifier = &tpi + return json.Unmarshal(bs, &tpi) + case IdentifierPhone: + var pi PhoneIdentifier + i.Identifier = &pi + return json.Unmarshal(bs, &pi) + case "": + return errors.New("missing identifier type") + default: + i.Identifier = &UnknownIdentifier{RawMessage: json.RawMessage(bytes.TrimSpace(bs)), Type: hdr.Type} + return nil + } +} + +// IdentifierType describes the type of identifier. +type IdentifierType string + +const ( + IdentifierUser IdentifierType = "m.id.user" + IdentifierThirdParty IdentifierType = "m.id.thirdparty" + IdentifierPhone IdentifierType = "m.id.phone" +) + +// Medium describes the interpretation of a third-party identifier. +type Medium string + +const ( + // MediumEmail signifies that the address is an email address. + MediumEmail Medium = "email" +) diff --git a/clientapi/userutil/identifier_test.go b/clientapi/userutil/identifier_test.go new file mode 100644 index 0000000000..cd02524c3b --- /dev/null +++ b/clientapi/userutil/identifier_test.go @@ -0,0 +1,61 @@ +package userutil + +import ( + "encoding/json" + "reflect" + "testing" +) + +func TestAnyIdentifierJSON(t *testing.T) { + tsts := []struct { + Name string + JSON string + Want Identifier + }{ + {Name: "empty", JSON: `{}`}, + {Name: "user", JSON: `{"type":"m.id.user","user":"auser"}`, Want: &UserIdentifier{UserID: "auser"}}, + {Name: "thirdparty", JSON: `{"type":"m.id.thirdparty","medium":"email","address":"auser@example.com"}`, Want: &ThirdPartyIdentifier{Medium: "email", Address: "auser@example.com"}}, + {Name: "phone", JSON: `{"type":"m.id.phone","country":"GB","phone":"123456789"}`, Want: &PhoneIdentifier{Country: "GB", PhoneNumber: "123456789"}}, + // This test is a little fragile since it compares the output of json.Marshal. + {Name: "unknown", JSON: `{"type":"other"}`, Want: &UnknownIdentifier{Type: "other", RawMessage: json.RawMessage(`{"type":"other"}`)}}, + } + for _, tst := range tsts { + t.Run("Unmarshal/"+tst.Name, func(t *testing.T) { + var got AnyIdentifier + if err := json.Unmarshal([]byte(tst.JSON), &got); err != nil { + if tst.Want == nil { + return + } + t.Fatalf("Unmarshal failed: %v", err) + } + + if !reflect.DeepEqual(got.Identifier, tst.Want) { + t.Errorf("got %+v, want %+v", got.Identifier, tst.Want) + } + }) + + if tst.Want == nil { + continue + } + t.Run("Marshal/"+tst.Name, func(t *testing.T) { + id := AnyIdentifier{Identifier: tst.Want} + bs, err := json.Marshal(id) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + t.Logf("Marshalled JSON: %q", string(bs)) + + var got AnyIdentifier + if err := json.Unmarshal(bs, &got); err != nil { + if tst.Want == nil { + return + } + t.Fatalf("Unmarshal failed: %v", err) + } + + if !reflect.DeepEqual(got.Identifier, tst.Want) { + t.Errorf("got %+v, want %+v", got.Identifier, tst.Want) + } + }) + } +} diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index c7cb9c33e0..cad284e20f 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -32,6 +32,8 @@ type ClientAPI struct { // was successful RecaptchaSiteVerifyAPI string `yaml:"recaptcha_siteverify_api"` + Login Login `yaml:"login"` + // TURN options TURN TURN `yaml:"turn"` @@ -53,6 +55,7 @@ func (c *ClientAPI) Defaults() { c.RecaptchaSiteVerifyAPI = "" c.RegistrationDisabled = false c.RateLimiting.Defaults() + c.Login.SSO.Enabled = false } func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { @@ -66,10 +69,130 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { checkNotEmpty(configErrs, "client_api.recaptcha_private_key", string(c.RecaptchaPrivateKey)) checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", string(c.RecaptchaSiteVerifyAPI)) } + c.Login.Verify(configErrs) c.TURN.Verify(configErrs) c.RateLimiting.Verify(configErrs) } +type Login struct { + SSO SSO `yaml:"sso"` +} + +func (l *Login) Verify(configErrs *ConfigErrors) { + l.SSO.Verify(configErrs) +} + +type SSO struct { + // Enabled determines whether SSO should be allowed. + Enabled bool `yaml:"enabled"` + + // Providers list the identity providers this server is capable of confirming an + // identity with. + Providers []IdentityProvider `yaml:"providers"` + + // DefaultProviderID is the provider to use when the client doesn't indicate one. + // This is legacy support. If empty, the first provider listed is used. + DefaultProviderID string `yaml:"default_provider"` +} + +func (sso *SSO) Verify(configErrs *ConfigErrors) { + var foundDefaultProvider bool + seenPIDs := make(map[string]bool, len(sso.Providers)) + for _, p := range sso.Providers { + p.Verify(configErrs) + if p.ID == sso.DefaultProviderID { + foundDefaultProvider = true + } + if seenPIDs[p.ID] { + configErrs.Add(fmt.Sprintf("duplicate identity provider for config key %q: %s", "client_api.sso.providers", p.ID)) + } + seenPIDs[p.ID] = true + } + if sso.DefaultProviderID != "" && !foundDefaultProvider { + configErrs.Add(fmt.Sprintf("identity provider ID not found for config key %q: %s", "client_api.sso.default_provider", sso.DefaultProviderID)) + } + + if sso.Enabled { + if len(sso.Providers) == 0 { + configErrs.Add(fmt.Sprintf("empty list for config key %q", "client_api.sso.providers")) + } + } +} + +// See https://github.com/matrix-org/matrix-doc/blob/old_master/informal/idp-brands.md. +type IdentityProvider struct { + // ID is the unique identifier of this IdP. We use the brand identifiers as provider + // identifiers for simplicity. + ID string `yaml:"id"` + + // Name is a human-friendly name of the provider. + Name string `yaml:"name"` + + // Brand is a hint on how to display the IdP to the user. If this is empty, a default + // based on the type is used. + Brand string `yaml:"brand"` + + // Icon is an MXC URI describing how to display the IdP to the user. Prefer using `brand`. + Icon string `yaml:"icon"` + + // Type describes how this provider is implemented. It must match "github". If this is + // empty, the ID is used, which means there is a weak expectation that ID is also a + // valid type, unless you have a complicated setup. + Type string `yaml:"type"` + + // OIDC contains settings for providers based on OpenID Connect (OAuth 2). + OIDC struct { + ClientID string `yaml:"client_id"` + ClientSecret string `yaml:"client_secret"` + } `yaml:"oidc"` +} + +func (idp *IdentityProvider) Verify(configErrs *ConfigErrors) { + checkNotEmpty(configErrs, "client_api.sso.providers.id", idp.ID) + if !checkIdentityProviderBrand(idp.ID) { + configErrs.Add(fmt.Sprintf("unrecognized ID config key %q: %s", "client_api.sso.providers", idp.ID)) + } + checkNotEmpty(configErrs, "client_api.sso.providers.name", idp.Name) + if idp.Brand != "" && !checkIdentityProviderBrand(idp.Brand) { + configErrs.Add(fmt.Sprintf("unrecognized brand in identity provider %q for config key %q: %s", idp.ID, "client_api.sso.providers", idp.Brand)) + } + if idp.Icon != "" { + checkURL(configErrs, "client_api.sso.providers.icon", idp.Icon) + } + typ := idp.Type + if idp.Type == "" { + typ = idp.ID + } + + switch typ { + case "github": + checkNotEmpty(configErrs, "client_api.sso.providers.oidc.client_id", idp.OIDC.ClientID) + checkNotEmpty(configErrs, "client_api.sso.providers.oidc.client_secret", idp.OIDC.ClientSecret) + + default: + configErrs.Add(fmt.Sprintf("unrecognized type in identity provider %q for config key %q: %s", idp.ID, "client_api.sso.providers", typ)) + } +} + +// See https://github.com/matrix-org/matrix-doc/blob/old_master/informal/idp-brands.md. +func checkIdentityProviderBrand(s string) bool { + switch s { + case SSOBrandApple, SSOBrandFacebook, SSOBrandGitHub, SSOBrandGitLab, SSOBrandGoogle, SSOBrandTwitter: + return true + default: + return false + } +} + +const ( + SSOBrandApple = "apple" + SSOBrandFacebook = "facebook" + SSOBrandGitHub = "github" + SSOBrandGitLab = "gitlab" + SSOBrandGoogle = "google" + SSOBrandTwitter = "twitter" +) + type TURN struct { // TODO Guest Support // Whether or not guests can request TURN credentials diff --git a/sytest-whitelist b/sytest-whitelist index a63e90323f..ebe42ec3b4 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -259,6 +259,7 @@ Real non-joined users cannot room initalSync for non-world_readable rooms Push rules come down in an initial /sync Regular users can add and delete aliases in the default room configuration GET /r0/capabilities is not public +login types include SSO GET /joined_rooms lists newly-created room /joined_rooms returns only joined rooms Message history can be paginated over federation