Skip to content

Commit

Permalink
feat (config): add support for a http.RoundTripper
Browse files Browse the repository at this point in the history
Add support for specifying an optional http.RoundTripper
for a provider config.  If specified the http
client will use the RoundTripper when making
requests to the provider.
  • Loading branch information
jimlambrt committed Aug 1, 2024
1 parent 36b85f9 commit dfc6d99
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 8 deletions.
40 changes: 38 additions & 2 deletions oidc/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"fmt"
"hash"
"hash/fnv"
"net/http"
"net/url"
"reflect"
"runtime"
Expand Down Expand Up @@ -89,9 +90,15 @@ type Config struct {

// ProviderCA is an optional CA certs (PEM encoded) to use when sending
// requests to the provider. If you have a list of *x509.Certificates, then
// see EncodeCertificates(...) to PEM encode them.
// see EncodeCertificates(...) to PEM encode them. Note: specifying both
// ProviderCA and RoundTripper is an error.
ProviderCA string

// RoundTripper is an optional http.RoundTripper to use when sending requests
// to the provider. Note: specifying both ProviderCA and RoundTripper is an
// error.
RoundTripper http.RoundTripper

// NowFunc is a time func that returns the current time.
NowFunc func() time.Time `json:"-"`

Expand All @@ -118,6 +125,7 @@ func NewConfig(issuer string, clientID string, clientSecret ClientSecret, suppor
SupportedSigningAlgs: supported,
Scopes: opts.withScopes,
ProviderCA: opts.withProviderCA,
RoundTripper: opts.withRoundTripper,
Audiences: opts.withAudiences,
NowFunc: opts.withNowFunc,
AllowedRedirectURLs: allowedRedirectURLs,
Expand Down Expand Up @@ -168,6 +176,16 @@ func (c *Config) Hash() (uint64, error) {
args = append(args, audiences...)
args = append(args, redirects...)

if c.RoundTripper != nil {
v := reflect.ValueOf(c.RoundTripper)
switch {
case v.CanAddr():
args = append(args, v.Addr().String())
default:
args = append(args, v.String())
}
}

if c.ProviderConfig != nil {
args = append(
args,
Expand Down Expand Up @@ -269,6 +287,9 @@ func (c *Config) Validate() error {
return fmt.Errorf("%s: %w", op, ErrInvalidCACert)
}
}
if c.ProviderCA != "" && c.RoundTripper != nil {
return fmt.Errorf("%s: you cannot specify both a ProviderCA and Transport: %w", op, ErrInvalidParameter)
}

if c.ProviderConfig != nil {
switch {
Expand Down Expand Up @@ -300,6 +321,7 @@ type configOptions struct {
withProviderCA string
withNowFunc func() time.Time
withProviderConfig *ProviderConfig
withRoundTripper http.RoundTripper
}

// configDefaults is a handy way to get the defaults at runtime and
Expand All @@ -319,12 +341,14 @@ func getConfigOpts(opt ...Option) configOptions {
}

// WithProviderCA provides optional CA certs (PEM encoded) for the provider's
// config. These certs will can be used when making http requests to the
// config. These certs will be used when making http requests to the
// provider.
//
// Valid for: Config
//
// See EncodeCertificates(...) to PEM encode a number of certs.
//
// Note: specifying both WithProviderCA and WithRoundTripper is a error.
func WithProviderCA(cert string) Option {
return func(o interface{}) {
if o, ok := o.(*configOptions); ok {
Expand All @@ -333,6 +357,18 @@ func WithProviderCA(cert string) Option {
}
}

// WithRoundTripper provides and optional RoundTripper for the provider's
// config. This RoundTripper will be used when making http requests to the
// provider. Note: specifying both WithProviderCA and WithRoundTripper is a
// error.
func WithRoundTripper(rt http.RoundTripper) Option {
return func(o interface{}) {
if o, ok := o.(*configOptions); ok {
o.withRoundTripper = rt
}
}
}

// EncodeCertificates will encode a number of x509 certificates to PEM. It will
// help encode certs for use with the WithProviderCA(...) option.
func EncodeCertificates(certs ...*x509.Certificate) (string, error) {
Expand Down
4 changes: 2 additions & 2 deletions oidc/docs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func ExampleNewConfig() {
fmt.Println(pc)

// Output:
// &{your_client_id [REDACTED: client secret] [openid] http://your_issuer/ [RS256] [http://your_redirect_url/callback] [] <nil> <nil>}
// &{your_client_id [REDACTED: client secret] [openid] http://your_issuer/ [RS256] [http://your_redirect_url/callback] [] <nil> <nil> <nil>}
}

func ExampleWithProviderConfig() {
Expand All @@ -120,7 +120,7 @@ func ExampleWithProviderConfig() {
fmt.Println(string(val))

// Output:
// {"ClientID":"your_client_id","ClientSecret":"[REDACTED: client secret]","Scopes":["openid"],"Issuer":"https://your_issuer/","SupportedSigningAlgs":["RS256"],"AllowedRedirectURLs":["https://your_redirect_url/callback"],"Audiences":null,"ProviderCA":"","ProviderConfig":{"AuthURL":"https://your_issuer/authorize","TokenURL":"https://your_issuer/token","UserInfoURL":"https://your_issuer/userinfo","JWKSURL":"https://your_issuer/.well-known/jwks.json"}}
// {"ClientID":"your_client_id","ClientSecret":"[REDACTED: client secret]","Scopes":["openid"],"Issuer":"https://your_issuer/","SupportedSigningAlgs":["RS256"],"AllowedRedirectURLs":["https://your_redirect_url/callback"],"Audiences":null,"ProviderCA":"","RoundTripper":null,"ProviderConfig":{"AuthURL":"https://your_issuer/authorize","TokenURL":"https://your_issuer/token","UserInfoURL":"https://your_issuer/userinfo","JWKSURL":"https://your_issuer/.well-known/jwks.json"}}
}

func ExampleNewProvider() {
Expand Down
13 changes: 9 additions & 4 deletions oidc/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -635,17 +635,22 @@ func (p *Provider) HTTPClient() (*http.Client, error) {
// to the same host. On the downside, this transport can leak file
// descriptors over time, so we'll be sure to call
// client.CloseIdleConnections() in the Provider.Done() to stave that off.
tr := cleanhttp.DefaultPooledTransport()
var tr http.RoundTripper

if p.config.ProviderCA != "" {
switch {
case p.config.RoundTripper != nil && p.config.ProviderCA != "":
return nil, fmt.Errorf("%s: you cannot specify config for both a ProviderCA and Transport: %w", op, ErrInvalidParameter)
case p.config.ProviderCA != "":
certPool := x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM([]byte(p.config.ProviderCA)); !ok {
return nil, fmt.Errorf("%s: %w", op, ErrInvalidCACert)
}

tr.TLSClientConfig = &tls.Config{
tr = cleanhttp.DefaultPooledTransport()
tr.(*http.Transport).TLSClientConfig = &tls.Config{
RootCAs: certPool,
}
case p.config.RoundTripper != nil:
tr = p.config.RoundTripper
}

c := &http.Client{
Expand Down

0 comments on commit dfc6d99

Please sign in to comment.