Skip to content

Commit

Permalink
feat: optional static domain for callback & allow setting statecookie…
Browse files Browse the repository at this point in the history
….domain (#30)
  • Loading branch information
cdanis authored Jan 5, 2025
1 parent ff743d8 commit 1969587
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 66 deletions.
25 changes: 24 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
"net/url"
"os"
"strings"
"text/template"
Expand Down Expand Up @@ -31,6 +32,11 @@ type Config struct {
Provider *ProviderConfig `json:"provider"`
Scopes []string `json:"scopes"`

// Can be a relative path or a full URL.
// If a relative path is used, the scheme and domain will be taken from the incoming request.
// In this case, the callback path will overlay all hostnames behind the middleware.
// If a full URL is used, all callbacks are sent there. It is the user's responsibility to ensure
// that the callback URL is also routed to this middleware plugin.
CallbackUri string `json:"callback_uri"`

// The URL used to start authorization when needed.
Expand Down Expand Up @@ -74,6 +80,7 @@ type ProviderConfig struct {
type StateCookieConfig struct {
Name string `json:"name"`
Path string `json:"path"`
Domain string `json:"domain"`
Secure bool `json:"secure"`
HttpOnly bool `json:"http_only"`
SameSite string `json:"same_site"`
Expand Down Expand Up @@ -115,6 +122,7 @@ func CreateConfig() *Config {
StateCookie: &StateCookieConfig{
Name: "Authorization",
Path: "/",
Domain: "",
Secure: true,
HttpOnly: true,
SameSite: "default",
Expand Down Expand Up @@ -165,6 +173,12 @@ func New(uctx context.Context, next http.Handler, config *Config, name string) (
return nil, err
}

parsedCallbackURL, err := url.Parse(config.CallbackUri)
if err != nil {
log(config.LogLevel, LogLevelError, "Error while parsing CallbackUri: %s", err.Error())
return nil, err
}

if config.Provider.TokenValidation == "" {
// For EntraID, we cannot validate the access token using JWKS, so we fall back to the id token by default
if strings.HasPrefix(config.Provider.Url, "https://login.microsoftonline.com") {
Expand All @@ -174,12 +188,21 @@ func New(uctx context.Context, next http.Handler, config *Config, name string) (
}
}

log(config.LogLevel, LogLevelInfo, "Configuration loaded. Provider Url: %v", parsedURL)
log(config.LogLevel, LogLevelInfo, "Provider Url: %v", parsedURL)
log(config.LogLevel, LogLevelInfo, "I will use this URL for callbacks from the IDP: %v", parsedCallbackURL)
if urlIsAbsolute(parsedCallbackURL) {
log(config.LogLevel, LogLevelInfo, "Callback URL is absolute, will not overlay wrapped services")
} else {
log(config.LogLevel, LogLevelInfo, "Callback URL is relative, will overlay any wrapped host")
}
log(config.LogLevel, LogLevelDebug, "Scopes: %s", strings.Join(config.Scopes, ", "))
log(config.LogLevel, LogLevelDebug, "StateCookie: %v", config.StateCookie)

log(config.LogLevel, LogLevelInfo, "Configuration loaded successfully, starting OIDC Auth middleware...")
return &TraefikOidcAuth{
next: next,
ProviderURL: parsedURL,
CallbackURL: parsedCallbackURL,
Config: config,
SessionStorage: CreateCookieSessionStorage(),
}, nil
Expand Down
134 changes: 76 additions & 58 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
type TraefikOidcAuth struct {
next http.Handler
ProviderURL *url.URL
CallbackURL *url.URL
Config *Config
SessionStorage SessionStorage
DiscoveryDocument *OidcDiscovery
Expand Down Expand Up @@ -66,6 +67,33 @@ func (toa *TraefikOidcAuth) EnsureOidcDiscovery() error {
return nil
}

func (toa *TraefikOidcAuth) GetAbsoluteCallbackURL(req *http.Request) *url.URL {
if urlIsAbsolute(toa.CallbackURL) {
return toa.CallbackURL
} else {
abs := *toa.CallbackURL
fillHostSchemeFromRequest(req, &abs)
return &abs
}
}

func (toa *TraefikOidcAuth) isCallbackRequest(req *http.Request) bool {
u := req.URL
fillHostSchemeFromRequest(req, u)

if u.Path != toa.CallbackURL.Path {
return false
}

if urlIsAbsolute(toa.CallbackURL) {
if u.Scheme != toa.CallbackURL.Scheme || u.Host != toa.CallbackURL.Host {
return false
}
}

return true
}

func (toa *TraefikOidcAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
err := toa.EnsureOidcDiscovery()

Expand All @@ -75,7 +103,7 @@ func (toa *TraefikOidcAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request)
return
}

if strings.HasPrefix(req.RequestURI, toa.Config.CallbackUri) {
if toa.isCallbackRequest(req) {
toa.handleCallback(rw, req)
return
}
Expand Down Expand Up @@ -146,13 +174,9 @@ func (toa *TraefikOidcAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request)
}

if !ok {
http.SetCookie(rw, &http.Cookie{
Name: toa.Config.StateCookie.Name,
Value: "",
Path: toa.Config.StateCookie.Path,
Expires: time.Now().Add(-24 * time.Hour),
MaxAge: -1,
})
c := toa.createStateCookie()
makeCookieExpireImmediately(c)
http.SetCookie(rw, c)

toa.handleUnauthorized(rw, req)
return
Expand Down Expand Up @@ -316,7 +340,8 @@ func (toa *TraefikOidcAuth) handleCallback(rw http.ResponseWriter, req *http.Req
MaxAge: -1,
Secure: true,
HttpOnly: true,
Path: toa.Config.CallbackUri,
Path: toa.CallbackURL.Path,
Domain: toa.CallbackURL.Host,
SameSite: http.SameSiteDefaultMode,
})

Expand Down Expand Up @@ -348,7 +373,7 @@ func (toa *TraefikOidcAuth) handleLogout(rw http.ResponseWriter, req *http.Reque
return
}

callbackUri := ensureAbsoluteUrl(req, toa.Config.CallbackUri)
callbackUri := toa.GetAbsoluteCallbackURL(req).String()
redirectUri := ensureAbsoluteUrl(req, toa.Config.PostLogoutRedirectUri)

if req.URL.Query().Get("redirect_uri") != "" {
Expand Down Expand Up @@ -390,9 +415,9 @@ func (toa *TraefikOidcAuth) redirectToProvider(rw http.ResponseWriter, req *http
log(toa.Config.LogLevel, LogLevelInfo, "Redirecting to OIDC provider...")

host := getFullHost(req)

originalUrl := fmt.Sprintf("%s%s", host, req.RequestURI)
redirectUrl := host + toa.Config.CallbackUri

redirectUrl := toa.GetAbsoluteCallbackURL(req).String()

state := OidcState{
Action: "Login",
Expand Down Expand Up @@ -443,12 +468,14 @@ func (toa *TraefikOidcAuth) redirectToProvider(rw http.ResponseWriter, req *http
}

// TODO: Make configurable
// TODO does this need domain tweaks? it is in the login flow
http.SetCookie(rw, &http.Cookie{
Name: "CodeVerifier",
Value: encryptedCodeVerifier,
Secure: true,
HttpOnly: true,
Path: toa.Config.CallbackUri,
Path: toa.CallbackURL.Path,
Domain: toa.CallbackURL.Host,
SameSite: http.SameSiteDefaultMode,
})
}
Expand Down Expand Up @@ -478,38 +505,39 @@ func (toa *TraefikOidcAuth) storeSessionAndAttachCookie(session SessionState, rw
toa.SetChunkedCookies(rw, toa.Config.StateCookie.Name, encryptedSessionTicket)
}

func (toa *TraefikOidcAuth) createStateCookie() *http.Cookie {
return &http.Cookie{
Name: toa.Config.StateCookie.Name,
Value: "",
Secure: toa.Config.StateCookie.Secure,
HttpOnly: toa.Config.StateCookie.HttpOnly,
Path: toa.Config.StateCookie.Path,
Domain: toa.Config.StateCookie.Domain,
SameSite: parseCookieSameSite(toa.Config.StateCookie.SameSite),
}
}

func (toa *TraefikOidcAuth) SetChunkedCookies(rw http.ResponseWriter, cookieName string, cookieValue string) {
cookieChunks := ChunkString(cookieValue, 3072)

baseCookie := toa.createStateCookie()
baseCookie.Name = cookieName

// Set the cookie
if len(cookieChunks) == 1 {
http.SetCookie(rw, &http.Cookie{
Name: cookieName,
Value: cookieValue,
Secure: toa.Config.StateCookie.Secure,
HttpOnly: toa.Config.StateCookie.HttpOnly,
Path: toa.Config.StateCookie.Path,
SameSite: parseCookieSameSite(toa.Config.StateCookie.SameSite),
})
c := baseCookie
c.Value = cookieValue
http.SetCookie(rw, c)
} else {
http.SetCookie(rw, &http.Cookie{
Name: cookieName + "Chunks",
Value: fmt.Sprintf("%d", len(cookieChunks)),
Secure: toa.Config.StateCookie.Secure,
HttpOnly: toa.Config.StateCookie.HttpOnly,
Path: toa.Config.StateCookie.Path,
SameSite: parseCookieSameSite(toa.Config.StateCookie.SameSite),
})
c := baseCookie
c.Name = cookieName + "Chunks"
c.Value = fmt.Sprintf("%d", len(cookieChunks))
http.SetCookie(rw, c)

for index, chunk := range cookieChunks {
http.SetCookie(rw, &http.Cookie{
Name: fmt.Sprintf("%s%d", cookieName, index+1),
Value: chunk,
Secure: toa.Config.StateCookie.Secure,
HttpOnly: toa.Config.StateCookie.HttpOnly,
Path: toa.Config.StateCookie.Path,
SameSite: parseCookieSameSite(toa.Config.StateCookie.SameSite),
})
c.Name = fmt.Sprintf("%s%d", cookieName, index+1)
c.Value = chunk
http.SetCookie(rw, c)
}
}
}
Expand Down Expand Up @@ -560,31 +588,21 @@ func (toa *TraefikOidcAuth) ClearChunkedCookie(rw http.ResponseWriter, req *http
return err
}

baseCookie := toa.createStateCookie()
baseCookie.Name = cookieName
baseCookie.Value = ""
makeCookieExpireImmediately(baseCookie)

if chunkCount == 0 {
http.SetCookie(rw, &http.Cookie{
Name: cookieName,
Value: "",
Path: toa.Config.StateCookie.Path,
Expires: time.Now().Add(-24 * time.Hour),
MaxAge: -1,
})
http.SetCookie(rw, baseCookie)
} else {
http.SetCookie(rw, &http.Cookie{
Name: fmt.Sprintf("%sChunks", cookieName),
Value: "",
Path: toa.Config.StateCookie.Path,
Expires: time.Now().Add(-24 * time.Hour),
MaxAge: -1,
})
c := baseCookie
c.Name = cookieName + "Chunks"
http.SetCookie(rw, c)

for i := 0; i < chunkCount; i++ {
http.SetCookie(rw, &http.Cookie{
Name: fmt.Sprintf("%s%d", cookieName, i+1),
Value: "",
Path: toa.Config.StateCookie.Path,
Expires: time.Now().Add(-24 * time.Hour),
MaxAge: -1,
})
c.Name = fmt.Sprintf("%s%d", cookieName, i+1)
http.SetCookie(rw, c)
}
}

Expand Down
4 changes: 1 addition & 3 deletions oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,7 @@ func randomBytesInHex(count int) (string, error) {
}

func exchangeAuthCode(oidcAuth *TraefikOidcAuth, req *http.Request, authCode string) (*OidcTokenResponse, error) {
host := getFullHost(req)

redirectUrl := host + oidcAuth.Config.CallbackUri
redirectUrl := oidcAuth.GetAbsoluteCallbackURL(req).String()

urlValues := url.Values{
"grant_type": {"authorization_code"},
Expand Down
31 changes: 28 additions & 3 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ func parseCookieSameSite(sameSite string) http.SameSite {
}
}

func makeCookieExpireImmediately(cookie *http.Cookie) *http.Cookie {
cookie.Expires = time.Now().Add(-24 * time.Hour)
cookie.MaxAge = -1
return cookie
}

func urlIsAbsolute(u *url.URL) bool {
return u.Scheme != "" && u.Host != ""
}

func parseUrl(rawUrl string) (*url.URL, error) {
if rawUrl == "" {
return nil, errors.New("invalid empty url")
Expand All @@ -59,17 +69,32 @@ func parseUrl(rawUrl string) (*url.URL, error) {
return u, nil
}

func getFullHost(req *http.Request) string {
func getSchemeFromRequest(req *http.Request) string {
scheme := req.Header.Get("X-Forwarded-Proto")
host := req.Header.Get("X-Forwarded-Host")

if scheme == "" {
if req.TLS != nil {
scheme = "https"
} else {
scheme = "http"
}
}
return scheme
}

func fillHostSchemeFromRequest(req *http.Request, u *url.URL) *url.URL {
scheme := getSchemeFromRequest(req)
host := req.Header.Get("X-Forwarded-Host")
if host == "" {
host = req.Host
}
u.Scheme = scheme
u.Host = host
return u
}

func getFullHost(req *http.Request) string {
scheme := getSchemeFromRequest(req)
host := req.Header.Get("X-Forwarded-Host")
if host == "" {
host = req.Host
}
Expand Down
Loading

0 comments on commit 1969587

Please sign in to comment.