Skip to content

Commit

Permalink
fix(auth): handle non-Transport DefaultTransport (#10162)
Browse files Browse the repository at this point in the history
Since `http.DefaultTransport` is a `RoundTripper` interface and mutable global variable, it is not safe to assume it is always going to concretely be `*http.Transport`. If it is not, I suppose we should just use the value directly instead of making a `Clone`. The `http.DefaultTransport` being overridden is pretty intentional by application authors, so it is best if we respect that and just reuse it direct.

Fixes #10159
  • Loading branch information
noahdietz authored May 14, 2024
1 parent 1320d7d commit fa3bfdb
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
9 changes: 8 additions & 1 deletion auth/httptransport/httptransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,14 @@ func AddAuthorizationMiddleware(client *http.Client, creds *auth.Credentials) er
}
base := client.Transport
if base == nil {
base = http.DefaultTransport.(*http.Transport).Clone()
if dt, ok := http.DefaultTransport.(*http.Transport); ok {
base = dt.Clone()
} else {
// Directly reuse the DefaultTransport if the application has
// replaced it with an implementation of RoundTripper other than
// http.Transport.
base = http.DefaultTransport
}
}
client.Transport = &authTransport{
creds: creds,
Expand Down
22 changes: 22 additions & 0 deletions auth/httptransport/httptransport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,28 @@ func TestAddAuthorizationMiddleware(t *testing.T) {
}
}

func TestAddAuthorizationMiddleware_HandlesNonTransportAsDefaultTransport(t *testing.T) {
client := &http.Client{}
creds := auth.NewCredentials(&auth.CredentialsOptions{
TokenProvider: staticTP("fakeToken"),
})
dt := http.DefaultTransport

http.DefaultTransport = &rt{}
defer func() { http.DefaultTransport = dt }()

err := AddAuthorizationMiddleware(client, creds)
if err != nil {
t.Fatal(err)
}

at := client.Transport.(*authTransport)
_, ok := at.base.(*rt)
if !ok {
t.Errorf("got %T, want %T", at.base, &rt{})
}
}

func TestNewClient_FailsValidation(t *testing.T) {
tests := []struct {
name string
Expand Down

0 comments on commit fa3bfdb

Please sign in to comment.