Skip to content

Commit

Permalink
Refactored oauth2RoundTripper.RoundTrip (#634)
Browse files Browse the repository at this point in the history
* Avoid race condidtion on rt.rt == nil check
* Trying to improve readability (less ifs)
* Some comment fixes

Signed-off-by: bwplotka <[email protected]>
  • Loading branch information
bwplotka authored May 16, 2024
1 parent a7407da commit 6b9921f
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 135 deletions.
203 changes: 103 additions & 100 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,12 +263,12 @@ func (o *OAuth2) UnmarshalJSON(data []byte) error {
}

// SetDirectory joins any relative file paths with dir.
func (a *OAuth2) SetDirectory(dir string) {
if a == nil {
func (o *OAuth2) SetDirectory(dir string) {
if o == nil {
return
}
a.ClientSecretFile = JoinDir(dir, a.ClientSecretFile)
a.TLSConfig.SetDirectory(dir)
o.ClientSecretFile = JoinDir(dir, o.ClientSecretFile)
o.TLSConfig.SetDirectory(dir)
}

// LoadHTTPConfig parses the YAML input s into a HTTPClientConfig.
Expand Down Expand Up @@ -563,7 +563,7 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
return NewRoundTripperFromConfigWithContext(context.Background(), cfg, name, optFuncs...)
}

// NewRoundTripperFromConfig returns a new HTTP RoundTripper configured for the
// NewRoundTripperFromConfigWithContext returns a new HTTP RoundTripper configured for the
// given config.HTTPClientConfig and config.HTTPClientOption.
// The name is used as go-conntrack metric label.
func NewRoundTripperFromConfigWithContext(ctx context.Context, cfg HTTPClientConfig, name string, optFuncs ...HTTPClientOption) (http.RoundTripper, error) {
Expand Down Expand Up @@ -647,7 +647,7 @@ func NewRoundTripperFromConfigWithContext(ctx context.Context, cfg HTTPClientCon
}

if cfg.OAuth2 != nil {
clientSecret, err := toSecret(opts.secretManager, Secret(cfg.OAuth2.ClientSecret), cfg.OAuth2.ClientSecretFile, cfg.OAuth2.ClientSecretRef)
clientSecret, err := toSecret(opts.secretManager, cfg.OAuth2.ClientSecret, cfg.OAuth2.ClientSecretFile, cfg.OAuth2.ClientSecretRef)
if err != nil {
return nil, fmt.Errorf("unable to use client secret: %w", err)
}
Expand Down Expand Up @@ -702,7 +702,7 @@ type inlineSecret struct {
text string
}

func (s *inlineSecret) fetch(ctx context.Context) (string, error) {
func (s *inlineSecret) fetch(context.Context) (string, error) {
return s.text, nil
}

Expand Down Expand Up @@ -737,7 +737,7 @@ func (s *fileSecret) immutable() bool {
// refSecret fetches a single secret from a secret manager.
type refSecret struct {
ref string
manager SecretManager
manager SecretManager // manager is expected to be not nil.
}

func (s *refSecret) fetch(ctx context.Context) (string, error) {
Expand Down Expand Up @@ -791,20 +791,22 @@ func NewAuthorizationCredentialsRoundTripper(authType string, authCredentials se
}

func (rt *authorizationCredentialsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if len(req.Header.Get("Authorization")) == 0 {
var authCredentials string
if rt.authCredentials != nil {
var err error
authCredentials, err = rt.authCredentials.fetch(req.Context())
if err != nil {
return nil, fmt.Errorf("unable to read authorization credentials: %w", err)
}
}
if len(req.Header.Get("Authorization")) != 0 {
return rt.rt.RoundTrip(req)
}

req = cloneRequest(req)
req.Header.Set("Authorization", fmt.Sprintf("%s %s", rt.authType, authCredentials))
var authCredentials string
if rt.authCredentials != nil {
var err error
authCredentials, err = rt.authCredentials.fetch(req.Context())
if err != nil {
return nil, fmt.Errorf("unable to read authorization credentials: %w", err)
}
}

req = cloneRequest(req)
req.Header.Set("Authorization", fmt.Sprintf("%s %s", rt.authType, authCredentials))

return rt.rt.RoundTrip(req)
}

Expand Down Expand Up @@ -858,117 +860,118 @@ func (rt *basicAuthRoundTripper) CloseIdleConnections() {
}

type oauth2RoundTripper struct {
mtx sync.RWMutex
lastRT *oauth2.Transport
lastSecret string

// Required for interaction with Oauth2 server.
config *OAuth2
rt http.RoundTripper
next http.RoundTripper
clientSecret secret
lastSecret string
mtx sync.RWMutex
opts *httpClientOptions
client *http.Client
}

func NewOAuth2RoundTripper(clientSecret secret, config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper {
if clientSecret == nil {
clientSecret = &inlineSecret{text: ""}
}

return &oauth2RoundTripper{
config: config,
next: next,
config: config,
// A correct tokenSource will be added later on.
lastRT: &oauth2.Transport{Base: next},
opts: opts,
clientSecret: clientSecret,
}
}

func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
var (
secret string
changed bool
)
func (rt *oauth2RoundTripper) newOauth2TokenSource(req *http.Request, secret string) (client *http.Client, source oauth2.TokenSource, err error) {
tlsConfig, err := NewTLSConfig(&rt.config.TLSConfig, WithSecretManager(rt.opts.secretManager))
if err != nil {
return nil, nil, err
}

// Fetch the secret if it's our first run or always if the secret can change.
if rt.rt == nil || (rt.clientSecret != nil && !rt.clientSecret.immutable()) {
if rt.clientSecret != nil {
var err error
secret, err = rt.clientSecret.fetch(req.Context())
if err != nil {
return nil, fmt.Errorf("unable to read oauth2 client secret: %w", err)
}
tlsTransport := func(tlsConfig *tls.Config) (http.RoundTripper, error) {
return &http.Transport{
TLSClientConfig: tlsConfig,
Proxy: rt.config.ProxyConfig.Proxy(),
ProxyConnectHeader: rt.config.ProxyConfig.GetProxyConnectHeader(),
DisableKeepAlives: !rt.opts.keepAlivesEnabled,
MaxIdleConns: 20,
MaxIdleConnsPerHost: 1, // see https://github.com/golang/go/issues/13801
IdleConnTimeout: 10 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}, nil
}

if !rt.clientSecret.immutable() {
rt.mtx.RLock()
changed = secret != rt.lastSecret
rt.mtx.RUnlock()
}
var t http.RoundTripper
tlsSettings, err := rt.config.TLSConfig.roundTripperSettings(rt.opts.secretManager)
if err != nil {
return nil, nil, err
}
if tlsSettings.CA == nil || tlsSettings.CA.immutable() {
t, _ = tlsTransport(tlsConfig)
} else {
t, err = NewTLSRoundTripperWithContext(req.Context(), tlsConfig, tlsSettings, tlsTransport)
if err != nil {
return nil, nil, err
}
}

if rt.rt == nil {
changed = true
}
if ua := req.UserAgent(); ua != "" {
t = NewUserAgentRoundTripper(ua, t)
}

if changed {
config := &clientcredentials.Config{
ClientID: rt.config.ClientID,
ClientSecret: secret,
Scopes: rt.config.Scopes,
TokenURL: rt.config.TokenURL,
EndpointParams: mapToValues(rt.config.EndpointParams),
}
config := &clientcredentials.Config{
ClientID: rt.config.ClientID,
ClientSecret: secret,
Scopes: rt.config.Scopes,
TokenURL: rt.config.TokenURL,
EndpointParams: mapToValues(rt.config.EndpointParams),
}
client = &http.Client{Transport: t}
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client)
return client, config.TokenSource(ctx), nil
}

tlsConfig, err := NewTLSConfig(&rt.config.TLSConfig, WithSecretManager(rt.opts.secretManager))
if err != nil {
return nil, err
}
func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
var (
secret string
needsInit bool
)

tlsTransport := func(tlsConfig *tls.Config) (http.RoundTripper, error) {
return &http.Transport{
TLSClientConfig: tlsConfig,
Proxy: rt.config.ProxyConfig.Proxy(),
ProxyConnectHeader: rt.config.ProxyConfig.GetProxyConnectHeader(),
DisableKeepAlives: !rt.opts.keepAlivesEnabled,
MaxIdleConns: 20,
MaxIdleConnsPerHost: 1, // see https://github.com/golang/go/issues/13801
IdleConnTimeout: 10 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}, nil
}
rt.mtx.RLock()
secret = rt.lastSecret
needsInit = rt.lastRT.Source == nil
rt.mtx.RUnlock()

var t http.RoundTripper
tlsSettings, err := rt.config.TLSConfig.roundTripperSettings(rt.opts.secretManager)
// Fetch the secret if it's our first run or always if the secret can change.
if !rt.clientSecret.immutable() || needsInit {
newSecret, err := rt.clientSecret.fetch(req.Context())
if err != nil {
return nil, err
return nil, fmt.Errorf("unable to read oauth2 client secret: %w", err)
}
if tlsSettings.CA == nil || tlsSettings.CA.immutable() {
t, _ = tlsTransport(tlsConfig)
} else {
t, err = NewTLSRoundTripperWithContext(req.Context(), tlsConfig, tlsSettings, tlsTransport)
if newSecret != secret || needsInit {
// Secret changed or it's a first run. Rebuilt oauth2 setup.
client, source, err := rt.newOauth2TokenSource(req, newSecret)
if err != nil {
return nil, err
}
}

if ua := req.UserAgent(); ua != "" {
t = NewUserAgentRoundTripper(ua, t)
}

client := &http.Client{Transport: t}
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client)
tokenSource := config.TokenSource(ctx)

rt.mtx.Lock()
rt.lastSecret = secret
rt.rt = &oauth2.Transport{
Base: rt.next,
Source: tokenSource,
}
if rt.client != nil {
rt.client.CloseIdleConnections()
rt.mtx.Lock()
rt.lastSecret = secret
rt.lastRT.Source = source
if rt.client != nil {
rt.client.CloseIdleConnections()
}
rt.client = client
rt.mtx.Unlock()
}
rt.client = client
rt.mtx.Unlock()
}

rt.mtx.RLock()
currentRT := rt.rt
currentRT := rt.lastRT
rt.mtx.RUnlock()
return currentRT.RoundTrip(req)
}
Expand All @@ -977,7 +980,7 @@ func (rt *oauth2RoundTripper) CloseIdleConnections() {
if rt.client != nil {
rt.client.CloseIdleConnections()
}
if ci, ok := rt.next.(closeIdler); ok {
if ci, ok := rt.lastRT.Base.(closeIdler); ok {
ci.CloseIdleConnections()
}
}
Expand Down Expand Up @@ -1019,7 +1022,7 @@ func NewTLSConfig(cfg *TLSConfig, optFuncs ...TLSConfigOption) (*tls.Config, err
return NewTLSConfigWithContext(context.Background(), cfg, optFuncs...)
}

// NewTLSConfig creates a new tls.Config from the given TLSConfig.
// NewTLSConfigWithContext creates a new tls.Config from the given TLSConfig.
func NewTLSConfigWithContext(ctx context.Context, cfg *TLSConfig, optFuncs ...TLSConfigOption) (*tls.Config, error) {
opts := tlsConfigOptions{}
for _, opt := range optFuncs {
Expand Down
72 changes: 37 additions & 35 deletions config/http_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,46 +507,48 @@ func TestNewClientFromConfig(t *testing.T) {
}

for _, validConfig := range newClientValidConfig {
testServer, err := newTestServer(validConfig.handler)
if err != nil {
t.Fatal(err.Error())
}
defer testServer.Close()
t.Run("", func(t *testing.T) {
testServer, err := newTestServer(validConfig.handler)
if err != nil {
t.Fatal(err.Error())
}
defer testServer.Close()

if validConfig.clientConfig.OAuth2 != nil {
// We don't have access to the test server's URL when configuring the test cases,
// so it has to be specified here.
validConfig.clientConfig.OAuth2.TokenURL = testServer.URL + "/token"
}
if validConfig.clientConfig.OAuth2 != nil {
// We don't have access to the test server's URL when configuring the test cases,
// so it has to be specified here.
validConfig.clientConfig.OAuth2.TokenURL = testServer.URL + "/token"
}

err = validConfig.clientConfig.Validate()
if err != nil {
t.Fatal(err.Error())
}
client, err := NewClientFromConfig(validConfig.clientConfig, "test")
if err != nil {
t.Errorf("Can't create a client from this config: %+v", validConfig.clientConfig)
continue
}
err = validConfig.clientConfig.Validate()
if err != nil {
t.Fatal(err.Error())
}
client, err := NewClientFromConfig(validConfig.clientConfig, "test")
if err != nil {
t.Errorf("Can't create a client from this config: %+v", validConfig.clientConfig)
return
}

response, err := client.Get(testServer.URL)
if err != nil {
t.Errorf("Can't connect to the test server using this config: %+v: %v", validConfig.clientConfig, err)
continue
}
response, err := client.Get(testServer.URL)
if err != nil {
t.Errorf("Can't connect to the test server using this config: %+v: %v", validConfig.clientConfig, err)
return
}

message, err := io.ReadAll(response.Body)
response.Body.Close()
if err != nil {
t.Errorf("Can't read the server response body using this config: %+v", validConfig.clientConfig)
continue
}
message, err := io.ReadAll(response.Body)
response.Body.Close()
if err != nil {
t.Errorf("Can't read the server response body using this config: %+v", validConfig.clientConfig)
return
}

trimMessage := strings.TrimSpace(string(message))
if ExpectedMessage != trimMessage {
t.Errorf("The expected message (%s) differs from the obtained message (%s) using this config: %+v",
ExpectedMessage, trimMessage, validConfig.clientConfig)
}
trimMessage := strings.TrimSpace(string(message))
if ExpectedMessage != trimMessage {
t.Errorf("The expected message (%s) differs from the obtained message (%s) using this config: %+v",
ExpectedMessage, trimMessage, validConfig.clientConfig)
}
})
}
}

Expand Down

0 comments on commit 6b9921f

Please sign in to comment.