Skip to content

Commit

Permalink
Add OIDC discovery and empty user info
Browse files Browse the repository at this point in the history
  • Loading branch information
giftkugel committed Aug 27, 2024
1 parent b2c5954 commit ef81a3e
Show file tree
Hide file tree
Showing 17 changed files with 289 additions and 85 deletions.
2 changes: 1 addition & 1 deletion cmd/stopnik/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func readConfiguration(configurationFile *string, configLoader *config.Loader) (
}
logger.SetLogLevel(currentConfig.Server.LogLevel)
logger.Info("Config loaded from %s", *configurationFile)
if currentConfig.GetOIDC() {
if currentConfig.GetOidc() {
logger.Info("OpenId Connect is enabled")
}

Expand Down
14 changes: 9 additions & 5 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto/rand"
"errors"
"fmt"
internalHttp "github.com/webishdev/stopnik/internal/http"
"github.com/webishdev/stopnik/log"
"math/big"
)
Expand Down Expand Up @@ -80,7 +81,7 @@ type Client struct {
Id string `yaml:"id"`
ClientSecret string `yaml:"clientSecret"`
Salt string `yaml:"salt"`
OIDC bool `yaml:"oidc"`
Oidc bool `yaml:"oidc"`
ClientType string `yaml:"type"`
AccessTTL int `yaml:"accessTTL"`
RefreshTTL int `yaml:"refreshTTL"`
Expand Down Expand Up @@ -220,7 +221,7 @@ func (config *Config) Setup() error {
return errors.New(invalidClient)
}

config.oidc = config.oidc || client.OIDC
config.oidc = config.oidc || client.Oidc
}

config.userMap = setup[User](&config.Users, func(user User) string {
Expand Down Expand Up @@ -292,7 +293,7 @@ func (config *Config) GetFooterText() string {
return GetOrDefaultString(config.UI.FooterText, "STOPnik")
}

func (config *Config) GetOIDC() bool {
func (config *Config) GetOidc() bool {
return config.oidc
}

Expand All @@ -304,8 +305,11 @@ func (client *Client) GetRefreshTTL() int {
return GetOrDefaultInt(client.RefreshTTL, 0)
}

func (client *Client) GetIssuer() string {
return GetOrDefaultString(client.Issuer, "STOPnik")
func (client *Client) GetIssuer(requestData *internalHttp.RequestData) string {
if requestData == nil || requestData.Host == "" || requestData.Scheme == "" {
return GetOrDefaultString(client.Issuer, "STOPnik")
}
return GetOrDefaultString(client.Issuer, requestData.IssuerString())
}

func (client *Client) GetAudience() []string {
Expand Down
3 changes: 2 additions & 1 deletion internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package config

import (
"errors"
internalHttp "github.com/webishdev/stopnik/internal/http"
"reflect"
"testing"
)
Expand Down Expand Up @@ -452,7 +453,7 @@ func assertClientValues(t *testing.T, config *Config, id string, expectedAccessT
t.Errorf("expected refresh TTL to be %d, got %d", expectedRefreshTTL, refreshTTL)
}

issuer := client.GetIssuer()
issuer := client.GetIssuer(&internalHttp.RequestData{})
if issuer != expectedIssuer {
t.Errorf("expected issuer to be '%s', got '%s'", expectedIssuer, issuer)
}
Expand Down
2 changes: 2 additions & 0 deletions internal/endpoint/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ const (
Revoke string = "/revoke"
Metadata string = "/.well-known/oauth-authorization-server"
Keys string = "/keys"
OidcDiscovery string = "/.well-known/openid-configuration"
OidcUserInfo string = "/userinfo"
)
48 changes: 48 additions & 0 deletions internal/http/request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package http

import (
"fmt"
"net/http"
"net/url"
)

type RequestData struct {
Scheme string
Host string
Path string
Query string
Fragment string
}

func NewRequestData(r *http.Request) *RequestData {
scheme := "http"
if r.TLS != nil {
scheme = "https"
}

query := ""
if r.URL.RawQuery != "" {
query = "?" + r.URL.RawQuery
}
fragment := ""
if r.URL.RawFragment != "" {
fragment = "#" + r.URL.RawFragment
}
return &RequestData{
Scheme: scheme,
Host: r.Host,
Path: r.URL.RawPath,
Query: query,
Fragment: fragment,
}
}

func (r *RequestData) IssuerString() string {
return fmt.Sprintf("%s://%s", r.Scheme, r.Host)
}

func (r *RequestData) URL() (*url.URL, error) {
uri := fmt.Sprintf("%s://%s%s%s%s", r.Scheme, r.Host, r.Path, r.Query, r.Fragment)

return url.Parse(uri)
}
28 changes: 15 additions & 13 deletions internal/manager/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/webishdev/stopnik/internal/oauth2"
"github.com/webishdev/stopnik/internal/store"
"github.com/webishdev/stopnik/log"
"net/http"
"strings"
"time"
)
Expand Down Expand Up @@ -53,14 +54,15 @@ func (tokenManager *TokenManager) RevokeRefreshToken(refreshToken *oauth2.Refres
refreshTokenStore.Delete(refreshToken.Key)
}

func (tokenManager *TokenManager) CreateAccessTokenResponse(username string, client *config.Client, scopes []string, nonce string) oauth2.AccessTokenResponse {
func (tokenManager *TokenManager) CreateAccessTokenResponse(r *http.Request, username string, client *config.Client, scopes []string, nonce string) oauth2.AccessTokenResponse {
log.Debug("Creating new access token for %s, access TTL %d, refresh TTL %d", client.Id, client.GetAccessTTL(), client.GetRefreshTTL())

requestData := internalHttp.NewRequestData(r)
accessTokenStore := *tokenManager.accessTokenStore
refreshTokenStore := *tokenManager.refreshTokenStore

accessTokenDuration := time.Minute * time.Duration(client.GetAccessTTL())
accessTokenKey := tokenManager.generateAccessToken(username, client, accessTokenDuration)
accessTokenKey := tokenManager.generateAccessToken(requestData, username, client, accessTokenDuration)
accessToken := &oauth2.AccessToken{
Key: accessTokenKey,
TokenType: oauth2.TtBearer,
Expand All @@ -79,7 +81,7 @@ func (tokenManager *TokenManager) CreateAccessTokenResponse(username string, cli

if client.GetRefreshTTL() > 0 {
refreshTokenDuration := time.Minute * time.Duration(client.GetRefreshTTL())
refreshTokenKey := tokenManager.generateAccessToken(username, client, refreshTokenDuration)
refreshTokenKey := tokenManager.generateAccessToken(requestData, username, client, refreshTokenDuration)
refreshToken := &oauth2.RefreshToken{
Key: refreshTokenKey,
Username: username,
Expand All @@ -92,11 +94,11 @@ func (tokenManager *TokenManager) CreateAccessTokenResponse(username string, cli
accessTokenResponse.RefreshTokenKey = refreshTokenKey
}

if client.OIDC {
if client.Oidc {
user, userExists := tokenManager.config.GetUser(username)
if userExists {
idTokenDuration := time.Minute * time.Duration(client.GetAccessTTL())
idTokenKey := tokenManager.generateIdToken(user, client, nonce, idTokenDuration)
idTokenKey := tokenManager.generateIdToken(requestData, user, client, nonce, idTokenDuration)
accessTokenResponse.IdToken = idTokenKey
}
}
Expand Down Expand Up @@ -126,17 +128,17 @@ func (tokenManager *TokenManager) ValidateAccessToken(authorizationHeader string
return user, accessToken.Scopes, true
}

func (tokenManager *TokenManager) generateIdToken(user *config.User, client *config.Client, nonce string, duration time.Duration) string {
idToken := generateIdToken(user, client, nonce, duration)
func (tokenManager *TokenManager) generateIdToken(requestData *internalHttp.RequestData, user *config.User, client *config.Client, nonce string, duration time.Duration) string {
idToken := generateIdToken(requestData, user, client, nonce, duration)
return tokenManager.generateJWTToken(client, idToken)
}

func (tokenManager *TokenManager) generateAccessToken(username string, client *config.Client, duration time.Duration) string {
func (tokenManager *TokenManager) generateAccessToken(requestData *internalHttp.RequestData, username string, client *config.Client, duration time.Duration) string {
tokenId := uuid.New()
if client.OpaqueToken {
return tokenManager.generateOpaqueAccessToken(tokenId.String())
}
accessToken := generateAccessToken(tokenId.String(), duration, username, client)
accessToken := generateAccessToken(requestData, tokenId.String(), duration, username, client)
return tokenManager.generateJWTToken(client, accessToken)
}

Expand Down Expand Up @@ -187,7 +189,7 @@ func (tokenManager *TokenManager) generateJWTToken(client *config.Client, token

}

func generateIdToken(user *config.User, client *config.Client, nonce string, duration time.Duration) jwt.Token {
func generateIdToken(requestData *internalHttp.RequestData, user *config.User, client *config.Client, nonce string, duration time.Duration) jwt.Token {
tokenId := uuid.New().String()
builder := jwt.NewBuilder().
Expiration(time.Now().Add(duration)). // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.4
Expand All @@ -197,7 +199,7 @@ func generateIdToken(user *config.User, client *config.Client, nonce string, dur
builder.JwtID(tokenId)

// https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.1
builder.Issuer(client.GetIssuer())
builder.Issuer(client.GetIssuer(requestData))

// https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.2
builder.Subject(user.Username)
Expand All @@ -218,7 +220,7 @@ func generateIdToken(user *config.User, client *config.Client, nonce string, dur
return token
}

func generateAccessToken(tokenId string, duration time.Duration, username string, client *config.Client) jwt.Token {
func generateAccessToken(requestData *internalHttp.RequestData, tokenId string, duration time.Duration, username string, client *config.Client) jwt.Token {
builder := jwt.NewBuilder().
Expiration(time.Now().Add(duration)). // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.4
IssuedAt(time.Now()) // https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.6
Expand All @@ -232,7 +234,7 @@ func generateAccessToken(tokenId string, duration time.Duration, username string
builder.JwtID(tokenId)

// https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.1
builder.Issuer(client.GetIssuer())
builder.Issuer(client.GetIssuer(requestData))

// https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.2
builder.Subject(username)
Expand Down
9 changes: 7 additions & 2 deletions internal/manager/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package manager
import (
"fmt"
"github.com/webishdev/stopnik/internal/config"
"github.com/webishdev/stopnik/internal/endpoint"
internalHttp "github.com/webishdev/stopnik/internal/http"
"github.com/webishdev/stopnik/internal/oauth2"
"net/http"
"net/http/httptest"
"reflect"
"testing"
)
Expand Down Expand Up @@ -33,7 +36,8 @@ func Test_Token(t *testing.T) {
t.Fatal("client does not exist")
}

accessTokenResponse := tokenManager.CreateAccessTokenResponse("foo", client, []string{"abc", "def"}, "")
request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil)
accessTokenResponse := tokenManager.CreateAccessTokenResponse(request, "foo", client, []string{"abc", "def"}, "")

if accessTokenResponse.AccessTokenKey == "" {
t.Error("empty access token")
Expand Down Expand Up @@ -137,7 +141,8 @@ func Test_Token(t *testing.T) {
t.Fatal("client does not exist")
}

accessTokenResponse := tokenManager.CreateAccessTokenResponse("bar", client, []string{"abc", "def"}, "")
request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil)
accessTokenResponse := tokenManager.CreateAccessTokenResponse(request, "bar", client, []string{"abc", "def"}, "")

_, _, valid := tokenManager.ValidateAccessToken(fmt.Sprintf("%s %s", internalHttp.AuthBearer, accessTokenResponse.AccessTokenKey))

Expand Down
6 changes: 3 additions & 3 deletions internal/server/handler/authorize/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func (h *Handler) handleGetRequest(w http.ResponseWriter, r *http.Request) {
State: state,
}

if client.OIDC {
if client.Oidc {
nonceQueryParameter := r.URL.Query().Get(oidc.ParameterNonce)
authSession.Nonce = nonceQueryParameter
} else {
Expand All @@ -157,7 +157,7 @@ func (h *Handler) handleGetRequest(w http.ResponseWriter, r *http.Request) {
query := redirectURL.Query()

if slices.Contains(responseTypes, oauth2.RtToken) {
accessTokenResponse := h.tokenManager.CreateAccessTokenResponse(user.Username, client, scopes, authSession.Nonce)
accessTokenResponse := h.tokenManager.CreateAccessTokenResponse(r, user.Username, client, scopes, authSession.Nonce)
setImplicitGrantParameter(query, accessTokenResponse)
} else if slices.Contains(responseTypes, oauth2.RtCode) {
setAuthorizationGrantParameter(query, id.String())
Expand Down Expand Up @@ -225,7 +225,7 @@ func (h *Handler) handlePostRequest(w http.ResponseWriter, r *http.Request, user
h.errorHandler.InternalServerErrorHandler(w, r)
return
}
accessTokenResponse := h.tokenManager.CreateAccessTokenResponse(user.Username, client, authSession.Scopes, authSession.Nonce)
accessTokenResponse := h.tokenManager.CreateAccessTokenResponse(r, user.Username, client, authSession.Scopes, authSession.Nonce)
setImplicitGrantParameter(query, accessTokenResponse)
} else if slices.Contains(responseTypes, oauth2.RtCode) {
setAuthorizationGrantParameter(query, authSession.Id)
Expand Down
3 changes: 2 additions & 1 deletion internal/server/handler/health/health_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ func Test_Health(t *testing.T) {
t.Error("client should exist")
}

tokenResponse := tokenManager.CreateAccessTokenResponse("foo", client, []string{"a:foo", "b:bar"}, "")
request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil)
tokenResponse := tokenManager.CreateAccessTokenResponse(request, "foo", client, []string{"a:foo", "b:bar"}, "")

healthHandler := NewHealthHandler(tokenManager)

Expand Down
15 changes: 9 additions & 6 deletions internal/server/handler/introspect/introspect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,8 @@ func testIntrospect(t *testing.T, testConfig *config.Config, keyManger *manager.
sessionManager := manager.NewSessionManager(testConfig)
tokenManager := manager.NewTokenManager(testConfig, manager.NewDefaultKeyLoader(testConfig, keyManger))
sessionManager.StartSession(authSession)
accessTokenResponse := tokenManager.CreateAccessTokenResponse(user.Username, client, scopes, "")
request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil)
accessTokenResponse := tokenManager.CreateAccessTokenResponse(request, user.Username, client, scopes, "")

introspectHandler := NewIntrospectHandler(testConfig, requestValidator, tokenManager)

Expand All @@ -253,7 +254,7 @@ func testIntrospect(t *testing.T, testConfig *config.Config, keyManger *manager.
)
body := strings.NewReader(bodyString)

request := httptest.NewRequest(http.MethodPost, endpoint.Introspect, body)
request = httptest.NewRequest(http.MethodPost, endpoint.Introspect, body)
request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "bar")))
request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded")

Expand Down Expand Up @@ -308,7 +309,8 @@ func testIntrospectWithoutHint(t *testing.T, testConfig *config.Config, keyMange
sessionManager := manager.NewSessionManager(testConfig)
tokenManager := manager.NewTokenManager(testConfig, manager.NewDefaultKeyLoader(testConfig, keyManger))
sessionManager.StartSession(authSession)
accessTokenResponse := tokenManager.CreateAccessTokenResponse(user.Username, client, scopes, "")
request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil)
accessTokenResponse := tokenManager.CreateAccessTokenResponse(request, user.Username, client, scopes, "")

introspectHandler := NewIntrospectHandler(testConfig, requestValidator, tokenManager)

Expand All @@ -324,7 +326,7 @@ func testIntrospectWithoutHint(t *testing.T, testConfig *config.Config, keyMange
)
body := strings.NewReader(bodyString)

request := httptest.NewRequest(http.MethodPost, endpoint.Introspect, body)
request = httptest.NewRequest(http.MethodPost, endpoint.Introspect, body)
request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("foo", "bar")))
request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded")

Expand Down Expand Up @@ -379,7 +381,8 @@ func testIntrospectDisabled(t *testing.T, testConfig *config.Config, keyManger *
sessionManager := manager.NewSessionManager(testConfig)
tokenManager := manager.NewTokenManager(testConfig, manager.NewDefaultKeyLoader(testConfig, keyManger))
sessionManager.StartSession(authSession)
accessTokenResponse := tokenManager.CreateAccessTokenResponse(user.Username, client, scopes, "")
request := httptest.NewRequest(http.MethodPost, endpoint.Token, nil)
accessTokenResponse := tokenManager.CreateAccessTokenResponse(request, user.Username, client, scopes, "")

introspectHandler := NewIntrospectHandler(testConfig, requestValidator, tokenManager)

Expand All @@ -396,7 +399,7 @@ func testIntrospectDisabled(t *testing.T, testConfig *config.Config, keyManger *
)
body := strings.NewReader(bodyString)

request := httptest.NewRequest(http.MethodPost, endpoint.Introspect, body)
request = httptest.NewRequest(http.MethodPost, endpoint.Introspect, body)
request.Header.Add(internalHttp.Authorization, fmt.Sprintf("Basic %s", testTokenCreateBasicAuth("bar", "bar")))
request.Header.Add(internalHttp.ContentType, "application/x-www-form-urlencoded")

Expand Down
Loading

0 comments on commit ef81a3e

Please sign in to comment.