diff --git a/jwt/jwt.go b/jwt/jwt.go index e095d52..5805d5b 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "fmt" + "strings" "time" "github.com/go-jose/go-jose/v3" @@ -18,19 +19,26 @@ import ( // for validating the "nbf" (Not Before) and "exp" (Expiration Time) claims. const DefaultLeewaySeconds = 150 +type Validator interface { + Validate(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) + ValidateAllowMissingIatNbfExp(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) +} + // Validator validates JSON Web Tokens (JWT) by providing signature // verification and claims set validation. -type Validator struct { +type validator struct { keySet KeySet } +var _ Validator = (*validator)(nil) + // NewValidator returns a Validator that uses the given KeySet to verify JWT signatures. -func NewValidator(keySet KeySet) (*Validator, error) { +func NewValidator(keySet KeySet) (Validator, error) { if keySet == nil { return nil, errors.New("keySet must not be nil") } - return &Validator{ + return &validator{ keySet: keySet, }, nil } @@ -95,7 +103,7 @@ type Expected struct { // and "exp" (Expiration Time) claims and after the time given by the "iat" // (Issued At) claim, with configurable leeway. See Expected.Now() for details // on how the current time is provided for validation. -func (v *Validator) Validate(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) { +func (v *validator) Validate(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) { return v.validateAll(ctx, token, expected, false) } @@ -111,11 +119,11 @@ func (v *Validator) Validate(ctx context.Context, token string, expected Expecte // of "nbf", "exp", and "iat" are missing, then this check is skipped. See // Expected.Now() for details on how the current time is provided for // validation. -func (v *Validator) ValidateAllowMissingIatNbfExp(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) { +func (v *validator) ValidateAllowMissingIatNbfExp(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) { return v.validateAll(ctx, token, expected, true) } -func (v *Validator) validateAll(ctx context.Context, token string, expected Expected, allowMissingIatExpNbf bool) (map[string]interface{}, error) { +func (v *validator) validateAll(ctx context.Context, token string, expected Expected, allowMissingIatExpNbf bool) (map[string]interface{}, error) { // First, verify the signature to ensure subsequent validation is against verified claims allClaims, err := v.keySet.VerifySignature(ctx, token) if err != nil { @@ -227,6 +235,73 @@ func (v *Validator) validateAll(ctx context.Context, token string, expected Expe return allClaims, nil } +// multiValidator validates JSON Web Tokens (JWT) by providing signature +// verification and claims set validation. Unlike Validator, multiValidator +// supports multiple KeySet to verify JWT signatures. +// multiValidator is in now way optimized and is just stubbed out as a POC. +type multiValidator struct { + validators []Validator +} + +var _ Validator = (*multiValidator)(nil) + +// NewMultiValidator returns a Validator that uses the given slice of KeySet to verify JWT signatures. +func NewMultiValidator(keySets []KeySet) (Validator, error) { + if len(keySets) == 0 { + return nil, errors.New("must provide at least one key set") + } + + validators := make([]Validator, 0, len(keySets)) + + for _, keySet := range keySets { + v, err := NewValidator(keySet) + if err != nil { + return nil, err + } + validators = append(validators, v) + } + + return &multiValidator{ + validators: validators, + }, nil +} + +func (m *multiValidator) Validate(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) { + for _, v := range m.validators { + claims, err := v.Validate(ctx, token, expected) + if err != nil { + if strings.Contains(err.Error(), "error verifying token signature") { + // this isn't the right key set, try the next one + continue + } + // this is the right key set, but there was an error validating the claims + return nil, err + } + + return claims, nil + } + + return nil, errors.New("no key set was able to verify the token signature") +} + +func (m *multiValidator) ValidateAllowMissingIatNbfExp(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) { + for _, v := range m.validators { + claims, err := v.ValidateAllowMissingIatNbfExp(ctx, token, expected) + if err != nil { + if strings.Contains(err.Error(), "error verifying token signature") { + // this isn't the right key set, try the next one + continue + } + // this is the right key set, but there was an error validating the claims + return nil, err + } + + return claims, nil + } + + return nil, errors.New("no key set was able to verify the token signature") +} + // validateSigningAlgorithm checks whether the JWS "alg" (Algorithm) header // parameter value for the given JWT matches any given in expectedAlgorithms. // If expectedAlgorithms is empty, RS256 will be expected by default. diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index bd69317..1c479af 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -7,6 +7,7 @@ import ( "context" "crypto/rand" "crypto/rsa" + "fmt" "strings" "testing" "time" @@ -18,15 +19,22 @@ import ( ) var priv *rsa.PrivateKey +var priv2 *rsa.PrivateKey func init() { // Generate a key to sign JWTs with throughout most test cases. - // It can be slow sometimes to generate a 4096-bit RSA key, so we only do it once. + // It can be slow sometimes to generate a 4096-bit RSA key, so we only + // generate the test keys once on initialization. var err error priv, err = rsa.GenerateKey(rand.Reader, 4096) if err != nil { panic(err) } + + priv2, err = rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + panic(err) + } } // TestValidator_Validate_Valid_JWT tests cases where a JWT is expected to be valid. @@ -605,6 +613,1012 @@ func TestNewValidator(t *testing.T) { } } +// TestMultiValidator_Validate_Valid_JWT tests cases where a JWT is expected to be valid. +func TestMultiValidator_Validate_Valid_JWT(t *testing.T) { + tp := oidc.StartTestProvider(t, oidc.WithTestPort(8181)) + tp2 := oidc.StartTestProvider(t, oidc.WithTestPort(8182)) + + // Create the KeySet to be used to verify JWT signatures + keySet1, err := NewOIDCDiscoveryKeySet(context.Background(), tp.Addr(), tp.CACert()) + require.NoError(t, err) + + tp.SetSigningKeys(priv, priv.Public(), oidc.RS256, testKeyID) + + keySet2, err := NewOIDCDiscoveryKeySet(context.Background(), tp2.Addr(), tp2.CACert()) + require.NoError(t, err) + + testKeyID2 := fmt.Sprintf("%s-2", testKeyID) + tp2.SetSigningKeys(priv, priv2.Public(), oidc.RS256, testKeyID2) + + // Establish past, now, and future for validation of time related claims + now := time.Now() + nowUnix := float64(now.Unix()) + pastUnix := float64(now.Add(-2 * jwt.DefaultLeeway).Unix()) + futureUnix := float64(now.Add(2 * jwt.DefaultLeeway).Unix()) + + type args struct { + claims map[string]interface{} + token func(map[string]interface{}) string + expected Expected + } + tests := []struct { + name string + args args + }{ + { + name: "valid jwt with assertion on issuer claim", + args: args{ + claims: map[string]interface{}{ + "iss": "https://example.com/", + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + Issuer: "https://example.com/", + }, + }, + }, + { + name: "valid jwt with assertion on issuer claim from key set 2", + args: args{ + claims: map[string]interface{}{ + "iss": "https://example.com/", + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + Issuer: "https://example.com/", + }, + }, + }, + { + name: "valid jwt with assertion on subject claim", + args: args{ + claims: map[string]interface{}{ + "sub": "alice@example.com", + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + Subject: "alice@example.com", + }, + }, + }, + { + name: "valid jwt with assertion on subject claim from key set 2", + args: args{ + claims: map[string]interface{}{ + "sub": "alice@example.com", + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + Subject: "alice@example.com", + }, + }, + }, + { + name: "valid jwt with assertion on id claim", + args: args{ + claims: map[string]interface{}{ + "jti": "abc123", + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + ID: "abc123", + }, + }, + }, + { + name: "valid jwt with assertion on id claim from key set 2", + args: args{ + claims: map[string]interface{}{ + "jti": "abc123", + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + ID: "abc123", + }, + }, + }, + { + name: "valid jwt with assertion on audience claim", + args: args{ + claims: map[string]interface{}{ + "aud": []interface{}{"www.example.com", "www.other.com"}, + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + Audiences: []string{"www.example.com", "www.other.com"}, + }, + }, + }, + { + name: "valid jwt with assertion on audience claim from key set 2", + args: args{ + claims: map[string]interface{}{ + "aud": []interface{}{"www.example.com", "www.other.com"}, + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + Audiences: []string{"www.example.com", "www.other.com"}, + }, + }, + }, + { + name: "valid jwt with assertion on algorithm header parameter", + args: args{ + claims: map[string]interface{}{ + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS512), claims, []byte(testKeyID)) + }, + expected: Expected{ + SigningAlgorithms: []Alg{RS512}, + }, + }, + }, + { + name: "valid jwt with assertion on algorithm header parameter from key set 2", + args: args{ + claims: map[string]interface{}{ + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS512), claims, []byte(testKeyID2)) + }, + expected: Expected{ + SigningAlgorithms: []Alg{RS512}, + }, + }, + }, + { + name: "valid jwt with assertions on all expected claims", + args: args{ + claims: map[string]interface{}{ + "iss": "https://example.com/", + "sub": "alice@example.com", + "jti": "abc123", + "aud": []interface{}{"www.example.com"}, + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + Issuer: "https://example.com/", + Subject: "alice@example.com", + ID: "abc123", + Audiences: []string{"www.example.com"}, + SigningAlgorithms: []Alg{RS256}, + }, + }, + }, + { + name: "valid jwt with assertions on all expected claims from key set 2", + args: args{ + claims: map[string]interface{}{ + "iss": "https://example.com/", + "sub": "alice@example.com", + "jti": "abc123", + "aud": []interface{}{"www.example.com"}, + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + Issuer: "https://example.com/", + Subject: "alice@example.com", + ID: "abc123", + Audiences: []string{"www.example.com"}, + SigningAlgorithms: []Alg{RS256}, + }, + }, + }, + { + name: "valid jwt with registered claims assertions skipped when empty", + args: args{ + claims: map[string]interface{}{ + "iss": "https://example.com/", + "sub": "alice@example.com", + "jti": "abc123", + "aud": []interface{}{"www.example.com"}, + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{}, + }, + }, + { + name: "valid jwt with registered claims assertions skipped when empty from key set 2", + args: args{ + claims: map[string]interface{}{ + "iss": "https://example.com/", + "sub": "alice@example.com", + "jti": "abc123", + "aud": []interface{}{"www.example.com"}, + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{}, + }, + }, + { + name: "valid jwt exp after exp leeway set", + args: args{ + claims: map[string]interface{}{ + "iat": nowUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + // The JWT exp would be invalid with exp leeway < 2 min + ExpirationLeeway: 2 * time.Minute, + ClockSkewLeeway: -1, + Now: func() time.Time { + return time.Unix(int64(futureUnix), 0) + }, + }, + }, + }, + { + name: "valid jwt exp after exp leeway set from key set 2", + args: args{ + claims: map[string]interface{}{ + "iat": nowUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + // The JWT exp would be invalid with exp leeway < 2 min + ExpirationLeeway: 2 * time.Minute, + ClockSkewLeeway: -1, + Now: func() time.Time { + return time.Unix(int64(futureUnix), 0) + }, + }, + }, + }, + { + name: "valid jwt nbf after nbf leeway set", + args: args{ + claims: map[string]interface{}{ + "exp": nowUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + // The JWT nbf would be invalid with nbf leeway < 2 min + NotBeforeLeeway: 2 * time.Minute, + ClockSkewLeeway: -1, + Now: func() time.Time { + return time.Unix(int64(pastUnix), 0) + }, + }, + }, + }, + { + name: "valid jwt nbf after nbf leeway set from key set 2", + args: args{ + claims: map[string]interface{}{ + "exp": nowUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + // The JWT nbf would be invalid with nbf leeway < 2 min + NotBeforeLeeway: 2 * time.Minute, + ClockSkewLeeway: -1, + Now: func() time.Time { + return time.Unix(int64(pastUnix), 0) + }, + }, + }, + }, + { + name: "valid jwt nbf after clock skew leeway", + args: args{ + claims: map[string]interface{}{ + "iat": pastUnix, + "nbf": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + // The JWT nbf would be invalid with clock skew leeway < 2 min + ClockSkewLeeway: 2 * time.Minute, + Now: func() time.Time { + return time.Unix(int64(pastUnix), 0) + }, + }, + }, + }, + { + name: "valid jwt nbf after clock skew leeway from key set 2", + args: args{ + claims: map[string]interface{}{ + "iat": pastUnix, + "nbf": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + // The JWT nbf would be invalid with clock skew leeway < 2 min + ClockSkewLeeway: 2 * time.Minute, + Now: func() time.Time { + return time.Unix(int64(pastUnix), 0) + }, + }, + }, + }, + { + name: "valid jwt exp after clock skew leeway", + args: args{ + claims: map[string]interface{}{ + "iat": pastUnix, + "nbf": pastUnix, + "exp": nowUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + // The JWT exp would be invalid with clock skew leeway < 2 min + ClockSkewLeeway: 2 * time.Minute, + Now: func() time.Time { + return time.Unix(int64(futureUnix), 0) + }, + }, + }, + }, + { + name: "valid jwt exp after clock skew leeway from key set 2", + args: args{ + claims: map[string]interface{}{ + "iat": pastUnix, + "nbf": pastUnix, + "exp": nowUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + // The JWT exp would be invalid with clock skew leeway < 2 min + ClockSkewLeeway: 2 * time.Minute, + Now: func() time.Time { + return time.Unix(int64(futureUnix), 0) + }, + }, + }, + }, + { + name: "valid jwt iat after clock skew leeway", + args: args{ + claims: map[string]interface{}{ + "iat": nowUnix, + "nbf": pastUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + // The JWT iat would be invalid with clock skew leeway < 2 min + ClockSkewLeeway: 2 * time.Minute, + Now: func() time.Time { + return time.Unix(int64(pastUnix), 0) + }, + }, + }, + }, + { + name: "valid jwt iat after clock skew leeway from key set 2", + args: args{ + claims: map[string]interface{}{ + "iat": nowUnix, + "nbf": pastUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + // The JWT iat would be invalid with clock skew leeway < 2 min + ClockSkewLeeway: 2 * time.Minute, + Now: func() time.Time { + return time.Unix(int64(pastUnix), 0) + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + // Create the signed JWT with the given claims + token := tt.args.token(tt.args.claims) + + // Create the validator with the KeySet + validator, err := NewMultiValidator([]KeySet{keySet1, keySet2}) + require.NoError(t, err) + + // Validate the JWT claims against expected values + got, err := validator.Validate(ctx, token, tt.args.expected) + + // Expect to get back the same claims that were serialized in the JWT + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, tt.args.claims, got) + }) + } +} + +func TestMultiValidator_NoExpIatNbf(t *testing.T) { + tp := oidc.StartTestProvider(t, oidc.WithTestPort(8181)) + tp2 := oidc.StartTestProvider(t, oidc.WithTestPort(8182)) + + // Create the KeySet to be used to verify JWT signatures + keySet1, err := NewOIDCDiscoveryKeySet(context.Background(), tp.Addr(), tp.CACert()) + require.NoError(t, err) + + tp.SetSigningKeys(priv, priv.Public(), oidc.RS256, testKeyID) + + keySet2, err := NewOIDCDiscoveryKeySet(context.Background(), tp2.Addr(), tp2.CACert()) + require.NoError(t, err) + + testKeyID2 := fmt.Sprintf("%s-2", testKeyID) + tp2.SetSigningKeys(priv, priv2.Public(), oidc.RS256, testKeyID2) + + type args struct { + claims map[string]interface{} + token func(map[string]interface{}) string + expected Expected + } + tests := []struct { + name string + args args + }{ + { + name: "valid jwt with assertion on issuer claim", + args: args{ + claims: map[string]interface{}{ + "iss": "https://example.com/", + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + Issuer: "https://example.com/", + }, + }, + }, + { + name: "valid jwt with assertion on issuer claim from key set 2", + args: args{ + claims: map[string]interface{}{ + "iss": "https://example.com/", + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + Issuer: "https://example.com/", + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + // Create the signed JWT with the given claims + token := tt.args.token(tt.args.claims) + + // Create the validator with the KeySet + validator, err := NewMultiValidator([]KeySet{keySet1, keySet2}) + require.NoError(t, err) + + // Validate the JWT claims against expected values + got, err := validator.ValidateAllowMissingIatNbfExp(ctx, token, tt.args.expected) + + // Expect to get back the same claims that were serialized in the JWT + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, tt.args.claims, got) + }) + } +} + +// TestValidator_Validate_Valid_JWT tests cases where a JWT is expected to be invalid. +func TestMultiValidator_Validate_Invalid_JWT(t *testing.T) { + tp := oidc.StartTestProvider(t, oidc.WithTestPort(8181)) + tp2 := oidc.StartTestProvider(t, oidc.WithTestPort(8182)) + + // Create the KeySet to be used to verify JWT signatures + keySet1, err := NewOIDCDiscoveryKeySet(context.Background(), tp.Addr(), tp.CACert()) + require.NoError(t, err) + + tp.SetSigningKeys(priv, priv.Public(), oidc.RS256, testKeyID) + + keySet2, err := NewOIDCDiscoveryKeySet(context.Background(), tp2.Addr(), tp2.CACert()) + require.NoError(t, err) + + testKeyID2 := fmt.Sprintf("%s-2", testKeyID) + tp2.SetSigningKeys(priv, priv2.Public(), oidc.RS256, testKeyID2) + + // Establish past, now, and future for validation of time related claims + now := time.Now() + nowUnix := float64(now.Unix()) + pastUnix := float64(now.Add(-2 * jwt.DefaultLeeway).Unix()) + futureUnix := float64(now.Add(2 * jwt.DefaultLeeway).Unix()) + + type args struct { + claims map[string]interface{} + token func(map[string]interface{}) string + expected Expected + } + tests := []struct { + name string + args args + }{ + { + name: "invalid jwt with assertion on issuer claim", + args: args{ + claims: map[string]interface{}{ + "iss": "https://example.com/", + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + Issuer: "https://wrong.com/", + }, + }, + }, + { + name: "invalid jwt with assertion on issuer claim from key set 2", + args: args{ + claims: map[string]interface{}{ + "iss": "https://example.com/", + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + Issuer: "https://wrong.com/", + }, + }, + }, + { + name: "invalid jwt with assertion on subject claim", + args: args{ + claims: map[string]interface{}{ + "sub": "alice@example.com", + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + Subject: "bob@example.com", + }, + }, + }, + { + name: "invalid jwt with assertion on subject claim from key set 2", + args: args{ + claims: map[string]interface{}{ + "sub": "alice@example.com", + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + Subject: "bob@example.com", + }, + }, + }, + { + name: "invalid jwt with assertion on id claim", + args: args{ + claims: map[string]interface{}{ + "jti": "abc123", + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + ID: "123abc", + }, + }, + }, + { + name: "invalid jwt with assertion on id claim from key set 2", + args: args{ + claims: map[string]interface{}{ + "jti": "abc123", + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + ID: "123abc", + }, + }, + }, + { + name: "invalid jwt with assertion on audience claim", + args: args{ + claims: map[string]interface{}{ + "aud": []interface{}{"www.other.com"}, + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + Audiences: []string{"www.example.com"}, + }, + }, + }, + { + name: "invalid jwt with assertion on audience claim from key set 2", + args: args{ + claims: map[string]interface{}{ + "aud": []interface{}{"www.other.com"}, + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + Audiences: []string{"www.example.com"}, + }, + }, + }, + { + name: "invalid jwt with assertion on algorithm header parameter", + args: args{ + claims: map[string]interface{}{ + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + SigningAlgorithms: []Alg{ES256}, + }, + }, + }, + { + name: "invalid jwt with assertion on algorithm header parameter from key set 2", + args: args{ + claims: map[string]interface{}{ + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + SigningAlgorithms: []Alg{ES256}, + }, + }, + }, + { + name: "invalid jwt from failed signature verification", + args: args{ + claims: map[string]interface{}{ + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + // Sign the JWT with a key not in the test provider + pk, err := rsa.GenerateKey(rand.Reader, 4096) + require.NoError(t, err) + return oidc.TestSignJWT(t, pk, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + SigningAlgorithms: []Alg{RS256}, + }, + }, + }, + { + name: "invalid jwt from failed signature verification from key set 2", + args: args{ + claims: map[string]interface{}{ + "iat": nowUnix, + "exp": futureUnix, + }, + token: func(claims map[string]interface{}) string { + // Sign the JWT with a key not in the test provider + pk, err := rsa.GenerateKey(rand.Reader, 4096) + require.NoError(t, err) + return oidc.TestSignJWT(t, pk, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + SigningAlgorithms: []Alg{RS256}, + }, + }, + }, + { + name: "invalid jwt with missing iat, nbf, and exp claims", + args: args{ + claims: map[string]interface{}{ + "iss": "https://example.com/", + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{}, + }, + }, + { + name: "invalid jwt with missing iat, nbf, and exp claims from key set 2", + args: args{ + claims: map[string]interface{}{ + "iss": "https://example.com/", + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{}, + }, + }, + { + name: "invalid jwt with now before nbf", + args: args{ + claims: map[string]interface{}{ + "nbf": nowUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + Now: func() time.Time { + return time.Unix(int64(pastUnix), 0) + }, + }, + }, + }, + { + name: "invalid jwt with now before nbf from key set 2", + args: args{ + claims: map[string]interface{}{ + "nbf": nowUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + Now: func() time.Time { + return time.Unix(int64(pastUnix), 0) + }, + }, + }, + }, + { + name: "invalid jwt with now after exp", + args: args{ + claims: map[string]interface{}{ + "exp": nowUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + Now: func() time.Time { + return time.Unix(int64(futureUnix), 0) + }, + }, + }, + }, + { + name: "invalid jwt with now after exp from key set 2", + args: args{ + claims: map[string]interface{}{ + "exp": nowUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + Now: func() time.Time { + return time.Unix(int64(futureUnix), 0) + }, + }, + }, + }, + { + name: "invalid jwt with now before iat", + args: args{ + claims: map[string]interface{}{ + "nbf": pastUnix, + "iat": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID)) + }, + expected: Expected{ + Now: func() time.Time { + return time.Unix(int64(nowUnix), 0) + }, + }, + }, + }, + { + name: "invalid jwt with now before iat from key set 2", + args: args{ + claims: map[string]interface{}{ + "nbf": pastUnix, + "iat": futureUnix, + }, + token: func(claims map[string]interface{}) string { + return oidc.TestSignJWT(t, priv2, string(RS256), claims, []byte(testKeyID2)) + }, + expected: Expected{ + Now: func() time.Time { + return time.Unix(int64(nowUnix), 0) + }, + }, + }, + }, + { + name: "invalid malformed jwt", + args: args{ + token: func(claims map[string]interface{}) string { + return "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + // Create the signed JWT with the given claims + token := tt.args.token(tt.args.claims) + + // Create the validator with the KeySet + validator, err := NewMultiValidator([]KeySet{keySet1, keySet2}) + require.NoError(t, err) + + // Validate the JWT claims against expected values + got, err := validator.Validate(ctx, token, tt.args.expected) + + // Expect an error and nil claims + require.Error(t, err) + require.Nil(t, got) + }) + } +} + +func TestNewMultiValidator(t *testing.T) { + type args struct { + keySets func() []KeySet + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "new multiValidator with single keySet", + args: args{ + keySets: func() []KeySet { + ks, err := NewJSONWebKeySet(context.Background(), + "https://issuer.com/"+wellKnownJWKS, "") + require.NoError(t, err) + return []KeySet{ks} + }, + }, + }, + { + name: "new multiValidator with multiple keySets", + args: args{ + keySets: func() []KeySet { + kSets := make([]KeySet, 0, 2) + ks, err := NewJSONWebKeySet(context.Background(), + "https://issuer.com/"+wellKnownJWKS, "") + require.NoError(t, err) + + kSets = append(kSets, ks) + ks, err = NewJSONWebKeySet(context.Background(), + "https://issuer2.com/"+wellKnownJWKS, "") + require.NoError(t, err) + + kSets = append(kSets, ks) + + return kSets + }, + }, + }, + { + name: "new multiValidator with no keySets", + args: args{ + keySets: func() []KeySet { + return nil + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewMultiValidator(tt.args.keySets()) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.NotNil(t, got) + }) + } +} + func Test_validateAudience(t *testing.T) { type args struct { expectedAudiences []string