diff --git a/oidc/config.go b/oidc/config.go index db2f4a6..f782e81 100644 --- a/oidc/config.go +++ b/oidc/config.go @@ -12,6 +12,7 @@ import ( "fmt" "hash" "hash/fnv" + "net/http" "net/url" "reflect" "runtime" @@ -89,9 +90,15 @@ type Config struct { // ProviderCA is an optional CA certs (PEM encoded) to use when sending // requests to the provider. If you have a list of *x509.Certificates, then - // see EncodeCertificates(...) to PEM encode them. + // see EncodeCertificates(...) to PEM encode them. Note: specifying both + // ProviderCA and RoundTripper is an error. ProviderCA string + // RoundTripper is an optional http.RoundTripper to use when sending requests + // to the provider. Note: specifying both ProviderCA and RoundTripper is an + // error. + RoundTripper http.RoundTripper + // NowFunc is a time func that returns the current time. NowFunc func() time.Time `json:"-"` @@ -118,6 +125,7 @@ func NewConfig(issuer string, clientID string, clientSecret ClientSecret, suppor SupportedSigningAlgs: supported, Scopes: opts.withScopes, ProviderCA: opts.withProviderCA, + RoundTripper: opts.withRoundTripper, Audiences: opts.withAudiences, NowFunc: opts.withNowFunc, AllowedRedirectURLs: allowedRedirectURLs, @@ -168,6 +176,16 @@ func (c *Config) Hash() (uint64, error) { args = append(args, audiences...) args = append(args, redirects...) + if c.RoundTripper != nil { + v := reflect.ValueOf(c.RoundTripper) + switch { + case v.CanAddr(): + args = append(args, v.Addr().String()) + default: + args = append(args, v.String()) + } + } + if c.ProviderConfig != nil { args = append( args, @@ -269,6 +287,9 @@ func (c *Config) Validate() error { return fmt.Errorf("%s: %w", op, ErrInvalidCACert) } } + if c.ProviderCA != "" && c.RoundTripper != nil { + return fmt.Errorf("%s: you cannot specify both a ProviderCA and Transport: %w", op, ErrInvalidParameter) + } if c.ProviderConfig != nil { switch { @@ -300,6 +321,7 @@ type configOptions struct { withProviderCA string withNowFunc func() time.Time withProviderConfig *ProviderConfig + withRoundTripper http.RoundTripper } // configDefaults is a handy way to get the defaults at runtime and @@ -319,12 +341,14 @@ func getConfigOpts(opt ...Option) configOptions { } // WithProviderCA provides optional CA certs (PEM encoded) for the provider's -// config. These certs will can be used when making http requests to the +// config. These certs will be used when making http requests to the // provider. // // Valid for: Config // // See EncodeCertificates(...) to PEM encode a number of certs. +// +// Note: specifying both WithProviderCA and WithRoundTripper is a error. func WithProviderCA(cert string) Option { return func(o interface{}) { if o, ok := o.(*configOptions); ok { @@ -333,6 +357,18 @@ func WithProviderCA(cert string) Option { } } +// WithRoundTripper provides and optional RoundTripper for the provider's +// config. This RoundTripper will be used when making http requests to the +// provider. Note: specifying both WithProviderCA and WithRoundTripper is a +// error. +func WithRoundTripper(rt http.RoundTripper) Option { + return func(o interface{}) { + if o, ok := o.(*configOptions); ok { + o.withRoundTripper = rt + } + } +} + // EncodeCertificates will encode a number of x509 certificates to PEM. It will // help encode certs for use with the WithProviderCA(...) option. func EncodeCertificates(certs ...*x509.Certificate) (string, error) { diff --git a/oidc/docs_test.go b/oidc/docs_test.go index 0e50db0..4b84ec5 100644 --- a/oidc/docs_test.go +++ b/oidc/docs_test.go @@ -95,7 +95,7 @@ func ExampleNewConfig() { fmt.Println(pc) // Output: - // &{your_client_id [REDACTED: client secret] [openid] http://your_issuer/ [RS256] [http://your_redirect_url/callback] [] } + // &{your_client_id [REDACTED: client secret] [openid] http://your_issuer/ [RS256] [http://your_redirect_url/callback] [] } } func ExampleWithProviderConfig() { @@ -120,7 +120,7 @@ func ExampleWithProviderConfig() { fmt.Println(string(val)) // Output: - // {"ClientID":"your_client_id","ClientSecret":"[REDACTED: client secret]","Scopes":["openid"],"Issuer":"https://your_issuer/","SupportedSigningAlgs":["RS256"],"AllowedRedirectURLs":["https://your_redirect_url/callback"],"Audiences":null,"ProviderCA":"","ProviderConfig":{"AuthURL":"https://your_issuer/authorize","TokenURL":"https://your_issuer/token","UserInfoURL":"https://your_issuer/userinfo","JWKSURL":"https://your_issuer/.well-known/jwks.json"}} + // {"ClientID":"your_client_id","ClientSecret":"[REDACTED: client secret]","Scopes":["openid"],"Issuer":"https://your_issuer/","SupportedSigningAlgs":["RS256"],"AllowedRedirectURLs":["https://your_redirect_url/callback"],"Audiences":null,"ProviderCA":"","RoundTripper":null,"ProviderConfig":{"AuthURL":"https://your_issuer/authorize","TokenURL":"https://your_issuer/token","UserInfoURL":"https://your_issuer/userinfo","JWKSURL":"https://your_issuer/.well-known/jwks.json"}} } func ExampleNewProvider() { diff --git a/oidc/provider.go b/oidc/provider.go index 4c0f491..d03b2e7 100644 --- a/oidc/provider.go +++ b/oidc/provider.go @@ -635,17 +635,22 @@ func (p *Provider) HTTPClient() (*http.Client, error) { // to the same host. On the downside, this transport can leak file // descriptors over time, so we'll be sure to call // client.CloseIdleConnections() in the Provider.Done() to stave that off. - tr := cleanhttp.DefaultPooledTransport() + var tr http.RoundTripper - if p.config.ProviderCA != "" { + switch { + case p.config.RoundTripper != nil && p.config.ProviderCA != "": + return nil, fmt.Errorf("%s: you cannot specify config for both a ProviderCA and Transport: %w", op, ErrInvalidParameter) + case p.config.ProviderCA != "": certPool := x509.NewCertPool() if ok := certPool.AppendCertsFromPEM([]byte(p.config.ProviderCA)); !ok { return nil, fmt.Errorf("%s: %w", op, ErrInvalidCACert) } - - tr.TLSClientConfig = &tls.Config{ + tr = cleanhttp.DefaultPooledTransport() + tr.(*http.Transport).TLSClientConfig = &tls.Config{ RootCAs: certPool, } + case p.config.RoundTripper != nil: + tr = p.config.RoundTripper } c := &http.Client{