From 51d16a388155a2fd089ef6cb19523cd5efc894de Mon Sep 17 00:00:00 2001 From: Simon Skoczylas Date: Thu, 5 Sep 2024 17:03:39 +0200 Subject: [PATCH] Add tests, improve documentation, improve ForwardAuth --- internal/config/config.go | 25 +- internal/config/config_test.go | 15 +- internal/http/header.go | 3 + internal/manager/token/token.go | 22 +- internal/manager/token/token_test.go | 102 ++- internal/oauth2/token.go | 10 +- .../server/handler/account/account_test.go | 118 ++-- internal/server/handler/assets/assets_test.go | 16 +- .../server/handler/authorize/authorize.go | 4 +- .../handler/authorize/authorize_test.go | 637 +++++++++--------- .../server/handler/forwardauth/forwardauth.go | 20 +- .../handler/forwardauth/forwardauth_test.go | 113 ++++ internal/server/handler/health/health_test.go | 90 ++- .../handler/introspect/introspect_test.go | 136 ++-- internal/server/handler/keys/keys_test.go | 73 +- internal/server/handler/logout/logout_test.go | 12 +- .../server/handler/metadata/metadata_test.go | 69 +- .../server/handler/oidc/discovery_test.go | 69 +- internal/server/handler/oidc/userinfo_test.go | 9 +- internal/server/handler/revoke/revoke_test.go | 86 +-- internal/server/handler/token/token_test.go | 288 ++++---- internal/server/server.go | 2 +- website/docs/introduction/config.md | 220 +++--- 23 files changed, 1092 insertions(+), 1047 deletions(-) create mode 100644 internal/server/handler/forwardauth/forwardauth_test.go diff --git a/internal/config/config.go b/internal/config/config.go index 28b6423..4ff232f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -32,8 +32,9 @@ type Cookies struct { } type ForwardAuth struct { - Endpoint string `yaml:"endpoint"` - ExternalUrl string `yaml:"externalUrl"` + Endpoint string `yaml:"endpoint"` + ExternalUrl string `yaml:"externalUrl"` + ParameterName string `yaml:"parameterName"` } type Server struct { @@ -47,6 +48,7 @@ type Server struct { IntrospectScope string `yaml:"introspectScope"` RevokeScope string `yaml:"revokeScopeScope"` SessionTimeoutSeconds int `yaml:"sessionTimeoutSeconds"` + Issuer string `yaml:"issuer"` ForwardAuth ForwardAuth `yaml:"forwardAuth"` } @@ -108,7 +110,6 @@ type Client struct { OpaqueToken bool `yaml:"opaqueToken"` PasswordFallbackAllowed bool `yaml:"passwordFallbackAllowed"` Claims []Claim `yaml:"claims"` - Issuer string `yaml:"issuer"` Audience []string `yaml:"audience"` PrivateKey string `yaml:"privateKey"` RolesClaim string `yaml:"rolesClaim"` @@ -295,6 +296,13 @@ func (config *Config) GetOidc() bool { return config.oidc } +func (config *Config) GetIssuer(requestData *internalHttp.RequestData) string { + if requestData == nil || requestData.Host == "" || requestData.Scheme == "" { + return GetOrDefaultString(config.Server.Issuer, "STOPnik") + } + return GetOrDefaultString(config.Server.Issuer, requestData.IssuerString()) +} + func (config *Config) GetForwardAuthEnabled() bool { return config.Server.ForwardAuth.ExternalUrl != "" } @@ -303,6 +311,10 @@ func (config *Config) GetForwardAuthEndpoint() string { return GetOrDefaultString(config.Server.ForwardAuth.Endpoint, "/forward") } +func (config *Config) GetForwardAuthParameterName() string { + return GetOrDefaultString(config.Server.ForwardAuth.ParameterName, "forward_id") +} + func (client *Client) GetRolesClaim() string { return GetOrDefaultString(client.RolesClaim, "roles") } @@ -319,13 +331,6 @@ func (client *Client) GetIdTTL() int { return GetOrDefaultInt(client.IdTTL, 0) } -func (client *Client) GetIssuer(requestData *internalHttp.RequestData) string { - if requestData == nil || requestData.Host == "" || requestData.Scheme == "" { - return GetOrDefaultString(client.Issuer, "STOPnik") - } - return GetOrDefaultString(client.Issuer, requestData.IssuerString()) -} - func (client *Client) GetAudience() []string { return GetOrDefaultStringSlice(client.Audience, []string{"all"}) } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 58044d4..ff248ea 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -163,6 +163,7 @@ func Test_SimpleServerConfiguration(t *testing.T) { *origin = Config{ Server: Server{ Secret: "5XyLSgKpo5kWrJqm", + Issuer: "http://foo.com/bar", Cookies: Cookies{ AuthName: "my_auth", }, @@ -199,6 +200,11 @@ func Test_SimpleServerConfiguration(t *testing.T) { t.Error("expected server secret to be '5XyLSgKpo5kWrJqm'") } + issuer := config.GetIssuer(&internalHttp.RequestData{}) + if issuer != "http://foo.com/bar" { + t.Error("expected issuer to be 'http://foo.com/bar'") + } + authCookieName := config.GetAuthCookieName() if authCookieName != "my_auth" { t.Error("expected auth cookie name to be 'my_auth'") @@ -431,6 +437,9 @@ func Test_ValidClients(t *testing.T) { }, func(in []byte, out interface{}) (err error) { origin := out.(*Config) *origin = Config{ + Server: Server{ + Issuer: "other", + }, Clients: []Client{ { Id: "foo", @@ -444,7 +453,6 @@ func Test_ValidClients(t *testing.T) { AccessTTL: 20, RefreshTTL: 60, IdTTL: 40, - Issuer: "other", RolesClaim: "groups", Audience: []string{"one", "two"}, }, @@ -652,11 +660,6 @@ func assertClientValues(t *testing.T, config *Config, expected testExpectedClien t.Errorf("expected id token TTL to be %d, got %d", expected.expectedIdTokenTTL, idTokenTTL) } - issuer := client.GetIssuer(&internalHttp.RequestData{}) - if issuer != expected.expectedIssuer { - t.Errorf("expected issuer to be '%s', got '%s'", expected.expectedIssuer, issuer) - } - audience := client.GetAudience() if !reflect.DeepEqual(audience, expected.expectedAudience) { t.Errorf("expected audience to be '%s', got '%s'", expected.expectedAudience, audience) diff --git a/internal/http/header.go b/internal/http/header.go index ee71a90..81f74c1 100644 --- a/internal/http/header.go +++ b/internal/http/header.go @@ -5,6 +5,9 @@ const ( ContentType string = "Content-Type" Authorization string = "Authorization" AccessControlAllowOrigin string = "Access-Control-Allow-Origin" + XForwardProtocol string = "X-Forwarded-Proto" + XForwardHost string = "X-Forwarded-Host" + XForwardUri string = "X-Forwarded-Uri" ) const ( diff --git a/internal/manager/token/token.go b/internal/manager/token/token.go index e99cd35..e753022 100644 --- a/internal/manager/token/token.go +++ b/internal/manager/token/token.go @@ -90,9 +90,9 @@ func (tokenManager *Manager) CreateAccessTokenResponse(r *http.Request, username accessTokenStore.SetWithDuration(accessToken.Key, accessToken, accessTokenDuration) accessTokenResponse := oauth2.AccessTokenResponse{ - AccessTokenKey: accessToken.Key, - TokenType: oauth2.TtBearer, - ExpiresIn: int(accessTokenDuration / time.Second), + AccessTokenValue: accessToken.Key, + TokenType: oauth2.TtBearer, + ExpiresIn: int(accessTokenDuration / time.Second), } if client.GetRefreshTTL() > 0 { @@ -107,14 +107,14 @@ func (tokenManager *Manager) CreateAccessTokenResponse(r *http.Request, username refreshTokenStore.SetWithDuration(refreshToken.Key, refreshToken, refreshTokenDuration) - accessTokenResponse.RefreshTokenKey = refreshToken.Key + accessTokenResponse.RefreshTokenValue = refreshToken.Key } if client.Oidc && oidc.HasOidcScope(scopes) { user, userExists := tokenManager.config.GetUser(username) if userExists { accessTokenHash := tokenManager.CreateAccessTokenHash(client, accessToken.Key) - accessTokenResponse.IdToken = tokenManager.CreateIdToken(r, user.Username, client, scopes, nonce, accessTokenHash) + accessTokenResponse.IdTokenValue = tokenManager.CreateIdToken(r, user.Username, client, scopes, nonce, accessTokenHash) } } @@ -178,7 +178,7 @@ func (tokenManager *Manager) ValidateAccessToken(authorizationHeader string) (*c } func (tokenManager *Manager) generateIdToken(requestData *internalHttp.RequestData, user *config.User, client *config.Client, nonce string, atHash string, duration time.Duration) string { - idToken := generateIdToken(requestData, user, client, nonce, atHash, duration) + idToken := generateIdToken(requestData, tokenManager.config, user, client, nonce, atHash, duration) return tokenManager.generateJWTToken(client, idToken) } @@ -187,7 +187,7 @@ func (tokenManager *Manager) generateAccessToken(requestData *internalHttp.Reque if client.OpaqueToken { return tokenManager.generateOpaqueAccessToken(tokenId.String()) } - accessToken := generateAccessToken(requestData, tokenId.String(), duration, username, client) + accessToken := generateAccessToken(requestData, tokenManager.config, client, tokenId.String(), duration, username) return tokenManager.generateJWTToken(client, accessToken) } @@ -238,7 +238,7 @@ func (tokenManager *Manager) generateJWTToken(client *config.Client, token jwt.T } -func generateIdToken(requestData *internalHttp.RequestData, user *config.User, client *config.Client, nonce string, atHash string, duration time.Duration) jwt.Token { +func generateIdToken(requestData *internalHttp.RequestData, config *config.Config, user *config.User, client *config.Client, nonce string, atHash string, duration time.Duration) jwt.Token { tokenId := uuid.New().String() builder := jwt.NewBuilder(). Expiration(time.Now().Add(duration)). // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.4 @@ -248,7 +248,7 @@ func generateIdToken(requestData *internalHttp.RequestData, user *config.User, c builder.JwtID(tokenId) // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.1 - builder.Issuer(client.GetIssuer(requestData)) + builder.Issuer(config.GetIssuer(requestData)) // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.2 builder.Subject(user.Username) @@ -274,7 +274,7 @@ func generateIdToken(requestData *internalHttp.RequestData, user *config.User, c return token } -func generateAccessToken(requestData *internalHttp.RequestData, tokenId string, duration time.Duration, username string, client *config.Client) jwt.Token { +func generateAccessToken(requestData *internalHttp.RequestData, config *config.Config, client *config.Client, tokenId string, duration time.Duration, username string) jwt.Token { builder := jwt.NewBuilder(). Expiration(time.Now().Add(duration)). // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.4 IssuedAt(time.Now()) // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.6 @@ -288,7 +288,7 @@ func generateAccessToken(requestData *internalHttp.RequestData, tokenId string, builder.JwtID(tokenId) // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.1 - builder.Issuer(client.GetIssuer(requestData)) + builder.Issuer(config.GetIssuer(requestData)) // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.2 builder.Subject(username) diff --git a/internal/manager/token/token_test.go b/internal/manager/token/token_test.go index 8ec7b7a..7d806f5 100644 --- a/internal/manager/token/token_test.go +++ b/internal/manager/token/token_test.go @@ -7,55 +7,77 @@ import ( "github.com/webishdev/stopnik/internal/endpoint" internalHttp "github.com/webishdev/stopnik/internal/http" "github.com/webishdev/stopnik/internal/oauth2" + "github.com/webishdev/stopnik/internal/oidc" "net/http" "net/http/httptest" "reflect" "testing" ) -func Test_Token(t *testing.T) { +func Test_AccessTokenResponse(t *testing.T) { type tokenTestParameter struct { opaque bool refreshTokenTTL int + idTTL int } var opaqueTokenParameter = []tokenTestParameter{ - {true, 0}, - {false, 0}, - {true, 100}, - {false, 100}, + {true, 0, 0}, + {false, 0, 0}, + {true, 100, 0}, + {false, 100, 0}, + {false, 0, 100}, + {false, 100, 100}, + {false, 100, 100}, } for _, test := range opaqueTokenParameter { - testMessage := fmt.Sprintf("Valid token opaque %t refreshTTL %d", test.opaque, test.refreshTokenTTL) + testMessage := fmt.Sprintf("Valid token opaque %t refreshTTL %d idTTL %d", test.opaque, test.refreshTokenTTL, test.idTTL) t.Run(testMessage, func(t *testing.T) { - testConfig := createTestConfig(t, test.opaque, test.refreshTokenTTL) + testConfig := createTestConfig(t, test.opaque, test.refreshTokenTTL, test.idTTL, "../../../.test_files/ecdsa521key.pem") tokenManager := GetTokenManagerInstance() client, clientExists := testConfig.GetClient("foo") if !clientExists { t.Fatal("client does not exist") } + requestScopes := []string{"abc", "def"} + if test.idTTL > 0 { + requestScopes = append(requestScopes, oidc.ScopeOpenId) + } + request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil) - accessTokenResponse := tokenManager.CreateAccessTokenResponse(request, "foo", client, []string{"abc", "def"}, "") + accessTokenResponse := tokenManager.CreateAccessTokenResponse(request, "foo", client, requestScopes, "") - if accessTokenResponse.AccessTokenKey == "" { + if accessTokenResponse.AccessTokenValue == "" { t.Error("empty access token") } - if test.refreshTokenTTL == 0 && accessTokenResponse.RefreshTokenKey != "" { + if test.refreshTokenTTL == 0 && accessTokenResponse.RefreshTokenValue != "" { t.Error("refresh token should not exists") } - if test.refreshTokenTTL > 0 && accessTokenResponse.RefreshTokenKey == "" { + if test.refreshTokenTTL > 0 && accessTokenResponse.RefreshTokenValue == "" { t.Error("refresh token should exists") } + if test.idTTL > 0 && !client.Oidc { + t.Error("client should be configured for OIDC because of id token") + } + + if test.idTTL == 0 && accessTokenResponse.IdTokenValue != "" { + t.Error("id token should not exists") + } + + if test.idTTL > 0 && accessTokenResponse.IdTokenValue == "" { + t.Error("id token should exists") + } + if accessTokenResponse.TokenType != oauth2.TtBearer { t.Error("wrong token type") } - authorizationHeader := fmt.Sprintf("%s %s", internalHttp.AuthBearer, accessTokenResponse.AccessTokenKey) + authorizationHeader := fmt.Sprintf("%s %s", internalHttp.AuthBearer, accessTokenResponse.AccessTokenValue) user, _, scopes, valid := tokenManager.ValidateAccessToken(authorizationHeader) if !valid { @@ -66,21 +88,21 @@ func Test_Token(t *testing.T) { t.Error("wrong username") } - if !reflect.DeepEqual(scopes, []string{"abc", "def"}) { - t.Errorf("assertion error, %v != %v", scopes, []string{"abc", "def"}) + if !reflect.DeepEqual(scopes, requestScopes) { + t.Errorf("assertion error, %v != %v", scopes, requestScopes) } - accessToken, accessTokenExists := tokenManager.GetAccessToken(accessTokenResponse.AccessTokenKey) + accessToken, accessTokenExists := tokenManager.GetAccessToken(accessTokenResponse.AccessTokenValue) if !accessTokenExists { t.Error("access token does not exist") } - if accessToken.Key != accessTokenResponse.AccessTokenKey { + if accessToken.Key != accessTokenResponse.AccessTokenValue { t.Error("wrong access token") } - refreshToken, refreshTokenExists := tokenManager.GetRefreshToken(accessTokenResponse.RefreshTokenKey) + refreshToken, refreshTokenExists := tokenManager.GetRefreshToken(accessTokenResponse.RefreshTokenValue) if test.refreshTokenTTL == 0 && refreshTokenExists { t.Error("refresh token should not exists") @@ -90,58 +112,60 @@ func Test_Token(t *testing.T) { t.Error("refresh token should exists") } - if test.refreshTokenTTL > 0 && refreshToken.Key != accessTokenResponse.RefreshTokenKey { + if test.refreshTokenTTL > 0 && refreshToken.Key != accessTokenResponse.RefreshTokenValue { t.Error("wrong refresh token") } tokenManager.RevokeRefreshToken(refreshToken) - _, refreshTokenExists = tokenManager.GetRefreshToken(accessTokenResponse.RefreshTokenKey) + _, refreshTokenExists = tokenManager.GetRefreshToken(accessTokenResponse.RefreshTokenValue) if refreshTokenExists { t.Error("refresh token should not exists") } tokenManager.RevokeAccessToken(accessToken) - _, accessTokenExists = tokenManager.GetAccessToken(accessTokenResponse.AccessTokenKey) + _, accessTokenExists = tokenManager.GetAccessToken(accessTokenResponse.AccessTokenValue) if accessTokenExists { t.Error("access token should not exists") } }) } - t.Run("Invalid HTTP Authorization header", func(t *testing.T) { - createTestConfig(t, false, 0) + t.Run("Invalid User in token", func(t *testing.T) { + testConfig := createTestConfig(t, false, 0, 0, "") tokenManager := GetTokenManagerInstance() + client, clientExists := testConfig.GetClient("foo") + if !clientExists { + t.Fatal("client does not exist") + } - _, _, _, valid := tokenManager.ValidateAccessToken("foooo") + request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil) + accessTokenResponse := tokenManager.CreateAccessTokenResponse(request, "bar", client, []string{"abc", "def"}, "") + + _, _, _, valid := tokenManager.ValidateAccessToken(fmt.Sprintf("%s %s", internalHttp.AuthBearer, accessTokenResponse.AccessTokenValue)) if valid { t.Error("should not be valid") } }) +} - t.Run("Invalid Token value", func(t *testing.T) { - createTestConfig(t, false, 0) +func Test_ValidAccessToken(t *testing.T) { + t.Run("Invalid HTTP Authorization header", func(t *testing.T) { + createTestConfig(t, false, 0, 0, "") tokenManager := GetTokenManagerInstance() - _, _, _, valid := tokenManager.ValidateAccessToken(fmt.Sprintf("%s %s", internalHttp.AuthBearer, "foo")) + _, _, _, valid := tokenManager.ValidateAccessToken("foooo") if valid { t.Error("should not be valid") } }) - t.Run("Invalid User in token", func(t *testing.T) { - testConfig := createTestConfig(t, false, 0) + t.Run("Invalid Token value", func(t *testing.T) { + createTestConfig(t, false, 0, 0, "") tokenManager := GetTokenManagerInstance() - client, clientExists := testConfig.GetClient("foo") - if !clientExists { - t.Fatal("client does not exist") - } - request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil) - accessTokenResponse := tokenManager.CreateAccessTokenResponse(request, "bar", client, []string{"abc", "def"}, "") - - _, _, _, valid := tokenManager.ValidateAccessToken(fmt.Sprintf("%s %s", internalHttp.AuthBearer, accessTokenResponse.AccessTokenKey)) + _, _, _, valid := tokenManager.ValidateAccessToken(fmt.Sprintf("%s %s", internalHttp.AuthBearer, "foo")) if valid { t.Error("should not be valid") @@ -174,7 +198,8 @@ func Test_HashToken(t *testing.T) { } -func createTestConfig(t *testing.T, opaque bool, refreshTokenTTL int) *config.Config { +func createTestConfig(t *testing.T, opaque bool, refreshTokenTTL int, idTTokenTTL int, keyPath string) *config.Config { + var isOidc = idTTokenTTL > 0 testConfig := &config.Config{ Clients: []config.Client{ { @@ -183,6 +208,9 @@ func createTestConfig(t *testing.T, opaque bool, refreshTokenTTL int) *config.Co Redirects: []string{"https://example.com/callback"}, OpaqueToken: opaque, RefreshTTL: refreshTokenTTL, + IdTTL: idTTokenTTL, + Oidc: isOidc, + PrivateKey: keyPath, }, }, Users: []config.User{ diff --git a/internal/oauth2/token.go b/internal/oauth2/token.go index adf2ed9..fe8fad6 100644 --- a/internal/oauth2/token.go +++ b/internal/oauth2/token.go @@ -16,9 +16,9 @@ type RefreshToken struct { // AccessTokenResponse as described in https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.4 type AccessTokenResponse struct { - AccessTokenKey string `json:"access_token,omitempty"` - TokenType TokenType `json:"token_type,omitempty"` - ExpiresIn int `json:"expires_in,omitempty"` // seconds - RefreshTokenKey string `json:"refresh_token,omitempty"` - IdToken string `json:"id_token,omitempty"` // https://openid.net/specs/openid-connect-core-1_0.html#IDToken + AccessTokenValue string `json:"access_token,omitempty"` + TokenType TokenType `json:"token_type,omitempty"` + ExpiresIn int `json:"expires_in,omitempty"` // seconds + RefreshTokenValue string `json:"refresh_token,omitempty"` + IdTokenValue string `json:"id_token,omitempty"` // https://openid.net/specs/openid-connect-core-1_0.html#IDToken } diff --git a/internal/server/handler/account/account_test.go b/internal/server/handler/account/account_test.go index 7e6f5fa..e4f4331 100644 --- a/internal/server/handler/account/account_test.go +++ b/internal/server/handler/account/account_test.go @@ -15,7 +15,7 @@ import ( "testing" ) -func Test_Account(t *testing.T) { +func Test_AccountWithCookie(t *testing.T) { testConfig := &config.Config{ Clients: []config.Client{ @@ -38,92 +38,78 @@ func Test_Account(t *testing.T) { t.Fatal(initializationError) } - testAccountWithoutCookie(t, testConfig) + requestValidator := validation.NewRequestValidator() + cookieManager := cookie.GetCookieManagerInstance() + templateManager := template.GetTemplateManagerInstance() - testAccountWithCookie(t, testConfig) + user, _ := testConfig.GetUser("foo") + authCookie, _ := cookieManager.CreateAuthCookie(user.Username) - testAccountLogin(t, testConfig) + accountHandler := NewAccountHandler(requestValidator, cookieManager, templateManager) - testAccountNotAllowedHttpMethods(t) -} - -func testAccountWithoutCookie(t *testing.T, testConfig *config.Config) { - t.Run("Account without cookie", func(t *testing.T) { - requestValidator := validation.NewRequestValidator() - cookieManager := cookie.GetCookieManagerInstance() - templateManager := template.GetTemplateManagerInstance() - - accountHandler := NewAccountHandler(requestValidator, cookieManager, templateManager) + rr := httptest.NewRecorder() - rr := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, endpoint.Account, nil) + request.AddCookie(&authCookie) - accountHandler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, endpoint.Account, nil)) + accountHandler.ServeHTTP(rr, request) - if rr.Code != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusOK) - } + if rr.Code != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusOK) + } - contentType := rr.Header().Get(internalHttp.ContentType) + contentType := rr.Header().Get(internalHttp.ContentType) - if contentType != "text/html; charset=utf-8" { - t.Errorf("content type was not text/html: %v", contentType) - } + if contentType != "text/html; charset=utf-8" { + t.Errorf("content type was not text/html: %v", contentType) + } - response := rr.Result() - body, bodyReadErr := io.ReadAll(response.Body) + response := rr.Result() + body, bodyReadErr := io.ReadAll(response.Body) - if bodyReadErr != nil { - t.Errorf("could not read response body: %v", bodyReadErr) - } + if bodyReadErr != nil { + t.Errorf("could not read response body: %v", bodyReadErr) + } - if body == nil { - t.Errorf("response body was nil") - } - }) + if body == nil { + t.Errorf("response body was nil") + } } -func testAccountWithCookie(t *testing.T, testConfig *config.Config) { - t.Run("Account cookie", func(t *testing.T) { - requestValidator := validation.NewRequestValidator() - cookieManager := cookie.GetCookieManagerInstance() - templateManager := template.GetTemplateManagerInstance() +func Test_AccountWithoutCookie(t *testing.T) { + requestValidator := validation.NewRequestValidator() + cookieManager := cookie.GetCookieManagerInstance() + templateManager := template.GetTemplateManagerInstance() - user, _ := testConfig.GetUser("foo") - cookie, _ := cookieManager.CreateAuthCookie(user.Username) + accountHandler := NewAccountHandler(requestValidator, cookieManager, templateManager) - accountHandler := NewAccountHandler(requestValidator, cookieManager, templateManager) + rr := httptest.NewRecorder() - rr := httptest.NewRecorder() + accountHandler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, endpoint.Account, nil)) - request := httptest.NewRequest(http.MethodGet, endpoint.Account, nil) - request.AddCookie(&cookie) - - accountHandler.ServeHTTP(rr, request) - - if rr.Code != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusOK) - } + if rr.Code != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusOK) + } - contentType := rr.Header().Get(internalHttp.ContentType) + contentType := rr.Header().Get(internalHttp.ContentType) - if contentType != "text/html; charset=utf-8" { - t.Errorf("content type was not text/html: %v", contentType) - } + if contentType != "text/html; charset=utf-8" { + t.Errorf("content type was not text/html: %v", contentType) + } - response := rr.Result() - body, bodyReadErr := io.ReadAll(response.Body) + response := rr.Result() + body, bodyReadErr := io.ReadAll(response.Body) - if bodyReadErr != nil { - t.Errorf("could not read response body: %v", bodyReadErr) - } + if bodyReadErr != nil { + t.Errorf("could not read response body: %v", bodyReadErr) + } - if body == nil { - t.Errorf("response body was nil") - } - }) + if body == nil { + t.Errorf("response body was nil") + } } -func testAccountLogin(t *testing.T, testConfig *config.Config) { +func Test_AccountLogin(t *testing.T) { type loginParameter struct { username string password string @@ -140,7 +126,7 @@ func testAccountLogin(t *testing.T, testConfig *config.Config) { cookieManager := cookie.GetCookieManagerInstance() templateManager := template.GetTemplateManagerInstance() - cookie, _ := cookieManager.CreateAuthCookie(test.username) + authCookie, _ := cookieManager.CreateAuthCookie(test.username) accountHandler := NewAccountHandler(requestValidator, cookieManager, templateManager) @@ -154,7 +140,7 @@ func testAccountLogin(t *testing.T, testConfig *config.Config) { request := httptest.NewRequest(http.MethodPost, endpoint.Account, body) request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") - request.AddCookie(&cookie) + request.AddCookie(&authCookie) accountHandler.ServeHTTP(rr, request) @@ -174,7 +160,7 @@ func testAccountLogin(t *testing.T, testConfig *config.Config) { } } -func testAccountNotAllowedHttpMethods(t *testing.T) { +func Test_AccountNotAllowedHttpMethods(t *testing.T) { var testInvalidAccountHttpMethods = []string{ http.MethodPut, http.MethodPatch, diff --git a/internal/server/handler/assets/assets_test.go b/internal/server/handler/assets/assets_test.go index 50639ac..355c661 100644 --- a/internal/server/handler/assets/assets_test.go +++ b/internal/server/handler/assets/assets_test.go @@ -33,13 +33,6 @@ func Test_Assets(t *testing.T) { assetsHandler := NewAssetHandler() - var testAssetsHttpMethods = []string{ - http.MethodPost, - http.MethodPut, - http.MethodPatch, - http.MethodDelete, - } - type assetHttpParameter struct { path string expectedCode int @@ -90,6 +83,15 @@ func Test_Assets(t *testing.T) { } }) } +} + +func Test_AssetsNotAllowedHttpMethods(t *testing.T) { + var testAssetsHttpMethods = []string{ + http.MethodPost, + http.MethodPut, + http.MethodPatch, + http.MethodDelete, + } for _, method := range testAssetsHttpMethods { testMessage := fmt.Sprintf("Assets with unsupported method %s", method) diff --git a/internal/server/handler/authorize/authorize.go b/internal/server/handler/authorize/authorize.go index 917b35c..8c55827 100644 --- a/internal/server/handler/authorize/authorize.go +++ b/internal/server/handler/authorize/authorize.go @@ -171,7 +171,7 @@ func (h *Handler) handleGetRequest(w http.ResponseWriter, r *http.Request) { } else if slices.Contains(responseTypes, oauth2.RtCode) { setAuthorizationGrantParameter(query, id.String()) } else if idTokenRequest { - accessTokenHash := h.tokenManager.CreateAccessTokenHash(client, accessTokenResponse.AccessTokenKey) + accessTokenHash := h.tokenManager.CreateAccessTokenHash(client, accessTokenResponse.AccessTokenValue) idToken = h.tokenManager.CreateIdToken(r, user.Username, client, scopes, authSession.Nonce, accessTokenHash) } else { log.Error("Invalid response type %v", responseTypes) @@ -289,7 +289,7 @@ func setAuthorizationGrantParameter(query url.Values, code string) { } func setImplicitGrantParameter(query url.Values, accessTokenResponse oauth2.AccessTokenResponse) { - query.Set(oauth2.ParameterAccessToken, accessTokenResponse.AccessTokenKey) + query.Set(oauth2.ParameterAccessToken, accessTokenResponse.AccessTokenValue) query.Set(oauth2.ParameterTokenType, string(accessTokenResponse.TokenType)) query.Set(oauth2.ParameterExpiresIn, fmt.Sprintf("%d", accessTokenResponse.ExpiresIn)) // https://datatracker.ietf.org/doc/html/rfc6749#section-4.2.2 diff --git a/internal/server/handler/authorize/authorize_test.go b/internal/server/handler/authorize/authorize_test.go index e945ea0..5af6cf4 100644 --- a/internal/server/handler/authorize/authorize_test.go +++ b/internal/server/handler/authorize/authorize_test.go @@ -46,35 +46,17 @@ func Test_Authorize(t *testing.T) { t.Fatal(initializationError) } - testAuthorizeNoClientId(t) - - testAuthorizeInvalidClientId(t) - - testAuthorizeInvalidRedirect(t) - - testAuthorizeInvalidResponseType(t) - - testAuthorizeNoCookeExists(t) - testAuthorizeAuthorizationGrant(t, testConfig) testAuthorizeImplicitGrant(t, testConfig) - testAuthorizeInvalidLogin(t) - - testAuthorizeEmptyLogin(t) - - testAuthorizeValidLoginNoSession(t) - testAuthorizeValidLoginAuthorizationGrant(t, testConfig) testAuthorizeValidLoginImplicitGrant(t, testConfig) - testAuthorizeNotAllowedHttpMethods(t) - } -func testAuthorizeInvalidLogin(t *testing.T) { +func Test_AuthorizeInvalidLogin(t *testing.T) { type invalidLoginParameter struct { state string scope string @@ -161,7 +143,7 @@ func testAuthorizeInvalidLogin(t *testing.T) { } } -func testAuthorizeEmptyLogin(t *testing.T) { +func Test_AuthorizeEmptyLogin(t *testing.T) { type emptyLoginParameter struct { state string scope string @@ -246,7 +228,7 @@ func testAuthorizeEmptyLogin(t *testing.T) { } } -func testAuthorizeValidLoginNoSession(t *testing.T) { +func Test_AuthorizeValidLoginNoSession(t *testing.T) { type validLoginParameter struct { state string scope string @@ -341,155 +323,224 @@ func testAuthorizeValidLoginNoSession(t *testing.T) { } } -func testAuthorizeValidLoginAuthorizationGrant(t *testing.T, testConfig *config.Config) { - type validLoginParameter struct { - state string - scope string +func Test_AuthorizeNotAllowedHttpMethods(t *testing.T) { + var testInvalidAuthorizeHttpMethods = []string{ + http.MethodPut, + http.MethodPatch, + http.MethodDelete, } - var validLoginParameters = []validLoginParameter{ - {"", ""}, - {"abc", ""}, - {"", "foo:moo"}, - {"abc", "foo:moo"}, + for _, method := range testInvalidAuthorizeHttpMethods { + testMessage := fmt.Sprintf("Authorize with unsupported method %s", method) + t.Run(testMessage, func(t *testing.T) { + authorizeHandler := NewAuthorizeHandler(&validation.RequestValidator{}, &cookie.Manager{}, &session.AuthManager{}, &token.Manager{}, &template.Manager{}) + + rr := httptest.NewRecorder() + + authorizeHandler.ServeHTTP(rr, httptest.NewRequest(method, endpoint.Authorization, nil)) + + if rr.Code != http.StatusMethodNotAllowed { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusMethodNotAllowed) + } + }) } - for _, test := range validLoginParameters { - testMessage := fmt.Sprintf("Valid login credentials, authorization grant session, with with state %v scope %v", test.state, test.scope) +} + +func Test_AuthorizeNoCookeExists(t *testing.T) { + parsedUri := createUri(t, endpoint.Authorization, func(query url.Values) { + query.Set(oauth2.ParameterClientId, "foo") + query.Set(oauth2.ParameterRedirectUri, "https://example.com/callback") + query.Set(oauth2.ParameterResponseType, oauth2.ParameterCode) + }) + requestValidator := validation.NewRequestValidator() + sessionManager := session.GetAuthSessionManagerInstance() + cookieManager := cookie.GetCookieManagerInstance() + templateManager := template.GetTemplateManagerInstance() + + authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, sessionManager, &token.Manager{}, templateManager) + + rr := httptest.NewRecorder() + + authorizeHandler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, parsedUri.String(), nil)) + + if rr.Code != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusOK) + } + + contentType := rr.Header().Get(internalHttp.ContentType) + + if contentType != "text/html; charset=utf-8" { + t.Errorf("content type was not text/html: %v", contentType) + } + + response := rr.Result() + body, bodyReadErr := io.ReadAll(response.Body) + + if bodyReadErr != nil { + t.Errorf("could not read response body: %v", bodyReadErr) + } + + if body == nil { + t.Errorf("response body was nil") + } +} + +func Test_AuthorizeInvalidResponseType(t *testing.T) { + parsedUri := createUri(t, endpoint.Authorization, func(query url.Values) { + query.Set(oauth2.ParameterClientId, "foo") + query.Set(oauth2.ParameterRedirectUri, "https://example.com/callback") + query.Set(oauth2.ParameterResponseType, "abc") + }) + requestValidator := validation.NewRequestValidator() + + authorizeHandler := NewAuthorizeHandler(requestValidator, &cookie.Manager{}, &session.AuthManager{}, &token.Manager{}, &template.Manager{}) + + rr := httptest.NewRecorder() + + authorizeHandler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, parsedUri.String(), nil)) + + if rr.Code != http.StatusFound { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusFound) + } + + location, locationError := rr.Result().Location() + if locationError != nil { + t.Errorf("location was not provied: %v", locationError) + } + + errorQueryParameter := location.Query().Get(oauth2.ParameterError) + + errorType, errorTypeExists := oauth2.AuthorizationErrorTypeFromString(errorQueryParameter) + + if !errorTypeExists { + t.Errorf("error type could not be parsed: %v", errorQueryParameter) + } + + if errorType != oauth2.AuthorizationEtInvalidRequest { + t.Errorf("error type was not Invalid: %v", errorQueryParameter) + } +} + +func Test_AuthorizeInvalidRedirect(t *testing.T) { + type redirectTest struct { + redirect string + status int + } + + var redirectTestParameters = []redirectTest{ + {"://hahaNoURI", http.StatusBadRequest}, + {"", http.StatusBadRequest}, + {"http://example.com/foo", http.StatusBadRequest}, + } + + for _, test := range redirectTestParameters { + testMessage := fmt.Sprintf("Invalid redirect with %s", test.redirect) t.Run(testMessage, func(t *testing.T) { parsedUri := createUri(t, endpoint.Authorization, func(query url.Values) { query.Set(oauth2.ParameterClientId, "foo") - query.Set(oauth2.ParameterRedirectUri, "https://example.com/callback") - query.Set(oauth2.ParameterResponseType, oauth2.ParameterCode) - if test.state != "" { - query.Set(oauth2.ParameterState, test.state) - } - if test.scope != "" { - query.Set(oauth2.ParameterScope, test.scope) - } + query.Set(oauth2.ParameterRedirectUri, test.redirect) }) - client, _ := testConfig.GetClient("foo") + requestValidator := validation.NewRequestValidator() - id := uuid.New() - authSession := &session.AuthSession{ - Id: id.String(), - Redirect: "https://example.com/callback", - AuthURI: parsedUri.RequestURI(), - CodeChallenge: "", - CodeChallengeMethod: "", - ClientId: client.Id, - ResponseTypes: []oauth2.ResponseType{oauth2.RtCode}, - Scopes: []string{test.scope}, - State: test.state, + authorizeHandler := NewAuthorizeHandler(requestValidator, &cookie.Manager{}, &session.AuthManager{}, &token.Manager{}, &template.Manager{}) + + rr := httptest.NewRecorder() + + authorizeHandler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, parsedUri.String(), nil)) + + if rr.Code != test.status { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, test.status) } + }) - requestValidator := validation.NewRequestValidator() - sessionManager := session.GetAuthSessionManagerInstance() - cookieManager := cookie.GetCookieManagerInstance() - tokenManager := token.GetTokenManagerInstance() - sessionManager.StartSession(authSession) + } +} - authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, sessionManager, tokenManager, &template.Manager{}) +func Test_AuthorizeInvalidClientId(t *testing.T) { + parsedUri := createUri(t, endpoint.Authorization, func(query url.Values) { + query.Set(oauth2.ParameterClientId, "bar") + }) - rr := httptest.NewRecorder() + requestValidator := validation.NewRequestValidator() - bodyString := testCreateBody( - "stopnik_auth_session", id.String(), - "stopnik_username", "foo", - "stopnik_password", "bar", - ) - body := strings.NewReader(bodyString) + authorizeHandler := NewAuthorizeHandler(requestValidator, &cookie.Manager{}, &session.AuthManager{}, &token.Manager{}, &template.Manager{}) - request := httptest.NewRequest(http.MethodPost, parsedUri.String(), body) - request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() - authorizeHandler.ServeHTTP(rr, request) + authorizeHandler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, parsedUri.String(), nil)) - if rr.Code != http.StatusFound { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusFound) - } + if rr.Code != http.StatusBadRequest { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusBadRequest) + } +} - location, locationError := rr.Result().Location() - if locationError != nil { - t.Errorf("location was not provied: %v", locationError) - } +func Test_AuthorizeNoClientId(t *testing.T) { + requestValidator := validation.NewRequestValidator() - codeQueryParameter := location.Query().Get(oauth2.ParameterCode) + authorizeHandler := NewAuthorizeHandler(requestValidator, &cookie.Manager{}, &session.AuthManager{}, &token.Manager{}, &template.Manager{}) - if codeQueryParameter == "" { - t.Errorf("code query parameter was not set") - } + rr := httptest.NewRecorder() - stateQueryParameter := location.Query().Get(oauth2.ParameterState) + authorizeHandler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, endpoint.Authorization, nil)) - if stateQueryParameter != test.state { - t.Errorf("state parameter %v did not match: %v", stateQueryParameter, test.state) - } - }) + if rr.Code != http.StatusBadRequest { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusBadRequest) } } -func testAuthorizeValidLoginImplicitGrant(t *testing.T, testConfig *config.Config) { - type validLoginParameter struct { - state string - scope string +func testAuthorizeAuthorizationGrant(t *testing.T, testConfig *config.Config) { + type authorizationGrantParameter struct { + state string + scope string + pkceCodeChallenge string + pkceCodeChallengeMethod *pkce.CodeChallengeMethod } - var validLoginParameters = []validLoginParameter{ - {"", ""}, - {"abc", ""}, - {"", "foo:moo"}, - {"abc", "foo:moo"}, + ccmS256 := pkce.S256 + ccmPlain := pkce.PLAIN + + var authorizationGrantParameters = []authorizationGrantParameter{ + {"", "", "", nil}, + {"abc", "", "", nil}, + {"", "foo:moo", "", nil}, + {"abc", "foo:moo", "", nil}, + {"abc", "foo:moo", uuid.New().String(), &ccmS256}, + {"abc", "foo:moo", uuid.New().String(), &ccmPlain}, } - for _, test := range validLoginParameters { - testMessage := fmt.Sprintf("Valid login credentials, implicit grant session with with state %v scope %v", test.state, test.scope) + + for _, test := range authorizationGrantParameters { + testMessage := fmt.Sprintf("Cookie exists, authorization code grant with state %v scope %v code challenge %v", test.state, test.scope, test.pkceCodeChallenge) t.Run(testMessage, func(t *testing.T) { + pkceCodeChallenge := "" parsedUri := createUri(t, endpoint.Authorization, func(query url.Values) { query.Set(oauth2.ParameterClientId, "foo") query.Set(oauth2.ParameterRedirectUri, "https://example.com/callback") - query.Set(oauth2.ParameterResponseType, oauth2.ParameterToken) + query.Set(oauth2.ParameterResponseType, oauth2.ParameterCode) if test.state != "" { query.Set(oauth2.ParameterState, test.state) } if test.scope != "" { query.Set(oauth2.ParameterScope, test.scope) } + if test.pkceCodeChallenge != "" && test.pkceCodeChallengeMethod != nil { + pkceCodeChallenge = pkce.CalculatePKCE(*test.pkceCodeChallengeMethod, test.pkceCodeChallenge) + query.Set(pkce.ParameterCodeChallenge, pkceCodeChallenge) + } }) - - client, _ := testConfig.GetClient("foo") - - id := uuid.New() - authSession := &session.AuthSession{ - Id: id.String(), - Redirect: "https://example.com/callback", - AuthURI: parsedUri.RequestURI(), - CodeChallenge: "", - CodeChallengeMethod: "", - ClientId: client.Id, - ResponseTypes: []oauth2.ResponseType{oauth2.RtToken}, - Scopes: []string{test.scope}, - State: test.state, - } - requestValidator := validation.NewRequestValidator() sessionManager := session.GetAuthSessionManagerInstance() cookieManager := cookie.GetCookieManagerInstance() tokenManager := token.GetTokenManagerInstance() - sessionManager.StartSession(authSession) + + user, _ := testConfig.GetUser("foo") + authCookie, _ := cookieManager.CreateAuthCookie(user.Username) authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, sessionManager, tokenManager, &template.Manager{}) rr := httptest.NewRecorder() - - bodyString := testCreateBody( - "stopnik_auth_session", id.String(), - "stopnik_username", "foo", - "stopnik_password", "bar", - ) - body := strings.NewReader(bodyString) - - request := httptest.NewRequest(http.MethodPost, parsedUri.String(), body) - request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") + request := httptest.NewRequest(http.MethodGet, parsedUri.String(), nil) + request.AddCookie(&authCookie) authorizeHandler.ServeHTTP(rr, request) @@ -502,52 +553,32 @@ func testAuthorizeValidLoginImplicitGrant(t *testing.T, testConfig *config.Confi t.Errorf("location was not provied: %v", locationError) } - accessTokenQueryParameter := location.Query().Get(oauth2.ParameterAccessToken) + codeQueryParameter := location.Query().Get(oauth2.ParameterCode) - if accessTokenQueryParameter == "" { - t.Errorf("access token query parameter was not set") + if codeQueryParameter == "" { + t.Errorf("code query parameter was not set") } - tokenTypeQueryParameter := location.Query().Get(oauth2.ParameterTokenType) + stateQueryParameter := location.Query().Get(oauth2.ParameterState) - if tokenTypeQueryParameter != string(oauth2.TtBearer) { - t.Errorf("token type parameter %v did not match: %v", tokenTypeQueryParameter, oauth2.TtBearer) + if stateQueryParameter != test.state { + t.Errorf("state parameter %v did not match: %v", stateQueryParameter, test.state) } - expiresInTypeQueryParameter := location.Query().Get(oauth2.ParameterExpiresIn) - expiresIn, expiresParseError := strconv.Atoi(expiresInTypeQueryParameter) - if expiresParseError != nil { - t.Errorf("expires query parameter was not parsed: %v", expiresParseError) + authSession, sessionExists := sessionManager.GetSession(codeQueryParameter) + if !sessionExists { + t.Errorf("session does not exist: %v", codeQueryParameter) } - accessTokenDuration := time.Minute * time.Duration(client.GetAccessTTL()) - expectedExpiresIn := int(accessTokenDuration / time.Second) - - if expiresIn != expectedExpiresIn { - t.Errorf("expires in parameter %v did not match %v", expiresIn, expectedExpiresIn) + if authSession.CodeChallenge != pkceCodeChallenge { + t.Errorf("session code challenge %v did not match: %v", authSession.CodeChallenge, pkceCodeChallenge) } - }) - } -} - -func testAuthorizeNotAllowedHttpMethods(t *testing.T) { - var testInvalidAuthorizeHttpMethods = []string{ - http.MethodPut, - http.MethodPatch, - http.MethodDelete, - } - - for _, method := range testInvalidAuthorizeHttpMethods { - testMessage := fmt.Sprintf("Authorize with unsupported method %s", method) - t.Run(testMessage, func(t *testing.T) { - authorizeHandler := NewAuthorizeHandler(&validation.RequestValidator{}, &cookie.Manager{}, &session.AuthManager{}, &token.Manager{}, &template.Manager{}) - rr := httptest.NewRecorder() - - authorizeHandler.ServeHTTP(rr, httptest.NewRequest(method, endpoint.Authorization, nil)) - - if rr.Code != http.StatusMethodNotAllowed { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusMethodNotAllowed) + if pkceCodeChallenge != "" { + validatePKCE := pkce.ValidatePKCE(*test.pkceCodeChallengeMethod, pkceCodeChallenge, test.pkceCodeChallenge) + if !validatePKCE { + t.Errorf("invalid PKCE code challenge: %v", pkceCodeChallenge) + } } }) } @@ -587,13 +618,13 @@ func testAuthorizeImplicitGrant(t *testing.T, testConfig *config.Config) { client, _ := testConfig.GetClient("foo") user, _ := testConfig.GetUser("foo") - cookie, _ := cookieManager.CreateAuthCookie(user.Username) + authCookie, _ := cookieManager.CreateAuthCookie(user.Username) authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, sessionManager, tokenManager, &template.Manager{}) rr := httptest.NewRecorder() request := httptest.NewRequest(http.MethodGet, parsedUri.String(), nil) - request.AddCookie(&cookie) + request.AddCookie(&authCookie) authorizeHandler.ServeHTTP(rr, request) @@ -634,30 +665,21 @@ func testAuthorizeImplicitGrant(t *testing.T, testConfig *config.Config) { } } -func testAuthorizeAuthorizationGrant(t *testing.T, testConfig *config.Config) { - type authorizationGrantParameter struct { - state string - scope string - pkceCodeChallenge string - pkceCodeChallengeMethod *pkce.CodeChallengeMethod +func testAuthorizeValidLoginAuthorizationGrant(t *testing.T, testConfig *config.Config) { + type validLoginParameter struct { + state string + scope string } - ccmS256 := pkce.S256 - ccmPlain := pkce.PLAIN - - var authorizationGrantParameters = []authorizationGrantParameter{ - {"", "", "", nil}, - {"abc", "", "", nil}, - {"", "foo:moo", "", nil}, - {"abc", "foo:moo", "", nil}, - {"abc", "foo:moo", uuid.New().String(), &ccmS256}, - {"abc", "foo:moo", uuid.New().String(), &ccmPlain}, + var validLoginParameters = []validLoginParameter{ + {"", ""}, + {"abc", ""}, + {"", "foo:moo"}, + {"abc", "foo:moo"}, } - - for _, test := range authorizationGrantParameters { - testMessage := fmt.Sprintf("Cookie exists, authorization code grant with state %v scope %v code challenge %v", test.state, test.scope, test.pkceCodeChallenge) + for _, test := range validLoginParameters { + testMessage := fmt.Sprintf("Valid login credentials, authorization grant session, with with state %v scope %v", test.state, test.scope) t.Run(testMessage, func(t *testing.T) { - pkceCodeChallenge := "" parsedUri := createUri(t, endpoint.Authorization, func(query url.Values) { query.Set(oauth2.ParameterClientId, "foo") query.Set(oauth2.ParameterRedirectUri, "https://example.com/callback") @@ -668,24 +690,42 @@ func testAuthorizeAuthorizationGrant(t *testing.T, testConfig *config.Config) { if test.scope != "" { query.Set(oauth2.ParameterScope, test.scope) } - if test.pkceCodeChallenge != "" && test.pkceCodeChallengeMethod != nil { - pkceCodeChallenge = pkce.CalculatePKCE(*test.pkceCodeChallengeMethod, test.pkceCodeChallenge) - query.Set(pkce.ParameterCodeChallenge, pkceCodeChallenge) - } }) + + client, _ := testConfig.GetClient("foo") + + id := uuid.New() + authSession := &session.AuthSession{ + Id: id.String(), + Redirect: "https://example.com/callback", + AuthURI: parsedUri.RequestURI(), + CodeChallenge: "", + CodeChallengeMethod: "", + ClientId: client.Id, + ResponseTypes: []oauth2.ResponseType{oauth2.RtCode}, + Scopes: []string{test.scope}, + State: test.state, + } + requestValidator := validation.NewRequestValidator() sessionManager := session.GetAuthSessionManagerInstance() cookieManager := cookie.GetCookieManagerInstance() tokenManager := token.GetTokenManagerInstance() - - user, _ := testConfig.GetUser("foo") - cookie, _ := cookieManager.CreateAuthCookie(user.Username) + sessionManager.StartSession(authSession) authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, sessionManager, tokenManager, &template.Manager{}) rr := httptest.NewRecorder() - request := httptest.NewRequest(http.MethodGet, parsedUri.String(), nil) - request.AddCookie(&cookie) + + bodyString := testCreateBody( + "stopnik_auth_session", id.String(), + "stopnik_username", "foo", + "stopnik_password", "bar", + ) + body := strings.NewReader(bodyString) + + request := httptest.NewRequest(http.MethodPost, parsedUri.String(), body) + request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") authorizeHandler.ServeHTTP(rr, request) @@ -709,176 +749,109 @@ func testAuthorizeAuthorizationGrant(t *testing.T, testConfig *config.Config) { if stateQueryParameter != test.state { t.Errorf("state parameter %v did not match: %v", stateQueryParameter, test.state) } - - session, sessionExists := sessionManager.GetSession(codeQueryParameter) - if !sessionExists { - t.Errorf("session does not exist: %v", codeQueryParameter) - } - - if session.CodeChallenge != pkceCodeChallenge { - t.Errorf("session code challenge %v did not match: %v", session.CodeChallenge, pkceCodeChallenge) - } - - if pkceCodeChallenge != "" { - validatePKCE := pkce.ValidatePKCE(*test.pkceCodeChallengeMethod, pkceCodeChallenge, test.pkceCodeChallenge) - if !validatePKCE { - t.Errorf("invalid PKCE code challenge: %v", pkceCodeChallenge) - } - } }) } } -func testAuthorizeNoCookeExists(t *testing.T) { - t.Run("No cookie exists", func(t *testing.T) { - parsedUri := createUri(t, endpoint.Authorization, func(query url.Values) { - query.Set(oauth2.ParameterClientId, "foo") - query.Set(oauth2.ParameterRedirectUri, "https://example.com/callback") - query.Set(oauth2.ParameterResponseType, oauth2.ParameterCode) - }) - requestValidator := validation.NewRequestValidator() - sessionManager := session.GetAuthSessionManagerInstance() - cookieManager := cookie.GetCookieManagerInstance() - templateManager := template.GetTemplateManagerInstance() - - authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, sessionManager, &token.Manager{}, templateManager) - - rr := httptest.NewRecorder() - - authorizeHandler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, parsedUri.String(), nil)) - - if rr.Code != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusOK) - } - - contentType := rr.Header().Get(internalHttp.ContentType) - - if contentType != "text/html; charset=utf-8" { - t.Errorf("content type was not text/html: %v", contentType) - } - - response := rr.Result() - body, bodyReadErr := io.ReadAll(response.Body) - - if bodyReadErr != nil { - t.Errorf("could not read response body: %v", bodyReadErr) - } - - if body == nil { - t.Errorf("response body was nil") - } - - }) -} - -func testAuthorizeInvalidResponseType(t *testing.T) { - t.Run("Invalid response type", func(t *testing.T) { - parsedUri := createUri(t, endpoint.Authorization, func(query url.Values) { - query.Set(oauth2.ParameterClientId, "foo") - query.Set(oauth2.ParameterRedirectUri, "https://example.com/callback") - query.Set(oauth2.ParameterResponseType, "abc") - }) - requestValidator := validation.NewRequestValidator() - - authorizeHandler := NewAuthorizeHandler(requestValidator, &cookie.Manager{}, &session.AuthManager{}, &token.Manager{}, &template.Manager{}) - - rr := httptest.NewRecorder() - - authorizeHandler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, parsedUri.String(), nil)) - - if rr.Code != http.StatusFound { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusFound) - } - - location, locationError := rr.Result().Location() - if locationError != nil { - t.Errorf("location was not provied: %v", locationError) - } - - errorQueryParameter := location.Query().Get(oauth2.ParameterError) - - errorType, errorTypeExists := oauth2.AuthorizationErrorTypeFromString(errorQueryParameter) - - if !errorTypeExists { - t.Errorf("error type could not be parsed: %v", errorQueryParameter) - } - - if errorType != oauth2.AuthorizationEtInvalidRequest { - t.Errorf("error type was not Invalid: %v", errorQueryParameter) - } - }) -} - -func testAuthorizeInvalidRedirect(t *testing.T) { - type redirectTest struct { - redirect string - status int +func testAuthorizeValidLoginImplicitGrant(t *testing.T, testConfig *config.Config) { + type validLoginParameter struct { + state string + scope string } - var redirectTestParameters = []redirectTest{ - {"://hahaNoURI", http.StatusBadRequest}, - {"", http.StatusBadRequest}, - {"http://example.com/foo", http.StatusBadRequest}, + var validLoginParameters = []validLoginParameter{ + {"", ""}, + {"abc", ""}, + {"", "foo:moo"}, + {"abc", "foo:moo"}, } - - for _, test := range redirectTestParameters { - testMessage := fmt.Sprintf("Invalid redirect with %s", test.redirect) + for _, test := range validLoginParameters { + testMessage := fmt.Sprintf("Valid login credentials, implicit grant session with with state %v scope %v", test.state, test.scope) t.Run(testMessage, func(t *testing.T) { parsedUri := createUri(t, endpoint.Authorization, func(query url.Values) { query.Set(oauth2.ParameterClientId, "foo") - query.Set(oauth2.ParameterRedirectUri, test.redirect) + query.Set(oauth2.ParameterRedirectUri, "https://example.com/callback") + query.Set(oauth2.ParameterResponseType, oauth2.ParameterToken) + if test.state != "" { + query.Set(oauth2.ParameterState, test.state) + } + if test.scope != "" { + query.Set(oauth2.ParameterScope, test.scope) + } }) - requestValidator := validation.NewRequestValidator() + client, _ := testConfig.GetClient("foo") - authorizeHandler := NewAuthorizeHandler(requestValidator, &cookie.Manager{}, &session.AuthManager{}, &token.Manager{}, &template.Manager{}) + id := uuid.New() + authSession := &session.AuthSession{ + Id: id.String(), + Redirect: "https://example.com/callback", + AuthURI: parsedUri.RequestURI(), + CodeChallenge: "", + CodeChallengeMethod: "", + ClientId: client.Id, + ResponseTypes: []oauth2.ResponseType{oauth2.RtToken}, + Scopes: []string{test.scope}, + State: test.state, + } - rr := httptest.NewRecorder() + requestValidator := validation.NewRequestValidator() + sessionManager := session.GetAuthSessionManagerInstance() + cookieManager := cookie.GetCookieManagerInstance() + tokenManager := token.GetTokenManagerInstance() + sessionManager.StartSession(authSession) - authorizeHandler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, parsedUri.String(), nil)) + authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, sessionManager, tokenManager, &template.Manager{}) - if rr.Code != test.status { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, test.status) - } - }) + rr := httptest.NewRecorder() - } -} + bodyString := testCreateBody( + "stopnik_auth_session", id.String(), + "stopnik_username", "foo", + "stopnik_password", "bar", + ) + body := strings.NewReader(bodyString) -func testAuthorizeInvalidClientId(t *testing.T) { - t.Run("Invalid client id", func(t *testing.T) { - parsedUri := createUri(t, endpoint.Authorization, func(query url.Values) { - query.Set(oauth2.ParameterClientId, "bar") - }) + request := httptest.NewRequest(http.MethodPost, parsedUri.String(), body) + request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") - requestValidator := validation.NewRequestValidator() + authorizeHandler.ServeHTTP(rr, request) - authorizeHandler := NewAuthorizeHandler(requestValidator, &cookie.Manager{}, &session.AuthManager{}, &token.Manager{}, &template.Manager{}) + if rr.Code != http.StatusFound { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusFound) + } - rr := httptest.NewRecorder() + location, locationError := rr.Result().Location() + if locationError != nil { + t.Errorf("location was not provied: %v", locationError) + } - authorizeHandler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, parsedUri.String(), nil)) + accessTokenQueryParameter := location.Query().Get(oauth2.ParameterAccessToken) - if rr.Code != http.StatusBadRequest { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusBadRequest) - } - }) -} + if accessTokenQueryParameter == "" { + t.Errorf("access token query parameter was not set") + } -func testAuthorizeNoClientId(t *testing.T) { - t.Run("No client id provided", func(t *testing.T) { - requestValidator := validation.NewRequestValidator() + tokenTypeQueryParameter := location.Query().Get(oauth2.ParameterTokenType) - authorizeHandler := NewAuthorizeHandler(requestValidator, &cookie.Manager{}, &session.AuthManager{}, &token.Manager{}, &template.Manager{}) + if tokenTypeQueryParameter != string(oauth2.TtBearer) { + t.Errorf("token type parameter %v did not match: %v", tokenTypeQueryParameter, oauth2.TtBearer) + } - rr := httptest.NewRecorder() + expiresInTypeQueryParameter := location.Query().Get(oauth2.ParameterExpiresIn) + expiresIn, expiresParseError := strconv.Atoi(expiresInTypeQueryParameter) + if expiresParseError != nil { + t.Errorf("expires query parameter was not parsed: %v", expiresParseError) + } - authorizeHandler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, endpoint.Authorization, nil)) + accessTokenDuration := time.Minute * time.Duration(client.GetAccessTTL()) + expectedExpiresIn := int(accessTokenDuration / time.Second) - if rr.Code != http.StatusBadRequest { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusBadRequest) - } - }) + if expiresIn != expectedExpiresIn { + t.Errorf("expires in parameter %v did not match %v", expiresIn, expectedExpiresIn) + } + }) + } } func createUri(t *testing.T, uri string, handler func(query url.Values)) *url.URL { diff --git a/internal/server/handler/forwardauth/forwardauth.go b/internal/server/handler/forwardauth/forwardauth.go index 1b1f7f3..6a356be 100644 --- a/internal/server/handler/forwardauth/forwardauth.go +++ b/internal/server/handler/forwardauth/forwardauth.go @@ -5,6 +5,7 @@ import ( "github.com/google/uuid" "github.com/webishdev/stopnik/internal/config" "github.com/webishdev/stopnik/internal/endpoint" + internalHttp "github.com/webishdev/stopnik/internal/http" "github.com/webishdev/stopnik/internal/manager/cookie" "github.com/webishdev/stopnik/internal/manager/session" "github.com/webishdev/stopnik/internal/oauth2" @@ -40,9 +41,14 @@ func NewForwardAuthHandler(cookieManager *cookie.Manager, authSessionManager ses func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.AccessLogRequest(r) - forwardProtocol := r.Header.Get("X-Forwarded-Proto") - forwardHost := r.Header.Get("X-Forwarded-Host") - forwardPath := r.Header.Get("X-Forwarded-Uri") + forwardProtocol := r.Header.Get(internalHttp.XForwardProtocol) + forwardHost := r.Header.Get(internalHttp.XForwardHost) + forwardPath := r.Header.Get(internalHttp.XForwardUri) + + if forwardProtocol == "" || forwardHost == "" || forwardPath == "" { + h.errorHandler.BadRequestHandler(w, r) + return + } forwardString := fmt.Sprintf("%s://%s%s", forwardProtocol, forwardHost, forwardPath) forwardUri, forwardUriError := url.Parse(forwardString) @@ -51,8 +57,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + forwardAuthParameterName := h.config.GetForwardAuthParameterName() + codeParameter := forwardUri.Query().Get(oauth2.ParameterCode) - forwardIdParameter := forwardUri.Query().Get("forward_id") + forwardIdParameter := forwardUri.Query().Get(forwardAuthParameterName) _, validCookie := h.cookieManager.ValidateAuthCookie(r) @@ -68,7 +76,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } else { http.SetCookie(w, authCookie) - w.Header().Set("Location", forwardSession.RedirectUri) + w.Header().Set(internalHttp.Location, forwardSession.RedirectUri) w.WriteHeader(http.StatusTemporaryRedirect) return } @@ -79,7 +87,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { forwardSessionId := uuid.NewString() redirectUri, redirectUriError := createUri(forwardString, "", func(query url.Values) { - query.Set("forward_id", forwardSessionId) + query.Set(forwardAuthParameterName, forwardSessionId) }) if redirectUriError != nil { h.errorHandler.InternalServerErrorHandler(w, r) diff --git a/internal/server/handler/forwardauth/forwardauth_test.go b/internal/server/handler/forwardauth/forwardauth_test.go new file mode 100644 index 0000000..3af0474 --- /dev/null +++ b/internal/server/handler/forwardauth/forwardauth_test.go @@ -0,0 +1,113 @@ +package forwardauth + +import ( + "github.com/webishdev/stopnik/internal/config" + internalHttp "github.com/webishdev/stopnik/internal/http" + "github.com/webishdev/stopnik/internal/manager/cookie" + "github.com/webishdev/stopnik/internal/manager/session" + "github.com/webishdev/stopnik/internal/template" + "net/http" + "net/http/httptest" + "testing" +) + +func Test_ForwardAuth(t *testing.T) { + testConfig := &config.Config{ + Clients: []config.Client{ + { + Id: "foo", + ClientSecret: "d82c4eb5261cb9c8aa9855edd67d1bd10482f41529858d925094d173fa662aa91ff39bc5b188615273484021dfb16fd8284cf684ccf0fc795be3aa2fc1e6c181", + Redirects: []string{"https://example.com/callback"}, + }, + }, + Users: []config.User{ + { + Username: "foo", + Password: "d82c4eb5261cb9c8aa9855edd67d1bd10482f41529858d925094d173fa662aa91ff39bc5b188615273484021dfb16fd8284cf684ccf0fc795be3aa2fc1e6c181", + }, + }, + } + + initializationError := config.Initialize(testConfig) + if initializationError != nil { + t.Fatal(initializationError) + } + + testForwardAuthWithoutCookie(t, testConfig) + + testForwardAuthMissingHeaders(t, testConfig) + + testForwardAuthInvalidHeaders(t, testConfig) +} + +func testForwardAuthWithoutCookie(t *testing.T, testConfig *config.Config) { + t.Run("ForwardAuth without cookie", func(t *testing.T) { + cookieManager := cookie.GetCookieManagerInstance() + authSessionManager := session.GetAuthSessionManagerInstance() + forwardSessionManager := session.GetForwardSessionManagerInstance() + templateManager := template.GetTemplateManagerInstance() + + forwardAuthHandler := NewForwardAuthHandler(cookieManager, authSessionManager, forwardSessionManager, templateManager) + + rr := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, testConfig.GetForwardAuthEndpoint(), nil) + request.Header.Set(internalHttp.XForwardProtocol, "http") + request.Header.Set(internalHttp.XForwardHost, "localhost:8080") + request.Header.Set(internalHttp.XForwardUri, "/blabla") + + forwardAuthHandler.ServeHTTP(rr, request) + + if rr.Code != http.StatusTemporaryRedirect { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusTemporaryRedirect) + } + + _, locationError := rr.Result().Location() + if locationError != nil { + t.Errorf("location was not provied: %v", locationError) + } + }) + +} + +func testForwardAuthMissingHeaders(t *testing.T, testConfig *config.Config) { + t.Run("ForwardAuth without forward headers", func(t *testing.T) { + cookieManager := cookie.GetCookieManagerInstance() + authSessionManager := session.GetAuthSessionManagerInstance() + forwardSessionManager := session.GetForwardSessionManagerInstance() + templateManager := template.GetTemplateManagerInstance() + + forwardAuthHandler := NewForwardAuthHandler(cookieManager, authSessionManager, forwardSessionManager, templateManager) + + rr := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, testConfig.GetForwardAuthEndpoint(), nil) + + forwardAuthHandler.ServeHTTP(rr, request) + + if rr.Code != http.StatusBadRequest { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusBadRequest) + } + }) +} + +func testForwardAuthInvalidHeaders(t *testing.T, testConfig *config.Config) { + t.Run("ForwardAuth with invalid forward headers", func(t *testing.T) { + cookieManager := cookie.GetCookieManagerInstance() + authSessionManager := session.GetAuthSessionManagerInstance() + forwardSessionManager := session.GetForwardSessionManagerInstance() + templateManager := template.GetTemplateManagerInstance() + + forwardAuthHandler := NewForwardAuthHandler(cookieManager, authSessionManager, forwardSessionManager, templateManager) + + rr := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, testConfig.GetForwardAuthEndpoint(), nil) + request.Header.Set(internalHttp.XForwardProtocol, "!6721abc") + request.Header.Set(internalHttp.XForwardHost, "??+-#127fkhas:8080") + request.Header.Set(internalHttp.XForwardUri, "+ß128lj") + + forwardAuthHandler.ServeHTTP(rr, request) + + if rr.Code != http.StatusInternalServerError { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusInternalServerError) + } + }) +} diff --git a/internal/server/handler/health/health_test.go b/internal/server/handler/health/health_test.go index d4422da..33a867c 100644 --- a/internal/server/handler/health/health_test.go +++ b/internal/server/handler/health/health_test.go @@ -11,7 +11,7 @@ import ( "testing" ) -func Test_Health(t *testing.T) { +func Test_HealthWithToken(t *testing.T) { testConfig := &config.Config{ Clients: []config.Client{ { @@ -32,69 +32,67 @@ func Test_Health(t *testing.T) { t.Fatal(initializationError) } - t.Run("Health without token", func(t *testing.T) { - tokenManager := token.GetTokenManagerInstance() + tokenManager := token.GetTokenManagerInstance() - healthHandler := NewHealthHandler(tokenManager) - - httpRequest := &http.Request{ - Method: http.MethodGet, - } - rr := httptest.NewRecorder() - - healthHandler.ServeHTTP(rr, httpRequest) + client, clientExists := testConfig.GetClient("foo") + if !clientExists { + t.Error("client should exist") + } - contentType := rr.Header().Get(internalHttp.ContentType) + request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil) + tokenResponse := tokenManager.CreateAccessTokenResponse(request, "foo", client, []string{"a:foo", "b:bar"}, "") - if contentType != internalHttp.ContentTypeJSON { - t.Errorf("content type should be %s", internalHttp.ContentTypeJSON) - } + healthHandler := NewHealthHandler(tokenManager) - jsonString := rr.Body.String() + httpRequest := &http.Request{ + Method: http.MethodGet, + Header: http.Header{ + internalHttp.Authorization: []string{"Bearer " + tokenResponse.AccessTokenValue}, + }, + } + rr := httptest.NewRecorder() - if jsonString != `{"ping":"pong"}` { - t.Errorf("json string should be %s, but was %s", `{"ping":"pong"}`, jsonString) - } + healthHandler.ServeHTTP(rr, httpRequest) - }) + contentType := rr.Header().Get(internalHttp.ContentType) - t.Run("Health with token", func(t *testing.T) { - tokenManager := token.GetTokenManagerInstance() + if contentType != internalHttp.ContentTypeJSON { + t.Errorf("content type should be %s", internalHttp.ContentTypeJSON) + } - client, clientExists := testConfig.GetClient("foo") - if !clientExists { - t.Error("client should exist") - } + jsonString := rr.Body.String() - request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil) - tokenResponse := tokenManager.CreateAccessTokenResponse(request, "foo", client, []string{"a:foo", "b:bar"}, "") + if jsonString != `{"ping":"pong","username":"foo","scopes":["a:foo","b:bar"]}` { + t.Errorf("json string should be %s, but was %s", `{"ping":"pong","username":"foo","scopes":["a:foo","b:bar"]}`, jsonString) + } +} - healthHandler := NewHealthHandler(tokenManager) +func Test_HealthWithoutToken(t *testing.T) { + tokenManager := token.GetTokenManagerInstance() - httpRequest := &http.Request{ - Method: http.MethodGet, - Header: http.Header{ - internalHttp.Authorization: []string{"Bearer " + tokenResponse.AccessTokenKey}, - }, - } - rr := httptest.NewRecorder() + healthHandler := NewHealthHandler(tokenManager) - healthHandler.ServeHTTP(rr, httpRequest) + httpRequest := &http.Request{ + Method: http.MethodGet, + } + rr := httptest.NewRecorder() - contentType := rr.Header().Get(internalHttp.ContentType) + healthHandler.ServeHTTP(rr, httpRequest) - if contentType != internalHttp.ContentTypeJSON { - t.Errorf("content type should be %s", internalHttp.ContentTypeJSON) - } + contentType := rr.Header().Get(internalHttp.ContentType) - jsonString := rr.Body.String() + if contentType != internalHttp.ContentTypeJSON { + t.Errorf("content type should be %s", internalHttp.ContentTypeJSON) + } - if jsonString != `{"ping":"pong","username":"foo","scopes":["a:foo","b:bar"]}` { - t.Errorf("json string should be %s, but was %s", `{"ping":"pong","username":"foo","scopes":["a:foo","b:bar"]}`, jsonString) - } + jsonString := rr.Body.String() - }) + if jsonString != `{"ping":"pong"}` { + t.Errorf("json string should be %s, but was %s", `{"ping":"pong"}`, jsonString) + } +} +func Test_HealthNotAllowedHttpMethods(t *testing.T) { var testInvalidHealthHttpMethods = []string{ http.MethodPost, http.MethodPut, diff --git a/internal/server/handler/introspect/introspect_test.go b/internal/server/handler/introspect/introspect_test.go index d9e0416..364a31f 100644 --- a/internal/server/handler/introspect/introspect_test.go +++ b/internal/server/handler/introspect/introspect_test.go @@ -51,61 +51,47 @@ func Test_Introspect(t *testing.T) { t.Fatal(initializationError) } - testIntrospectMissingClientCredentials(t) - - testIntrospectInvalidClientCredentials(t) - - testIntrospectEmptyToken(t) - - testIntrospectInvalidToken(t) - testIntrospect(t, testConfig) testIntrospectWithoutHint(t, testConfig) testIntrospectDisabled(t, testConfig) - - testIntrospectNotAllowedHttpMethods(t) } -func testIntrospectMissingClientCredentials(t *testing.T) { - t.Run("Missing client credentials", func(t *testing.T) { - requestValidator := validation.NewRequestValidator() - tokenManager := token.GetTokenManagerInstance() +func Test_IntrospectMissingClientCredentials(t *testing.T) { + requestValidator := validation.NewRequestValidator() + tokenManager := token.GetTokenManagerInstance() - introspectHandler := NewIntrospectHandler(requestValidator, tokenManager) + introspectHandler := NewIntrospectHandler(requestValidator, tokenManager) - rr := httptest.NewRecorder() + rr := httptest.NewRecorder() - introspectHandler.ServeHTTP(rr, httptest.NewRequest(http.MethodPost, endpoint.Introspect, nil)) + introspectHandler.ServeHTTP(rr, httptest.NewRequest(http.MethodPost, endpoint.Introspect, nil)) - if rr.Code != http.StatusUnauthorized { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusUnauthorized) - } - }) + if rr.Code != http.StatusUnauthorized { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusUnauthorized) + } } -func testIntrospectInvalidClientCredentials(t *testing.T) { - t.Run("Invalid client credentials", func(t *testing.T) { - requestValidator := validation.NewRequestValidator() - tokenManager := token.GetTokenManagerInstance() +func Test_IntrospectInvalidClientCredentials(t *testing.T) { + requestValidator := validation.NewRequestValidator() + tokenManager := token.GetTokenManagerInstance() - introspectHandler := NewIntrospectHandler(requestValidator, tokenManager) + introspectHandler := NewIntrospectHandler(requestValidator, tokenManager) - rr := httptest.NewRecorder() + rr := httptest.NewRecorder() - request := httptest.NewRequest(http.MethodPost, endpoint.Introspect, nil) - request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "xxx"))) + request := httptest.NewRequest(http.MethodPost, endpoint.Introspect, nil) + request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "xxx"))) - introspectHandler.ServeHTTP(rr, request) + introspectHandler.ServeHTTP(rr, request) - if rr.Code != http.StatusUnauthorized { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusUnauthorized) - } - }) + if rr.Code != http.StatusUnauthorized { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusUnauthorized) + } } -func testIntrospectEmptyToken(t *testing.T) { +func Test_IntrospectEmptyToken(t *testing.T) { type introspectParameter struct { tokenHint oauth2.IntrospectTokenType } @@ -151,7 +137,7 @@ func testIntrospectEmptyToken(t *testing.T) { } } -func testIntrospectInvalidToken(t *testing.T) { +func Test_IntrospectInvalidToken(t *testing.T) { type introspectParameter struct { tokenHint oauth2.IntrospectTokenType } @@ -198,6 +184,36 @@ func testIntrospectInvalidToken(t *testing.T) { } } +func Test_IntrospectNotAllowedHttpMethods(t *testing.T) { + var testInvalidIntrospectHttpMethods = []string{ + http.MethodGet, + http.MethodPut, + http.MethodPatch, + http.MethodDelete, + } + + testConfig := &config.Config{} + initializationError := config.Initialize(testConfig) + if initializationError != nil { + t.Fatal(initializationError) + } + + for _, method := range testInvalidIntrospectHttpMethods { + testMessage := fmt.Sprintf("Introspect with unsupported method %s", method) + t.Run(testMessage, func(t *testing.T) { + introspectHandler := NewIntrospectHandler(&validation.RequestValidator{}, &token.Manager{}) + + rr := httptest.NewRecorder() + + introspectHandler.ServeHTTP(rr, httptest.NewRequest(method, endpoint.Introspect, nil)) + + if rr.Code != http.StatusMethodNotAllowed { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusMethodNotAllowed) + } + }) + } +} + func testIntrospect(t *testing.T, testConfig *config.Config) { type introspectParameter struct { tokenHint oauth2.IntrospectTokenType @@ -237,15 +253,15 @@ func testIntrospect(t *testing.T, testConfig *config.Config) { introspectHandler := NewIntrospectHandler(requestValidator, tokenManager) - token := accessTokenResponse.AccessTokenKey + accessTokenValue := accessTokenResponse.AccessTokenValue if test.tokenHint == oauth2.ItRefreshToken { - token = accessTokenResponse.RefreshTokenKey + accessTokenValue = accessTokenResponse.RefreshTokenValue } rr := httptest.NewRecorder() bodyString := testCreateBody( - oauth2.ParameterToken, token, + oauth2.ParameterToken, accessTokenValue, oauth2.ParameterTokenTypeHint, test.tokenHint, ) body := strings.NewReader(bodyString) @@ -310,15 +326,15 @@ func testIntrospectWithoutHint(t *testing.T, testConfig *config.Config) { introspectHandler := NewIntrospectHandler(requestValidator, tokenManager) - token := accessTokenResponse.AccessTokenKey + accessTokenValue := accessTokenResponse.AccessTokenValue if test.tokenType == oauth2.ItRefreshToken { - token = accessTokenResponse.RefreshTokenKey + accessTokenValue = accessTokenResponse.RefreshTokenValue } rr := httptest.NewRecorder() bodyString := testCreateBody( - oauth2.ParameterToken, token, + oauth2.ParameterToken, accessTokenValue, ) body := strings.NewReader(bodyString) @@ -382,15 +398,15 @@ func testIntrospectDisabled(t *testing.T, testConfig *config.Config) { introspectHandler := NewIntrospectHandler(requestValidator, tokenManager) - token := accessTokenResponse.AccessTokenKey + accessTokenValue := accessTokenResponse.AccessTokenValue if test.tokenHint == oauth2.ItRefreshToken { - token = accessTokenResponse.RefreshTokenKey + accessTokenValue = accessTokenResponse.RefreshTokenValue } rr := httptest.NewRecorder() bodyString := testCreateBody( - oauth2.ParameterToken, token, + oauth2.ParameterToken, accessTokenValue, oauth2.ParameterTokenTypeHint, test.tokenHint, ) body := strings.NewReader(bodyString) @@ -408,36 +424,6 @@ func testIntrospectDisabled(t *testing.T, testConfig *config.Config) { } } -func testIntrospectNotAllowedHttpMethods(t *testing.T) { - var testInvalidIntrospectHttpMethods = []string{ - http.MethodGet, - http.MethodPut, - http.MethodPatch, - http.MethodDelete, - } - - testConfig := &config.Config{} - initializationError := config.Initialize(testConfig) - if initializationError != nil { - t.Fatal(initializationError) - } - - for _, method := range testInvalidIntrospectHttpMethods { - testMessage := fmt.Sprintf("Introspect with unsupported method %s", method) - t.Run(testMessage, func(t *testing.T) { - introspectHandler := NewIntrospectHandler(&validation.RequestValidator{}, &token.Manager{}) - - rr := httptest.NewRecorder() - - introspectHandler.ServeHTTP(rr, httptest.NewRequest(method, endpoint.Introspect, nil)) - - if rr.Code != http.StatusMethodNotAllowed { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusMethodNotAllowed) - } - }) - } -} - func testIntrospectParse(t *testing.T, r *http.Response) response { responseBody, bodyReadErr := io.ReadAll(r.Body) diff --git a/internal/server/handler/keys/keys_test.go b/internal/server/handler/keys/keys_test.go index d5325b7..442be47 100644 --- a/internal/server/handler/keys/keys_test.go +++ b/internal/server/handler/keys/keys_test.go @@ -56,62 +56,53 @@ func Test_Keys(t *testing.T) { t.Fatal(initializationError) } - testKeys(t) + keyManger := key.GetKeyMangerInstance() - testKeysNotAllowedHttpMethods(t) + keysHandler := NewKeysHandler(keyManger) -} - -func testKeys(t *testing.T) { - t.Run("Get keys", func(t *testing.T) { - keyManger := key.GetKeyMangerInstance() - - keysHandler := NewKeysHandler(keyManger) - - rr := httptest.NewRecorder() - request := httptest.NewRequest(http.MethodGet, endpoint.Keys, nil) + rr := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, endpoint.Keys, nil) - keysHandler.ServeHTTP(rr, request) + keysHandler.ServeHTTP(rr, request) - if rr.Code != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusOK) - } - - requestResponse := rr.Result() + if rr.Code != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusOK) + } - keys := testKeyParse(t, requestResponse) + requestResponse := rr.Result() - if len(keys.Keys) != 3 { - t.Errorf("handler returned wrong number of keys: got %v want %v", len(keys.Keys), 3) - } + keys := testKeyParse(t, requestResponse) - containsES512 := slices.ContainsFunc(keys.Keys, func(r responseKeys) bool { - return r.Alg == string(jwa.ES512) && r.Kty == "EC" - }) + if len(keys.Keys) != 3 { + t.Errorf("handler returned wrong number of keys: got %v want %v", len(keys.Keys), 3) + } - if !containsES512 { - t.Error("key for ES512 was missing") - } + containsES512 := slices.ContainsFunc(keys.Keys, func(r responseKeys) bool { + return r.Alg == string(jwa.ES512) && r.Kty == "EC" + }) - containsES256 := slices.ContainsFunc(keys.Keys, func(r responseKeys) bool { - return r.Alg == string(jwa.ES256) && r.Kty == "EC" - }) + if !containsES512 { + t.Error("key for ES512 was missing") + } - if !containsES256 { - t.Error("key for ES256 was missing") - } + containsES256 := slices.ContainsFunc(keys.Keys, func(r responseKeys) bool { + return r.Alg == string(jwa.ES256) && r.Kty == "EC" + }) - containsRSA256 := slices.ContainsFunc(keys.Keys, func(r responseKeys) bool { - return r.Alg == string(jwa.RS256) && r.Kty == "RSA" - }) + if !containsES256 { + t.Error("key for ES256 was missing") + } - if !containsRSA256 { - t.Error("key for RS256 was missing") - } + containsRSA256 := slices.ContainsFunc(keys.Keys, func(r responseKeys) bool { + return r.Alg == string(jwa.RS256) && r.Kty == "RSA" }) + + if !containsRSA256 { + t.Error("key for RS256 was missing") + } } -func testKeysNotAllowedHttpMethods(t *testing.T) { +func Test_KeysNotAllowedHttpMethods(t *testing.T) { var testInvalidIntrospectHttpMethods = []string{ http.MethodPost, http.MethodPut, diff --git a/internal/server/handler/logout/logout_test.go b/internal/server/handler/logout/logout_test.go index 3dbd9da..bb60c7f 100644 --- a/internal/server/handler/logout/logout_test.go +++ b/internal/server/handler/logout/logout_test.go @@ -38,15 +38,13 @@ func Test_Logout(t *testing.T) { testInvalidCookies(t, testConfig) testLogout(t, testConfig) - - testLogoutNotAllowedHttpMethods(t) } func testInvalidCookies(t *testing.T, testConfig *config.Config) { t.Run("Logout with invalid cookie", func(t *testing.T) { cookieManager := cookie.GetCookieManagerInstance() - cookie := http.Cookie{ + authCookie := http.Cookie{ Name: testConfig.GetAuthCookieName(), Value: "foobar", Path: "/", @@ -60,7 +58,7 @@ func testInvalidCookies(t *testing.T, testConfig *config.Config) { rr := httptest.NewRecorder() request := httptest.NewRequest(http.MethodPost, endpoint.Logout, nil) - request.AddCookie(&cookie) + request.AddCookie(&authCookie) logoutHandler.ServeHTTP(rr, request) @@ -88,7 +86,7 @@ func testLogout(t *testing.T, testConfig *config.Config) { cookieManager := cookie.GetCookieManagerInstance() user, _ := testConfig.GetUser("foo") - cookie, _ := cookieManager.CreateAuthCookie(user.Username) + authCookie, _ := cookieManager.CreateAuthCookie(user.Username) logoutHandler := NewLogoutHandler(cookieManager, test.handlerRedirect) @@ -103,7 +101,7 @@ func testLogout(t *testing.T, testConfig *config.Config) { body := strings.NewReader(bodyString) request := httptest.NewRequest(http.MethodPost, endpoint.Logout, body) - request.AddCookie(&cookie) + request.AddCookie(&authCookie) if bodyString != "" { request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") } @@ -137,7 +135,7 @@ func testLogout(t *testing.T, testConfig *config.Config) { } } -func testLogoutNotAllowedHttpMethods(t *testing.T) { +func Test_LogoutNotAllowedHttpMethods(t *testing.T) { var testInvalidLogoutHttpMethods = []string{ http.MethodGet, http.MethodPut, diff --git a/internal/server/handler/metadata/metadata_test.go b/internal/server/handler/metadata/metadata_test.go index bad77bb..1e7b4d6 100644 --- a/internal/server/handler/metadata/metadata_test.go +++ b/internal/server/handler/metadata/metadata_test.go @@ -11,58 +11,51 @@ import ( ) func Test_Metadata(t *testing.T) { - testMetadata(t) - testMetadataNotAllowedHttpMethods(t) -} - -func testMetadata(t *testing.T) { - t.Run("Get metadata", func(t *testing.T) { - metadataHandler := NewMetadataHandler() + metadataHandler := NewMetadataHandler() - rr := httptest.NewRecorder() - request := httptest.NewRequest(http.MethodGet, endpoint.Keys, nil) + rr := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, endpoint.Keys, nil) - metadataHandler.ServeHTTP(rr, request) + metadataHandler.ServeHTTP(rr, request) - if rr.Code != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusOK) - } + if rr.Code != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusOK) + } - requestResponse := rr.Result() + requestResponse := rr.Result() - metadata := testMetadataParse(t, requestResponse) + metadata := testMetadataParse(t, requestResponse) - if metadata.Issuer != "http://example.com" { - t.Error("metadata issuer did not match") - } + if metadata.Issuer != "http://example.com" { + t.Error("metadata issuer did not match") + } - if metadata.AuthorizationEndpoint != "http://example.com/authorize" { - t.Error("metadata authorization_endpoint did not match") - } + if metadata.AuthorizationEndpoint != "http://example.com/authorize" { + t.Error("metadata authorization_endpoint did not match") + } - if metadata.TokenEndpoint != "http://example.com/token" { - t.Error("metadata token_endpoint did not match") - } + if metadata.TokenEndpoint != "http://example.com/token" { + t.Error("metadata token_endpoint did not match") + } - if metadata.JWKsUri != "http://example.com/keys" { - t.Error("metadata jwks_uri did not match") - } + if metadata.JWKsUri != "http://example.com/keys" { + t.Error("metadata jwks_uri did not match") + } - if metadata.IntrospectionEndpoint != "http://example.com/introspect" { - t.Error("metadata introspection_endpoint did not match") - } + if metadata.IntrospectionEndpoint != "http://example.com/introspect" { + t.Error("metadata introspection_endpoint did not match") + } - if metadata.RevocationEndpoint != "http://example.com/revoke" { - t.Error("metadata revocation_endpoint did not match") - } + if metadata.RevocationEndpoint != "http://example.com/revoke" { + t.Error("metadata revocation_endpoint did not match") + } - if metadata.ServiceDocumentation != "https://stopnik.webish.dev" { - t.Error("metadata service_documentation did not match") - } - }) + if metadata.ServiceDocumentation != "https://stopnik.webish.dev" { + t.Error("metadata service_documentation did not match") + } } -func testMetadataNotAllowedHttpMethods(t *testing.T) { +func Test_MetadataNotAllowedHttpMethods(t *testing.T) { var testInvalidMetadataHttpMethods = []string{ http.MethodPost, http.MethodPut, diff --git a/internal/server/handler/oidc/discovery_test.go b/internal/server/handler/oidc/discovery_test.go index 39c9a8d..870e3df 100644 --- a/internal/server/handler/oidc/discovery_test.go +++ b/internal/server/handler/oidc/discovery_test.go @@ -11,58 +11,51 @@ import ( ) func Test_OidcConfiguration(t *testing.T) { - testOidcConfiguration(t) - testOidcConfigurationNotAllowedHttpMethods(t) -} - -func testOidcConfiguration(t *testing.T) { - t.Run("Get OIDC configuration", func(t *testing.T) { - oidcDiscoveryHandler := NewOidcDiscoveryHandler() + oidcDiscoveryHandler := NewOidcDiscoveryHandler() - rr := httptest.NewRecorder() - request := httptest.NewRequest(http.MethodGet, endpoint.Keys, nil) + rr := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, endpoint.Keys, nil) - oidcDiscoveryHandler.ServeHTTP(rr, request) + oidcDiscoveryHandler.ServeHTTP(rr, request) - if rr.Code != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusOK) - } + if rr.Code != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusOK) + } - requestResponse := rr.Result() + requestResponse := rr.Result() - oidcConfigurationParse := testOidcConfigurationParse(t, requestResponse) + oidcConfigurationParse := testOidcConfigurationParse(t, requestResponse) - if oidcConfigurationParse.Issuer != "http://example.com" { - t.Error("oidcConfigurationParse issuer did not match") - } + if oidcConfigurationParse.Issuer != "http://example.com" { + t.Error("oidcConfigurationParse issuer did not match") + } - if oidcConfigurationParse.AuthorizationEndpoint != "http://example.com/authorize" { - t.Error("oidcConfigurationParse authorization_endpoint did not match") - } + if oidcConfigurationParse.AuthorizationEndpoint != "http://example.com/authorize" { + t.Error("oidcConfigurationParse authorization_endpoint did not match") + } - if oidcConfigurationParse.TokenEndpoint != "http://example.com/token" { - t.Error("oidcConfigurationParse token_endpoint did not match") - } + if oidcConfigurationParse.TokenEndpoint != "http://example.com/token" { + t.Error("oidcConfigurationParse token_endpoint did not match") + } - if oidcConfigurationParse.JWKsUri != "http://example.com/keys" { - t.Error("oidcConfigurationParse jwks_uri did not match") - } + if oidcConfigurationParse.JWKsUri != "http://example.com/keys" { + t.Error("oidcConfigurationParse jwks_uri did not match") + } - if oidcConfigurationParse.IntrospectionEndpoint != "http://example.com/introspect" { - t.Error("oidcConfigurationParse introspection_endpoint did not match") - } + if oidcConfigurationParse.IntrospectionEndpoint != "http://example.com/introspect" { + t.Error("oidcConfigurationParse introspection_endpoint did not match") + } - if oidcConfigurationParse.RevocationEndpoint != "http://example.com/revoke" { - t.Error("oidcConfigurationParse revocation_endpoint did not match") - } + if oidcConfigurationParse.RevocationEndpoint != "http://example.com/revoke" { + t.Error("oidcConfigurationParse revocation_endpoint did not match") + } - if oidcConfigurationParse.ServiceDocumentation != "https://stopnik.webish.dev" { - t.Error("oidcConfigurationParse service_documentation did not match") - } - }) + if oidcConfigurationParse.ServiceDocumentation != "https://stopnik.webish.dev" { + t.Error("oidcConfigurationParse service_documentation did not match") + } } -func testOidcConfigurationNotAllowedHttpMethods(t *testing.T) { +func Test_OidcConfigurationNotAllowedHttpMethods(t *testing.T) { var testInvalidOidcConfigurationHttpMethods = []string{ http.MethodPost, http.MethodPut, diff --git a/internal/server/handler/oidc/userinfo_test.go b/internal/server/handler/oidc/userinfo_test.go index 9227a0d..d25128b 100644 --- a/internal/server/handler/oidc/userinfo_test.go +++ b/internal/server/handler/oidc/userinfo_test.go @@ -6,7 +6,6 @@ import ( "github.com/webishdev/stopnik/internal/config" "github.com/webishdev/stopnik/internal/endpoint" internalHttp "github.com/webishdev/stopnik/internal/http" - "github.com/webishdev/stopnik/internal/manager/key" "github.com/webishdev/stopnik/internal/manager/token" "io" "net/http" @@ -53,14 +52,12 @@ func Test_UserInfo(t *testing.T) { t.Fatal(initializationError) } - keyManger := key.GetKeyMangerInstance() - - testOidcUserInfo(t, testConfig, keyManger) + testOidcUserInfo(t, testConfig) testOidcUserInfoNotAllowedHttpMethods(t) } -func testOidcUserInfo(t *testing.T, testConfig *config.Config, keyManager *key.Manger) { +func testOidcUserInfo(t *testing.T, testConfig *config.Config) { t.Run("OIDC UserInfo", func(t *testing.T) { tokenManager := token.GetTokenManagerInstance() @@ -77,7 +74,7 @@ func testOidcUserInfo(t *testing.T, testConfig *config.Config, keyManager *key.M httpRequest := &http.Request{ Method: http.MethodGet, Header: http.Header{ - internalHttp.Authorization: []string{"Bearer " + tokenResponse.AccessTokenKey}, + internalHttp.Authorization: []string{"Bearer " + tokenResponse.AccessTokenValue}, }, } rr := httptest.NewRecorder() diff --git a/internal/server/handler/revoke/revoke_test.go b/internal/server/handler/revoke/revoke_test.go index 0d98503..174061e 100644 --- a/internal/server/handler/revoke/revoke_test.go +++ b/internal/server/handler/revoke/revoke_test.go @@ -49,61 +49,47 @@ func Test_Revoke(t *testing.T) { t.Fatal(initializationError) } - testRevokeMissingClientCredentials(t) - - testRevokeInvalidClientCredentials(t) - - testRevokeEmptyToken(t) - - testRevokeInvalidToken(t) - testRevoke(t, testConfig) testRevokeWithoutHint(t, testConfig) testRevokeDisabled(t, testConfig) - - testRevokeNotAllowedHttpMethods(t) } -func testRevokeMissingClientCredentials(t *testing.T) { - t.Run("Missing client credentials", func(t *testing.T) { - requestValidator := validation.NewRequestValidator() - tokenManager := token.GetTokenManagerInstance() +func Test_RevokeMissingClientCredentials(t *testing.T) { + requestValidator := validation.NewRequestValidator() + tokenManager := token.GetTokenManagerInstance() - revokeHandler := NewRevokeHandler(requestValidator, tokenManager) + revokeHandler := NewRevokeHandler(requestValidator, tokenManager) - rr := httptest.NewRecorder() + rr := httptest.NewRecorder() - revokeHandler.ServeHTTP(rr, httptest.NewRequest(http.MethodPost, endpoint.Revoke, nil)) + revokeHandler.ServeHTTP(rr, httptest.NewRequest(http.MethodPost, endpoint.Revoke, nil)) - if rr.Code != http.StatusUnauthorized { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusUnauthorized) - } - }) + if rr.Code != http.StatusUnauthorized { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusUnauthorized) + } } -func testRevokeInvalidClientCredentials(t *testing.T) { - t.Run("Invalid client credentials", func(t *testing.T) { - requestValidator := validation.NewRequestValidator() - tokenManager := token.GetTokenManagerInstance() +func Test_RevokeInvalidClientCredentials(t *testing.T) { + requestValidator := validation.NewRequestValidator() + tokenManager := token.GetTokenManagerInstance() - revokeHandler := NewRevokeHandler(requestValidator, tokenManager) + revokeHandler := NewRevokeHandler(requestValidator, tokenManager) - rr := httptest.NewRecorder() + rr := httptest.NewRecorder() - request := httptest.NewRequest(http.MethodPost, endpoint.Revoke, nil) - request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "xxx"))) + request := httptest.NewRequest(http.MethodPost, endpoint.Revoke, nil) + request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "xxx"))) - revokeHandler.ServeHTTP(rr, request) + revokeHandler.ServeHTTP(rr, request) - if rr.Code != http.StatusUnauthorized { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusUnauthorized) - } - }) + if rr.Code != http.StatusUnauthorized { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusUnauthorized) + } } -func testRevokeEmptyToken(t *testing.T) { +func Test_RevokeEmptyToken(t *testing.T) { type introspectParameter struct { tokenHint oauth2.IntrospectTokenType } @@ -142,7 +128,7 @@ func testRevokeEmptyToken(t *testing.T) { } } -func testRevokeInvalidToken(t *testing.T) { +func Test_RevokeInvalidToken(t *testing.T) { type introspectParameter struct { tokenHint oauth2.IntrospectTokenType } @@ -220,15 +206,15 @@ func testRevoke(t *testing.T, testConfig *config.Config) { revokeHandler := NewRevokeHandler(requestValidator, tokenManager) - token := accessTokenResponse.AccessTokenKey + accessTokenValue := accessTokenResponse.AccessTokenValue if test.tokenHint == oauth2.ItRefreshToken { - token = accessTokenResponse.RefreshTokenKey + accessTokenValue = accessTokenResponse.RefreshTokenValue } rr := httptest.NewRecorder() bodyString := testCreateBody( - oauth2.ParameterToken, token, + oauth2.ParameterToken, accessTokenValue, oauth2.ParameterTokenTypeHint, test.tokenHint, ) body := strings.NewReader(bodyString) @@ -244,12 +230,12 @@ func testRevoke(t *testing.T, testConfig *config.Config) { } if test.tokenHint == oauth2.ItAccessToken { - _, accessTokenExists := tokenManager.GetAccessToken(token) + _, accessTokenExists := tokenManager.GetAccessToken(accessTokenValue) if accessTokenExists { t.Errorf("access token should have been revoked") } } else if test.tokenHint == oauth2.ItRefreshToken { - _, refreshTokenExists := tokenManager.GetRefreshToken(token) + _, refreshTokenExists := tokenManager.GetRefreshToken(accessTokenValue) if refreshTokenExists { t.Errorf("refresh token should have been revoked") } @@ -298,15 +284,15 @@ func testRevokeWithoutHint(t *testing.T, testConfig *config.Config) { revokeHandler := NewRevokeHandler(requestValidator, tokenManager) - token := accessTokenResponse.AccessTokenKey + accessTokenValue := accessTokenResponse.AccessTokenValue if test.tokenHint == oauth2.ItRefreshToken { - token = accessTokenResponse.RefreshTokenKey + accessTokenValue = accessTokenResponse.RefreshTokenValue } rr := httptest.NewRecorder() bodyString := testCreateBody( - oauth2.ParameterToken, token, + oauth2.ParameterToken, accessTokenValue, ) body := strings.NewReader(bodyString) @@ -321,12 +307,12 @@ func testRevokeWithoutHint(t *testing.T, testConfig *config.Config) { } if test.tokenHint == oauth2.ItAccessToken { - _, accessTokenExists := tokenManager.GetAccessToken(token) + _, accessTokenExists := tokenManager.GetAccessToken(accessTokenValue) if accessTokenExists { t.Errorf("access token should have been revoked") } } else if test.tokenHint == oauth2.ItRefreshToken { - _, refreshTokenExists := tokenManager.GetRefreshToken(token) + _, refreshTokenExists := tokenManager.GetRefreshToken(accessTokenValue) if refreshTokenExists { t.Errorf("refresh token should have been revoked") } @@ -375,15 +361,15 @@ func testRevokeDisabled(t *testing.T, testConfig *config.Config) { revokeHandler := NewRevokeHandler(requestValidator, tokenManager) - token := accessTokenResponse.AccessTokenKey + accessTokenValue := accessTokenResponse.AccessTokenValue if test.tokenHint == oauth2.ItRefreshToken { - token = accessTokenResponse.RefreshTokenKey + accessTokenValue = accessTokenResponse.RefreshTokenValue } rr := httptest.NewRecorder() bodyString := testCreateBody( - oauth2.ParameterToken, token, + oauth2.ParameterToken, accessTokenValue, oauth2.ParameterTokenTypeHint, test.tokenHint, ) body := strings.NewReader(bodyString) @@ -402,7 +388,7 @@ func testRevokeDisabled(t *testing.T, testConfig *config.Config) { } } -func testRevokeNotAllowedHttpMethods(t *testing.T) { +func Test_RevokeNotAllowedHttpMethods(t *testing.T) { var testInvalidRevokeHttpMethods = []string{ http.MethodGet, http.MethodPut, diff --git a/internal/server/handler/token/token_test.go b/internal/server/handler/token/token_test.go index 24c37b1..9c9635b 100644 --- a/internal/server/handler/token/token_test.go +++ b/internal/server/handler/token/token_test.go @@ -49,30 +49,10 @@ func Test_Token(t *testing.T) { t.Fatal(initializationError) } - testTokenMissingClientCredentials(t) - - testTokenInvalidClientCredentials(t) - - testTokenMissingGrandType(t) - - testTokenInvalidGrandType(t) - - testTokenAuthorizationCodeGrantTypeMissingCodeParameter(t) - - testTokenAuthorizationCodeGrantTypeInvalidPKCE(t) - - testTokenAuthorizationCodeGrantType(t) - - testTokenPasswordGrantType(t) - - testTokenClientCredentialsGrantType(t) - testTokenRefreshTokenGrantType(t, testConfig) - - testTokenNotAllowedHttpMethods(t) } -func testTokenMissingClientCredentials(t *testing.T) { +func Test_TokenMissingClientCredentials(t *testing.T) { t.Run("Missing client credentials for confidential client", func(t *testing.T) { requestValidator := validation.NewRequestValidator() sessionManager := session.GetAuthSessionManagerInstance() @@ -90,7 +70,7 @@ func testTokenMissingClientCredentials(t *testing.T) { }) } -func testTokenInvalidClientCredentials(t *testing.T) { +func Test_TokenInvalidClientCredentials(t *testing.T) { t.Run("Invalid client credentials", func(t *testing.T) { requestValidator := validation.NewRequestValidator() sessionManager := session.GetAuthSessionManagerInstance() @@ -111,126 +91,118 @@ func testTokenInvalidClientCredentials(t *testing.T) { }) } -func testTokenMissingGrandType(t *testing.T) { - t.Run("Missing grant type", func(t *testing.T) { - requestValidator := validation.NewRequestValidator() - sessionManager := session.GetAuthSessionManagerInstance() - tokenManager := token.GetTokenManagerInstance() +func Test_TokenMissingGrandType(t *testing.T) { + requestValidator := validation.NewRequestValidator() + sessionManager := session.GetAuthSessionManagerInstance() + tokenManager := token.GetTokenManagerInstance() - tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) + tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) - rr := httptest.NewRecorder() + rr := httptest.NewRecorder() - request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil) - request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "bar"))) + request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil) + request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "bar"))) - tokenHandler.ServeHTTP(rr, request) + tokenHandler.ServeHTTP(rr, request) - if rr.Code != http.StatusBadRequest { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusBadRequest) - } - }) + if rr.Code != http.StatusBadRequest { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusBadRequest) + } } -func testTokenInvalidGrandType(t *testing.T) { - t.Run("Invalid grant type", func(t *testing.T) { - requestValidator := validation.NewRequestValidator() - sessionManager := session.GetAuthSessionManagerInstance() - tokenManager := token.GetTokenManagerInstance() +func Test_TokenInvalidGrandType(t *testing.T) { + requestValidator := validation.NewRequestValidator() + sessionManager := session.GetAuthSessionManagerInstance() + tokenManager := token.GetTokenManagerInstance() - tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) + tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) - rr := httptest.NewRecorder() + rr := httptest.NewRecorder() - bodyString := testCreateBody( - oauth2.ParameterGrantType, "foobar", - ) - body := strings.NewReader(bodyString) + bodyString := testCreateBody( + oauth2.ParameterGrantType, "foobar", + ) + body := strings.NewReader(bodyString) - request := httptest.NewRequest(http.MethodPost, endpoint.Token, body) - request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "bar"))) - request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") + request := httptest.NewRequest(http.MethodPost, endpoint.Token, body) + request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "bar"))) + request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") - tokenHandler.ServeHTTP(rr, request) + tokenHandler.ServeHTTP(rr, request) - if rr.Code != http.StatusBadRequest { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusBadRequest) - } - }) + if rr.Code != http.StatusBadRequest { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusBadRequest) + } } -func testTokenAuthorizationCodeGrantTypeMissingCodeParameter(t *testing.T) { - t.Run("Authorization code grant type, missing code parameter", func(t *testing.T) { - requestValidator := validation.NewRequestValidator() - sessionManager := session.GetAuthSessionManagerInstance() - tokenManager := token.GetTokenManagerInstance() +func Test_TokenAuthorizationCodeGrantTypeMissingCodeParameter(t *testing.T) { + requestValidator := validation.NewRequestValidator() + sessionManager := session.GetAuthSessionManagerInstance() + tokenManager := token.GetTokenManagerInstance() - tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) + tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) - rr := httptest.NewRecorder() + rr := httptest.NewRecorder() - bodyString := testCreateBody( - oauth2.ParameterGrantType, oauth2.GtAuthorizationCode, - ) - body := strings.NewReader(bodyString) + bodyString := testCreateBody( + oauth2.ParameterGrantType, oauth2.GtAuthorizationCode, + ) + body := strings.NewReader(bodyString) - request := httptest.NewRequest(http.MethodPost, endpoint.Token, body) - request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "bar"))) - request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") + request := httptest.NewRequest(http.MethodPost, endpoint.Token, body) + request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "bar"))) + request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") - tokenHandler.ServeHTTP(rr, request) + tokenHandler.ServeHTTP(rr, request) - if rr.Code != http.StatusBadRequest { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusBadRequest) - } - }) + if rr.Code != http.StatusBadRequest { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusBadRequest) + } } -func testTokenAuthorizationCodeGrantTypeInvalidPKCE(t *testing.T) { - t.Run("Authorization code grant type with invalid PKCE", func(t *testing.T) { - id := uuid.New() - pkceCodeChallenge := pkce.CalculatePKCE(pkce.S256, "foobar") - authSession := &session.AuthSession{ - Id: id.String(), - Redirect: "https://example.com/callback", - AuthURI: "https://example.com/auth", - CodeChallenge: pkceCodeChallenge, - CodeChallengeMethod: string(pkce.S256), - ClientId: "foo", - ResponseTypes: []oauth2.ResponseType{oauth2.RtCode}, - Scopes: []string{"foo:bar", "moo:abc"}, - State: "xyz", - } +func Test_TokenAuthorizationCodeGrantTypeInvalidPKCE(t *testing.T) { + id := uuid.New() + pkceCodeChallenge := pkce.CalculatePKCE(pkce.S256, "foobar") + authSession := &session.AuthSession{ + Id: id.String(), + Redirect: "https://example.com/callback", + AuthURI: "https://example.com/auth", + CodeChallenge: pkceCodeChallenge, + CodeChallengeMethod: string(pkce.S256), + ClientId: "foo", + ResponseTypes: []oauth2.ResponseType{oauth2.RtCode}, + Scopes: []string{"foo:bar", "moo:abc"}, + State: "xyz", + } - requestValidator := validation.NewRequestValidator() - sessionManager := session.GetAuthSessionManagerInstance() - tokenManager := token.GetTokenManagerInstance() - sessionManager.StartSession(authSession) + requestValidator := validation.NewRequestValidator() + sessionManager := session.GetAuthSessionManagerInstance() + tokenManager := token.GetTokenManagerInstance() + sessionManager.StartSession(authSession) - tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) + tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) - rr := httptest.NewRecorder() + rr := httptest.NewRecorder() - bodyString := testCreateBody( - oauth2.ParameterGrantType, oauth2.GtAuthorizationCode, - oauth2.ParameterCode, id.String(), - pkce.ParameterCodeVerifier, "barfoo", - ) - body := strings.NewReader(bodyString) + bodyString := testCreateBody( + oauth2.ParameterGrantType, oauth2.GtAuthorizationCode, + oauth2.ParameterCode, id.String(), + pkce.ParameterCodeVerifier, "barfoo", + ) + body := strings.NewReader(bodyString) - request := httptest.NewRequest(http.MethodPost, endpoint.Token, body) - request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "bar"))) - request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") + request := httptest.NewRequest(http.MethodPost, endpoint.Token, body) + request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "bar"))) + request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") - tokenHandler.ServeHTTP(rr, request) + tokenHandler.ServeHTTP(rr, request) - if rr.Code != http.StatusBadRequest { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusBadRequest) - } - }) + if rr.Code != http.StatusBadRequest { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusBadRequest) + } } -func testTokenAuthorizationCodeGrantType(t *testing.T) { +func Test_TokenAuthorizationCodeGrantType(t *testing.T) { type authorizationGrantParameter struct { state string scope string @@ -313,70 +285,66 @@ func testTokenAuthorizationCodeGrantType(t *testing.T) { } } -func testTokenPasswordGrantType(t *testing.T) { - t.Run("Password grant type", func(t *testing.T) { - requestValidator := validation.NewRequestValidator() - sessionManager := session.GetAuthSessionManagerInstance() - tokenManager := token.GetTokenManagerInstance() +func Test_TokenPasswordGrantType(t *testing.T) { + requestValidator := validation.NewRequestValidator() + sessionManager := session.GetAuthSessionManagerInstance() + tokenManager := token.GetTokenManagerInstance() - tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) + tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) - rr := httptest.NewRecorder() + rr := httptest.NewRecorder() - bodyString := testCreateBody( - oauth2.ParameterGrantType, oauth2.GtPassword, - oauth2.ParameterUsername, "foo", - oauth2.ParameterPassword, "bar", - oauth2.ParameterScope, "foo:bar moo:abc", - ) - body := strings.NewReader(bodyString) + bodyString := testCreateBody( + oauth2.ParameterGrantType, oauth2.GtPassword, + oauth2.ParameterUsername, "foo", + oauth2.ParameterPassword, "bar", + oauth2.ParameterScope, "foo:bar moo:abc", + ) + body := strings.NewReader(bodyString) - request := httptest.NewRequest(http.MethodPost, endpoint.Token, body) - request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "bar"))) - request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") + request := httptest.NewRequest(http.MethodPost, endpoint.Token, body) + request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "bar"))) + request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") - tokenHandler.ServeHTTP(rr, request) + tokenHandler.ServeHTTP(rr, request) - if rr.Code != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusOK) - } + if rr.Code != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusOK) + } - response := rr.Result() + response := rr.Result() - testTokenValidate(t, tokenManager, response) - }) + testTokenValidate(t, tokenManager, response) } -func testTokenClientCredentialsGrantType(t *testing.T) { - t.Run("Client credentials grant type", func(t *testing.T) { - requestValidator := validation.NewRequestValidator() - sessionManager := session.GetAuthSessionManagerInstance() - tokenManager := token.GetTokenManagerInstance() +func Test_TokenClientCredentialsGrantType(t *testing.T) { + requestValidator := validation.NewRequestValidator() + sessionManager := session.GetAuthSessionManagerInstance() + tokenManager := token.GetTokenManagerInstance() - tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) + tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) - rr := httptest.NewRecorder() + rr := httptest.NewRecorder() - bodyString := testCreateBody( - oauth2.ParameterGrantType, oauth2.GtClientCredentials, - oauth2.ParameterScope, "foo:bar moo:abc", - ) - body := strings.NewReader(bodyString) + bodyString := testCreateBody( + oauth2.ParameterGrantType, oauth2.GtClientCredentials, + oauth2.ParameterScope, "foo:bar moo:abc", + ) + body := strings.NewReader(bodyString) - request := httptest.NewRequest(http.MethodPost, endpoint.Token, body) - request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "bar"))) - request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") + request := httptest.NewRequest(http.MethodPost, endpoint.Token, body) + request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "bar"))) + request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded") - tokenHandler.ServeHTTP(rr, request) + tokenHandler.ServeHTTP(rr, request) - if rr.Code != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusOK) - } + if rr.Code != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, http.StatusOK) + } - response := rr.Result() + response := rr.Result() - testTokenValidate(t, tokenManager, response) - }) + testTokenValidate(t, tokenManager, response) } func testTokenRefreshTokenGrantType(t *testing.T, testConfig *config.Config) { @@ -412,7 +380,7 @@ func testTokenRefreshTokenGrantType(t *testing.T, testConfig *config.Config) { bodyString := testCreateBody( oauth2.ParameterGrantType, oauth2.GtRefreshToken, - oauth2.ParameterRefreshToken, accessTokenResponse.RefreshTokenKey, + oauth2.ParameterRefreshToken, accessTokenResponse.RefreshTokenValue, ) body := strings.NewReader(bodyString) @@ -432,7 +400,7 @@ func testTokenRefreshTokenGrantType(t *testing.T, testConfig *config.Config) { }) } -func testTokenNotAllowedHttpMethods(t *testing.T) { +func Test_TokenNotAllowedHttpMethods(t *testing.T) { var testInvalidTokenHttpMethods = []string{ http.MethodGet, http.MethodPut, @@ -478,11 +446,11 @@ func testTokenValidate(t *testing.T, tokenManager *token.Manager, response *http t.Errorf("could not parse response body: %v", jsonParseError) } - if accessTokenResponse.AccessTokenKey == "" { + if accessTokenResponse.AccessTokenValue == "" { t.Errorf("access token key was empty") } - _, exists := tokenManager.GetAccessToken(accessTokenResponse.AccessTokenKey) + _, exists := tokenManager.GetAccessToken(accessTokenResponse.AccessTokenValue) if !exists { t.Errorf("access token was not found in access token manager") diff --git a/internal/server/server.go b/internal/server/server.go index 4ff01d9..36cfee3 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -125,7 +125,7 @@ func (stopnikServer *StopnikServer) Start() { errorServer := stopnikServer.listenAndServe(stopnikServer.config.Server.TLS.Addr, *stopnikServer.serveTLS) if errorServer != nil && !errors.Is(errorServer, http.ErrServerClosed) { - log.Error("Error starting server: %v", errorServer) + log.Error("Error starting TLS server: %v", errorServer) } }() } diff --git a/website/docs/introduction/config.md b/website/docs/introduction/config.md index da7e05a..dc99577 100644 --- a/website/docs/introduction/config.md +++ b/website/docs/introduction/config.md @@ -10,30 +10,31 @@ The possible configuration options are listed in the next section. The configuration file (e.g. `config.yml`) may contain different root options which are described here as followed -| Property | Description | -|-----------|------------------------------| -| `server` | Server configuration | -| `ui` | User interface configuration | -| `clients` | List of clients | -| `users` | List of users | +| Property | Description | Required | +|-----------|------------------------------|----------| +| `server` | Server configuration | Yes | +| `ui` | User interface configuration | No | +| `clients` | List of clients | Yes | +| `users` | List of users | Yes | -### Server configuration +### Server configuration Root entry named `server` -| Property | Description | -|-------------------------|-----------------------------------------------------------------------------------| -| `logLevel` | Log level | -| `cookies` | Configuration related to cookie names | -| `addr` | [Go like address](https://pkg.go.dev/net#Dial), may contain IP and port | -| `secret` | Server secret | -| `privateKey` | General RSA or EC private key (can be overwritten for each client) to sign tokens | -| `tls` | Configuration for TLS | -| `logoutRedirect` | Where to redirect user after logout | -| `introspectScope` | Scope which allows token introspection | -| `revokeScopeScope` | Scope which allows token revocation | -| `sessionTimeoutSeconds` | Seconds until session will end | - +| Property | Description | Required | +|-------------------------|---------------------------------------------------------------------------------------------------|----------| +| `logLevel` | Log level | No | +| `cookies` | Configuration related to cookie names | No | +| `addr` | [Go like address](https://pkg.go.dev/net#Dial), may contain IP and port | Yes | +| `secret` | Server secret | No | +| `privateKey` | General RSA or EC private key (can be overwritten for each client) to sign tokens | No | +| `issuer` | Issuer | No | +| `tls` | Configuration for TLS | No | +| `logoutRedirect` | Where to redirect user after logout | No | +| `introspectScope` | Scope which allows token introspection | No | +| `revokeScopeScope` | Scope which allows token revocation | No | +| `sessionTimeoutSeconds` | Seconds until session will end | No | +| `forwardAuth` | [Traefik ForwardAuth](https://doc.traefik.io/traefik/middlewares/http/forwardauth/) configuration | No | #### TLS @@ -41,10 +42,10 @@ Public and private keys to sign tokens Entry `server.tls` -| Property | Description | -|----------|-------------------------------------------------------------------------| -| `addr` | [Go like address](https://pkg.go.dev/net#Dial), may contain IP and port | -| `keys` | Public and private keys for TLS | +| Property | Description | Required | +|----------|-------------------------------------------------------------------------|----------| +| `addr` | [Go like address](https://pkg.go.dev/net#Dial), may contain IP and port | Yes | +| `keys` | Public and private keys for TLS | Yes | ##### TLS keys @@ -52,10 +53,10 @@ Public and private keys for TLS Entry `server.tls.keys` -| Property | Description | -|----------|------------------| -| `cert` | Certificate file | -| `key` | Key file | +| Property | Description | Required | +|----------|------------------|----------| +| `cert` | Certificate file | Yes | +| `key` | Key file | Yes | #### Cookies @@ -63,23 +64,35 @@ Public and private keys to sign tokens Entry `server.cookies` -| Property | Description | -|---------------|----------------------------------| -| `authName` | Name of the authorization cookie | -| `messageName` | Name of internal message cookie | +| Property | Description | Required | +|---------------|----------------------------------|----------| +| `authName` | Name of the authorization cookie | No | +| `messageName` | Name of internal message cookie | No | + +#### ForwardAuth + +**STOPnik** supports [Traefik ForwardAuth](https://doc.traefik.io/traefik/middlewares/http/forwardauth/) out of the box. + +Entry `server.forwardAuth` + +| Property | Description | Required | +|-----------------|-----------------------------------------------------|----------| +| `endpoint` | Internal endpoint to be called by Traefik | No | +| `externalUrl` | URL of **STOPnik** to redirect the user for a login | Yes | +| `parameterName` | URL parameter used by **STOPnik** for ForwardAuth | No | ### User interface configuration Root entry named `ui` -| Property | Description | -|-------------------|---------------------------------------------------------------------------------------------------------------------------------------------| -| `logoImage` | Path of additional logo image | -| `logoContentType` | [HTTP Mime type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/MIME_types/Common_types) of logo image if not `image/png` | -| `hideFooter` | Will hide the **STOPnik** footer | -| `hideMascot` | Will hide the **STOPnik** mascot | -| `footerText` | The footer text | -| `title` | Title | +| Property | Description | Required | +|-------------------|---------------------------------------------------------------------------------------------------------------------------------------------|----------| +| `logoImage` | Path of additional logo image | No | +| `logoContentType` | [HTTP Mime type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/MIME_types/Common_types) of logo image if not `image/png` | No | +| `hideFooter` | Will hide the **STOPnik** footer | No | +| `hideMascot` | Will hide the **STOPnik** mascot | No | +| `footerText` | The footer text | No | +| `title` | Title | No | ### Clients @@ -89,26 +102,24 @@ Root entry `clients` Each entry may contain the following options -| Property | Description | -|---------------------------|---------------------------------------------------------| -| `id` | The id of the client | -| `clientSecret` | SHA512 hashed secret | -| `salt` | Optional salt for secret to avoid identical hash values | -| `accessTTL` | Access token time to live | -| `refreshTTL` | Refresh token time to live | -| `idTTL` | OpenId Connect ID token time to live | -| `oidc` | Flag to allow an client to handle OpenId Connect | -| `introspect` | Introspection scope | -| `revoke` | Revocation scope | -| `redirects` | List of redirects URIs | -| `opaqueToken` | Use opaque token | -| `passwordFallbackAllowed` | Form auth allowed | -| `claims` | List of claims | -| `issuer` | Issuer | -| `audience` | Audience | -| `privateKey` | RSA or EC private key to sign tokens | -| `rolesClaim` | Name for the claim used to provide roles | - +| Property | Description | Required | +|---------------------------|---------------------------------------------------------|----------| +| `id` | The id of the client | Yes | +| `clientSecret` | SHA512 hashed secret | Yes | +| `salt` | Optional salt for secret to avoid identical hash values | No | +| `accessTTL` | Access token time to live | No | +| `refreshTTL` | Refresh token time to live | No | +| `idTTL` | OpenId Connect ID token time to live | No | +| `oidc` | Flag to allow an client to handle OpenId Connect | No | +| `introspect` | Introspection scope | No | +| `revoke` | Revocation scope | No | +| `redirects` | List of redirects URIs | No | +| `opaqueToken` | Use opaque token | No | +| `passwordFallbackAllowed` | Form auth allowed | No | +| `claims` | List of claims | No | +| `audience` | Audience | No | +| `privateKey` | RSA or EC private key to sign tokens | No | +| `rolesClaim` | Name for the claim used to provide roles | No | For `clientSecret` and `salt` see, [Command line - Password](../advanced/cmd.md#password) @@ -120,10 +131,10 @@ Entry `clients[n].calims` Each entry may contain the following options -| Property | Description | -|----------|-------------| -| `name` | Name | -| `value` | Value | +| Property | Description | Required | +|----------|-------------|----------| +| `name` | Name | Yes | +| `value` | Value | Yes | ### Users @@ -133,13 +144,13 @@ Root entry `users` Each entry may contain the following options -| Property | Description | -|------------|--------------------------------------------------------------------| -| `username` | Username | -| `password` | SHA512 hashed password | -| `salt` | Optional salt for password to avoid identical hash values | -| `profile` | User profile which will be used for OpenId Connect UserInfo | -| `roles` | YAML map for roles, key of the map is the id of the related client | +| Property | Description | Required | +|------------|--------------------------------------------------------------------|----------| +| `username` | Username | Yes | +| `password` | SHA512 hashed password | Yes | +| `salt` | Optional salt for password to avoid identical hash values | No | +| `profile` | User profile which will be used for OpenId Connect UserInfo | No | +| `roles` | YAML map for roles, key of the map is the id of the related client | No | For `password` and `salt` see, [Command line - Password](../advanced/cmd.md#password) @@ -151,24 +162,24 @@ Entry `users[n].profile` Each entry may contain the following options -| Property | Description | -|---------------------|----------------------------------| -| `givenName` | Given name | -| `familyName` | Family name | -| `nickname` | Nickname | -| `preferredUserName` | Preferred username | -| `email` | E-Mail address | -| `emailVerified` | E-Mail address verification flag | -| `gender` | Gender | -| `birthDate` | Birthdate | -| `zoneInfo` | Zone information | -| `locale` | locale | -| `phoneNumber` | Phone number | -| `phoneVerified` | Phone number verficiation flag | -| `website` | Website URL | -| `profile` | Profile URL | -| `profilePicture` | Profile picture URL | -| `address` | User address | +| Property | Description | Required | +|---------------------|----------------------------------|----------| +| `givenName` | Given name | No | +| `familyName` | Family name | No | +| `nickname` | Nickname | No | +| `preferredUserName` | Preferred username | No | +| `email` | E-Mail address | No | +| `emailVerified` | E-Mail address verification flag | No | +| `gender` | Gender | No | +| `birthDate` | Birthdate | No | +| `zoneInfo` | Zone information | No | +| `locale` | locale | No | +| `phoneNumber` | Phone number | No | +| `phoneVerified` | Phone number verification flag | No | +| `website` | Website URL | No | +| `profile` | Profile URL | No | +| `profilePicture` | Profile picture URL | No | +| `address` | User address | No | #### User address @@ -178,14 +189,13 @@ Entry `users[n].profile.address` Each entry may contain the following options -| Property | Description | -|--------------|-------------| -| `street` | Street | -| `city` | City | -| `postalCode` | Postal code | -| `region` | Region | -| `country` | Country | - +| Property | Description | Required | +|--------------|-------------|----------| +| `street` | Street | No | +| `city` | City | No | +| `postalCode` | Postal code | No | +| `region` | Region | No | +| `country` | Country | No | ## Examples @@ -215,7 +225,8 @@ Not login and not `OAuth | OpenId Connect` flow will be possible. ### Development configuration -The shown `config.yml` is used during development and can be found [here](https://github.com/webishdev/stopnik/blob/main/config.yml) in the repository. +The shown `config.yml` is used during development and can be +found [here](https://github.com/webishdev/stopnik/blob/main/config.yml) in the repository. To be able to use it, the referenced `server.crt` and `server.key` must be created as self-signed certificate. @@ -226,17 +237,19 @@ server: authName: stopnik_auth messageName: stopnik_message #logoutRedirect: http://localhost:8080 + forwardAuth: + externalUrl: http://stopnik.localhost:9090 secret: WRYldij9ebtDZ5VJSsxNAfCZ - privateKey: ./test_keys/rsa256key.pem + privateKey: ./.test_files/rsa256key.pem addr: :8082 tls: addr: :8081 keys: - cert: ./test_keys/server.crt - key: ./test_keys/server.key + cert: ./.test_files/server.crt + key: ./.test_files/server.key ui: # hideFooter: true -# hideMascot: true +# hideLogo: true # footerText: Some nice line! # title: Test realm clients: @@ -252,6 +265,7 @@ clients: - https://oauth.pstmn.io/v1/callback - http://localhost:8080/session/callback - http://localhost:5173/reporting/oidc-callback* + - http://localhost:8082/health claims: - name: foo value: bar @@ -268,7 +282,7 @@ clients: salt: 321 accessTTL: 5 refreshTTL: 15 - privateKey: ./test_keys/ecdsa521key.pem + privateKey: ./.test_files/ecdsa521key.pem redirects: - https://oauth.pstmn.io/v1/callback users: