From ef81a3e4506ed7ff08cc8043fbd31b1055779e86 Mon Sep 17 00:00:00 2001 From: Simon Skoczylas Date: Tue, 27 Aug 2024 16:21:20 +0200 Subject: [PATCH] Add OIDC discovery and empty user info --- cmd/stopnik/commands.go | 2 +- internal/config/config.go | 14 +- internal/config/config_test.go | 3 +- internal/endpoint/endpoint.go | 2 + internal/http/request.go | 48 +++++++ internal/manager/token.go | 28 ++-- internal/manager/token_test.go | 9 +- .../server/handler/authorize/authorize.go | 6 +- internal/server/handler/health/health_test.go | 3 +- .../handler/introspect/introspect_test.go | 15 +- internal/server/handler/metadata/metadata.go | 45 +----- internal/server/handler/oidc/discovery.go | 129 ++++++++++++++++++ internal/server/handler/oidc/userinfo.go | 36 +++++ internal/server/handler/revoke/revoke_test.go | 15 +- internal/server/handler/token/token.go | 2 +- internal/server/handler/token/token_test.go | 6 +- internal/server/server.go | 11 ++ 17 files changed, 289 insertions(+), 85 deletions(-) create mode 100644 internal/http/request.go create mode 100644 internal/server/handler/oidc/discovery.go create mode 100644 internal/server/handler/oidc/userinfo.go diff --git a/cmd/stopnik/commands.go b/cmd/stopnik/commands.go index ab1138e..fa36477 100644 --- a/cmd/stopnik/commands.go +++ b/cmd/stopnik/commands.go @@ -69,7 +69,7 @@ func readConfiguration(configurationFile *string, configLoader *config.Loader) ( } logger.SetLogLevel(currentConfig.Server.LogLevel) logger.Info("Config loaded from %s", *configurationFile) - if currentConfig.GetOIDC() { + if currentConfig.GetOidc() { logger.Info("OpenId Connect is enabled") } diff --git a/internal/config/config.go b/internal/config/config.go index 9dfd88a..4b992cc 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "errors" "fmt" + internalHttp "github.com/webishdev/stopnik/internal/http" "github.com/webishdev/stopnik/log" "math/big" ) @@ -80,7 +81,7 @@ type Client struct { Id string `yaml:"id"` ClientSecret string `yaml:"clientSecret"` Salt string `yaml:"salt"` - OIDC bool `yaml:"oidc"` + Oidc bool `yaml:"oidc"` ClientType string `yaml:"type"` AccessTTL int `yaml:"accessTTL"` RefreshTTL int `yaml:"refreshTTL"` @@ -220,7 +221,7 @@ func (config *Config) Setup() error { return errors.New(invalidClient) } - config.oidc = config.oidc || client.OIDC + config.oidc = config.oidc || client.Oidc } config.userMap = setup[User](&config.Users, func(user User) string { @@ -292,7 +293,7 @@ func (config *Config) GetFooterText() string { return GetOrDefaultString(config.UI.FooterText, "STOPnik") } -func (config *Config) GetOIDC() bool { +func (config *Config) GetOidc() bool { return config.oidc } @@ -304,8 +305,11 @@ func (client *Client) GetRefreshTTL() int { return GetOrDefaultInt(client.RefreshTTL, 0) } -func (client *Client) GetIssuer() string { - return GetOrDefaultString(client.Issuer, "STOPnik") +func (client *Client) GetIssuer(requestData *internalHttp.RequestData) string { + if requestData == nil || requestData.Host == "" || requestData.Scheme == "" { + return GetOrDefaultString(client.Issuer, "STOPnik") + } + return GetOrDefaultString(client.Issuer, requestData.IssuerString()) } func (client *Client) GetAudience() []string { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index baed1a0..d1144da 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -2,6 +2,7 @@ package config import ( "errors" + internalHttp "github.com/webishdev/stopnik/internal/http" "reflect" "testing" ) @@ -452,7 +453,7 @@ func assertClientValues(t *testing.T, config *Config, id string, expectedAccessT t.Errorf("expected refresh TTL to be %d, got %d", expectedRefreshTTL, refreshTTL) } - issuer := client.GetIssuer() + issuer := client.GetIssuer(&internalHttp.RequestData{}) if issuer != expectedIssuer { t.Errorf("expected issuer to be '%s', got '%s'", expectedIssuer, issuer) } diff --git a/internal/endpoint/endpoint.go b/internal/endpoint/endpoint.go index e0eda43..027044a 100644 --- a/internal/endpoint/endpoint.go +++ b/internal/endpoint/endpoint.go @@ -10,4 +10,6 @@ const ( Revoke string = "/revoke" Metadata string = "/.well-known/oauth-authorization-server" Keys string = "/keys" + OidcDiscovery string = "/.well-known/openid-configuration" + OidcUserInfo string = "/userinfo" ) diff --git a/internal/http/request.go b/internal/http/request.go new file mode 100644 index 0000000..3e08775 --- /dev/null +++ b/internal/http/request.go @@ -0,0 +1,48 @@ +package http + +import ( + "fmt" + "net/http" + "net/url" +) + +type RequestData struct { + Scheme string + Host string + Path string + Query string + Fragment string +} + +func NewRequestData(r *http.Request) *RequestData { + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + + query := "" + if r.URL.RawQuery != "" { + query = "?" + r.URL.RawQuery + } + fragment := "" + if r.URL.RawFragment != "" { + fragment = "#" + r.URL.RawFragment + } + return &RequestData{ + Scheme: scheme, + Host: r.Host, + Path: r.URL.RawPath, + Query: query, + Fragment: fragment, + } +} + +func (r *RequestData) IssuerString() string { + return fmt.Sprintf("%s://%s", r.Scheme, r.Host) +} + +func (r *RequestData) URL() (*url.URL, error) { + uri := fmt.Sprintf("%s://%s%s%s%s", r.Scheme, r.Host, r.Path, r.Query, r.Fragment) + + return url.Parse(uri) +} diff --git a/internal/manager/token.go b/internal/manager/token.go index b952925..8cf9f71 100644 --- a/internal/manager/token.go +++ b/internal/manager/token.go @@ -11,6 +11,7 @@ import ( "github.com/webishdev/stopnik/internal/oauth2" "github.com/webishdev/stopnik/internal/store" "github.com/webishdev/stopnik/log" + "net/http" "strings" "time" ) @@ -53,14 +54,15 @@ func (tokenManager *TokenManager) RevokeRefreshToken(refreshToken *oauth2.Refres refreshTokenStore.Delete(refreshToken.Key) } -func (tokenManager *TokenManager) CreateAccessTokenResponse(username string, client *config.Client, scopes []string, nonce string) oauth2.AccessTokenResponse { +func (tokenManager *TokenManager) CreateAccessTokenResponse(r *http.Request, username string, client *config.Client, scopes []string, nonce string) oauth2.AccessTokenResponse { log.Debug("Creating new access token for %s, access TTL %d, refresh TTL %d", client.Id, client.GetAccessTTL(), client.GetRefreshTTL()) + requestData := internalHttp.NewRequestData(r) accessTokenStore := *tokenManager.accessTokenStore refreshTokenStore := *tokenManager.refreshTokenStore accessTokenDuration := time.Minute * time.Duration(client.GetAccessTTL()) - accessTokenKey := tokenManager.generateAccessToken(username, client, accessTokenDuration) + accessTokenKey := tokenManager.generateAccessToken(requestData, username, client, accessTokenDuration) accessToken := &oauth2.AccessToken{ Key: accessTokenKey, TokenType: oauth2.TtBearer, @@ -79,7 +81,7 @@ func (tokenManager *TokenManager) CreateAccessTokenResponse(username string, cli if client.GetRefreshTTL() > 0 { refreshTokenDuration := time.Minute * time.Duration(client.GetRefreshTTL()) - refreshTokenKey := tokenManager.generateAccessToken(username, client, refreshTokenDuration) + refreshTokenKey := tokenManager.generateAccessToken(requestData, username, client, refreshTokenDuration) refreshToken := &oauth2.RefreshToken{ Key: refreshTokenKey, Username: username, @@ -92,11 +94,11 @@ func (tokenManager *TokenManager) CreateAccessTokenResponse(username string, cli accessTokenResponse.RefreshTokenKey = refreshTokenKey } - if client.OIDC { + if client.Oidc { user, userExists := tokenManager.config.GetUser(username) if userExists { idTokenDuration := time.Minute * time.Duration(client.GetAccessTTL()) - idTokenKey := tokenManager.generateIdToken(user, client, nonce, idTokenDuration) + idTokenKey := tokenManager.generateIdToken(requestData, user, client, nonce, idTokenDuration) accessTokenResponse.IdToken = idTokenKey } } @@ -126,17 +128,17 @@ func (tokenManager *TokenManager) ValidateAccessToken(authorizationHeader string return user, accessToken.Scopes, true } -func (tokenManager *TokenManager) generateIdToken(user *config.User, client *config.Client, nonce string, duration time.Duration) string { - idToken := generateIdToken(user, client, nonce, duration) +func (tokenManager *TokenManager) generateIdToken(requestData *internalHttp.RequestData, user *config.User, client *config.Client, nonce string, duration time.Duration) string { + idToken := generateIdToken(requestData, user, client, nonce, duration) return tokenManager.generateJWTToken(client, idToken) } -func (tokenManager *TokenManager) generateAccessToken(username string, client *config.Client, duration time.Duration) string { +func (tokenManager *TokenManager) generateAccessToken(requestData *internalHttp.RequestData, username string, client *config.Client, duration time.Duration) string { tokenId := uuid.New() if client.OpaqueToken { return tokenManager.generateOpaqueAccessToken(tokenId.String()) } - accessToken := generateAccessToken(tokenId.String(), duration, username, client) + accessToken := generateAccessToken(requestData, tokenId.String(), duration, username, client) return tokenManager.generateJWTToken(client, accessToken) } @@ -187,7 +189,7 @@ func (tokenManager *TokenManager) generateJWTToken(client *config.Client, token } -func generateIdToken(user *config.User, client *config.Client, nonce string, duration time.Duration) jwt.Token { +func generateIdToken(requestData *internalHttp.RequestData, user *config.User, client *config.Client, nonce string, duration time.Duration) jwt.Token { tokenId := uuid.New().String() builder := jwt.NewBuilder(). Expiration(time.Now().Add(duration)). // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.4 @@ -197,7 +199,7 @@ func generateIdToken(user *config.User, client *config.Client, nonce string, dur builder.JwtID(tokenId) // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.1 - builder.Issuer(client.GetIssuer()) + builder.Issuer(client.GetIssuer(requestData)) // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.2 builder.Subject(user.Username) @@ -218,7 +220,7 @@ func generateIdToken(user *config.User, client *config.Client, nonce string, dur return token } -func generateAccessToken(tokenId string, duration time.Duration, username string, client *config.Client) jwt.Token { +func generateAccessToken(requestData *internalHttp.RequestData, tokenId string, duration time.Duration, username string, client *config.Client) jwt.Token { builder := jwt.NewBuilder(). Expiration(time.Now().Add(duration)). // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.4 IssuedAt(time.Now()) // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.6 @@ -232,7 +234,7 @@ func generateAccessToken(tokenId string, duration time.Duration, username string builder.JwtID(tokenId) // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.1 - builder.Issuer(client.GetIssuer()) + builder.Issuer(client.GetIssuer(requestData)) // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.2 builder.Subject(username) diff --git a/internal/manager/token_test.go b/internal/manager/token_test.go index 3490551..342991c 100644 --- a/internal/manager/token_test.go +++ b/internal/manager/token_test.go @@ -3,8 +3,11 @@ package manager import ( "fmt" "github.com/webishdev/stopnik/internal/config" + "github.com/webishdev/stopnik/internal/endpoint" internalHttp "github.com/webishdev/stopnik/internal/http" "github.com/webishdev/stopnik/internal/oauth2" + "net/http" + "net/http/httptest" "reflect" "testing" ) @@ -33,7 +36,8 @@ func Test_Token(t *testing.T) { t.Fatal("client does not exist") } - accessTokenResponse := tokenManager.CreateAccessTokenResponse("foo", client, []string{"abc", "def"}, "") + request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil) + accessTokenResponse := tokenManager.CreateAccessTokenResponse(request, "foo", client, []string{"abc", "def"}, "") if accessTokenResponse.AccessTokenKey == "" { t.Error("empty access token") @@ -137,7 +141,8 @@ func Test_Token(t *testing.T) { t.Fatal("client does not exist") } - accessTokenResponse := tokenManager.CreateAccessTokenResponse("bar", client, []string{"abc", "def"}, "") + request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil) + accessTokenResponse := tokenManager.CreateAccessTokenResponse(request, "bar", client, []string{"abc", "def"}, "") _, _, valid := tokenManager.ValidateAccessToken(fmt.Sprintf("%s %s", internalHttp.AuthBearer, accessTokenResponse.AccessTokenKey)) diff --git a/internal/server/handler/authorize/authorize.go b/internal/server/handler/authorize/authorize.go index 764b141..e6106f1 100644 --- a/internal/server/handler/authorize/authorize.go +++ b/internal/server/handler/authorize/authorize.go @@ -135,7 +135,7 @@ func (h *Handler) handleGetRequest(w http.ResponseWriter, r *http.Request) { State: state, } - if client.OIDC { + if client.Oidc { nonceQueryParameter := r.URL.Query().Get(oidc.ParameterNonce) authSession.Nonce = nonceQueryParameter } else { @@ -157,7 +157,7 @@ func (h *Handler) handleGetRequest(w http.ResponseWriter, r *http.Request) { query := redirectURL.Query() if slices.Contains(responseTypes, oauth2.RtToken) { - accessTokenResponse := h.tokenManager.CreateAccessTokenResponse(user.Username, client, scopes, authSession.Nonce) + accessTokenResponse := h.tokenManager.CreateAccessTokenResponse(r, user.Username, client, scopes, authSession.Nonce) setImplicitGrantParameter(query, accessTokenResponse) } else if slices.Contains(responseTypes, oauth2.RtCode) { setAuthorizationGrantParameter(query, id.String()) @@ -225,7 +225,7 @@ func (h *Handler) handlePostRequest(w http.ResponseWriter, r *http.Request, user h.errorHandler.InternalServerErrorHandler(w, r) return } - accessTokenResponse := h.tokenManager.CreateAccessTokenResponse(user.Username, client, authSession.Scopes, authSession.Nonce) + accessTokenResponse := h.tokenManager.CreateAccessTokenResponse(r, user.Username, client, authSession.Scopes, authSession.Nonce) setImplicitGrantParameter(query, accessTokenResponse) } else if slices.Contains(responseTypes, oauth2.RtCode) { setAuthorizationGrantParameter(query, authSession.Id) diff --git a/internal/server/handler/health/health_test.go b/internal/server/handler/health/health_test.go index 0992095..f0eace2 100644 --- a/internal/server/handler/health/health_test.go +++ b/internal/server/handler/health/health_test.go @@ -72,7 +72,8 @@ func Test_Health(t *testing.T) { t.Error("client should exist") } - tokenResponse := tokenManager.CreateAccessTokenResponse("foo", client, []string{"a:foo", "b:bar"}, "") + request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil) + tokenResponse := tokenManager.CreateAccessTokenResponse(request, "foo", client, []string{"a:foo", "b:bar"}, "") healthHandler := NewHealthHandler(tokenManager) diff --git a/internal/server/handler/introspect/introspect_test.go b/internal/server/handler/introspect/introspect_test.go index 857b38a..e04e529 100644 --- a/internal/server/handler/introspect/introspect_test.go +++ b/internal/server/handler/introspect/introspect_test.go @@ -236,7 +236,8 @@ func testIntrospect(t *testing.T, testConfig *config.Config, keyManger *manager. sessionManager := manager.NewSessionManager(testConfig) tokenManager := manager.NewTokenManager(testConfig, manager.NewDefaultKeyLoader(testConfig, keyManger)) sessionManager.StartSession(authSession) - accessTokenResponse := tokenManager.CreateAccessTokenResponse(user.Username, client, scopes, "") + request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil) + accessTokenResponse := tokenManager.CreateAccessTokenResponse(request, user.Username, client, scopes, "") introspectHandler := NewIntrospectHandler(testConfig, requestValidator, tokenManager) @@ -253,7 +254,7 @@ func testIntrospect(t *testing.T, testConfig *config.Config, keyManger *manager. ) body := strings.NewReader(bodyString) - request := httptest.NewRequest(http.MethodPost, endpoint.Introspect, body) + request = httptest.NewRequest(http.MethodPost, endpoint.Introspect, body) request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "bar"))) request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") @@ -308,7 +309,8 @@ func testIntrospectWithoutHint(t *testing.T, testConfig *config.Config, keyMange sessionManager := manager.NewSessionManager(testConfig) tokenManager := manager.NewTokenManager(testConfig, manager.NewDefaultKeyLoader(testConfig, keyManger)) sessionManager.StartSession(authSession) - accessTokenResponse := tokenManager.CreateAccessTokenResponse(user.Username, client, scopes, "") + request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil) + accessTokenResponse := tokenManager.CreateAccessTokenResponse(request, user.Username, client, scopes, "") introspectHandler := NewIntrospectHandler(testConfig, requestValidator, tokenManager) @@ -324,7 +326,7 @@ func testIntrospectWithoutHint(t *testing.T, testConfig *config.Config, keyMange ) body := strings.NewReader(bodyString) - request := httptest.NewRequest(http.MethodPost, endpoint.Introspect, body) + request = httptest.NewRequest(http.MethodPost, endpoint.Introspect, body) request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "bar"))) request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") @@ -379,7 +381,8 @@ func testIntrospectDisabled(t *testing.T, testConfig *config.Config, keyManger * sessionManager := manager.NewSessionManager(testConfig) tokenManager := manager.NewTokenManager(testConfig, manager.NewDefaultKeyLoader(testConfig, keyManger)) sessionManager.StartSession(authSession) - accessTokenResponse := tokenManager.CreateAccessTokenResponse(user.Username, client, scopes, "") + request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil) + accessTokenResponse := tokenManager.CreateAccessTokenResponse(request, user.Username, client, scopes, "") introspectHandler := NewIntrospectHandler(testConfig, requestValidator, tokenManager) @@ -396,7 +399,7 @@ func testIntrospectDisabled(t *testing.T, testConfig *config.Config, keyManger * ) body := strings.NewReader(bodyString) - request := httptest.NewRequest(http.MethodPost, endpoint.Introspect, body) + request = httptest.NewRequest(http.MethodPost, endpoint.Introspect, body) request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("bar", "bar"))) request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") diff --git a/internal/server/handler/metadata/metadata.go b/internal/server/handler/metadata/metadata.go index fa8572b..69ce22f 100644 --- a/internal/server/handler/metadata/metadata.go +++ b/internal/server/handler/metadata/metadata.go @@ -1,7 +1,6 @@ package metadata import ( - "fmt" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/webishdev/stopnik/internal/endpoint" internalHttp "github.com/webishdev/stopnik/internal/http" @@ -10,7 +9,6 @@ import ( errorHandler "github.com/webishdev/stopnik/internal/server/handler/error" "github.com/webishdev/stopnik/log" "net/http" - "net/url" ) type response struct { @@ -38,14 +36,6 @@ type response struct { CodeChallengeMethodsSupported []pkce.CodeChallengeMethod `json:"code_challenge_methods_supported,omitempty"` } -type requestData struct { - scheme string - host string - path string - query string - fragment string -} - type Handler struct { errorHandler *errorHandler.Handler } @@ -59,7 +49,7 @@ func NewMetadataHandler() *Handler { func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.AccessLogRequest(r) if r.Method == http.MethodGet { - requestData := newRequestData(r) + requestData := internalHttp.NewRequestData(r) urlFromRequest, parseError := requestData.URL() if parseError != nil { h.errorHandler.InternalServerErrorHandler(w, r) @@ -132,36 +122,3 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } } - -func newRequestData(r *http.Request) *requestData { - scheme := "http" - if r.TLS != nil { - scheme = "https" - } - - query := "" - if r.URL.RawQuery != "" { - query = "?" + r.URL.RawQuery - } - fragment := "" - if r.URL.RawFragment != "" { - fragment = "#" + r.URL.RawFragment - } - return &requestData{ - scheme: scheme, - host: r.Host, - path: r.URL.RawPath, - query: query, - fragment: fragment, - } -} - -func (r *requestData) IssuerString() string { - return fmt.Sprintf("%s://%s", r.scheme, r.host) -} - -func (r *requestData) URL() (*url.URL, error) { - uri := fmt.Sprintf("%s://%s%s%s%s", r.scheme, r.host, r.path, r.query, r.fragment) - - return url.Parse(uri) -} diff --git a/internal/server/handler/oidc/discovery.go b/internal/server/handler/oidc/discovery.go new file mode 100644 index 0000000..4319e4f --- /dev/null +++ b/internal/server/handler/oidc/discovery.go @@ -0,0 +1,129 @@ +package oidc + +import ( + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/webishdev/stopnik/internal/endpoint" + internalHttp "github.com/webishdev/stopnik/internal/http" + "github.com/webishdev/stopnik/internal/oauth2" + "github.com/webishdev/stopnik/internal/pkce" + errorHandler "github.com/webishdev/stopnik/internal/server/handler/error" + "github.com/webishdev/stopnik/log" + "net/http" +) + +type response struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + JWKsUri string `json:"jwks_uri,omitempty"` + UserInfoEndpoint string `json:"userinfo_endpoint,omitempty"` + RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + ScopesSupported []string `json:"scopes_supported,omitempty"` + ResponseTypesSupported []oauth2.ResponseType `json:"response_types_supported,omitempty"` + ResponseModesSupported []string `json:"response_modes_supported,omitempty"` + GrantTypesSupported []oauth2.GrantType `json:"grant_types_supported,omitempty"` + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"` + TokenEndpointAuthSigningAlgValuesSupported []jwa.SignatureAlgorithm `json:"token_endpoint_auth_signing_alg_values_supported,omitempty"` + ServiceDocumentation string `json:"service_documentation,omitempty"` + UILocalesSupported []string `json:"ui_locales_supported,omitempty"` + OpPolicyUri string `json:"op_policy_uri,omitempty"` + OpTosUri string `json:"op_tos_uri,omitempty"` + RevocationEndpoint string `json:"revocation_endpoint,omitempty"` + RevocationEndpointAuthMethodsSupported []string `json:"revocation_endpoint_auth_methods_supported,omitempty"` + RevocationEndpointAuthSigningAlgValuesSupported []jwa.SignatureAlgorithm `json:"revocation_endpoint_auth_signing_alg_values_supported,omitempty"` + IntrospectionEndpoint string `json:"introspection_endpoint,omitempty"` + IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported,omitempty"` + IntrospectionEndpointAuthSigningAlgValuesSupported []jwa.SignatureAlgorithm `json:"introspection_endpoint_auth_signing_alg_values_supported,omitempty"` + CodeChallengeMethodsSupported []pkce.CodeChallengeMethod `json:"code_challenge_methods_supported,omitempty"` +} + +type DiscoveryHandler struct { + errorHandler *errorHandler.Handler +} + +func NewOidcDiscoveryHandler() *DiscoveryHandler { + return &DiscoveryHandler{ + errorHandler: errorHandler.NewErrorHandler(), + } +} + +func (h *DiscoveryHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + log.AccessLogRequest(r) + if r.Method == http.MethodGet { + requestData := internalHttp.NewRequestData(r) + urlFromRequest, parseError := requestData.URL() + if parseError != nil { + h.errorHandler.InternalServerErrorHandler(w, r) + return + } + + // OAuth2 + authorizationEndpoint := urlFromRequest.JoinPath(endpoint.Authorization) + tokenEndpoint := urlFromRequest.JoinPath(endpoint.Token) + + // OAuth2 extensions + introspectEndpoint := urlFromRequest.JoinPath(endpoint.Introspect) + revokeEndpoint := urlFromRequest.JoinPath(endpoint.Revoke) + keysEndpoint := urlFromRequest.JoinPath(endpoint.Keys) + + // OIDC 1.0 Core + userInfoEndpoint := urlFromRequest.JoinPath(endpoint.OidcUserInfo) + + authMethodsSupported := []string{ + "client_secret_basic", + "client_secret_post", + } + + signatureAlgorithmSupported := []jwa.SignatureAlgorithm{ + jwa.RS256, + jwa.ES256, + jwa.ES384, + jwa.ES512, + jwa.HS256, + } + + metadataResponse := &response{ + Issuer: requestData.IssuerString(), + AuthorizationEndpoint: authorizationEndpoint.String(), + TokenEndpoint: tokenEndpoint.String(), + IntrospectionEndpoint: introspectEndpoint.String(), + RevocationEndpoint: revokeEndpoint.String(), + JWKsUri: keysEndpoint.String(), + UserInfoEndpoint: userInfoEndpoint.String(), + ServiceDocumentation: "https://stopnik.webish.dev", + CodeChallengeMethodsSupported: []pkce.CodeChallengeMethod{ + pkce.PLAIN, + pkce.S256, + }, + GrantTypesSupported: []oauth2.GrantType{ + oauth2.GtAuthorizationCode, + oauth2.GtClientCredentials, + oauth2.GtPassword, + oauth2.GtRefreshToken, + oauth2.GtImplicit, + }, + ResponseTypesSupported: []oauth2.ResponseType{ + oauth2.RtCode, + oauth2.RtToken, + }, + ResponseModesSupported: []string{ + "query", + "fragment", + }, + TokenEndpointAuthMethodsSupported: authMethodsSupported, + TokenEndpointAuthSigningAlgValuesSupported: signatureAlgorithmSupported, + IntrospectionEndpointAuthMethodsSupported: authMethodsSupported, + IntrospectionEndpointAuthSigningAlgValuesSupported: signatureAlgorithmSupported, + RevocationEndpointAuthMethodsSupported: authMethodsSupported, + RevocationEndpointAuthSigningAlgValuesSupported: signatureAlgorithmSupported, + } + jsonError := internalHttp.SendJson(metadataResponse, w) + if jsonError != nil { + h.errorHandler.InternalServerErrorHandler(w, r) + return + } + } else { + h.errorHandler.MethodNotAllowedHandler(w, r) + return + } +} diff --git a/internal/server/handler/oidc/userinfo.go b/internal/server/handler/oidc/userinfo.go new file mode 100644 index 0000000..8a30b97 --- /dev/null +++ b/internal/server/handler/oidc/userinfo.go @@ -0,0 +1,36 @@ +package oidc + +import ( + internalHttp "github.com/webishdev/stopnik/internal/http" + errorHandler "github.com/webishdev/stopnik/internal/server/handler/error" + "github.com/webishdev/stopnik/log" + "net/http" +) + +type UserInfo struct { +} + +type UserInfoHandler struct { + errorHandler *errorHandler.Handler +} + +func NewOidcUserInfoHandler() *UserInfoHandler { + return &UserInfoHandler{ + errorHandler: errorHandler.NewErrorHandler(), + } +} + +func (h *UserInfoHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + log.AccessLogRequest(r) + if r.Method == http.MethodGet { + userInfoResponse := &UserInfo{} + jsonError := internalHttp.SendJson(userInfoResponse, w) + if jsonError != nil { + h.errorHandler.InternalServerErrorHandler(w, r) + return + } + } else { + h.errorHandler.MethodNotAllowedHandler(w, r) + return + } +} diff --git a/internal/server/handler/revoke/revoke_test.go b/internal/server/handler/revoke/revoke_test.go index c4460b6..df8c901 100644 --- a/internal/server/handler/revoke/revoke_test.go +++ b/internal/server/handler/revoke/revoke_test.go @@ -219,7 +219,8 @@ func testRevoke(t *testing.T, testConfig *config.Config, keyManager *manager.Key sessionManager := manager.NewSessionManager(testConfig) tokenManager := manager.NewTokenManager(testConfig, manager.NewDefaultKeyLoader(testConfig, keyManager)) sessionManager.StartSession(authSession) - accessTokenResponse := tokenManager.CreateAccessTokenResponse(user.Username, client, scopes, "") + request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil) + accessTokenResponse := tokenManager.CreateAccessTokenResponse(request, user.Username, client, scopes, "") revokeHandler := NewRevokeHandler(testConfig, requestValidator, tokenManager) @@ -236,7 +237,7 @@ func testRevoke(t *testing.T, testConfig *config.Config, keyManager *manager.Key ) body := strings.NewReader(bodyString) - request := httptest.NewRequest(http.MethodPost, endpoint.Revoke, body) + request = httptest.NewRequest(http.MethodPost, endpoint.Revoke, body) request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "bar"))) request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") @@ -296,7 +297,8 @@ func testRevokeWithoutHint(t *testing.T, testConfig *config.Config, keyManager * sessionManager := manager.NewSessionManager(testConfig) tokenManager := manager.NewTokenManager(testConfig, manager.NewDefaultKeyLoader(testConfig, keyManager)) sessionManager.StartSession(authSession) - accessTokenResponse := tokenManager.CreateAccessTokenResponse(user.Username, client, scopes, "") + request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil) + accessTokenResponse := tokenManager.CreateAccessTokenResponse(request, user.Username, client, scopes, "") revokeHandler := NewRevokeHandler(testConfig, requestValidator, tokenManager) @@ -312,7 +314,7 @@ func testRevokeWithoutHint(t *testing.T, testConfig *config.Config, keyManager * ) body := strings.NewReader(bodyString) - request := httptest.NewRequest(http.MethodPost, endpoint.Revoke, body) + request = httptest.NewRequest(http.MethodPost, endpoint.Revoke, body) request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "bar"))) request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") @@ -372,7 +374,8 @@ func testRevokeDisabled(t *testing.T, testConfig *config.Config, keyManager *man sessionManager := manager.NewSessionManager(testConfig) tokenManager := manager.NewTokenManager(testConfig, manager.NewDefaultKeyLoader(testConfig, keyManager)) sessionManager.StartSession(authSession) - accessTokenResponse := tokenManager.CreateAccessTokenResponse(user.Username, client, scopes, "") + request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil) + accessTokenResponse := tokenManager.CreateAccessTokenResponse(request, user.Username, client, scopes, "") revokeHandler := NewRevokeHandler(testConfig, requestValidator, tokenManager) @@ -389,7 +392,7 @@ func testRevokeDisabled(t *testing.T, testConfig *config.Config, keyManager *man ) body := strings.NewReader(bodyString) - request := httptest.NewRequest(http.MethodPost, endpoint.Revoke, body) + request = httptest.NewRequest(http.MethodPost, endpoint.Revoke, body) request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("bar", "bar"))) request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") diff --git a/internal/server/handler/token/token.go b/internal/server/handler/token/token.go index 48a2cc0..e6910b5 100644 --- a/internal/server/handler/token/token.go +++ b/internal/server/handler/token/token.go @@ -129,7 +129,7 @@ func (h *Handler) handlePostRequest(w http.ResponseWriter, r *http.Request) { return } - accessTokenResponse := h.tokenManager.CreateAccessTokenResponse(username, client, scopes, nonce) + accessTokenResponse := h.tokenManager.CreateAccessTokenResponse(r, username, client, scopes, nonce) jsonError := internalHttp.SendJson(accessTokenResponse, w) if jsonError != nil { diff --git a/internal/server/handler/token/token_test.go b/internal/server/handler/token/token_test.go index 89c8f15..9b81dde 100644 --- a/internal/server/handler/token/token_test.go +++ b/internal/server/handler/token/token_test.go @@ -401,7 +401,9 @@ func testTokenRefreshTokenGrantType(t *testing.T, testConfig *config.Config, key sessionManager := manager.NewSessionManager(testConfig) tokenManager := manager.NewTokenManager(testConfig, manager.NewDefaultKeyLoader(testConfig, keyManager)) sessionManager.StartSession(authSession) - accessTokenResponse := tokenManager.CreateAccessTokenResponse(user.Username, client, scopes, "") + + request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil) + accessTokenResponse := tokenManager.CreateAccessTokenResponse(request, user.Username, client, scopes, "") tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) @@ -413,7 +415,7 @@ func testTokenRefreshTokenGrantType(t *testing.T, testConfig *config.Config, key ) body := strings.NewReader(bodyString) - request := httptest.NewRequest(http.MethodPost, endpoint.Token, body) + request = httptest.NewRequest(http.MethodPost, endpoint.Token, body) request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "bar"))) request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") diff --git a/internal/server/server.go b/internal/server/server.go index 9942c24..4d342bb 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -14,6 +14,7 @@ import ( "github.com/webishdev/stopnik/internal/server/handler/keys" "github.com/webishdev/stopnik/internal/server/handler/logout" "github.com/webishdev/stopnik/internal/server/handler/metadata" + "github.com/webishdev/stopnik/internal/server/handler/oidc" "github.com/webishdev/stopnik/internal/server/handler/revoke" "github.com/webishdev/stopnik/internal/server/handler/token" "github.com/webishdev/stopnik/internal/server/validation" @@ -197,6 +198,10 @@ func registerHandlers(config *config.Config, handle func(pattern string, handler metadataHandler := metadata.NewMetadataHandler() keysHandler := keys.NewKeysHandler(keyManger, config) + // Oidc 1.0 Core + discoveryHandler := oidc.NewOidcDiscoveryHandler() + userInfoHandler := oidc.NewOidcUserInfoHandler() + // Server handle(endpoint.Health, healthHandler) handle(endpoint.Account, accountHandler) @@ -211,4 +216,10 @@ func registerHandlers(config *config.Config, handle func(pattern string, handler handle(endpoint.Revoke, revokeHandler) handle(endpoint.Metadata, metadataHandler) handle(endpoint.Keys, keysHandler) + + // Oidc 1.0 Core + if config.GetOidc() { + handle(endpoint.OidcDiscovery, discoveryHandler) + handle(endpoint.OidcUserInfo, userInfoHandler) + } }