Skip to content

Commit

Permalink
#54: First draft for custom root CAs. Untested WIP.
Browse files Browse the repository at this point in the history
  • Loading branch information
sevensolutions committed Jan 24, 2025
1 parent 0f6a12e commit 47ce99b
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 11 deletions.
37 changes: 37 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package traefik_oidc_auth

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"net/http"
"net/url"
Expand Down Expand Up @@ -61,6 +63,9 @@ type ProviderConfig struct {
Url string `json:"url"`
UrlEnv string `json:"url_env"`

InsecureSkipVerify bool `json:"insecure_skip_verify"`
CACertificateFilePath string `json:"ca_certificate_file_path"`

ClientId string `json:"client_id"`
ClientIdEnv string `json:"client_id_env"`
ClientSecret string `json:"client_secret"`
Expand Down Expand Up @@ -213,9 +218,41 @@ func New(uctx context.Context, next http.Handler, config *Config, name string) (
log(config.LogLevel, LogLevelDebug, "Scopes: %s", strings.Join(config.Scopes, ", "))
log(config.LogLevel, LogLevelDebug, "SessionCookie: %v", config.SessionCookie)

rootCAs, _ := x509.SystemCertPool()
if rootCAs == nil {
rootCAs = x509.NewCertPool()
}

if config.Provider.CACertificateFilePath != "" {
certs, err := os.ReadFile(config.Provider.CACertificateFilePath)
if err != nil {
log(config.LogLevel, LogLevelInfo, "Failed to load CA certificate from %v: %v", config.Provider.CACertificateFilePath, err)
return nil, err
}

// Append our cert to the system pool
if ok := rootCAs.AppendCertsFromPEM(certs); !ok {
log(config.LogLevel, LogLevelWarn, "Failed to append CA certificate. Using system certificates only.")
}
}

httpTransport := &http.Transport{
// MaxIdleConns: 10,
// IdleConnTimeout: 30 * time.Second,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: config.Provider.InsecureSkipVerify,
RootCAs: rootCAs,
},
}

httpClient := &http.Client{
Transport: httpTransport,
}

log(config.LogLevel, LogLevelInfo, "Configuration loaded successfully, starting OIDC Auth middleware...")
return &TraefikOidcAuth{
next: next,
httpClient: httpClient,
ProviderURL: parsedURL,
CallbackURL: parsedCallbackURL,
Config: config,
Expand Down
6 changes: 3 additions & 3 deletions jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ func (h *JwksHandler) EnsureLoaded(oidcAuth *TraefikOidcAuth, forceReload bool)
if reload {
log(oidcAuth.Config.LogLevel, LogLevelInfo, "Reloading JWKS...")

return h.loadKeys()
return h.loadKeys(oidcAuth.httpClient)
}

return nil
}

func (h *JwksHandler) loadKeys() error {
resp, err := http.Get(h.Url)
func (h *JwksHandler) loadKeys(httpClient *http.Client) error {
resp, err := httpClient.Get(h.Url)

if err != nil {
return err
Expand Down
3 changes: 2 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (

type TraefikOidcAuth struct {
next http.Handler
httpClient *http.Client
ProviderURL *url.URL
CallbackURL *url.URL
Config *Config
Expand All @@ -40,7 +41,7 @@ func (toa *TraefikOidcAuth) EnsureOidcDiscovery() error {
toa.Jwks = jwks
log(config.LogLevel, LogLevelInfo, "Getting OIDC discovery document...")

oidcDiscoveryDocument, err := GetOidcDiscovery(config.LogLevel, parsedURL)
oidcDiscoveryDocument, err := GetOidcDiscovery(config.LogLevel, toa.httpClient, parsedURL)
if err != nil {
log(config.LogLevel, LogLevelError, "Error while retrieving discovery document: %s", err.Error())
return err
Expand Down
12 changes: 5 additions & 7 deletions oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ type OidcState struct {
RedirectUrl string `json:"redirect_url"`
}

func GetOidcDiscovery(logLevel string, providerUrl *url.URL) (*OidcDiscovery, error) {
func GetOidcDiscovery(logLevel string, httpClient *http.Client, providerUrl *url.URL) (*OidcDiscovery, error) {
wellKnownUrl := *providerUrl

wellKnownUrl.Path = path.Join(wellKnownUrl.Path, ".well-known/openid-configuration")
Expand All @@ -128,7 +128,7 @@ func GetOidcDiscovery(logLevel string, providerUrl *url.URL) (*OidcDiscovery, er
// client := &http.Client{Transport: tr}

// Make HTTP GET request to the OpenID provider's discovery endpoint
resp, err := http.Get(wellKnownUrl.String())
resp, err := httpClient.Get(wellKnownUrl.String())

if err != nil {
log(logLevel, LogLevelError, "http-get discovery endpoints - Err: %s", err.Error())
Expand Down Expand Up @@ -193,7 +193,7 @@ func exchangeAuthCode(oidcAuth *TraefikOidcAuth, req *http.Request, authCode str
urlValues.Add("code_verifier", codeVerifier)
}

resp, err := http.PostForm(oidcAuth.DiscoveryDocument.TokenEndpoint, urlValues)
resp, err := oidcAuth.httpClient.PostForm(oidcAuth.DiscoveryDocument.TokenEndpoint, urlValues)

if err != nil {
log(oidcAuth.Config.LogLevel, LogLevelError, "Sending AuthorizationCode in POST: %s", err.Error())
Expand Down Expand Up @@ -257,8 +257,6 @@ func (toa *TraefikOidcAuth) validateTokenLocally(tokenString string) (bool, map[
}

func (toa *TraefikOidcAuth) introspectToken(token string) (bool, map[string]interface{}, error) {
client := &http.Client{}

data := url.Values{
"token": {token},
}
Expand All @@ -284,7 +282,7 @@ func (toa *TraefikOidcAuth) introspectToken(token string) (bool, map[string]inte
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(toa.Config.Provider.ClientId, toa.Config.Provider.ClientSecret)

resp, err := client.Do(req)
resp, err := toa.httpClient.Do(req)
if err != nil {
log(toa.Config.LogLevel, LogLevelError, "Error on introspection request: %s", err.Error())
return false, nil, err
Expand Down Expand Up @@ -322,7 +320,7 @@ func (toa *TraefikOidcAuth) renewToken(refreshToken string) (*OidcTokenResponse,
urlValues.Add("client_secret", toa.Config.Provider.ClientSecret)
}

resp, err := http.PostForm(toa.DiscoveryDocument.TokenEndpoint, urlValues)
resp, err := toa.httpClient.PostForm(toa.DiscoveryDocument.TokenEndpoint, urlValues)

if err != nil {
log(toa.Config.LogLevel, LogLevelError, "Sending token renewal request in POST: %s", err.Error())
Expand Down

0 comments on commit 47ce99b

Please sign in to comment.