Skip to content

Commit

Permalink
fix(auth): disable automatic universe domain check for MDS (#10620)
Browse files Browse the repository at this point in the history
* add (*Token).MetadataString
  • Loading branch information
quartzmo authored Aug 1, 2024
1 parent 2fef238 commit 7cea5ed
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 16 deletions.
14 changes: 14 additions & 0 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,20 @@ func (t *Token) IsValid() bool {
return t.isValidWithEarlyExpiry(defaultExpiryDelta)
}

// MetadataString is a convenience method for accessing string values in the
// token's metadata. Returns an empty string if the metadata is nil or the value
// for the given key cannot be cast to a string.
func (t *Token) MetadataString(k string) string {
if t.Metadata == nil {
return ""
}
s, ok := t.Metadata[k].(string)
if !ok {
return ""
}
return s
}

func (t *Token) isValidWithEarlyExpiry(earlyExpiry time.Duration) bool {
if t.isEmpty() {
return false
Expand Down
33 changes: 33 additions & 0 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,39 @@ func TestError_Temporary(t *testing.T) {
}
}

func TestToken_MetadataString(t *testing.T) {
cases := []struct {
name string
metadata map[string]interface{}
want string
}{
{
name: "nil metadata",
want: "",
},
{
name: "not string",
metadata: map[string]interface{}{
"my.key": 123,
},
want: "",
},
{
name: "string",
metadata: map[string]interface{}{
"my.key": "my.value",
},
want: "my.value",
},
}
for _, tc := range cases {
tok := &Token{Metadata: tc.metadata}
if got, want := tok.MetadataString("my.key"), tc.want; got != want {
t.Errorf("got %q, want %q", got, want)
}
}
}

func TestToken_isValidWithEarlyExpiry(t *testing.T) {
now := time.Now()
timeNow = func() time.Time { return now }
Expand Down
3 changes: 3 additions & 0 deletions auth/credentials/compute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,7 @@ func TestComputeTokenProvider(t *testing.T) {
if want := "bearer"; tok.Type != want {
t.Errorf("got %q, want %q", tok.Type, want)
}
if got, want := tok.MetadataString("auth.google.tokenSource"), "compute-metadata"; got != want {
t.Errorf("got %q, want %q", got, want)
}
}
4 changes: 2 additions & 2 deletions auth/grpctransport/directpath.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ func isTokenProviderDirectPathCompatible(tp auth.TokenProvider, _ *Options) bool
if tok == nil {
return false
}
if source, _ := tok.Metadata["auth.google.tokenSource"].(string); source != "compute-metadata" {
if tok.MetadataString("auth.google.tokenSource") != "compute-metadata" {
return false
}
if acct, _ := tok.Metadata["auth.google.serviceAccount"].(string); acct != "default" {
if tok.MetadataString("auth.google.serviceAccount") != "default" {
return false
}
return true
Expand Down
16 changes: 9 additions & 7 deletions auth/grpctransport/grpctransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,17 +337,19 @@ func (c *grpcCredentialsProvider) getClientUniverseDomain() string {
}

func (c *grpcCredentialsProvider) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
credentialsUniverseDomain, err := c.creds.UniverseDomain(ctx)
if err != nil {
return nil, err
}
if err := transport.ValidateUniverseDomain(c.getClientUniverseDomain(), credentialsUniverseDomain); err != nil {
return nil, err
}
token, err := c.creds.Token(ctx)
if err != nil {
return nil, err
}
if token.MetadataString("auth.google.tokenSource") != "compute-metadata" {
credentialsUniverseDomain, err := c.creds.UniverseDomain(ctx)
if err != nil {
return nil, err
}
if err := transport.ValidateUniverseDomain(c.getClientUniverseDomain(), credentialsUniverseDomain); err != nil {
return nil, err
}
}
if c.secure {
ri, _ := grpccreds.RequestInfoFromContext(ctx)
if err = grpccreds.CheckSecurityLevel(ri.AuthInfo, grpccreds.PrivacyAndIntegrity); err != nil {
Expand Down
16 changes: 9 additions & 7 deletions auth/httptransport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,17 +193,19 @@ func (t *authTransport) RoundTrip(req *http.Request) (*http.Response, error) {
}
}()
}
credentialsUniverseDomain, err := t.creds.UniverseDomain(req.Context())
if err != nil {
return nil, err
}
if err := transport.ValidateUniverseDomain(t.getClientUniverseDomain(), credentialsUniverseDomain); err != nil {
return nil, err
}
token, err := t.creds.Token(req.Context())
if err != nil {
return nil, err
}
if token.MetadataString("auth.google.tokenSource") != "compute-metadata" {
credentialsUniverseDomain, err := t.creds.UniverseDomain(req.Context())
if err != nil {
return nil, err
}
if err := transport.ValidateUniverseDomain(t.getClientUniverseDomain(), credentialsUniverseDomain); err != nil {
return nil, err
}
}
req2 := req.Clone(req.Context())
SetAuthHeader(token, req2)
reqBodyClosed = true
Expand Down

0 comments on commit 7cea5ed

Please sign in to comment.