diff --git a/cmd/stopnik/commands.go b/cmd/stopnik/commands.go index 1ffe872..4f88c97 100644 --- a/cmd/stopnik/commands.go +++ b/cmd/stopnik/commands.go @@ -56,7 +56,7 @@ func start(configurationFile *string) error { return nil } -func readConfiguration(configurationFile *string, configLoader *config.Loader) (*config.Config, error) { +func readConfiguration(configurationFile *string, configLoader config.Loader) (*config.Config, error) { configError := configLoader.LoadConfig(*configurationFile, true) if configError != nil { fmt.Printf("STOPnik %s - %s\n\n", Version, GitHash) diff --git a/internal/config/config.go b/internal/config/config.go index b68d6a5..cdbd43a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,32 +11,37 @@ import ( "github.com/webishdev/stopnik/log" "io" "os" - "regexp" "strings" "sync" ) +// Keys defines path to TSL certificate and key file. type Keys struct { Cert string `yaml:"cert"` Key string `yaml:"key"` } +// TLS defines the Go like address to listen to and references the necessary Keys. type TLS struct { Addr string `yaml:"addr"` Keys Keys `yaml:"keys"` } +// Cookies defines the name for HTTP cookies used by STOPnik. type Cookies struct { AuthName string `yaml:"authName"` MessageName string `yaml:"messageName"` } +// ForwardAuth defines the configuration related to Traefik Forward Auth, +// only used when ExternalUrl is provided. type ForwardAuth struct { Endpoint string `yaml:"endpoint"` ExternalUrl string `yaml:"externalUrl"` ParameterName string `yaml:"parameterName"` } +// Server defines the main STOPnik server configuration. type Server struct { LogLevel string `yaml:"logLevel"` Addr string `yaml:"addr"` @@ -52,6 +57,8 @@ type Server struct { ForwardAuth ForwardAuth `yaml:"forwardAuth"` } +// UserAddress defines the address for a specific user, +// the definition provided in the YAML file will be mapped into values inside a JSON response. type UserAddress struct { Formatted string `json:"formatted,omitempty"` Street string `yaml:"street" json:"street_address,omitempty"` @@ -61,6 +68,8 @@ type UserAddress struct { Country string `yaml:"country" json:"country,omitempty"` } +// UserProfile defines the profile for a specific user, +// the definition provided in the YAML file will be mapped into values inside a JSON response. type UserProfile struct { Subject string `json:"sub,omitempty"` Name string `json:"name,omitempty"` @@ -83,6 +92,7 @@ type UserProfile struct { UpdatedAt string `json:"updated_at,omitempty"` } +// User defines the general user entry in the configuration. type User struct { Username string `yaml:"username"` Password string `yaml:"password"` @@ -91,11 +101,14 @@ type User struct { Roles map[string][]string `yaml:"roles"` } +// Claim defines additional claims with name and value, +// used for a specific Client. type Claim struct { Name string `yaml:"name"` Value string `yaml:"value"` } +// Client defines the general client entry in the configuration. type Client struct { Id string `yaml:"id"` ClientSecret string `yaml:"clientSecret"` @@ -116,6 +129,7 @@ type Client struct { isForwardAuth bool } +// UI defines the general web user interface entry in the configuration. type UI struct { HideFooter bool `yaml:"hideFooter"` HideLogo bool `yaml:"hideLogo"` @@ -125,6 +139,7 @@ type UI struct { LogoContentType string `yaml:"logoContentType"` } +// Config defines the root entry for the configuration. type Config struct { Server Server `yaml:"server"` Clients []Client `yaml:"clients"` @@ -141,6 +156,7 @@ type Config struct { var configLock = &sync.Mutex{} var configSingleton *Config +// GetConfigInstance returns the current singleton of Config when it was initialized by Initialize before. func GetConfigInstance() *Config { configLock.Lock() defer configLock.Unlock() @@ -151,6 +167,13 @@ func GetConfigInstance() *Config { return configSingleton } +// Initialize initializes a given Config. +// Checks for OIDC configuration on given Client entries. +// Initializes maps for faster Client and User access in the Config. +// Generates a server secret when none was provided. +// Loads a logo image into []byte to use in the web user interface. +// Checks for ForwardAuth settings. +// Sets the singleton for the current Config func Initialize(config *Config) error { configLock.Lock() defer configLock.Unlock() @@ -202,11 +225,12 @@ func Initialize(config *Config) error { config.logoImage = &bs } - if config.Server.ForwardAuth.Endpoint != "" { + if config.Server.ForwardAuth.ExternalUrl != "" { config.forwardAuthClient = &Client{ Id: uuid.NewString(), isForwardAuth: true, } + log.Info("Forward auth client created") } configSingleton = config @@ -214,6 +238,7 @@ func Initialize(config *Config) error { return nil } +// Validate validates the current Config and returns an error when necessary values are missing. func (config *Config) Validate() error { if config.Server.Addr == "" { return errors.New("no server address provided") @@ -262,67 +287,91 @@ func (config *Config) Validate() error { return nil } -func (config *Config) GetUser(name string) (*User, bool) { - value, exists := config.userMap[name] +// GetUser returns a User for the given username. +// Also returns a bool which indicates, whether the User exists or not. +func (config *Config) GetUser(username string) (*User, bool) { + value, exists := config.userMap[username] return value, exists } -func (config *Config) GetClient(name string) (*Client, bool) { - value, exists := config.clientMap[name] - if !exists && config.forwardAuthClient != nil { +// GetClient returns a Client for the given clientId. +// Also returns a bool which indicates, whether the Client exists or not. +func (config *Config) GetClient(clientId string) (*Client, bool) { + value, exists := config.clientMap[clientId] + if !exists && config.forwardAuthClient != nil && config.forwardAuthClient.Id == clientId { return config.forwardAuthClient, true } return value, exists } +// GetAuthCookieName returns the name of the authentication cookie. +// When no name is provided a default value will be returned. func (config *Config) GetAuthCookieName() string { return GetOrDefaultString(config.Server.Cookies.AuthName, "stopnik_auth") } +// GetMessageCookieName returns the name of the message cookie. +// When no name is provided a default value will be returned. func (config *Config) GetMessageCookieName() string { return GetOrDefaultString(config.Server.Cookies.MessageName, "stopnik_message") } +// GetSessionTimeoutSeconds returns the session timeout in seconds. +// When no session timeout is provided a default value will be returned. func (config *Config) GetSessionTimeoutSeconds() int { return GetOrDefaultInt(config.Server.SessionTimeoutSeconds, 3600) } +// GetIntrospectScope returns the scope which can be used to introspect tokens. +// When no scope is provided a default value will be returned. func (config *Config) GetIntrospectScope() string { return GetOrDefaultString(config.Server.IntrospectScope, "stopnik:introspect") } +// GetRevokeScope returns the scope which can be used to revoke tokens. +// When no scope is provided a default value will be returned. func (config *Config) GetRevokeScope() string { return GetOrDefaultString(config.Server.RevokeScope, "stopnik:revoke") } +// GetServerSecret returns the server secret. +// When no secret is provided a previously generated value will be returned. func (config *Config) GetServerSecret() string { return GetOrDefaultString(config.Server.Secret, config.generatedSecret) } +// GetHideFooter returns whether the footer should be hidden in the web user interface. func (config *Config) GetHideFooter() bool { return config.UI.HideFooter } -func (config *Config) GetHideMascot() bool { +// GetHideLogo returns whether the logo should be hidden in the web user interface. +func (config *Config) GetHideLogo() bool { return config.UI.HideLogo } +// GetTitle returns whether the title shown in the web user interface. func (config *Config) GetTitle() string { return config.UI.Title } +// GetFooterText returns whether the text shown in the footer of the web user interface. +// When no footer text is provided a default value will be returned. func (config *Config) GetFooterText() string { return GetOrDefaultString(config.UI.FooterText, "STOPnik") } +// GetLogoImage returns a pointer to the loaded logo image. Can be nil if no image was provided. func (config *Config) GetLogoImage() *[]byte { return config.logoImage } +// GetOidc returns whether one of the existing clients has OIDC flag set or not. func (config *Config) GetOidc() bool { return config.oidc } +// GetIssuer returns the issuer, either by mirroring from request, from Server configuration or default value. func (config *Config) GetIssuer(requestData *internalHttp.RequestData) string { if requestData == nil || requestData.Host == "" || requestData.Scheme == "" { return GetOrDefaultString(config.Server.Issuer, "STOPnik") @@ -330,38 +379,64 @@ func (config *Config) GetIssuer(requestData *internalHttp.RequestData) string { return GetOrDefaultString(config.Server.Issuer, requestData.IssuerString()) } +// GetForwardAuthEnabled returns whether Traefik Forward Auth is enabled or not. +// Check in general whether the ForwardAuth ExternalUrl value is set. func (config *Config) GetForwardAuthEnabled() bool { return config.Server.ForwardAuth.ExternalUrl != "" } +// GetForwardAuthEndpoint returns the endpoint which will use used for Traefik Forward Auth. +// When no endpoint is provided a default value will be returned. func (config *Config) GetForwardAuthEndpoint() string { return GetOrDefaultString(config.Server.ForwardAuth.Endpoint, "/forward") } +// GetForwardAuthParameterName returns the query parameter name which will use used for Traefik Forward Auth. +// When no query parameter name is provided a default value will be returned. func (config *Config) GetForwardAuthParameterName() string { return GetOrDefaultString(config.Server.ForwardAuth.ParameterName, "forward_id") } +func (config *Config) GetForwardAuthClient() (*Client, bool) { + if config.forwardAuthClient != nil && config.forwardAuthClient.Id != "" { + return config.forwardAuthClient, true + } + return nil, false +} + +// GetRolesClaim returns the name of the claim uses to provide User roles in a Client. +// When no name is provided a default value will be returned. func (client *Client) GetRolesClaim() string { return GetOrDefaultString(client.RolesClaim, "roles") } +// GetAccessTTL returns access token time to live. +// When no time to live is provided a default value will be returned. func (client *Client) GetAccessTTL() int { return GetOrDefaultInt(client.AccessTTL, 5) } +// GetRefreshTTL returns refresh token time to live. +// When no time to live is provided a default value will be returned. func (client *Client) GetRefreshTTL() int { return GetOrDefaultInt(client.RefreshTTL, 0) } +// GetIdTTL returns id token time to live. +// When no time to live is provided a default value will be returned. func (client *Client) GetIdTTL() int { return GetOrDefaultInt(client.IdTTL, 0) } +// GetAudience returns the audience value. +// When no audience value is provided a default value will be returned. func (client *Client) GetAudience() []string { return GetOrDefaultStringSlice(client.Audience, []string{"all"}) } +// GetClientType returns the client type value. +// When no client secret is provided the client will be a public client, confidential otherwise. +// See oauth2.ClientType func (client *Client) GetClientType() oauth2.ClientType { if client.ClientSecret == "" { return oauth2.CtPublic @@ -370,13 +445,15 @@ func (client *Client) GetClientType() oauth2.ClientType { } } -func (client *Client) ValidateRedirect(redirect string) (bool, error) { +// ValidateRedirect return whether the redirect is valid for a given Client or not. +func (client *Client) ValidateRedirect(redirect string) bool { if client.isForwardAuth { - return true, nil + return true } return validateRedirect(client.Id, client.Redirects, redirect) } +// GetPreferredUsername returns the preferred username for a given User, or just the username. func (user *User) GetPreferredUsername() string { if user.Profile.PreferredUserName == "" { return user.Username @@ -385,6 +462,7 @@ func (user *User) GetPreferredUsername() string { } } +// GetFormattedAddress return the formatted address for a User. func (user *User) GetFormattedAddress() string { userAddress := user.Profile.Address var sb strings.Builder @@ -403,14 +481,15 @@ func (user *User) GetFormattedAddress() string { return sb.String() } +// GetRoles returns the roles configured for the User for a given clientId. func (user *User) GetRoles(clientId string) []string { return user.Roles[clientId] } -func validateRedirect(clientId string, redirects []string, redirect string) (bool, error) { +func validateRedirect(clientId string, redirects []string, redirect string) bool { if redirect == "" { log.Error("Redirect provided for client %s was empty", clientId) - return false, nil + return false } redirectCount := len(redirects) @@ -419,29 +498,20 @@ func validateRedirect(clientId string, redirects []string, redirect string) (boo matchesRedirect := false for redirectIndex := range redirectCount { clientRedirect := redirects[redirectIndex] - clientRedirect = strings.Replace(clientRedirect, "/", "\\/", 1) - clientRedirect = strings.Replace(clientRedirect, ".", "\\.", 1) - clientRedirect = strings.Replace(clientRedirect, "?", "\\?", 1) endsWithWildcard := strings.HasSuffix(clientRedirect, "*") + var matched bool if endsWithWildcard { - clientRedirect = strings.Replace(clientRedirect[:len(clientRedirect)-1], "*", "\\*", 1) - clientRedirect = clientRedirect + ".*" + clientRedirect = clientRedirect[:len(clientRedirect)-1] + matched = strings.HasPrefix(redirect, clientRedirect) } else { - clientRedirect = strings.Replace(clientRedirect, "*", "\\*", 1) - } - clientRedirect = fmt.Sprintf("^%s$", clientRedirect) - matched, regexError := regexp.MatchString(clientRedirect, redirect) - if regexError != nil { - log.Error("Cloud not match redirect URI %s for client %s", redirect, clientId) - return false, regexError + matched = redirect == clientRedirect } - matchesRedirect = matchesRedirect || matched } - return matchesRedirect, nil + return matchesRedirect } else { log.Error("Client %s has no redirect URI(s) configured!", clientId) - return false, nil + return false } } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index ecf3de3..be81bc2 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -269,7 +269,7 @@ func Test_EmptyUIConfiguration(t *testing.T) { t.Error("expected title to be empty") } - hideMascot := config.GetHideMascot() + hideMascot := config.GetHideLogo() if hideMascot { t.Error("expected hideMascot to be false") } @@ -324,7 +324,7 @@ func Test_SimpleUIConfiguration(t *testing.T) { t.Error("expected title to be 'Oh my Foo!'") } - hideMascot := config.GetHideMascot() + hideMascot := config.GetHideLogo() if !hideMascot { t.Error("expected hideMascot to be true") } @@ -852,10 +852,7 @@ func Test_ValidateRedirects(t *testing.T) { for index, test := range redirectParameters { testMessage := fmt.Sprintf("Validate redirect %d %s", index, test.redirect) t.Run(testMessage, func(t *testing.T) { - result, validationError := validateRedirect("fooId", validRedirects, test.redirect) - if validationError != nil { - t.Error("Validation should not return an error") - } + result := validateRedirect("fooId", validRedirects, test.redirect) if result != test.expectedResult { t.Error("Redirect validation did not match") } @@ -865,20 +862,14 @@ func Test_ValidateRedirects(t *testing.T) { func Test_EmptyRedirect(t *testing.T) { var validRedirects = []string{"http://foo.com/callback", "https://foo.com/callback", "https://foo.com/wildcard/*"} - result, validationError := validateRedirect("fooId", validRedirects, "") - if validationError != nil { - t.Error("Validation should not return an error") - } + result := validateRedirect("fooId", validRedirects, "") if result { t.Error("Redirect validation did not match") } } func Test_NoRedirects(t *testing.T) { - result, validationError := validateRedirect("fooId", []string{}, "http://foo.com/callback") - if validationError != nil { - t.Error("Validation should not return an error") - } + result := validateRedirect("fooId", []string{}, "http://foo.com/callback") if result { t.Error("Redirect validation did not match") } @@ -970,18 +961,12 @@ func assertClientValues(t *testing.T, config *Config, expected testExpectedClien t.Errorf("expected expectedRoles claim to be '%s', got '%s'", expected.expectedRolesClaim, rolesClaim) } - validRedirect, validRedirectError := client.ValidateRedirect("http://localhost:8080/callback") - if validRedirectError != nil { - t.Error("expected valid redirect not return an error") - } + validRedirect := client.ValidateRedirect("http://localhost:8080/callback") if !validRedirect { t.Error("expected valid redirect") } - invalidRedirect, invalidRedirectError := client.ValidateRedirect("http://foo.com:8080/callback") - if invalidRedirectError != nil { - t.Error("expected invalid redirect not return an error") - } + invalidRedirect := client.ValidateRedirect("http://foo.com:8080/callback") if invalidRedirect { t.Error("did not expect redirect to be valid") } diff --git a/internal/config/doc.go b/internal/config/doc.go new file mode 100644 index 0000000..e5b1065 --- /dev/null +++ b/internal/config/doc.go @@ -0,0 +1,3 @@ +// Package config implements general handling for configuration files +// and adds definition for the configuration file YAML structure +package config diff --git a/internal/config/loader.go b/internal/config/loader.go index a59e554..4b113ed 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -1,21 +1,31 @@ package config +// ReadFile function definition to read a file by name into []byte. type ReadFile func(filename string) ([]byte, error) + +// Unmarshal function definition to unmarshal a []byte into a general interface. type Unmarshal func(in []byte, out interface{}) (err error) -type Loader struct { +// Loader defines how a configuration is loaded. +type Loader interface { + // LoadConfig loads the given configuration and validates if necessary. + LoadConfig(name string, validate bool) error +} + +type loader struct { fileReader ReadFile unmarshaler Unmarshal } -func NewConfigLoader(fileReader ReadFile, unmarshaler Unmarshal) *Loader { - return &Loader{ +// NewConfigLoader combines the ReadFile and Unmarshal functions into a Loader. +func NewConfigLoader(fileReader ReadFile, unmarshaler Unmarshal) Loader { + return &loader{ fileReader: fileReader, unmarshaler: unmarshaler, } } -func (loader *Loader) LoadConfig(name string, validate bool) error { +func (loader *loader) LoadConfig(name string, validate bool) error { data, readError := loader.fileReader(name) if readError != nil { return readError diff --git a/internal/config/support.go b/internal/config/support.go index f96db36..4103d27 100644 --- a/internal/config/support.go +++ b/internal/config/support.go @@ -34,6 +34,7 @@ func setup[T any](values *[]T, accessor func(T) string) map[string]*T { return valueMap } +// GetOrDefaultString returns a value or a default value if the given value is empty. func GetOrDefaultString(value string, defaultValue string) string { if value == "" { return defaultValue @@ -42,6 +43,7 @@ func GetOrDefaultString(value string, defaultValue string) string { } } +// GetOrDefaultStringSlice returns an array or a default array if the given array is empty. func GetOrDefaultStringSlice(value []string, defaultValue []string) []string { if len(value) == 0 { return defaultValue @@ -50,6 +52,7 @@ func GetOrDefaultStringSlice(value []string, defaultValue []string) []string { } } +// GetOrDefaultInt returns a value or a default value if the given value is empty. func GetOrDefaultInt(value int, defaultValue int) int { if value == 0 { return defaultValue diff --git a/internal/crypto/hash.go b/internal/crypto/hash.go index 5231fe5..5af41f2 100644 --- a/internal/crypto/hash.go +++ b/internal/crypto/hash.go @@ -6,10 +6,12 @@ import ( "fmt" ) +// Sha512Hash returns a SHA512 hash for the given value. func Sha512Hash(value string) string { return fmt.Sprintf("%x", sha512.Sum512([]byte(value))) } +// Sha512SaltedHash returns a SHA512 hash for the given value and salt. func Sha512SaltedHash(value string, salt string) string { if salt != "" { saltedValue := fmt.Sprintf("%s/!%s", value, salt) @@ -19,6 +21,7 @@ func Sha512SaltedHash(value string, salt string) string { } } +// Sha1Hash returns a SHA1 hash for the given value. func Sha1Hash(value string) string { return fmt.Sprintf("%x", sha1.Sum([]byte(value))) } diff --git a/internal/crypto/key.go b/internal/crypto/key.go index 444769e..dae8f77 100644 --- a/internal/crypto/key.go +++ b/internal/crypto/key.go @@ -12,6 +12,7 @@ import ( "os" ) +// HashAlgorithm used for the names of the supported hash algorithms. type HashAlgorithm string const ( @@ -20,12 +21,14 @@ const ( SHA512 HashAlgorithm = "SHA512" ) +// SigningPrivateKey defines a combination of private key, signing and hash algorithm. type SigningPrivateKey struct { PrivateKey interface{} SignatureAlgorithm jwa.SignatureAlgorithm HashAlgorithm HashAlgorithm } +// ManagedKey defines a combination of keys defined for config.Client. type ManagedKey struct { Id string Clients []*config.Client @@ -34,7 +37,9 @@ type ManagedKey struct { HashAlgorithm HashAlgorithm } +// ServerSecretLoader defines how to receive a private server key. type ServerSecretLoader interface { + // GetServerKey returns the private key of the server. GetServerKey() jwt.SignEncryptParseOption } @@ -42,11 +47,14 @@ type serverSecret struct { secret string } +// KeyLoader defines how to get ManagedKey for a specific client. type KeyLoader interface { + // LoadKeys returns a ManagedKey for a specific client and a bool indicating whether a key exists or not. LoadKeys(client *config.Client) (*ManagedKey, bool) ServerSecretLoader } +// NewServerSecretLoader creates a ServerSecretLoader based on the current config.Config. func NewServerSecretLoader() ServerSecretLoader { currentConfig := config.GetConfigInstance() return &serverSecret{secret: currentConfig.GetServerSecret()} @@ -56,8 +64,9 @@ func (s *serverSecret) GetServerKey() jwt.SignEncryptParseOption { return jwt.WithKey(jwa.HS256, []byte(s.secret)) } -func LoadPrivateKey(name string) (*SigningPrivateKey, error) { - privateKeyBytes, readError := os.ReadFile(name) +// LoadPrivateKey loads a private key from a given filename. +func LoadPrivateKey(filename string) (*SigningPrivateKey, error) { + privateKeyBytes, readError := os.ReadFile(filename) if readError != nil { return nil, readError } diff --git a/internal/manager/cookie/cookie.go b/internal/manager/cookie/cookie.go index c89c785..ca43b34 100644 --- a/internal/manager/cookie/cookie.go +++ b/internal/manager/cookie/cookie.go @@ -1,9 +1,11 @@ package cookie import ( + "fmt" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/webishdev/stopnik/internal/config" "github.com/webishdev/stopnik/internal/crypto" + "github.com/webishdev/stopnik/internal/manager/session" "github.com/webishdev/stopnik/log" "net/http" "sync" @@ -13,9 +15,10 @@ import ( type Now func() time.Time type Manager struct { - config *config.Config - keyFallback crypto.ServerSecretLoader - now Now + config *config.Config + loginSession session.Manager[session.LoginSession] + keyFallback crypto.ServerSecretLoader + now Now } var cookieManagerLock = &sync.Mutex{} @@ -33,9 +36,10 @@ func GetCookieManagerInstance() *Manager { func newCookieManagerWithTime(now Now) *Manager { configInstance := config.GetConfigInstance() return &Manager{ - config: configInstance, - keyFallback: crypto.NewServerSecretLoader(), - now: now, + config: configInstance, + loginSession: session.GetLoginSessionManagerInstance(), + keyFallback: crypto.NewServerSecretLoader(), + now: now, } } @@ -73,10 +77,10 @@ func (cookieManager *Manager) DeleteAuthCookie() http.Cookie { } } -func (cookieManager *Manager) CreateAuthCookie(username string) (http.Cookie, error) { +func (cookieManager *Manager) CreateAuthCookie(username string, loginSessionId string) (http.Cookie, error) { authCookieName := cookieManager.config.GetAuthCookieName() log.Debug("Creating %s auth cookie", authCookieName) - value, err := cookieManager.generateCookieValue(username) + value, err := cookieManager.generateAuthCookieValue(username, loginSessionId) if err != nil { return http.Cookie{}, err } @@ -90,34 +94,47 @@ func (cookieManager *Manager) CreateAuthCookie(username string) (http.Cookie, er }, nil } -func (cookieManager *Manager) ValidateAuthCookie(r *http.Request) (*config.User, bool) { +func (cookieManager *Manager) ValidateAuthCookie(r *http.Request) (*config.User, *session.LoginSession, bool) { authCookieName := cookieManager.config.GetAuthCookieName() - log.Debug("Validating %s cookie", authCookieName) + log.Debug("Validating %s auth cookie", authCookieName) cookie, cookieError := r.Cookie(authCookieName) if cookieError != nil { - return &config.User{}, false + return &config.User{}, &session.LoginSession{}, false } else { - return cookieManager.validateCookieValue(cookie) + return cookieManager.validateAuthCookieValue(cookie) } } -func (cookieManager *Manager) validateCookieValue(cookie *http.Cookie) (*config.User, bool) { +func (cookieManager *Manager) validateAuthCookieValue(cookie *http.Cookie) (*config.User, *session.LoginSession, bool) { options := cookieManager.keyFallback.GetServerKey() token, err := jwt.Parse([]byte(cookie.Value), options) if err != nil { - return &config.User{}, false + return &config.User{}, &session.LoginSession{}, false + } + + loginClaim, loginClaimExists := token.Get("login") + if !loginClaimExists { + return &config.User{}, &session.LoginSession{}, false + } + + loginSessionId := fmt.Sprintf("%s", loginClaim) + + loginSession, loginSessionExists := cookieManager.loginSession.GetSession(loginSessionId) + if !loginSessionExists { + return &config.User{}, &session.LoginSession{}, false } // https://stackoverflow.com/a/61284284/4094586 username := token.Subject() user, userExists := cookieManager.config.GetUser(username) - return user, userExists + return user, loginSession, userExists } -func (cookieManager *Manager) generateCookieValue(username string) (string, error) { +func (cookieManager *Manager) generateAuthCookieValue(username string, loginSessionId string) (string, error) { sessionTimeout := cookieManager.config.GetSessionTimeoutSeconds() token, builderError := jwt.NewBuilder(). Subject(username). + Claim("login", loginSessionId). Expiration(cookieManager.now().Add(time.Second * time.Duration(sessionTimeout))). Build() if builderError != nil { diff --git a/internal/manager/cookie/cookie_test.go b/internal/manager/cookie/cookie_test.go index 9078c24..ddb43cf 100644 --- a/internal/manager/cookie/cookie_test.go +++ b/internal/manager/cookie/cookie_test.go @@ -2,8 +2,10 @@ package cookie import ( "fmt" + "github.com/google/uuid" "github.com/webishdev/stopnik/internal/config" "github.com/webishdev/stopnik/internal/endpoint" + "github.com/webishdev/stopnik/internal/manager/session" "net/http" "net/http/httptest" "testing" @@ -29,8 +31,15 @@ func Test_Cookie(t *testing.T) { t.Run("Create and validate auth cookie", func(t *testing.T) { cookieManager := GetCookieManagerInstance() + loginSessionManager := session.GetLoginSessionManagerInstance() - cookie, cookieError := cookieManager.CreateAuthCookie("foo") + loginSession := &session.LoginSession{ + Id: uuid.NewString(), + Username: "foo", + } + loginSessionManager.StartSession(loginSession) + + cookie, cookieError := cookieManager.CreateAuthCookie("foo", loginSession.Id) if cookieError != nil { t.Error(cookieError) @@ -44,9 +53,9 @@ func Test_Cookie(t *testing.T) { }, } - _, userExists := cookieManager.ValidateAuthCookie(httpRequest) + _, _, exists := cookieManager.ValidateAuthCookie(httpRequest) - if !userExists { + if !exists { t.Error("Token in auth cookie is invalid") } }) @@ -54,7 +63,7 @@ func Test_Cookie(t *testing.T) { t.Run("Create and validate expired auth cookie", func(t *testing.T) { cookieManager := newCookieManagerWithTime(now) - cookie, cookieError := cookieManager.CreateAuthCookie("foo") + cookie, cookieError := cookieManager.CreateAuthCookie("foo", "someId") if cookieError != nil { t.Error(cookieError) @@ -68,9 +77,9 @@ func Test_Cookie(t *testing.T) { mockedTime = mockedTime.Add(time.Hour * time.Duration(-6)) - _, userExists := cookieManager.ValidateAuthCookie(httpRequest) + _, _, exists := cookieManager.ValidateAuthCookie(httpRequest) - if userExists { + if exists { t.Error("Expired token should not provide a user") } }) @@ -78,7 +87,7 @@ func Test_Cookie(t *testing.T) { t.Run("Create and validate auth cookie with wrong username", func(t *testing.T) { cookieManager := GetCookieManagerInstance() - cookie, cookieError := cookieManager.CreateAuthCookie("bar") + cookie, cookieError := cookieManager.CreateAuthCookie("bar", "someId") if cookieError != nil { t.Error(cookieError) @@ -90,9 +99,9 @@ func Test_Cookie(t *testing.T) { }, } - _, userExists := cookieManager.ValidateAuthCookie(httpRequest) + _, _, exists := cookieManager.ValidateAuthCookie(httpRequest) - if userExists { + if exists { t.Error("User should not exists") } }) @@ -108,9 +117,9 @@ func Test_Cookie(t *testing.T) { }, } - _, userExists := cookieManager.ValidateAuthCookie(httpRequest) + _, _, exists := cookieManager.ValidateAuthCookie(httpRequest) - if userExists { + if exists { t.Error("User should not exists") } }) diff --git a/internal/manager/session/auth.go b/internal/manager/session/auth.go index caadf03..bdad461 100644 --- a/internal/manager/session/auth.go +++ b/internal/manager/session/auth.go @@ -52,3 +52,6 @@ func (authManager *AuthManager) GetSession(id string) (*AuthSession, bool) { authSessionStore := *authManager.authSessionStore return authSessionStore.Get(id) } + +func (authManager *AuthManager) CloseSession(_ string, _ bool) { +} diff --git a/internal/manager/session/forward.go b/internal/manager/session/forward.go index befdd7e..5fd9e23 100644 --- a/internal/manager/session/forward.go +++ b/internal/manager/session/forward.go @@ -43,3 +43,6 @@ func (forwardManager *ForwardManager) GetSession(id string) (*ForwardSession, bo forwardSessionStore := *forwardManager.forwardSessionStore return forwardSessionStore.Get(id) } + +func (forwardManager *ForwardManager) CloseSession(_ string, _ bool) { +} diff --git a/internal/manager/session/login.go b/internal/manager/session/login.go new file mode 100644 index 0000000..b8959df --- /dev/null +++ b/internal/manager/session/login.go @@ -0,0 +1,70 @@ +package session + +import ( + "github.com/webishdev/stopnik/internal/config" + "github.com/webishdev/stopnik/internal/store" + "github.com/webishdev/stopnik/log" + "sync" + "time" +) + +type LoginSession struct { + Id string + Username string +} + +type LoginManager struct { + config *config.Config + loginSessionStore *store.ExpiringStore[LoginSession] +} + +var loginSessionManagerLock = &sync.Mutex{} +var loginSessionManagerSingleton *LoginManager + +func GetLoginSessionManagerInstance() Manager[LoginSession] { + loginSessionManagerLock.Lock() + defer loginSessionManagerLock.Unlock() + if loginSessionManagerSingleton == nil { + currentConfig := config.GetConfigInstance() + duration := time.Minute * time.Duration(currentConfig.GetSessionTimeoutSeconds()) + loginSessionStore := store.NewTimedStore[LoginSession](duration) + loginSessionManagerSingleton = &LoginManager{ + config: currentConfig, + loginSessionStore: &loginSessionStore, + } + } + return loginSessionManagerSingleton +} + +func (loginManager *LoginManager) StartSession(loginSession *LoginSession) { + loginSessionStore := *loginManager.loginSessionStore + loginSessionStore.Set(loginSession.Id, loginSession) +} + +func (loginManager *LoginManager) GetSession(id string) (*LoginSession, bool) { + loginSessionStore := *loginManager.loginSessionStore + return loginSessionStore.Get(id) +} + +func (loginManager *LoginManager) CloseSession(id string, all bool) { + loginSessionStore := *loginManager.loginSessionStore + loginSession, loginSessionExists := loginSessionStore.Get(id) + if loginSessionExists { + log.Info("Closing main login session with id %s", id) + loginSessionStore.Delete(id) + if all { + username := loginSession.Username + var userSessionIds []string + for _, otherSession := range loginSessionStore.GetValues() { + if otherSession.Username == username && otherSession.Id != id { + userSessionIds = append(userSessionIds, otherSession.Id) + } + } + for _, otherSessionId := range userSessionIds { + log.Info("Closing login session with id %s", otherSessionId) + loginSessionStore.Delete(otherSessionId) + } + } + } + +} diff --git a/internal/manager/session/session.go b/internal/manager/session/session.go index ab5029e..51dd663 100644 --- a/internal/manager/session/session.go +++ b/internal/manager/session/session.go @@ -3,4 +3,5 @@ package session type Manager[T any] interface { StartSession(session *T) GetSession(id string) (*T, bool) + CloseSession(id string, all bool) } diff --git a/internal/oauth2/types.go b/internal/oauth2/types.go index 17133b3..9c2958a 100644 --- a/internal/oauth2/types.go +++ b/internal/oauth2/types.go @@ -2,13 +2,11 @@ package oauth2 import "strings" +// GrantType as described in +// - https://datatracker.ietf.org/doc/html/rfc6749#appendix-A.10 +// - https://datatracker.ietf.org/doc/html/rfc7591#section-2 type GrantType string -/* - * GrantType as described in - * - https://datatracker.ietf.org/doc/html/rfc6749#appendix-A.10 - * - https://datatracker.ietf.org/doc/html/rfc7591#section-2 - */ const ( GtAuthorizationCode GrantType = "authorization_code" GtClientCredentials GrantType = "client_credentials" diff --git a/internal/server/handler/account/account.go b/internal/server/handler/account/account.go index bfc2474..3297302 100644 --- a/internal/server/handler/account/account.go +++ b/internal/server/handler/account/account.go @@ -4,6 +4,7 @@ import ( "github.com/google/uuid" internalHttp "github.com/webishdev/stopnik/internal/http" "github.com/webishdev/stopnik/internal/manager/cookie" + "github.com/webishdev/stopnik/internal/manager/session" "github.com/webishdev/stopnik/internal/server/handler/error" "github.com/webishdev/stopnik/internal/server/validation" "github.com/webishdev/stopnik/internal/template" @@ -12,29 +13,32 @@ import ( ) type Handler struct { - validator *validation.RequestValidator - cookieManager *cookie.Manager - templateManager *template.Manager - errorHandler *error.Handler + validator *validation.RequestValidator + cookieManager *cookie.Manager + loginSessionManager session.Manager[session.LoginSession] + templateManager *template.Manager + errorHandler *error.Handler } func NewAccountHandler( validator *validation.RequestValidator, cookieManager *cookie.Manager, + loginSessionManager session.Manager[session.LoginSession], templateManager *template.Manager, ) *Handler { return &Handler{ - validator: validator, - cookieManager: cookieManager, - templateManager: templateManager, - errorHandler: error.NewErrorHandler(), + validator: validator, + cookieManager: cookieManager, + loginSessionManager: loginSessionManager, + templateManager: templateManager, + errorHandler: error.NewErrorHandler(), } } func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.AccessLogRequest(r) if r.Method == http.MethodGet { - user, validCookie := h.cookieManager.ValidateAuthCookie(r) + user, _, validCookie := h.cookieManager.ValidateAuthCookie(r) if validCookie { logoutTemplate := h.templateManager.LogoutTemplate(user.Username, r.RequestURI) @@ -68,13 +72,18 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - cookie, err := h.cookieManager.CreateAuthCookie(user.Username) + loginSession := &session.LoginSession{ + Id: uuid.NewString(), + Username: user.Username, + } + h.loginSessionManager.StartSession(loginSession) + authCookie, err := h.cookieManager.CreateAuthCookie(user.Username, loginSession.Id) if err != nil { h.errorHandler.InternalServerErrorHandler(w, r) return } - http.SetCookie(w, &cookie) + http.SetCookie(w, &authCookie) w.Header().Set(internalHttp.Location, r.RequestURI) w.WriteHeader(http.StatusSeeOther) diff --git a/internal/server/handler/account/account_test.go b/internal/server/handler/account/account_test.go index e4f4331..18433c2 100644 --- a/internal/server/handler/account/account_test.go +++ b/internal/server/handler/account/account_test.go @@ -2,10 +2,12 @@ package account import ( "fmt" + "github.com/google/uuid" "github.com/webishdev/stopnik/internal/config" "github.com/webishdev/stopnik/internal/endpoint" internalHttp "github.com/webishdev/stopnik/internal/http" "github.com/webishdev/stopnik/internal/manager/cookie" + "github.com/webishdev/stopnik/internal/manager/session" "github.com/webishdev/stopnik/internal/server/validation" "github.com/webishdev/stopnik/internal/template" "io" @@ -40,12 +42,18 @@ func Test_AccountWithCookie(t *testing.T) { requestValidator := validation.NewRequestValidator() cookieManager := cookie.GetCookieManagerInstance() + loginSessionManager := session.GetLoginSessionManagerInstance() templateManager := template.GetTemplateManagerInstance() user, _ := testConfig.GetUser("foo") - authCookie, _ := cookieManager.CreateAuthCookie(user.Username) + loginSession := &session.LoginSession{ + Id: uuid.NewString(), + Username: user.Username, + } + loginSessionManager.StartSession(loginSession) + authCookie, _ := cookieManager.CreateAuthCookie(user.Username, loginSession.Id) - accountHandler := NewAccountHandler(requestValidator, cookieManager, templateManager) + accountHandler := NewAccountHandler(requestValidator, cookieManager, loginSessionManager, templateManager) rr := httptest.NewRecorder() @@ -79,9 +87,10 @@ func Test_AccountWithCookie(t *testing.T) { func Test_AccountWithoutCookie(t *testing.T) { requestValidator := validation.NewRequestValidator() cookieManager := cookie.GetCookieManagerInstance() + loginSessionManager := session.GetLoginSessionManagerInstance() templateManager := template.GetTemplateManagerInstance() - accountHandler := NewAccountHandler(requestValidator, cookieManager, templateManager) + accountHandler := NewAccountHandler(requestValidator, cookieManager, loginSessionManager, templateManager) rr := httptest.NewRecorder() @@ -124,11 +133,17 @@ func Test_AccountLogin(t *testing.T) { t.Run(testMessage, func(t *testing.T) { requestValidator := validation.NewRequestValidator() cookieManager := cookie.GetCookieManagerInstance() + loginSessionManager := session.GetLoginSessionManagerInstance() templateManager := template.GetTemplateManagerInstance() - authCookie, _ := cookieManager.CreateAuthCookie(test.username) + loginSession := &session.LoginSession{ + Id: uuid.NewString(), + Username: test.username, + } + loginSessionManager.StartSession(loginSession) + authCookie, _ := cookieManager.CreateAuthCookie(test.username, loginSession.Id) - accountHandler := NewAccountHandler(requestValidator, cookieManager, templateManager) + accountHandler := NewAccountHandler(requestValidator, cookieManager, loginSessionManager, templateManager) rr := httptest.NewRecorder() @@ -170,7 +185,7 @@ func Test_AccountNotAllowedHttpMethods(t *testing.T) { for _, method := range testInvalidAccountHttpMethods { testMessage := fmt.Sprintf("Account with unsupported method %s", method) t.Run(testMessage, func(t *testing.T) { - accountHandler := NewAccountHandler(&validation.RequestValidator{}, &cookie.Manager{}, &template.Manager{}) + accountHandler := NewAccountHandler(&validation.RequestValidator{}, &cookie.Manager{}, &session.LoginManager{}, &template.Manager{}) rr := httptest.NewRecorder() diff --git a/internal/server/handler/authorize/authorize.go b/internal/server/handler/authorize/authorize.go index 8c55827..cedc130 100644 --- a/internal/server/handler/authorize/authorize.go +++ b/internal/server/handler/authorize/authorize.go @@ -22,27 +22,30 @@ import ( ) type Handler struct { - validator *validation.RequestValidator - cookieManager *cookie.Manager - authSessionManager session.Manager[session.AuthSession] - tokenManager *token.Manager - templateManager *template.Manager - errorHandler *error.Handler + validator *validation.RequestValidator + cookieManager *cookie.Manager + authSessionManager session.Manager[session.AuthSession] + loginSessionManager session.Manager[session.LoginSession] + tokenManager *token.Manager + templateManager *template.Manager + errorHandler *error.Handler } func NewAuthorizeHandler( validator *validation.RequestValidator, cookieManager *cookie.Manager, authSessionManager session.Manager[session.AuthSession], + loginSessionManager session.Manager[session.LoginSession], tokenManager *token.Manager, templateManager *template.Manager) *Handler { return &Handler{ - validator: validator, - cookieManager: cookieManager, - authSessionManager: authSessionManager, - tokenManager: tokenManager, - templateManager: templateManager, - errorHandler: error.NewErrorHandler(), + validator: validator, + cookieManager: cookieManager, + authSessionManager: authSessionManager, + loginSessionManager: loginSessionManager, + tokenManager: tokenManager, + templateManager: templateManager, + errorHandler: error.NewErrorHandler(), } } @@ -157,7 +160,7 @@ func (h *Handler) handleGetRequest(w http.ResponseWriter, r *http.Request) { h.authSessionManager.StartSession(authSession) } - user, validCookie := h.cookieManager.ValidateAuthCookie(r) + user, _, validCookie := h.cookieManager.ValidateAuthCookie(r) if validCookie { authSession.Username = user.Username @@ -210,7 +213,12 @@ func (h *Handler) handleGetRequest(w http.ResponseWriter, r *http.Request) { } func (h *Handler) handlePostRequest(w http.ResponseWriter, r *http.Request, user *config.User) { - authCookie, authCookieError := h.cookieManager.CreateAuthCookie(user.Username) + loginSession := &session.LoginSession{ + Id: uuid.NewString(), + Username: user.Username, + } + h.loginSessionManager.StartSession(loginSession) + authCookie, authCookieError := h.cookieManager.CreateAuthCookie(user.Username, loginSession.Id) if authCookieError != nil { h.errorHandler.InternalServerErrorHandler(w, r) return @@ -268,10 +276,7 @@ func (h *Handler) validateLogin(w http.ResponseWriter, r *http.Request) (*config } func (h *Handler) validateRedirect(client *config.Client, redirect string) func(w http.ResponseWriter, r *http.Request) { - validRedirect, validationError := client.ValidateRedirect(redirect) - if validationError != nil { - return h.errorHandler.InternalServerErrorHandler - } + validRedirect := client.ValidateRedirect(redirect) if !validRedirect { return h.errorHandler.BadRequestHandler diff --git a/internal/server/handler/authorize/authorize_test.go b/internal/server/handler/authorize/authorize_test.go index 5af6cf4..39a0a5f 100644 --- a/internal/server/handler/authorize/authorize_test.go +++ b/internal/server/handler/authorize/authorize_test.go @@ -85,7 +85,7 @@ func Test_AuthorizeInvalidLogin(t *testing.T) { cookieManager := cookie.GetCookieManagerInstance() requestValidator := validation.NewRequestValidator() - authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, &session.AuthManager{}, &token.Manager{}, &template.Manager{}) + authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, &session.AuthManager{}, &session.LoginManager{}, &token.Manager{}, &template.Manager{}) rr := httptest.NewRecorder() @@ -172,7 +172,7 @@ func Test_AuthorizeEmptyLogin(t *testing.T) { cookieManager := cookie.GetCookieManagerInstance() requestValidator := validation.NewRequestValidator() - authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, &session.AuthManager{}, &token.Manager{}, &template.Manager{}) + authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, &session.AuthManager{}, &session.LoginManager{}, &token.Manager{}, &template.Manager{}) rr := httptest.NewRecorder() @@ -255,11 +255,12 @@ func Test_AuthorizeValidLoginNoSession(t *testing.T) { } }) requestValidator := validation.NewRequestValidator() - sessionManager := session.GetAuthSessionManagerInstance() + authSessionManager := session.GetAuthSessionManagerInstance() cookieManager := cookie.GetCookieManagerInstance() + loginSessionManager := session.GetLoginSessionManagerInstance() tokenManager := token.GetTokenManagerInstance() - authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, sessionManager, tokenManager, &template.Manager{}) + authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, authSessionManager, loginSessionManager, tokenManager, &template.Manager{}) rr := httptest.NewRecorder() @@ -333,7 +334,7 @@ func Test_AuthorizeNotAllowedHttpMethods(t *testing.T) { for _, method := range testInvalidAuthorizeHttpMethods { testMessage := fmt.Sprintf("Authorize with unsupported method %s", method) t.Run(testMessage, func(t *testing.T) { - authorizeHandler := NewAuthorizeHandler(&validation.RequestValidator{}, &cookie.Manager{}, &session.AuthManager{}, &token.Manager{}, &template.Manager{}) + authorizeHandler := NewAuthorizeHandler(&validation.RequestValidator{}, &cookie.Manager{}, &session.AuthManager{}, &session.LoginManager{}, &token.Manager{}, &template.Manager{}) rr := httptest.NewRecorder() @@ -346,18 +347,19 @@ func Test_AuthorizeNotAllowedHttpMethods(t *testing.T) { } } -func Test_AuthorizeNoCookeExists(t *testing.T) { +func Test_AuthorizeNoCookieExists(t *testing.T) { parsedUri := createUri(t, endpoint.Authorization, func(query url.Values) { query.Set(oauth2.ParameterClientId, "foo") query.Set(oauth2.ParameterRedirectUri, "https://example.com/callback") query.Set(oauth2.ParameterResponseType, oauth2.ParameterCode) }) requestValidator := validation.NewRequestValidator() - sessionManager := session.GetAuthSessionManagerInstance() + authSessionManager := session.GetAuthSessionManagerInstance() + loginSessionManager := session.GetLoginSessionManagerInstance() cookieManager := cookie.GetCookieManagerInstance() templateManager := template.GetTemplateManagerInstance() - authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, sessionManager, &token.Manager{}, templateManager) + authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, authSessionManager, loginSessionManager, &token.Manager{}, templateManager) rr := httptest.NewRecorder() @@ -393,7 +395,7 @@ func Test_AuthorizeInvalidResponseType(t *testing.T) { }) requestValidator := validation.NewRequestValidator() - authorizeHandler := NewAuthorizeHandler(requestValidator, &cookie.Manager{}, &session.AuthManager{}, &token.Manager{}, &template.Manager{}) + authorizeHandler := NewAuthorizeHandler(requestValidator, &cookie.Manager{}, &session.AuthManager{}, &session.LoginManager{}, &token.Manager{}, &template.Manager{}) rr := httptest.NewRecorder() @@ -443,7 +445,7 @@ func Test_AuthorizeInvalidRedirect(t *testing.T) { requestValidator := validation.NewRequestValidator() - authorizeHandler := NewAuthorizeHandler(requestValidator, &cookie.Manager{}, &session.AuthManager{}, &token.Manager{}, &template.Manager{}) + authorizeHandler := NewAuthorizeHandler(requestValidator, &cookie.Manager{}, &session.AuthManager{}, &session.LoginManager{}, &token.Manager{}, &template.Manager{}) rr := httptest.NewRecorder() @@ -464,7 +466,7 @@ func Test_AuthorizeInvalidClientId(t *testing.T) { requestValidator := validation.NewRequestValidator() - authorizeHandler := NewAuthorizeHandler(requestValidator, &cookie.Manager{}, &session.AuthManager{}, &token.Manager{}, &template.Manager{}) + authorizeHandler := NewAuthorizeHandler(requestValidator, &cookie.Manager{}, &session.AuthManager{}, &session.LoginManager{}, &token.Manager{}, &template.Manager{}) rr := httptest.NewRecorder() @@ -478,7 +480,7 @@ func Test_AuthorizeInvalidClientId(t *testing.T) { func Test_AuthorizeNoClientId(t *testing.T) { requestValidator := validation.NewRequestValidator() - authorizeHandler := NewAuthorizeHandler(requestValidator, &cookie.Manager{}, &session.AuthManager{}, &token.Manager{}, &template.Manager{}) + authorizeHandler := NewAuthorizeHandler(requestValidator, &cookie.Manager{}, &session.AuthManager{}, &session.LoginManager{}, &token.Manager{}, &template.Manager{}) rr := httptest.NewRecorder() @@ -529,14 +531,20 @@ func testAuthorizeAuthorizationGrant(t *testing.T, testConfig *config.Config) { } }) requestValidator := validation.NewRequestValidator() - sessionManager := session.GetAuthSessionManagerInstance() + authSessionManager := session.GetAuthSessionManagerInstance() cookieManager := cookie.GetCookieManagerInstance() + loginSessionManager := session.GetLoginSessionManagerInstance() tokenManager := token.GetTokenManagerInstance() user, _ := testConfig.GetUser("foo") - authCookie, _ := cookieManager.CreateAuthCookie(user.Username) + loginSession := &session.LoginSession{ + Id: uuid.NewString(), + Username: user.Username, + } + loginSessionManager.StartSession(loginSession) + authCookie, _ := cookieManager.CreateAuthCookie(user.Username, loginSession.Id) - authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, sessionManager, tokenManager, &template.Manager{}) + authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, authSessionManager, loginSessionManager, tokenManager, &template.Manager{}) rr := httptest.NewRecorder() request := httptest.NewRequest(http.MethodGet, parsedUri.String(), nil) @@ -565,7 +573,7 @@ func testAuthorizeAuthorizationGrant(t *testing.T, testConfig *config.Config) { t.Errorf("state parameter %v did not match: %v", stateQueryParameter, test.state) } - authSession, sessionExists := sessionManager.GetSession(codeQueryParameter) + authSession, sessionExists := authSessionManager.GetSession(codeQueryParameter) if !sessionExists { t.Errorf("session does not exist: %v", codeQueryParameter) } @@ -612,15 +620,21 @@ func testAuthorizeImplicitGrant(t *testing.T, testConfig *config.Config) { } }) requestValidator := validation.NewRequestValidator() - sessionManager := session.GetAuthSessionManagerInstance() + authSessionManager := session.GetAuthSessionManagerInstance() + loginSessionManager := session.GetLoginSessionManagerInstance() cookieManager := cookie.GetCookieManagerInstance() tokenManager := token.GetTokenManagerInstance() client, _ := testConfig.GetClient("foo") user, _ := testConfig.GetUser("foo") - authCookie, _ := cookieManager.CreateAuthCookie(user.Username) + loginSession := &session.LoginSession{ + Id: uuid.NewString(), + Username: user.Username, + } + loginSessionManager.StartSession(loginSession) + authCookie, _ := cookieManager.CreateAuthCookie(user.Username, loginSession.Id) - authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, sessionManager, tokenManager, &template.Manager{}) + authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, authSessionManager, loginSessionManager, tokenManager, &template.Manager{}) rr := httptest.NewRecorder() request := httptest.NewRequest(http.MethodGet, parsedUri.String(), nil) @@ -708,12 +722,13 @@ func testAuthorizeValidLoginAuthorizationGrant(t *testing.T, testConfig *config. } requestValidator := validation.NewRequestValidator() - sessionManager := session.GetAuthSessionManagerInstance() + authSessionManager := session.GetAuthSessionManagerInstance() + loginSessionManager := session.GetLoginSessionManagerInstance() cookieManager := cookie.GetCookieManagerInstance() tokenManager := token.GetTokenManagerInstance() - sessionManager.StartSession(authSession) + authSessionManager.StartSession(authSession) - authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, sessionManager, tokenManager, &template.Manager{}) + authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, authSessionManager, loginSessionManager, tokenManager, &template.Manager{}) rr := httptest.NewRecorder() @@ -796,12 +811,13 @@ func testAuthorizeValidLoginImplicitGrant(t *testing.T, testConfig *config.Confi } requestValidator := validation.NewRequestValidator() - sessionManager := session.GetAuthSessionManagerInstance() + authSessionManager := session.GetAuthSessionManagerInstance() + loginSessionManager := session.GetLoginSessionManagerInstance() cookieManager := cookie.GetCookieManagerInstance() tokenManager := token.GetTokenManagerInstance() - sessionManager.StartSession(authSession) + authSessionManager.StartSession(authSession) - authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, sessionManager, tokenManager, &template.Manager{}) + authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, authSessionManager, loginSessionManager, tokenManager, &template.Manager{}) rr := httptest.NewRecorder() diff --git a/internal/server/handler/forwardauth/forwardauth.go b/internal/server/handler/forwardauth/forwardauth.go index 6a356be..30182c6 100644 --- a/internal/server/handler/forwardauth/forwardauth.go +++ b/internal/server/handler/forwardauth/forwardauth.go @@ -22,17 +22,19 @@ type Handler struct { cookieManager *cookie.Manager authSessionManager session.Manager[session.AuthSession] forwardSessionManager session.Manager[session.ForwardSession] + loginSessionManager session.Manager[session.LoginSession] templateManager *template.Manager errorHandler *internalError.Handler } -func NewForwardAuthHandler(cookieManager *cookie.Manager, authSessionManager session.Manager[session.AuthSession], forwardSessionManager session.Manager[session.ForwardSession], templateManager *template.Manager) *Handler { +func NewForwardAuthHandler(cookieManager *cookie.Manager, authSessionManager session.Manager[session.AuthSession], forwardSessionManager session.Manager[session.ForwardSession], loginSessionManager session.Manager[session.LoginSession], templateManager *template.Manager) *Handler { currentConfig := config.GetConfigInstance() return &Handler{ config: currentConfig, cookieManager: cookieManager, authSessionManager: authSessionManager, forwardSessionManager: forwardSessionManager, + loginSessionManager: loginSessionManager, templateManager: templateManager, errorHandler: internalError.NewErrorHandler(), } @@ -41,6 +43,12 @@ func NewForwardAuthHandler(cookieManager *cookie.Manager, authSessionManager ses func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.AccessLogRequest(r) + forwardAuthClient, forwardAuthClientExists := h.config.GetForwardAuthClient() + if !forwardAuthClientExists { + h.errorHandler.InternalServerErrorHandler(w, r) + return + } + forwardProtocol := r.Header.Get(internalHttp.XForwardProtocol) forwardHost := r.Header.Get(internalHttp.XForwardHost) forwardPath := r.Header.Get(internalHttp.XForwardUri) @@ -62,7 +70,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { codeParameter := forwardUri.Query().Get(oauth2.ParameterCode) forwardIdParameter := forwardUri.Query().Get(forwardAuthParameterName) - _, validCookie := h.cookieManager.ValidateAuthCookie(r) + _, _, validCookie := h.cookieManager.ValidateAuthCookie(r) if validCookie { w.WriteHeader(http.StatusOK) @@ -98,7 +106,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { parsedUri, parsedUriError := createUri(h.config.Server.ForwardAuth.ExternalUrl, endpoint.Authorization, func(query url.Values) { query.Set(oauth2.ParameterResponseType, string(oauth2.RtCode)) - query.Set(oauth2.ParameterClientId, "fooobar") + query.Set(oauth2.ParameterClientId, forwardAuthClient.Id) query.Set(oauth2.ParameterState, "abc") query.Set(oauth2.ParameterScope, "xzy") query.Set(oauth2.ParameterRedirectUri, redirectUri.String()) @@ -123,13 +131,18 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *Handler) validate(code string, forwardSessionId string) (*http.Cookie, *session.ForwardSession, bool) { authSession, authSessionExists := h.authSessionManager.GetSession(code) - forwardSession, forwardSessionExists := h.forwardSessionManager.GetSession(forwardSessionId) - if authSessionExists && forwardSessionExists { + if authSessionExists { codeChallengeMethod, codeChallengeMethodExists := pkce.CodeChallengeMethodFromString(authSession.CodeChallengeMethod) - if codeChallengeMethodExists { + forwardSession, forwardSessionExists := h.forwardSessionManager.GetSession(forwardSessionId) + if codeChallengeMethodExists && forwardSessionExists { validatePKCE := pkce.ValidatePKCE(codeChallengeMethod, forwardSession.CodeChallengeVerifier, authSession.CodeChallenge) if validatePKCE { - authCookie, authCookieError := h.cookieManager.CreateAuthCookie(authSession.Username) + loginSession := &session.LoginSession{ + Id: uuid.NewString(), + Username: authSession.Username, + } + h.loginSessionManager.StartSession(loginSession) + authCookie, authCookieError := h.cookieManager.CreateAuthCookie(authSession.Username, loginSession.Id) if authCookieError != nil { return nil, nil, false } diff --git a/internal/server/handler/forwardauth/forwardauth_test.go b/internal/server/handler/forwardauth/forwardauth_test.go index 3af0474..ddf170b 100644 --- a/internal/server/handler/forwardauth/forwardauth_test.go +++ b/internal/server/handler/forwardauth/forwardauth_test.go @@ -13,6 +13,11 @@ import ( func Test_ForwardAuth(t *testing.T) { testConfig := &config.Config{ + Server: config.Server{ + ForwardAuth: config.ForwardAuth{ + ExternalUrl: "http://foo.com", + }, + }, Clients: []config.Client{ { Id: "foo", @@ -45,9 +50,10 @@ func testForwardAuthWithoutCookie(t *testing.T, testConfig *config.Config) { cookieManager := cookie.GetCookieManagerInstance() authSessionManager := session.GetAuthSessionManagerInstance() forwardSessionManager := session.GetForwardSessionManagerInstance() + loginSessionManager := session.GetLoginSessionManagerInstance() templateManager := template.GetTemplateManagerInstance() - forwardAuthHandler := NewForwardAuthHandler(cookieManager, authSessionManager, forwardSessionManager, templateManager) + forwardAuthHandler := NewForwardAuthHandler(cookieManager, authSessionManager, forwardSessionManager, loginSessionManager, templateManager) rr := httptest.NewRecorder() request := httptest.NewRequest(http.MethodGet, testConfig.GetForwardAuthEndpoint(), nil) @@ -74,9 +80,10 @@ func testForwardAuthMissingHeaders(t *testing.T, testConfig *config.Config) { cookieManager := cookie.GetCookieManagerInstance() authSessionManager := session.GetAuthSessionManagerInstance() forwardSessionManager := session.GetForwardSessionManagerInstance() + loginSessionManager := session.GetLoginSessionManagerInstance() templateManager := template.GetTemplateManagerInstance() - forwardAuthHandler := NewForwardAuthHandler(cookieManager, authSessionManager, forwardSessionManager, templateManager) + forwardAuthHandler := NewForwardAuthHandler(cookieManager, authSessionManager, forwardSessionManager, loginSessionManager, templateManager) rr := httptest.NewRecorder() request := httptest.NewRequest(http.MethodGet, testConfig.GetForwardAuthEndpoint(), nil) @@ -94,9 +101,10 @@ func testForwardAuthInvalidHeaders(t *testing.T, testConfig *config.Config) { cookieManager := cookie.GetCookieManagerInstance() authSessionManager := session.GetAuthSessionManagerInstance() forwardSessionManager := session.GetForwardSessionManagerInstance() + loginSessionManager := session.GetLoginSessionManagerInstance() templateManager := template.GetTemplateManagerInstance() - forwardAuthHandler := NewForwardAuthHandler(cookieManager, authSessionManager, forwardSessionManager, templateManager) + forwardAuthHandler := NewForwardAuthHandler(cookieManager, authSessionManager, forwardSessionManager, loginSessionManager, templateManager) rr := httptest.NewRecorder() request := httptest.NewRequest(http.MethodGet, testConfig.GetForwardAuthEndpoint(), nil) diff --git a/internal/server/handler/logout/logout.go b/internal/server/handler/logout/logout.go index b9f8ca3..b095daa 100644 --- a/internal/server/handler/logout/logout.go +++ b/internal/server/handler/logout/logout.go @@ -3,35 +3,39 @@ package logout import ( internalHttp "github.com/webishdev/stopnik/internal/http" "github.com/webishdev/stopnik/internal/manager/cookie" + "github.com/webishdev/stopnik/internal/manager/session" "github.com/webishdev/stopnik/internal/server/handler/error" "github.com/webishdev/stopnik/log" "net/http" ) type Handler struct { - logoutRedirect string - cookieManager *cookie.Manager - errorHandler *error.Handler + logoutRedirect string + cookieManager *cookie.Manager + loginSessionManager session.Manager[session.LoginSession] + errorHandler *error.Handler } -func NewLogoutHandler(cookieManager *cookie.Manager, logoutRedirect string) *Handler { +func NewLogoutHandler(cookieManager *cookie.Manager, loginSessionManager session.Manager[session.LoginSession], logoutRedirect string) *Handler { return &Handler{ - cookieManager: cookieManager, - logoutRedirect: logoutRedirect, + cookieManager: cookieManager, + loginSessionManager: loginSessionManager, + logoutRedirect: logoutRedirect, } } func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.AccessLogRequest(r) if r.Method == http.MethodPost { - _, validCookie := h.cookieManager.ValidateAuthCookie(r) + _, loginSession, validCookie := h.cookieManager.ValidateAuthCookie(r) if !validCookie { h.errorHandler.ForbiddenHandler(w, r) return } - cookie := h.cookieManager.DeleteAuthCookie() + h.loginSessionManager.CloseSession(loginSession.Id, true) + authCookie := h.cookieManager.DeleteAuthCookie() - http.SetCookie(w, &cookie) + http.SetCookie(w, &authCookie) logoutRedirectFrom := r.PostFormValue("stopnik_logout_redirect") diff --git a/internal/server/handler/logout/logout_test.go b/internal/server/handler/logout/logout_test.go index bb60c7f..7ac799d 100644 --- a/internal/server/handler/logout/logout_test.go +++ b/internal/server/handler/logout/logout_test.go @@ -2,10 +2,12 @@ package logout import ( "fmt" + "github.com/google/uuid" "github.com/webishdev/stopnik/internal/config" "github.com/webishdev/stopnik/internal/endpoint" internalHttp "github.com/webishdev/stopnik/internal/http" "github.com/webishdev/stopnik/internal/manager/cookie" + "github.com/webishdev/stopnik/internal/manager/session" "net/http" "net/http/httptest" "strings" @@ -43,6 +45,7 @@ func Test_Logout(t *testing.T) { func testInvalidCookies(t *testing.T, testConfig *config.Config) { t.Run("Logout with invalid cookie", func(t *testing.T) { cookieManager := cookie.GetCookieManagerInstance() + loginSessionManager := session.GetLoginSessionManagerInstance() authCookie := http.Cookie{ Name: testConfig.GetAuthCookieName(), @@ -53,7 +56,7 @@ func testInvalidCookies(t *testing.T, testConfig *config.Config) { SameSite: http.SameSiteLaxMode, } - logoutHandler := NewLogoutHandler(cookieManager, "") + logoutHandler := NewLogoutHandler(cookieManager, loginSessionManager, "") rr := httptest.NewRecorder() @@ -84,11 +87,17 @@ func testLogout(t *testing.T, testConfig *config.Config) { testMessage := fmt.Sprintf("Logout handler redirect %v, form redirect %v", test.handlerRedirect, test.formRedirect) t.Run(testMessage, func(t *testing.T) { cookieManager := cookie.GetCookieManagerInstance() + loginSessionManager := session.GetLoginSessionManagerInstance() user, _ := testConfig.GetUser("foo") - authCookie, _ := cookieManager.CreateAuthCookie(user.Username) + loginSession := &session.LoginSession{ + Id: uuid.NewString(), + Username: user.Username, + } + loginSessionManager.StartSession(loginSession) + authCookie, _ := cookieManager.CreateAuthCookie(user.Username, loginSession.Id) - logoutHandler := NewLogoutHandler(cookieManager, test.handlerRedirect) + logoutHandler := NewLogoutHandler(cookieManager, loginSessionManager, test.handlerRedirect) rr := httptest.NewRecorder() @@ -146,7 +155,7 @@ func Test_LogoutNotAllowedHttpMethods(t *testing.T) { for _, method := range testInvalidLogoutHttpMethods { testMessage := fmt.Sprintf("Logout with unsupported method %s", method) t.Run(testMessage, func(t *testing.T) { - logoutHandler := NewLogoutHandler(&cookie.Manager{}, "") + logoutHandler := NewLogoutHandler(&cookie.Manager{}, &session.LoginManager{}, "") rr := httptest.NewRecorder() diff --git a/internal/server/server.go b/internal/server/server.go index 36cfee3..96bbd27 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -180,16 +180,17 @@ func registerHandlers(config *config.Config, handle func(pattern string, handler authSessionManager := session.GetAuthSessionManagerInstance() tokenManager := token2.GetTokenManagerInstance() cookieManager := cookie.GetCookieManagerInstance() + loginSessionManager := session.GetLoginSessionManagerInstance() requestValidator := validation.NewRequestValidator() templateManager := template.GetTemplateManagerInstance() // Own healthHandler := health.NewHealthHandler(tokenManager) - accountHandler := account.NewAccountHandler(requestValidator, cookieManager, templateManager) - logoutHandler := logout.NewLogoutHandler(cookieManager, config.Server.LogoutRedirect) + accountHandler := account.NewAccountHandler(requestValidator, cookieManager, loginSessionManager, templateManager) + logoutHandler := logout.NewLogoutHandler(cookieManager, loginSessionManager, config.Server.LogoutRedirect) // OAuth2 - authorizeHandler := authorize.NewAuthorizeHandler(requestValidator, cookieManager, authSessionManager, tokenManager, templateManager) + authorizeHandler := authorize.NewAuthorizeHandler(requestValidator, cookieManager, authSessionManager, loginSessionManager, tokenManager, templateManager) tokenHandler := token.NewTokenHandler(requestValidator, authSessionManager, tokenManager) // OAuth2 extensions @@ -206,7 +207,7 @@ func registerHandlers(config *config.Config, handle func(pattern string, handler if config.GetForwardAuthEnabled() { log.Info("ForwardAuth enabled with endpoint %s", config.GetForwardAuthEndpoint()) forwardSessionManager := session.GetForwardSessionManagerInstance() - forwardAuthHandler := forwardauth.NewForwardAuthHandler(cookieManager, authSessionManager, forwardSessionManager, templateManager) + forwardAuthHandler := forwardauth.NewForwardAuthHandler(cookieManager, authSessionManager, forwardSessionManager, loginSessionManager, templateManager) handle(config.GetForwardAuthEndpoint(), forwardAuthHandler) } diff --git a/internal/store/store.go b/internal/store/store.go index 0a39668..e6466dc 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -181,7 +181,7 @@ func (s *store[T]) Set(key string, value *T) { s.storeMap[key] = value } -func (s *store[T]) SetWithDuration(key string, value *T, duration time.Duration) { +func (s *store[T]) SetWithDuration(key string, value *T, _ time.Duration) { s.Set(key, value) } diff --git a/internal/template/template.go b/internal/template/template.go index 95e26f9..74b5ab9 100644 --- a/internal/template/template.go +++ b/internal/template/template.go @@ -87,7 +87,7 @@ func (templateManager *Manager) LoginTemplate(id string, action string, message Action: action, Token: id, HideFooter: templateManager.config.GetHideFooter(), - HideMascot: templateManager.config.GetHideMascot(), + HideMascot: templateManager.config.GetHideLogo(), ShowTitle: templateManager.config.GetTitle() != "", Title: templateManager.config.GetTitle(), FooterText: templateManager.config.GetFooterText(), @@ -125,7 +125,7 @@ func (templateManager *Manager) LogoutTemplate(username string, requestURI strin Username: username, RequestURI: requestURI, HideFooter: templateManager.config.GetHideFooter(), - HideMascot: templateManager.config.GetHideMascot(), + HideMascot: templateManager.config.GetHideLogo(), ShowTitle: templateManager.config.GetTitle() != "", Title: templateManager.config.GetTitle(), FooterText: templateManager.config.GetFooterText(),