From e06cb6499f7eda3aef08ab18ff197016f667684b Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Mon, 28 Oct 2024 16:56:20 -0600 Subject: [PATCH] feat(auth): add universe domain support to credentials/impersonate (#10953) * return err if both Client and Credentials are set in CredentialsOptions --- auth/credentials/impersonate/impersonate.go | 102 +++++----- .../impersonate/impersonate_test.go | 183 ++++++++++-------- .../impersonate/integration_test.go | 29 ++- auth/credentials/impersonate/user.go | 36 +++- auth/credentials/impersonate/user_test.go | 7 +- auth/httptransport/httptransport.go | 3 +- 6 files changed, 206 insertions(+), 154 deletions(-) diff --git a/auth/credentials/impersonate/impersonate.go b/auth/credentials/impersonate/impersonate.go index 91b42bc3f7f3..3af236f7d07d 100644 --- a/auth/credentials/impersonate/impersonate.go +++ b/auth/credentials/impersonate/impersonate.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "net/http" + "strings" "time" "cloud.google.com/go/auth" @@ -30,11 +31,13 @@ import ( ) var ( - iamCredentialsEndpoint = "https://iamcredentials.googleapis.com" + universeDomainPlaceholder = "UNIVERSE_DOMAIN" + iamCredentialsEndpoint = "https://iamcredentials.UNIVERSE_DOMAIN" oauth2Endpoint = "https://oauth2.googleapis.com" errMissingTargetPrincipal = errors.New("impersonate: target service account must be provided") errMissingScopes = errors.New("impersonate: scopes must be provided") errLifetimeOverMax = errors.New("impersonate: max lifetime is 12 hours") + errClientAndCredentials = errors.New("impersonate: client and credentials must not both be provided") errUniverseNotSupportedDomainWideDelegation = errors.New("impersonate: service account user is configured for the credential. " + "Domain-wide delegation is not supported in universes other than googleapis.com") ) @@ -62,55 +65,49 @@ func NewCredentials(opts *CredentialsOptions) (*auth.Credentials, error) { var client *http.Client var creds *auth.Credentials - if opts.Client == nil && opts.Credentials == nil { + if opts.Client == nil { var err error - creds, err = credentials.DetectDefault(&credentials.DetectOptions{ - Scopes: []string{defaultScope}, - UseSelfSignedJWT: true, - }) - if err != nil { - return nil, err + if opts.Credentials == nil { + creds, err = credentials.DetectDefault(&credentials.DetectOptions{ + Scopes: []string{defaultScope}, + UseSelfSignedJWT: true, + }) + if err != nil { + return nil, err + } + } else { + creds = opts.Credentials } client, err = httptransport.NewClient(&httptransport.Options{ - Credentials: creds, + Credentials: creds, + UniverseDomain: opts.UniverseDomain, }) if err != nil { return nil, err } - } else if opts.Credentials != nil { - creds = opts.Credentials - client = internal.DefaultClient() - if err := httptransport.AddAuthorizationMiddleware(client, opts.Credentials); err != nil { - return nil, err - } } else { client = opts.Client } + universeDomainProvider := resolveUniverseDomainProvider(creds) // If a subject is specified a domain-wide delegation auth-flow is initiated // to impersonate as the provided subject (user). if opts.Subject != "" { - if !opts.isUniverseDomainGDU() { - return nil, errUniverseNotSupportedDomainWideDelegation - } - tp, err := user(opts, client, lifetime, isStaticToken) + tp, err := user(opts, client, lifetime, isStaticToken, universeDomainProvider) if err != nil { return nil, err } - var udp auth.CredentialsPropertyProvider - if creds != nil { - udp = auth.CredentialsPropertyFunc(creds.UniverseDomain) - } return auth.NewCredentials(&auth.CredentialsOptions{ TokenProvider: tp, - UniverseDomainProvider: udp, + UniverseDomainProvider: universeDomainProvider, }), nil } its := impersonatedTokenProvider{ - client: client, - targetPrincipal: opts.TargetPrincipal, - lifetime: fmt.Sprintf("%.fs", lifetime.Seconds()), + client: client, + targetPrincipal: opts.TargetPrincipal, + lifetime: fmt.Sprintf("%.fs", lifetime.Seconds()), + universeDomainProvider: universeDomainProvider, } for _, v := range opts.Delegates { its.delegates = append(its.delegates, formatIAMServiceAccountName(v)) @@ -125,16 +122,23 @@ func NewCredentials(opts *CredentialsOptions) (*auth.Credentials, error) { } } - var udp auth.CredentialsPropertyProvider - if creds != nil { - udp = auth.CredentialsPropertyFunc(creds.UniverseDomain) - } return auth.NewCredentials(&auth.CredentialsOptions{ TokenProvider: auth.NewCachedTokenProvider(its, tpo), - UniverseDomainProvider: udp, + UniverseDomainProvider: universeDomainProvider, }), nil } +// resolveUniverseDomainProvider returns the default service domain for a given +// Cloud universe. This is the universe domain configured for the credentials, +// which will be used in endpoint(s), and compared to the universe domain that +// is separately configured for the client. +func resolveUniverseDomainProvider(creds *auth.Credentials) auth.CredentialsPropertyProvider { + if creds != nil { + return auth.CredentialsPropertyFunc(creds.UniverseDomain) + } + return internal.StaticCredentialsProperty(internal.DefaultUniverseDomain) +} + // CredentialsOptions for generating an impersonated credential token. type CredentialsOptions struct { // TargetPrincipal is the email address of the service account to @@ -163,11 +167,13 @@ type CredentialsOptions struct { // will try to be detected from the environment. Optional. Credentials *auth.Credentials // Client configures the underlying client used to make network requests - // when fetching tokens. If provided the client should provide it's own + // when fetching tokens. If provided the client should provide its own // credentials at call time. Optional. Client *http.Client // UniverseDomain is the default service domain for a given Cloud universe. - // The default value is "googleapis.com". Optional. + // The default value is "googleapis.com". This is the universe domain + // configured for the client, which will be compared to the universe domain + // that is separately configured for the credentials. Optional. UniverseDomain string } @@ -184,22 +190,10 @@ func (o *CredentialsOptions) validate() error { if o.Lifetime.Hours() > 12 { return errLifetimeOverMax } - return nil -} - -// getUniverseDomain is the default service domain for a given Cloud universe. -// The default value is "googleapis.com". -func (o *CredentialsOptions) getUniverseDomain() string { - if o.UniverseDomain == "" { - return internal.DefaultUniverseDomain + if o.Client != nil && o.Credentials != nil { + return errClientAndCredentials } - return o.UniverseDomain -} - -// isUniverseDomainGDU returns true if the universe domain is the default Google -// universe. -func (o *CredentialsOptions) isUniverseDomainGDU() bool { - return o.getUniverseDomain() == internal.DefaultUniverseDomain + return nil } func formatIAMServiceAccountName(name string) string { @@ -218,7 +212,8 @@ type generateAccessTokenResponse struct { } type impersonatedTokenProvider struct { - client *http.Client + client *http.Client + universeDomainProvider auth.CredentialsPropertyProvider targetPrincipal string lifetime string @@ -237,7 +232,12 @@ func (i impersonatedTokenProvider) Token(ctx context.Context) (*auth.Token, erro if err != nil { return nil, fmt.Errorf("impersonate: unable to marshal request: %w", err) } - url := fmt.Sprintf("%s/v1/%s:generateAccessToken", iamCredentialsEndpoint, formatIAMServiceAccountName(i.targetPrincipal)) + universeDomain, err := i.universeDomainProvider.GetProperty(ctx) + if err != nil { + return nil, err + } + endpoint := strings.Replace(iamCredentialsEndpoint, universeDomainPlaceholder, universeDomain, 1) + url := fmt.Sprintf("%s/v1/%s:generateAccessToken", endpoint, formatIAMServiceAccountName(i.targetPrincipal)) req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(b)) if err != nil { return nil, fmt.Errorf("impersonate: unable to create request: %w", err) diff --git a/auth/credentials/impersonate/impersonate_test.go b/auth/credentials/impersonate/impersonate_test.go index cd468a43a1f2..777b3a1e02c6 100644 --- a/auth/credentials/impersonate/impersonate_test.go +++ b/auth/credentials/impersonate/impersonate_test.go @@ -24,15 +24,18 @@ import ( "testing" "time" + "cloud.google.com/go/auth" + "cloud.google.com/go/auth/internal" "github.com/google/go-cmp/cmp" ) func TestNewCredentials_serviceAccount(t *testing.T) { ctx := context.Background() tests := []struct { - name string - config CredentialsOptions - wantErr error + name string + config CredentialsOptions + wantErr error + wantUniverseDomain string }{ { name: "missing targetPrincipal", @@ -54,23 +57,55 @@ func TestNewCredentials_serviceAccount(t *testing.T) { }, wantErr: errLifetimeOverMax, }, + { + name: "credentials and client", + config: CredentialsOptions{ + TargetPrincipal: "foo@project-id.iam.gserviceaccount.com", + Scopes: []string{"scope"}, + Client: &http.Client{}, + Credentials: staticCredentials("googleapis.com"), + }, + wantErr: errClientAndCredentials, + }, { name: "works", config: CredentialsOptions{ TargetPrincipal: "foo@project-id.iam.gserviceaccount.com", Scopes: []string{"scope"}, }, - wantErr: nil, + wantErr: nil, + wantUniverseDomain: "googleapis.com", }, { - name: "universe domain", + name: "universe domain from options", config: CredentialsOptions{ TargetPrincipal: "foo@project-id.iam.gserviceaccount.com", Scopes: []string{"scope"}, - Subject: "admin@example.com", UniverseDomain: "example.com", }, - wantErr: errUniverseNotSupportedDomainWideDelegation, + wantErr: nil, + wantUniverseDomain: "googleapis.com", // From creds, not CredentialsOptions.UniverseDomain + }, + { + name: "universe domain from options and credentials", + config: CredentialsOptions{ + TargetPrincipal: "foo@project-id.iam.gserviceaccount.com", + Scopes: []string{"scope"}, + UniverseDomain: "NOT.example.com", + Credentials: staticCredentials("example.com"), + }, + wantErr: nil, + wantUniverseDomain: "example.com", // From creds, not CredentialsOptions.UniverseDomain + }, + { + name: "universe domain from credentials", + config: CredentialsOptions{ + TargetPrincipal: "foo@project-id.iam.gserviceaccount.com", + Scopes: []string{"scope"}, + Credentials: staticCredentials("example.com"), + }, + wantErr: nil, + wantUniverseDomain: "example.com", }, } @@ -80,53 +115,66 @@ func TestNewCredentials_serviceAccount(t *testing.T) { saTok := "sa-token" client := &http.Client{ Transport: RoundTripFn(func(req *http.Request) *http.Response { - if strings.Contains(req.URL.Path, "generateAccessToken") { - defer req.Body.Close() - b, err := io.ReadAll(req.Body) - if err != nil { - t.Error(err) - } - var r generateAccessTokenRequest - if err := json.Unmarshal(b, &r); err != nil { - t.Error(err) - } - if !cmp.Equal(r.Scope, tt.config.Scopes) { - t.Errorf("got %v, want %v", r.Scope, tt.config.Scopes) - } - if !strings.Contains(req.URL.Path, tt.config.TargetPrincipal) { - t.Errorf("got %q, want %q", req.URL.Path, tt.config.TargetPrincipal) - } + if !strings.Contains(req.URL.Path, "generateAccessToken") { + t.Fatal("path must contain 'generateAccessToken'") + } + defer req.Body.Close() + b, err := io.ReadAll(req.Body) + if err != nil { + t.Error(err) + } + var r generateAccessTokenRequest + if err := json.Unmarshal(b, &r); err != nil { + t.Error(err) + } + if !cmp.Equal(r.Scope, tt.config.Scopes) { + t.Errorf("got %v, want %v", r.Scope, tt.config.Scopes) + } + if !strings.Contains(req.URL.Path, tt.config.TargetPrincipal) { + t.Errorf("got %q, want %q", req.URL.Path, tt.config.TargetPrincipal) + } + if !strings.Contains(req.URL.Hostname(), tt.wantUniverseDomain) { + t.Errorf("got %q, want %q", req.URL.Hostname(), tt.wantUniverseDomain) + } - resp := generateAccessTokenResponse{ - AccessToken: saTok, - ExpireTime: time.Now().Format(time.RFC3339), - } - b, err = json.Marshal(&resp) - if err != nil { - t.Fatalf("unable to marshal response: %v", err) - } - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(bytes.NewReader(b)), - Header: http.Header{}, - } + resp := generateAccessTokenResponse{ + AccessToken: saTok, + ExpireTime: time.Now().Format(time.RFC3339), + } + b, err = json.Marshal(&resp) + if err != nil { + t.Fatalf("unable to marshal response: %v", err) + } + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader(b)), + Header: http.Header{}, } - return nil }), } - tt.config.Client = client - ts, err := NewCredentials(&tt.config) + if tt.config.Credentials == nil { + tt.config.Client = client + } + creds, err := NewCredentials(&tt.config) if err != nil { if err != tt.wantErr { t.Fatalf("err: %v", err) } + } else if tt.config.Credentials != nil { + // config.Credentials is invalid for Token request, just assert universe domain. + if got, _ := creds.UniverseDomain(ctx); got != tt.wantUniverseDomain { + t.Errorf("got %q, want %q", got, tt.wantUniverseDomain) + } } else { - tok, err := ts.Token(ctx) + tok, err := creds.Token(ctx) if err != nil { - t.Fatal(err) + t.Error(err) } if tok.Value != saTok { - t.Fatalf("got %q, want %q", tok.Value, saTok) + t.Errorf("got %q, want %q", tok.Value, saTok) + } + if got, _ := creds.UniverseDomain(ctx); got != tt.wantUniverseDomain { + t.Errorf("got %q, want %q", got, tt.wantUniverseDomain) } } }) @@ -137,44 +185,15 @@ type RoundTripFn func(req *http.Request) *http.Response func (f RoundTripFn) RoundTrip(req *http.Request) (*http.Response, error) { return f(req), nil } -func TestCredentialsOptions_UniverseDomain(t *testing.T) { - testCases := []struct { - name string - opts *CredentialsOptions - wantUniverseDomain string - wantIsGDU bool - }{ - { - name: "empty", - opts: &CredentialsOptions{}, - wantUniverseDomain: "googleapis.com", - wantIsGDU: true, - }, - { - name: "defaults", - opts: &CredentialsOptions{ - UniverseDomain: "googleapis.com", - }, - wantUniverseDomain: "googleapis.com", - wantIsGDU: true, - }, - { - name: "non-GDU", - opts: &CredentialsOptions{ - UniverseDomain: "example.com", - }, - wantUniverseDomain: "example.com", - wantIsGDU: false, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if got := tc.opts.getUniverseDomain(); got != tc.wantUniverseDomain { - t.Errorf("got %v, want %v", got, tc.wantUniverseDomain) - } - if got := tc.opts.isUniverseDomainGDU(); got != tc.wantIsGDU { - t.Errorf("got %v, want %v", got, tc.wantIsGDU) - } - }) - } +func staticCredentials(universeDomain string) *auth.Credentials { + return auth.NewCredentials(&auth.CredentialsOptions{ + TokenProvider: staticTokenProvider("base credentials Token should never be called"), + UniverseDomainProvider: internal.StaticCredentialsProperty(universeDomain), + }) +} + +type staticTokenProvider string + +func (s staticTokenProvider) Token(context.Context) (*auth.Token, error) { + return &auth.Token{Value: string(s)}, nil } diff --git a/auth/credentials/impersonate/integration_test.go b/auth/credentials/impersonate/integration_test.go index fd12b4fffc75..2637fc24fae6 100644 --- a/auth/credentials/impersonate/integration_test.go +++ b/auth/credentials/impersonate/integration_test.go @@ -58,15 +58,28 @@ func TestMain(m *testing.M) { readerEmail = os.Getenv(envReaderEmail) writerEmail = os.Getenv(envWriterEmail) - if !testing.Short() && (baseKeyFile == "" || - readerKeyFile == "" || - readerEmail == "" || - writerEmail == "" || - projectID == "") { - log.Println("required environment variable not set, skipping") - os.Exit(0) + if !testing.Short() { + missing := []string{} + if baseKeyFile == "" { + missing = append(missing, credsfile.GoogleAppCredsEnvVar) + } + if projectID == "" { + missing = append(missing, envProjectID) + } + if readerKeyFile == "" { + missing = append(missing, envReaderCreds) + } + if readerEmail == "" { + missing = append(missing, envReaderEmail) + } + if writerEmail == "" { + missing = append(missing, envWriterEmail) + } + if len(missing) > 0 { + log.Printf("skipping, required environment variable(s) not set: %s\n", missing) + os.Exit(0) + } } - os.Exit(m.Run()) } diff --git a/auth/credentials/impersonate/user.go b/auth/credentials/impersonate/user.go index 1acaaa922d9d..b5e5fc8f6645 100644 --- a/auth/credentials/impersonate/user.go +++ b/auth/credentials/impersonate/user.go @@ -18,6 +18,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "net/http" "net/url" @@ -30,12 +31,16 @@ import ( // user provides an auth flow for domain-wide delegation, setting // CredentialsConfig.Subject to be the impersonated user. -func user(opts *CredentialsOptions, client *http.Client, lifetime time.Duration, isStaticToken bool) (auth.TokenProvider, error) { +func user(opts *CredentialsOptions, client *http.Client, lifetime time.Duration, isStaticToken bool, universeDomainProvider auth.CredentialsPropertyProvider) (auth.TokenProvider, error) { + if opts.Subject == "" { + return nil, errors.New("CredentialsConfig.Subject must not be empty") + } u := userTokenProvider{ - client: client, - targetPrincipal: opts.TargetPrincipal, - subject: opts.Subject, - lifetime: lifetime, + client: client, + targetPrincipal: opts.TargetPrincipal, + subject: opts.Subject, + lifetime: lifetime, + universeDomainProvider: universeDomainProvider, } u.delegates = make([]string, len(opts.Delegates)) for i, v := range opts.Delegates { @@ -84,14 +89,25 @@ type exchangeTokenResponse struct { type userTokenProvider struct { client *http.Client - targetPrincipal string - subject string - scopes []string - lifetime time.Duration - delegates []string + targetPrincipal string + subject string + scopes []string + lifetime time.Duration + delegates []string + universeDomainProvider auth.CredentialsPropertyProvider } func (u userTokenProvider) Token(ctx context.Context) (*auth.Token, error) { + // Because a subject is specified a domain-wide delegation auth-flow is initiated + // to impersonate as the provided subject (user). + // Return error if users try to use domain-wide delegation in a non-GDU universe. + ud, err := u.universeDomainProvider.GetProperty(ctx) + if err != nil { + return nil, err + } + if ud != internal.DefaultUniverseDomain { + return nil, errUniverseNotSupportedDomainWideDelegation + } signedJWT, err := u.signJWT(ctx) if err != nil { return nil, err diff --git a/auth/credentials/impersonate/user_test.go b/auth/credentials/impersonate/user_test.go index adb4612d5eca..87b897a1c248 100644 --- a/auth/credentials/impersonate/user_test.go +++ b/auth/credentials/impersonate/user_test.go @@ -37,6 +37,7 @@ func TestNewCredentials_user(t *testing.T) { lifetime time.Duration subject string wantErr bool + wantTokenErr bool universeDomain string }{ { @@ -60,14 +61,13 @@ func TestNewCredentials_user(t *testing.T) { targetPrincipal: "foo@project-id.iam.gserviceaccount.com", scopes: []string{"scope"}, subject: "admin@example.com", - wantErr: false, }, { name: "universeDomain", targetPrincipal: "foo@project-id.iam.gserviceaccount.com", scopes: []string{"scope"}, subject: "admin@example.com", - wantErr: true, + wantTokenErr: true, // Non-GDU Universe Domain should result in error if // CredentialsConfig.Subject is present for domain-wide delegation. universeDomain: "example.com", @@ -152,6 +152,9 @@ func TestNewCredentials_user(t *testing.T) { t.Fatal(err) } tok, err := ts.Token(ctx) + if tt.wantTokenErr && err != nil { + return + } if err != nil { t.Fatal(err) } diff --git a/auth/httptransport/httptransport.go b/auth/httptransport/httptransport.go index 30fedf9562f9..38e8c99399bb 100644 --- a/auth/httptransport/httptransport.go +++ b/auth/httptransport/httptransport.go @@ -155,6 +155,8 @@ type InternalOptions struct { // transport that sets the Authorization header with the value produced by the // provided [cloud.google.com/go/auth.Credentials]. An error is returned only // if client or creds is nil. +// +// This function does not support setting a universe domain value on the client. func AddAuthorizationMiddleware(client *http.Client, creds *auth.Credentials) error { if client == nil || creds == nil { return fmt.Errorf("httptransport: client and tp must not be nil") @@ -173,7 +175,6 @@ func AddAuthorizationMiddleware(client *http.Client, creds *auth.Credentials) er client.Transport = &authTransport{ creds: creds, base: base, - // TODO(quartzmo): Somehow set clientUniverseDomain from impersonate calls. } return nil }