diff --git a/jwt/jwt.go b/jwt/jwt.go index 5805d5b..063e913 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -8,7 +8,6 @@ import ( "encoding/json" "errors" "fmt" - "strings" "time" "github.com/go-jose/go-jose/v3" @@ -19,27 +18,26 @@ import ( // for validating the "nbf" (Not Before) and "exp" (Expiration Time) claims. const DefaultLeewaySeconds = 150 -type Validator interface { - Validate(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) - ValidateAllowMissingIatNbfExp(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) -} - // Validator validates JSON Web Tokens (JWT) by providing signature // verification and claims set validation. -type validator struct { - keySet KeySet +type Validator struct { + keySets []KeySet } -var _ Validator = (*validator)(nil) - // NewValidator returns a Validator that uses the given KeySet to verify JWT signatures. -func NewValidator(keySet KeySet) (Validator, error) { - if keySet == nil { - return nil, errors.New("keySet must not be nil") +func NewValidator(keySets ...KeySet) (*Validator, error) { + if len(keySets) <= 0 { + return nil, errors.New("must provide at least one key set") + } + + for _, keySet := range keySets { + if keySet == nil { + return nil, errors.New("keySet must not be nil") + } } - return &validator{ - keySet: keySet, + return &Validator{ + keySets: keySets, }, nil } @@ -103,7 +101,7 @@ type Expected struct { // and "exp" (Expiration Time) claims and after the time given by the "iat" // (Issued At) claim, with configurable leeway. See Expected.Now() for details // on how the current time is provided for validation. -func (v *validator) Validate(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) { +func (v *Validator) Validate(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) { return v.validateAll(ctx, token, expected, false) } @@ -119,14 +117,26 @@ func (v *validator) Validate(ctx context.Context, token string, expected Expecte // of "nbf", "exp", and "iat" are missing, then this check is skipped. See // Expected.Now() for details on how the current time is provided for // validation. -func (v *validator) ValidateAllowMissingIatNbfExp(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) { +func (v *Validator) ValidateAllowMissingIatNbfExp(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) { return v.validateAll(ctx, token, expected, true) } -func (v *validator) validateAll(ctx context.Context, token string, expected Expected, allowMissingIatExpNbf bool) (map[string]interface{}, error) { - // First, verify the signature to ensure subsequent validation is against verified claims - allClaims, err := v.keySet.VerifySignature(ctx, token) - if err != nil { +func (v *Validator) validateAll(ctx context.Context, token string, expected Expected, allowMissingIatExpNbf bool) (map[string]interface{}, error) { + var allClaims map[string]interface{} + var err error + + // Ensure that the token is signed by at least one of the given key sets + var tokenVerified bool + for _, keySet := range v.keySets { + // First, verify the signature to ensure subsequent validation is against verified claims + allClaims, err = keySet.VerifySignature(ctx, token) + if err == nil { + tokenVerified = true + break + } + } + + if !tokenVerified { return nil, fmt.Errorf("error verifying token signature: %w", err) } @@ -235,73 +245,6 @@ func (v *validator) validateAll(ctx context.Context, token string, expected Expe return allClaims, nil } -// multiValidator validates JSON Web Tokens (JWT) by providing signature -// verification and claims set validation. Unlike Validator, multiValidator -// supports multiple KeySet to verify JWT signatures. -// multiValidator is in now way optimized and is just stubbed out as a POC. -type multiValidator struct { - validators []Validator -} - -var _ Validator = (*multiValidator)(nil) - -// NewMultiValidator returns a Validator that uses the given slice of KeySet to verify JWT signatures. -func NewMultiValidator(keySets []KeySet) (Validator, error) { - if len(keySets) == 0 { - return nil, errors.New("must provide at least one key set") - } - - validators := make([]Validator, 0, len(keySets)) - - for _, keySet := range keySets { - v, err := NewValidator(keySet) - if err != nil { - return nil, err - } - validators = append(validators, v) - } - - return &multiValidator{ - validators: validators, - }, nil -} - -func (m *multiValidator) Validate(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) { - for _, v := range m.validators { - claims, err := v.Validate(ctx, token, expected) - if err != nil { - if strings.Contains(err.Error(), "error verifying token signature") { - // this isn't the right key set, try the next one - continue - } - // this is the right key set, but there was an error validating the claims - return nil, err - } - - return claims, nil - } - - return nil, errors.New("no key set was able to verify the token signature") -} - -func (m *multiValidator) ValidateAllowMissingIatNbfExp(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) { - for _, v := range m.validators { - claims, err := v.ValidateAllowMissingIatNbfExp(ctx, token, expected) - if err != nil { - if strings.Contains(err.Error(), "error verifying token signature") { - // this isn't the right key set, try the next one - continue - } - // this is the right key set, but there was an error validating the claims - return nil, err - } - - return claims, nil - } - - return nil, errors.New("no key set was able to verify the token signature") -} - // validateSigningAlgorithm checks whether the JWS "alg" (Algorithm) header // parameter value for the given JWT matches any given in expectedAlgorithms. // If expectedAlgorithms is empty, RS256 will be expected by default. diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index 1c479af..3739639 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -572,7 +572,7 @@ func TestValidator_Validate_Invalid_JWT(t *testing.T) { func TestNewValidator(t *testing.T) { type args struct { - keySet func() KeySet + keySets func() []KeySet } tests := []struct { name string @@ -582,27 +582,41 @@ func TestNewValidator(t *testing.T) { { name: "new validator with keySet", args: args{ - keySet: func() KeySet { + keySets: func() []KeySet { ks, err := NewJSONWebKeySet(context.Background(), "https://issuer.com/"+wellKnownJWKS, "") require.NoError(t, err) - return ks + return []KeySet{ks} }, }, }, { name: "new validator with nil keySet", args: args{ - keySet: func() KeySet { + keySets: func() []KeySet { return nil }, }, wantErr: true, }, + { + name: "new validator with multiple keySets", + args: args{ + keySets: func() []KeySet { + ks, err := NewJSONWebKeySet(context.Background(), + "https://issuer.com/"+wellKnownJWKS, "") + require.NoError(t, err) + + ks2, err := NewJSONWebKeySet(context.Background(), + "https://issuer2.com/"+wellKnownJWKS, "") + return []KeySet{ks, ks2} + }, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := NewValidator(tt.args.keySet()) + got, err := NewValidator(tt.args.keySets()...) if tt.wantErr { require.Error(t, err) return @@ -613,8 +627,9 @@ func TestNewValidator(t *testing.T) { } } -// TestMultiValidator_Validate_Valid_JWT tests cases where a JWT is expected to be valid. -func TestMultiValidator_Validate_Valid_JWT(t *testing.T) { +// TestValidator_MultipleKeySets_Validate_Valid_JWT tests cases where a JWT is expected to be valid where the +// validator is initialized with multiple KeySets. +func TestValidator_MultipleKeySets_Validate_Valid_JWT(t *testing.T) { tp := oidc.StartTestProvider(t, oidc.WithTestPort(8181)) tp2 := oidc.StartTestProvider(t, oidc.WithTestPort(8182)) @@ -1088,11 +1103,11 @@ func TestMultiValidator_Validate_Valid_JWT(t *testing.T) { token := tt.args.token(tt.args.claims) // Create the validator with the KeySet - validator, err := NewMultiValidator([]KeySet{keySet1, keySet2}) + v, err := NewValidator(keySet1, keySet2) require.NoError(t, err) // Validate the JWT claims against expected values - got, err := validator.Validate(ctx, token, tt.args.expected) + got, err := v.Validate(ctx, token, tt.args.expected) // Expect to get back the same claims that were serialized in the JWT require.NoError(t, err) @@ -1102,7 +1117,7 @@ func TestMultiValidator_Validate_Valid_JWT(t *testing.T) { } } -func TestMultiValidator_NoExpIatNbf(t *testing.T) { +func TestValidator_MultipleKeySets_NoExpIatNbf(t *testing.T) { tp := oidc.StartTestProvider(t, oidc.WithTestPort(8181)) tp2 := oidc.StartTestProvider(t, oidc.WithTestPort(8182)) @@ -1164,11 +1179,11 @@ func TestMultiValidator_NoExpIatNbf(t *testing.T) { token := tt.args.token(tt.args.claims) // Create the validator with the KeySet - validator, err := NewMultiValidator([]KeySet{keySet1, keySet2}) + v, err := NewValidator(keySet1, keySet2) require.NoError(t, err) // Validate the JWT claims against expected values - got, err := validator.ValidateAllowMissingIatNbfExp(ctx, token, tt.args.expected) + got, err := v.ValidateAllowMissingIatNbfExp(ctx, token, tt.args.expected) // Expect to get back the same claims that were serialized in the JWT require.NoError(t, err) @@ -1178,8 +1193,9 @@ func TestMultiValidator_NoExpIatNbf(t *testing.T) { } } -// TestValidator_Validate_Valid_JWT tests cases where a JWT is expected to be invalid. -func TestMultiValidator_Validate_Invalid_JWT(t *testing.T) { +// TestValidator_MultipleKeySets_Validate_Invalid_JWT tests cases where a JWT is expected to be invalid where the +// validator is initialized with multiple KeySets. +func TestValidator_MultipleKeySets_Validate_Invalid_JWT(t *testing.T) { tp := oidc.StartTestProvider(t, oidc.WithTestPort(8181)) tp2 := oidc.StartTestProvider(t, oidc.WithTestPort(8182)) @@ -1543,11 +1559,11 @@ func TestMultiValidator_Validate_Invalid_JWT(t *testing.T) { token := tt.args.token(tt.args.claims) // Create the validator with the KeySet - validator, err := NewMultiValidator([]KeySet{keySet1, keySet2}) + v, err := NewValidator(keySet1, keySet2) require.NoError(t, err) // Validate the JWT claims against expected values - got, err := validator.Validate(ctx, token, tt.args.expected) + got, err := v.Validate(ctx, token, tt.args.expected) // Expect an error and nil claims require.Error(t, err) @@ -1556,69 +1572,6 @@ func TestMultiValidator_Validate_Invalid_JWT(t *testing.T) { } } -func TestNewMultiValidator(t *testing.T) { - type args struct { - keySets func() []KeySet - } - tests := []struct { - name string - args args - wantErr bool - }{ - { - name: "new multiValidator with single keySet", - args: args{ - keySets: func() []KeySet { - ks, err := NewJSONWebKeySet(context.Background(), - "https://issuer.com/"+wellKnownJWKS, "") - require.NoError(t, err) - return []KeySet{ks} - }, - }, - }, - { - name: "new multiValidator with multiple keySets", - args: args{ - keySets: func() []KeySet { - kSets := make([]KeySet, 0, 2) - ks, err := NewJSONWebKeySet(context.Background(), - "https://issuer.com/"+wellKnownJWKS, "") - require.NoError(t, err) - - kSets = append(kSets, ks) - ks, err = NewJSONWebKeySet(context.Background(), - "https://issuer2.com/"+wellKnownJWKS, "") - require.NoError(t, err) - - kSets = append(kSets, ks) - - return kSets - }, - }, - }, - { - name: "new multiValidator with no keySets", - args: args{ - keySets: func() []KeySet { - return nil - }, - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := NewMultiValidator(tt.args.keySets()) - if tt.wantErr { - require.Error(t, err) - return - } - require.NoError(t, err) - require.NotNil(t, got) - }) - } -} - func Test_validateAudience(t *testing.T) { type args struct { expectedAudiences []string