diff --git a/auth/credentials/idtoken/validate.go b/auth/credentials/idtoken/validate.go index 3b5e948d7db0..7abd1dab3a73 100644 --- a/auth/credentials/idtoken/validate.go +++ b/auth/credentials/idtoken/validate.go @@ -34,9 +34,11 @@ import ( ) const ( - es256KeySize int = 32 + es256KeySize int = 32 + // googleIAPCertsURL is used for ES256 Certs. googleIAPCertsURL string = "https://www.gstatic.com/iap/verify/public_key-jwk" - googleSACertsURL string = "https://www.googleapis.com/oauth2/v3/certs" + // googleSACertsURL is used for RS256 Certs. + googleSACertsURL string = "https://www.googleapis.com/oauth2/v3/certs" ) var ( @@ -67,18 +69,20 @@ type jwk struct { // Validator provides a way to validate Google ID Tokens type Validator struct { - client *cachingClient - rsa256CertsURL string - es256CertsURL string + client *cachingClient + rs256URL string + es256URL string } // ValidatorOptions provides a way to configure a [Validator]. type ValidatorOptions struct { // Client used to make requests to the certs URL. Optional. Client *http.Client - // Custom certs URL for RSA256 JWK to be used. Optional. - RSA256CertsURL string - // Custom certs URL for ES256 JWK to be used. Optional. + // Custom certs URL for RS256 JWK to be used. If not provided, the default + // Google oauth2 endpoint will be used. Optional. + RS256CertsURL string + // Custom certs URL for ES256 JWK to be used. If not provided, the default + // Google IAP endpoint will be used. Optional. ES256CertsURL string } @@ -91,17 +95,15 @@ func NewValidator(opts *ValidatorOptions) (*Validator, error) { } else { client = internal.DefaultClient() } - - rsa256CertsURL := googleSACertsURL - es256CertsURL := googleIAPCertsURL - if opts != nil && opts.RSA256CertsURL != "" { - rsa256CertsURL = opts.RSA256CertsURL + var rs256URL string + if opts != nil { + rs256URL = opts.RS256CertsURL } - if opts != nil && opts.ES256CertsURL != "" { - es256CertsURL = opts.ES256CertsURL + var es256URL string + if opts != nil { + es256URL = opts.ES256CertsURL } - - return &Validator{client: newCachingClient(client), rsa256CertsURL: rsa256CertsURL, es256CertsURL: es256CertsURL}, nil + return &Validator{client: newCachingClient(client), rs256URL: rs256URL, es256URL: es256URL}, nil } // Validate is used to validate the provided idToken with a known Google cert @@ -153,11 +155,11 @@ func (v *Validator) validate(ctx context.Context, idToken string, audience strin hashedContent := hashHeaderPayload(idToken) switch header.Algorithm { case jwt.HeaderAlgRSA256: - if err := v.validateRS256(ctx, header.KeyID, hashedContent, sig, v.rsa256CertsURL); err != nil { + if err := v.validateRS256(ctx, header.KeyID, hashedContent, sig); err != nil { return nil, err } case jwt.HeaderAlgES256: - if err := v.validateES256(ctx, header.KeyID, hashedContent, sig, v.es256CertsURL); err != nil { + if err := v.validateES256(ctx, header.KeyID, hashedContent, sig); err != nil { return nil, err } default: @@ -167,8 +169,8 @@ func (v *Validator) validate(ctx context.Context, idToken string, audience strin return payload, nil } -func (v *Validator) validateRS256(ctx context.Context, keyID string, hashedContent []byte, sig []byte, certsURL string) error { - certResp, err := v.client.getCert(ctx, certsURL) +func (v *Validator) validateRS256(ctx context.Context, keyID string, hashedContent []byte, sig []byte) error { + certResp, err := v.client.getCert(ctx, v.rs256CertsURL()) if err != nil { return err } @@ -192,8 +194,15 @@ func (v *Validator) validateRS256(ctx context.Context, keyID string, hashedConte return rsa.VerifyPKCS1v15(pk, crypto.SHA256, hashedContent, sig) } -func (v *Validator) validateES256(ctx context.Context, keyID string, hashedContent []byte, sig []byte, certsURL string) error { - certResp, err := v.client.getCert(ctx, certsURL) +func (v *Validator) rs256CertsURL() string { + if v.rs256URL == "" { + return googleSACertsURL + } + return v.rs256URL +} + +func (v *Validator) validateES256(ctx context.Context, keyID string, hashedContent []byte, sig []byte) error { + certResp, err := v.client.getCert(ctx, v.es256CertsURL()) if err != nil { return err } @@ -223,6 +232,13 @@ func (v *Validator) validateES256(ctx context.Context, keyID string, hashedConte return nil } +func (v *Validator) es256CertsURL() string { + if v.es256URL == "" { + return googleIAPCertsURL + } + return v.es256URL +} + func findMatchingKey(response *certResponse, keyID string) (*jwk, error) { if response == nil { return nil, fmt.Errorf("idtoken: cert response is nil") diff --git a/auth/credentials/idtoken/validate_test.go b/auth/credentials/idtoken/validate_test.go index fe6f1510ef36..8a100da19b89 100644 --- a/auth/credentials/idtoken/validate_test.go +++ b/auth/credentials/idtoken/validate_test.go @@ -138,8 +138,8 @@ func TestValidateRS256(t *testing.T) { now = tt.nowFunc v, err := NewValidator(&ValidatorOptions{ - Client: client, - RSA256CertsURL: tt.certsURL, + Client: client, + RS256CertsURL: tt.certsURL, }) if err != nil { t.Fatalf("NewValidator(...) = %q, want nil", err)