Skip to content

Commit

Permalink
Add tests, improve documentation, improve ForwardAuth
Browse files Browse the repository at this point in the history
  • Loading branch information
giftkugel committed Sep 5, 2024
1 parent 56c1df4 commit 51d16a3
Show file tree
Hide file tree
Showing 23 changed files with 1,092 additions and 1,047 deletions.
25 changes: 15 additions & 10 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ type Cookies struct {
}

type ForwardAuth struct {
Endpoint string `yaml:"endpoint"`
ExternalUrl string `yaml:"externalUrl"`
Endpoint string `yaml:"endpoint"`
ExternalUrl string `yaml:"externalUrl"`
ParameterName string `yaml:"parameterName"`
}

type Server struct {
Expand All @@ -47,6 +48,7 @@ type Server struct {
IntrospectScope string `yaml:"introspectScope"`
RevokeScope string `yaml:"revokeScopeScope"`
SessionTimeoutSeconds int `yaml:"sessionTimeoutSeconds"`
Issuer string `yaml:"issuer"`
ForwardAuth ForwardAuth `yaml:"forwardAuth"`
}

Expand Down Expand Up @@ -108,7 +110,6 @@ type Client struct {
OpaqueToken bool `yaml:"opaqueToken"`
PasswordFallbackAllowed bool `yaml:"passwordFallbackAllowed"`
Claims []Claim `yaml:"claims"`
Issuer string `yaml:"issuer"`
Audience []string `yaml:"audience"`
PrivateKey string `yaml:"privateKey"`
RolesClaim string `yaml:"rolesClaim"`
Expand Down Expand Up @@ -295,6 +296,13 @@ func (config *Config) GetOidc() bool {
return config.oidc
}

func (config *Config) GetIssuer(requestData *internalHttp.RequestData) string {
if requestData == nil || requestData.Host == "" || requestData.Scheme == "" {
return GetOrDefaultString(config.Server.Issuer, "STOPnik")
}
return GetOrDefaultString(config.Server.Issuer, requestData.IssuerString())
}

func (config *Config) GetForwardAuthEnabled() bool {
return config.Server.ForwardAuth.ExternalUrl != ""
}
Expand All @@ -303,6 +311,10 @@ func (config *Config) GetForwardAuthEndpoint() string {
return GetOrDefaultString(config.Server.ForwardAuth.Endpoint, "/forward")
}

func (config *Config) GetForwardAuthParameterName() string {
return GetOrDefaultString(config.Server.ForwardAuth.ParameterName, "forward_id")
}

func (client *Client) GetRolesClaim() string {
return GetOrDefaultString(client.RolesClaim, "roles")
}
Expand All @@ -319,13 +331,6 @@ func (client *Client) GetIdTTL() int {
return GetOrDefaultInt(client.IdTTL, 0)
}

func (client *Client) GetIssuer(requestData *internalHttp.RequestData) string {
if requestData == nil || requestData.Host == "" || requestData.Scheme == "" {
return GetOrDefaultString(client.Issuer, "STOPnik")
}
return GetOrDefaultString(client.Issuer, requestData.IssuerString())
}

func (client *Client) GetAudience() []string {
return GetOrDefaultStringSlice(client.Audience, []string{"all"})
}
Expand Down
15 changes: 9 additions & 6 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ func Test_SimpleServerConfiguration(t *testing.T) {
*origin = Config{
Server: Server{
Secret: "5XyLSgKpo5kWrJqm",
Issuer: "http://foo.com/bar",
Cookies: Cookies{
AuthName: "my_auth",
},
Expand Down Expand Up @@ -199,6 +200,11 @@ func Test_SimpleServerConfiguration(t *testing.T) {
t.Error("expected server secret to be '5XyLSgKpo5kWrJqm'")
}

issuer := config.GetIssuer(&internalHttp.RequestData{})
if issuer != "http://foo.com/bar" {
t.Error("expected issuer to be 'http://foo.com/bar'")
}

authCookieName := config.GetAuthCookieName()
if authCookieName != "my_auth" {
t.Error("expected auth cookie name to be 'my_auth'")
Expand Down Expand Up @@ -431,6 +437,9 @@ func Test_ValidClients(t *testing.T) {
}, func(in []byte, out interface{}) (err error) {
origin := out.(*Config)
*origin = Config{
Server: Server{
Issuer: "other",
},
Clients: []Client{
{
Id: "foo",
Expand All @@ -444,7 +453,6 @@ func Test_ValidClients(t *testing.T) {
AccessTTL: 20,
RefreshTTL: 60,
IdTTL: 40,
Issuer: "other",
RolesClaim: "groups",
Audience: []string{"one", "two"},
},
Expand Down Expand Up @@ -652,11 +660,6 @@ func assertClientValues(t *testing.T, config *Config, expected testExpectedClien
t.Errorf("expected id token TTL to be %d, got %d", expected.expectedIdTokenTTL, idTokenTTL)
}

issuer := client.GetIssuer(&internalHttp.RequestData{})
if issuer != expected.expectedIssuer {
t.Errorf("expected issuer to be '%s', got '%s'", expected.expectedIssuer, issuer)
}

audience := client.GetAudience()
if !reflect.DeepEqual(audience, expected.expectedAudience) {
t.Errorf("expected audience to be '%s', got '%s'", expected.expectedAudience, audience)
Expand Down
3 changes: 3 additions & 0 deletions internal/http/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ const (
ContentType string = "Content-Type"
Authorization string = "Authorization"
AccessControlAllowOrigin string = "Access-Control-Allow-Origin"
XForwardProtocol string = "X-Forwarded-Proto"
XForwardHost string = "X-Forwarded-Host"
XForwardUri string = "X-Forwarded-Uri"
)

const (
Expand Down
22 changes: 11 additions & 11 deletions internal/manager/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ func (tokenManager *Manager) CreateAccessTokenResponse(r *http.Request, username
accessTokenStore.SetWithDuration(accessToken.Key, accessToken, accessTokenDuration)

accessTokenResponse := oauth2.AccessTokenResponse{
AccessTokenKey: accessToken.Key,
TokenType: oauth2.TtBearer,
ExpiresIn: int(accessTokenDuration / time.Second),
AccessTokenValue: accessToken.Key,
TokenType: oauth2.TtBearer,
ExpiresIn: int(accessTokenDuration / time.Second),
}

if client.GetRefreshTTL() > 0 {
Expand All @@ -107,14 +107,14 @@ func (tokenManager *Manager) CreateAccessTokenResponse(r *http.Request, username

refreshTokenStore.SetWithDuration(refreshToken.Key, refreshToken, refreshTokenDuration)

accessTokenResponse.RefreshTokenKey = refreshToken.Key
accessTokenResponse.RefreshTokenValue = refreshToken.Key
}

if client.Oidc && oidc.HasOidcScope(scopes) {
user, userExists := tokenManager.config.GetUser(username)
if userExists {
accessTokenHash := tokenManager.CreateAccessTokenHash(client, accessToken.Key)
accessTokenResponse.IdToken = tokenManager.CreateIdToken(r, user.Username, client, scopes, nonce, accessTokenHash)
accessTokenResponse.IdTokenValue = tokenManager.CreateIdToken(r, user.Username, client, scopes, nonce, accessTokenHash)
}
}

Expand Down Expand Up @@ -178,7 +178,7 @@ func (tokenManager *Manager) ValidateAccessToken(authorizationHeader string) (*c
}

func (tokenManager *Manager) generateIdToken(requestData *internalHttp.RequestData, user *config.User, client *config.Client, nonce string, atHash string, duration time.Duration) string {
idToken := generateIdToken(requestData, user, client, nonce, atHash, duration)
idToken := generateIdToken(requestData, tokenManager.config, user, client, nonce, atHash, duration)
return tokenManager.generateJWTToken(client, idToken)
}

Expand All @@ -187,7 +187,7 @@ func (tokenManager *Manager) generateAccessToken(requestData *internalHttp.Reque
if client.OpaqueToken {
return tokenManager.generateOpaqueAccessToken(tokenId.String())
}
accessToken := generateAccessToken(requestData, tokenId.String(), duration, username, client)
accessToken := generateAccessToken(requestData, tokenManager.config, client, tokenId.String(), duration, username)
return tokenManager.generateJWTToken(client, accessToken)
}

Expand Down Expand Up @@ -238,7 +238,7 @@ func (tokenManager *Manager) generateJWTToken(client *config.Client, token jwt.T

}

func generateIdToken(requestData *internalHttp.RequestData, user *config.User, client *config.Client, nonce string, atHash string, duration time.Duration) jwt.Token {
func generateIdToken(requestData *internalHttp.RequestData, config *config.Config, user *config.User, client *config.Client, nonce string, atHash string, duration time.Duration) jwt.Token {
tokenId := uuid.New().String()
builder := jwt.NewBuilder().
Expiration(time.Now().Add(duration)). // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.4
Expand All @@ -248,7 +248,7 @@ func generateIdToken(requestData *internalHttp.RequestData, user *config.User, c
builder.JwtID(tokenId)

// https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.1
builder.Issuer(client.GetIssuer(requestData))
builder.Issuer(config.GetIssuer(requestData))

// https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.2
builder.Subject(user.Username)
Expand All @@ -274,7 +274,7 @@ func generateIdToken(requestData *internalHttp.RequestData, user *config.User, c
return token
}

func generateAccessToken(requestData *internalHttp.RequestData, tokenId string, duration time.Duration, username string, client *config.Client) jwt.Token {
func generateAccessToken(requestData *internalHttp.RequestData, config *config.Config, client *config.Client, tokenId string, duration time.Duration, username string) jwt.Token {
builder := jwt.NewBuilder().
Expiration(time.Now().Add(duration)). // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.4
IssuedAt(time.Now()) // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.6
Expand All @@ -288,7 +288,7 @@ func generateAccessToken(requestData *internalHttp.RequestData, tokenId string,
builder.JwtID(tokenId)

// https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.1
builder.Issuer(client.GetIssuer(requestData))
builder.Issuer(config.GetIssuer(requestData))

// https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.2
builder.Subject(username)
Expand Down
Loading

0 comments on commit 51d16a3

Please sign in to comment.