diff --git a/auth/auth.go b/auth/auth.go index 58af93188774..2a4350af3f11 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -101,6 +101,20 @@ func (t *Token) IsValid() bool { return t.isValidWithEarlyExpiry(defaultExpiryDelta) } +// MetadataString is a convenience method for accessing string values in the +// token's metadata. Returns an empty string if the metadata is nil or the value +// for the given key cannot be cast to a string. +func (t *Token) MetadataString(k string) string { + if t.Metadata == nil { + return "" + } + s, ok := t.Metadata[k].(string) + if !ok { + return "" + } + return s +} + func (t *Token) isValidWithEarlyExpiry(earlyExpiry time.Duration) bool { if t.isEmpty() { return false diff --git a/auth/auth_test.go b/auth/auth_test.go index b25222bfafa7..4cd726c0de37 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -103,6 +103,39 @@ func TestError_Temporary(t *testing.T) { } } +func TestToken_MetadataString(t *testing.T) { + cases := []struct { + name string + metadata map[string]interface{} + want string + }{ + { + name: "nil metadata", + want: "", + }, + { + name: "not string", + metadata: map[string]interface{}{ + "my.key": 123, + }, + want: "", + }, + { + name: "string", + metadata: map[string]interface{}{ + "my.key": "my.value", + }, + want: "my.value", + }, + } + for _, tc := range cases { + tok := &Token{Metadata: tc.metadata} + if got, want := tok.MetadataString("my.key"), tc.want; got != want { + t.Errorf("got %q, want %q", got, want) + } + } +} + func TestToken_isValidWithEarlyExpiry(t *testing.T) { now := time.Now() timeNow = func() time.Time { return now } diff --git a/auth/credentials/compute_test.go b/auth/credentials/compute_test.go index 6d0d691839e1..9f2cee8d2a17 100644 --- a/auth/credentials/compute_test.go +++ b/auth/credentials/compute_test.go @@ -53,4 +53,7 @@ func TestComputeTokenProvider(t *testing.T) { if want := "bearer"; tok.Type != want { t.Errorf("got %q, want %q", tok.Type, want) } + if got, want := tok.MetadataString("auth.google.tokenSource"), "compute-metadata"; got != want { + t.Errorf("got %q, want %q", got, want) + } } diff --git a/auth/grpctransport/directpath.go b/auth/grpctransport/directpath.go index 8dbfa7ef7e90..efc91c2b0c35 100644 --- a/auth/grpctransport/directpath.go +++ b/auth/grpctransport/directpath.go @@ -66,10 +66,10 @@ func isTokenProviderDirectPathCompatible(tp auth.TokenProvider, _ *Options) bool if tok == nil { return false } - if source, _ := tok.Metadata["auth.google.tokenSource"].(string); source != "compute-metadata" { + if tok.MetadataString("auth.google.tokenSource") != "compute-metadata" { return false } - if acct, _ := tok.Metadata["auth.google.serviceAccount"].(string); acct != "default" { + if tok.MetadataString("auth.google.serviceAccount") != "default" { return false } return true diff --git a/auth/grpctransport/grpctransport.go b/auth/grpctransport/grpctransport.go index 5c3bc66f9981..0442a5938a80 100644 --- a/auth/grpctransport/grpctransport.go +++ b/auth/grpctransport/grpctransport.go @@ -337,17 +337,19 @@ func (c *grpcCredentialsProvider) getClientUniverseDomain() string { } func (c *grpcCredentialsProvider) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { - credentialsUniverseDomain, err := c.creds.UniverseDomain(ctx) - if err != nil { - return nil, err - } - if err := transport.ValidateUniverseDomain(c.getClientUniverseDomain(), credentialsUniverseDomain); err != nil { - return nil, err - } token, err := c.creds.Token(ctx) if err != nil { return nil, err } + if token.MetadataString("auth.google.tokenSource") != "compute-metadata" { + credentialsUniverseDomain, err := c.creds.UniverseDomain(ctx) + if err != nil { + return nil, err + } + if err := transport.ValidateUniverseDomain(c.getClientUniverseDomain(), credentialsUniverseDomain); err != nil { + return nil, err + } + } if c.secure { ri, _ := grpccreds.RequestInfoFromContext(ctx) if err = grpccreds.CheckSecurityLevel(ri.AuthInfo, grpccreds.PrivacyAndIntegrity); err != nil { diff --git a/auth/httptransport/transport.go b/auth/httptransport/transport.go index 94caeb00f0ab..07eea474446b 100644 --- a/auth/httptransport/transport.go +++ b/auth/httptransport/transport.go @@ -193,17 +193,19 @@ func (t *authTransport) RoundTrip(req *http.Request) (*http.Response, error) { } }() } - credentialsUniverseDomain, err := t.creds.UniverseDomain(req.Context()) - if err != nil { - return nil, err - } - if err := transport.ValidateUniverseDomain(t.getClientUniverseDomain(), credentialsUniverseDomain); err != nil { - return nil, err - } token, err := t.creds.Token(req.Context()) if err != nil { return nil, err } + if token.MetadataString("auth.google.tokenSource") != "compute-metadata" { + credentialsUniverseDomain, err := t.creds.UniverseDomain(req.Context()) + if err != nil { + return nil, err + } + if err := transport.ValidateUniverseDomain(t.getClientUniverseDomain(), credentialsUniverseDomain); err != nil { + return nil, err + } + } req2 := req.Clone(req.Context()) SetAuthHeader(token, req2) reqBodyClosed = true