diff --git a/auth/credentials/idtoken/validate.go b/auth/credentials/idtoken/validate.go index 4b17af202118..3b5e948d7db0 100644 --- a/auth/credentials/idtoken/validate.go +++ b/auth/credentials/idtoken/validate.go @@ -67,13 +67,19 @@ type jwk struct { // Validator provides a way to validate Google ID Tokens type Validator struct { - client *cachingClient + client *cachingClient + rsa256CertsURL string + es256CertsURL string } // ValidatorOptions provides a way to configure a [Validator]. type ValidatorOptions struct { // Client used to make requests to the certs URL. Optional. Client *http.Client + // Custom certs URL for RSA256 JWK to be used. Optional. + RSA256CertsURL string + // Custom certs URL for ES256 JWK to be used. Optional. + ES256CertsURL string } // NewValidator creates a Validator that uses the options provided to configure @@ -85,7 +91,17 @@ func NewValidator(opts *ValidatorOptions) (*Validator, error) { } else { client = internal.DefaultClient() } - return &Validator{client: newCachingClient(client)}, nil + + rsa256CertsURL := googleSACertsURL + es256CertsURL := googleIAPCertsURL + if opts != nil && opts.RSA256CertsURL != "" { + rsa256CertsURL = opts.RSA256CertsURL + } + if opts != nil && opts.ES256CertsURL != "" { + es256CertsURL = opts.ES256CertsURL + } + + return &Validator{client: newCachingClient(client), rsa256CertsURL: rsa256CertsURL, es256CertsURL: es256CertsURL}, nil } // Validate is used to validate the provided idToken with a known Google cert @@ -137,11 +153,11 @@ func (v *Validator) validate(ctx context.Context, idToken string, audience strin hashedContent := hashHeaderPayload(idToken) switch header.Algorithm { case jwt.HeaderAlgRSA256: - if err := v.validateRS256(ctx, header.KeyID, hashedContent, sig); err != nil { + if err := v.validateRS256(ctx, header.KeyID, hashedContent, sig, v.rsa256CertsURL); err != nil { return nil, err } - case "ES256": - if err := v.validateES256(ctx, header.KeyID, hashedContent, sig); err != nil { + case jwt.HeaderAlgES256: + if err := v.validateES256(ctx, header.KeyID, hashedContent, sig, v.es256CertsURL); err != nil { return nil, err } default: @@ -151,8 +167,8 @@ func (v *Validator) validate(ctx context.Context, idToken string, audience strin return payload, nil } -func (v *Validator) validateRS256(ctx context.Context, keyID string, hashedContent []byte, sig []byte) error { - certResp, err := v.client.getCert(ctx, googleSACertsURL) +func (v *Validator) validateRS256(ctx context.Context, keyID string, hashedContent []byte, sig []byte, certsURL string) error { + certResp, err := v.client.getCert(ctx, certsURL) if err != nil { return err } @@ -176,8 +192,8 @@ func (v *Validator) validateRS256(ctx context.Context, keyID string, hashedConte return rsa.VerifyPKCS1v15(pk, crypto.SHA256, hashedContent, sig) } -func (v *Validator) validateES256(ctx context.Context, keyID string, hashedContent []byte, sig []byte) error { - certResp, err := v.client.getCert(ctx, googleIAPCertsURL) +func (v *Validator) validateES256(ctx context.Context, keyID string, hashedContent []byte, sig []byte, certsURL string) error { + certResp, err := v.client.getCert(ctx, certsURL) if err != nil { return err } diff --git a/auth/credentials/idtoken/validate_test.go b/auth/credentials/idtoken/validate_test.go index afabe814c3a2..fe6f1510ef36 100644 --- a/auth/credentials/idtoken/validate_test.go +++ b/auth/credentials/idtoken/validate_test.go @@ -49,44 +49,60 @@ var ( func TestValidateRS256(t *testing.T) { idToken, pk := createRS256JWT(t) tests := []struct { - name string - keyID string - n *big.Int - e int - nowFunc func() time.Time - wantErr bool + name string + keyID string + certsURL string + n *big.Int + e int + nowFunc func() time.Time + wantErr bool + wantCertsURL string }{ { - name: "works", - keyID: keyID, - n: pk.N, - e: pk.E, - nowFunc: beforeExp, - wantErr: false, + name: "works", + keyID: keyID, + n: pk.N, + e: pk.E, + nowFunc: beforeExp, + wantErr: false, + wantCertsURL: googleSACertsURL, }, { - name: "no matching key", - keyID: "5678", - n: pk.N, - e: pk.E, - nowFunc: beforeExp, - wantErr: true, + name: "works with custom certs url", + keyID: keyID, + certsURL: "https://www.googleapis.com/service_accounts/v1/jwk/chat@system.gserviceaccount.com", + n: pk.N, + e: pk.E, + nowFunc: beforeExp, + wantErr: false, + wantCertsURL: "https://www.googleapis.com/service_accounts/v1/jwk/chat@system.gserviceaccount.com", }, { - name: "sig does not match", - keyID: keyID, - n: new(big.Int).SetBytes([]byte("42")), - e: 42, - nowFunc: beforeExp, - wantErr: true, + name: "no matching key", + keyID: "5678", + n: pk.N, + e: pk.E, + nowFunc: beforeExp, + wantErr: true, + wantCertsURL: googleSACertsURL, }, { - name: "token expired", - keyID: keyID, - n: pk.N, - e: pk.E, - nowFunc: afterExp, - wantErr: true, + name: "sig does not match", + keyID: keyID, + n: new(big.Int).SetBytes([]byte("42")), + e: 42, + nowFunc: beforeExp, + wantErr: true, + wantCertsURL: googleSACertsURL, + }, + { + name: "token expired", + keyID: keyID, + n: pk.N, + e: pk.E, + nowFunc: afterExp, + wantErr: true, + wantCertsURL: googleSACertsURL, }, } @@ -94,6 +110,9 @@ func TestValidateRS256(t *testing.T) { t.Run(tt.name, func(t *testing.T) { client := &http.Client{ Transport: RoundTripFn(func(req *http.Request) *http.Response { + if req.URL.String() != tt.wantCertsURL { + t.Fatalf("Invalid request uri, want %v got %v", tt.wantCertsURL, req.URL.String()) + } cr := certResponse{ Keys: []jwk{ { @@ -119,7 +138,8 @@ func TestValidateRS256(t *testing.T) { now = tt.nowFunc v, err := NewValidator(&ValidatorOptions{ - Client: client, + Client: client, + RSA256CertsURL: tt.certsURL, }) if err != nil { t.Fatalf("NewValidator(...) = %q, want nil", err) @@ -162,50 +182,69 @@ func TestValidateRS256(t *testing.T) { func TestValidateES256(t *testing.T) { idToken, pk := createES256JWT(t) tests := []struct { - name string - keyID string - x *big.Int - y *big.Int - nowFunc func() time.Time - wantErr bool + name string + keyID string + certsURL string + x *big.Int + y *big.Int + nowFunc func() time.Time + wantErr bool + wantCertsURL string }{ { - name: "works", - keyID: keyID, - x: pk.X, - y: pk.Y, - nowFunc: beforeExp, - wantErr: false, + name: "works", + keyID: keyID, + x: pk.X, + y: pk.Y, + nowFunc: beforeExp, + wantErr: false, + wantCertsURL: googleIAPCertsURL, }, { - name: "no matching key", - keyID: "5678", - x: pk.X, - y: pk.Y, - nowFunc: beforeExp, - wantErr: true, + name: "works with custom certs url", + keyID: keyID, + certsURL: "http://example.com", + x: pk.X, + y: pk.Y, + nowFunc: beforeExp, + wantErr: false, + wantCertsURL: "http://example.com", }, { - name: "sig does not match", - keyID: keyID, - x: new(big.Int), - y: new(big.Int), - nowFunc: beforeExp, - wantErr: true, + name: "no matching key", + keyID: "5678", + x: pk.X, + y: pk.Y, + nowFunc: beforeExp, + wantErr: true, + wantCertsURL: googleIAPCertsURL, }, { - name: "token expired", - keyID: keyID, - x: pk.X, - y: pk.Y, - nowFunc: afterExp, - wantErr: true, + name: "sig does not match", + keyID: keyID, + x: new(big.Int), + y: new(big.Int), + nowFunc: beforeExp, + wantErr: true, + wantCertsURL: googleIAPCertsURL, + }, + { + name: "token expired", + keyID: keyID, + x: pk.X, + y: pk.Y, + nowFunc: afterExp, + wantErr: true, + wantCertsURL: googleIAPCertsURL, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { client := &http.Client{ Transport: RoundTripFn(func(req *http.Request) *http.Response { + if req.URL.String() != tt.wantCertsURL { + t.Fatalf("Invalid request uri, want %v got %v", tt.wantCertsURL, req.URL.String()) + } cr := certResponse{ Keys: []jwk{ { @@ -231,7 +270,8 @@ func TestValidateES256(t *testing.T) { now = tt.nowFunc v, err := NewValidator(&ValidatorOptions{ - Client: client, + Client: client, + ES256CertsURL: tt.certsURL, }) if err != nil { t.Fatalf("NewValidator(...) = %q, want nil", err)