Skip to content

Commit

Permalink
PR review, simplify naive multi JWKS implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
johnlanda committed Feb 2, 2024
1 parent 4b2bdd5 commit 8ecd48f
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 167 deletions.
119 changes: 31 additions & 88 deletions jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"

"github.com/go-jose/go-jose/v3"
Expand All @@ -19,27 +18,26 @@ import (
// for validating the "nbf" (Not Before) and "exp" (Expiration Time) claims.
const DefaultLeewaySeconds = 150

type Validator interface {
Validate(ctx context.Context, token string, expected Expected) (map[string]interface{}, error)
ValidateAllowMissingIatNbfExp(ctx context.Context, token string, expected Expected) (map[string]interface{}, error)
}

// Validator validates JSON Web Tokens (JWT) by providing signature
// verification and claims set validation.
type validator struct {
keySet KeySet
type Validator struct {
keySets []KeySet
}

var _ Validator = (*validator)(nil)

// NewValidator returns a Validator that uses the given KeySet to verify JWT signatures.
func NewValidator(keySet KeySet) (Validator, error) {
if keySet == nil {
return nil, errors.New("keySet must not be nil")
func NewValidator(keySets ...KeySet) (*Validator, error) {
if len(keySets) <= 0 {
return nil, errors.New("must provide at least one key set")
}

for _, keySet := range keySets {
if keySet == nil {
return nil, errors.New("keySet must not be nil")
}
}

return &validator{
keySet: keySet,
return &Validator{
keySets: keySets,
}, nil
}

Expand Down Expand Up @@ -103,7 +101,7 @@ type Expected struct {
// and "exp" (Expiration Time) claims and after the time given by the "iat"
// (Issued At) claim, with configurable leeway. See Expected.Now() for details
// on how the current time is provided for validation.
func (v *validator) Validate(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) {
func (v *Validator) Validate(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) {
return v.validateAll(ctx, token, expected, false)
}

Expand All @@ -119,14 +117,26 @@ func (v *validator) Validate(ctx context.Context, token string, expected Expecte
// of "nbf", "exp", and "iat" are missing, then this check is skipped. See
// Expected.Now() for details on how the current time is provided for
// validation.
func (v *validator) ValidateAllowMissingIatNbfExp(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) {
func (v *Validator) ValidateAllowMissingIatNbfExp(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) {
return v.validateAll(ctx, token, expected, true)
}

func (v *validator) validateAll(ctx context.Context, token string, expected Expected, allowMissingIatExpNbf bool) (map[string]interface{}, error) {
// First, verify the signature to ensure subsequent validation is against verified claims
allClaims, err := v.keySet.VerifySignature(ctx, token)
if err != nil {
func (v *Validator) validateAll(ctx context.Context, token string, expected Expected, allowMissingIatExpNbf bool) (map[string]interface{}, error) {
var allClaims map[string]interface{}
var err error

// Ensure that the token is signed by at least one of the given key sets
var tokenVerified bool
for _, keySet := range v.keySets {
// First, verify the signature to ensure subsequent validation is against verified claims
allClaims, err = keySet.VerifySignature(ctx, token)
if err == nil {
tokenVerified = true
break
}
}

if !tokenVerified {
return nil, fmt.Errorf("error verifying token signature: %w", err)
}

Expand Down Expand Up @@ -235,73 +245,6 @@ func (v *validator) validateAll(ctx context.Context, token string, expected Expe
return allClaims, nil
}

// multiValidator validates JSON Web Tokens (JWT) by providing signature
// verification and claims set validation. Unlike Validator, multiValidator
// supports multiple KeySet to verify JWT signatures.
// multiValidator is in now way optimized and is just stubbed out as a POC.
type multiValidator struct {
validators []Validator
}

var _ Validator = (*multiValidator)(nil)

// NewMultiValidator returns a Validator that uses the given slice of KeySet to verify JWT signatures.
func NewMultiValidator(keySets []KeySet) (Validator, error) {
if len(keySets) == 0 {
return nil, errors.New("must provide at least one key set")
}

validators := make([]Validator, 0, len(keySets))

for _, keySet := range keySets {
v, err := NewValidator(keySet)
if err != nil {
return nil, err
}
validators = append(validators, v)
}

return &multiValidator{
validators: validators,
}, nil
}

func (m *multiValidator) Validate(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) {
for _, v := range m.validators {
claims, err := v.Validate(ctx, token, expected)
if err != nil {
if strings.Contains(err.Error(), "error verifying token signature") {
// this isn't the right key set, try the next one
continue
}
// this is the right key set, but there was an error validating the claims
return nil, err
}

return claims, nil
}

return nil, errors.New("no key set was able to verify the token signature")
}

func (m *multiValidator) ValidateAllowMissingIatNbfExp(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) {
for _, v := range m.validators {
claims, err := v.ValidateAllowMissingIatNbfExp(ctx, token, expected)
if err != nil {
if strings.Contains(err.Error(), "error verifying token signature") {
// this isn't the right key set, try the next one
continue
}
// this is the right key set, but there was an error validating the claims
return nil, err
}

return claims, nil
}

return nil, errors.New("no key set was able to verify the token signature")
}

// validateSigningAlgorithm checks whether the JWS "alg" (Algorithm) header
// parameter value for the given JWT matches any given in expectedAlgorithms.
// If expectedAlgorithms is empty, RS256 will be expected by default.
Expand Down
111 changes: 32 additions & 79 deletions jwt/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ func TestValidator_Validate_Invalid_JWT(t *testing.T) {

func TestNewValidator(t *testing.T) {
type args struct {
keySet func() KeySet
keySets func() []KeySet
}
tests := []struct {
name string
Expand All @@ -582,27 +582,41 @@ func TestNewValidator(t *testing.T) {
{
name: "new validator with keySet",
args: args{
keySet: func() KeySet {
keySets: func() []KeySet {
ks, err := NewJSONWebKeySet(context.Background(),
"https://issuer.com/"+wellKnownJWKS, "")
require.NoError(t, err)
return ks
return []KeySet{ks}
},
},
},
{
name: "new validator with nil keySet",
args: args{
keySet: func() KeySet {
keySets: func() []KeySet {
return nil
},
},
wantErr: true,
},
{
name: "new validator with multiple keySets",
args: args{
keySets: func() []KeySet {
ks, err := NewJSONWebKeySet(context.Background(),
"https://issuer.com/"+wellKnownJWKS, "")
require.NoError(t, err)

ks2, err := NewJSONWebKeySet(context.Background(),
"https://issuer2.com/"+wellKnownJWKS, "")
return []KeySet{ks, ks2}
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewValidator(tt.args.keySet())
got, err := NewValidator(tt.args.keySets()...)
if tt.wantErr {
require.Error(t, err)
return
Expand All @@ -613,8 +627,9 @@ func TestNewValidator(t *testing.T) {
}
}

// TestMultiValidator_Validate_Valid_JWT tests cases where a JWT is expected to be valid.
func TestMultiValidator_Validate_Valid_JWT(t *testing.T) {
// TestValidator_MultipleKeySets_Validate_Valid_JWT tests cases where a JWT is expected to be valid where the
// validator is initialized with multiple KeySets.
func TestValidator_MultipleKeySets_Validate_Valid_JWT(t *testing.T) {
tp := oidc.StartTestProvider(t, oidc.WithTestPort(8181))
tp2 := oidc.StartTestProvider(t, oidc.WithTestPort(8182))

Expand Down Expand Up @@ -1088,11 +1103,11 @@ func TestMultiValidator_Validate_Valid_JWT(t *testing.T) {
token := tt.args.token(tt.args.claims)

// Create the validator with the KeySet
validator, err := NewMultiValidator([]KeySet{keySet1, keySet2})
v, err := NewValidator(keySet1, keySet2)
require.NoError(t, err)

// Validate the JWT claims against expected values
got, err := validator.Validate(ctx, token, tt.args.expected)
got, err := v.Validate(ctx, token, tt.args.expected)

// Expect to get back the same claims that were serialized in the JWT
require.NoError(t, err)
Expand All @@ -1102,7 +1117,7 @@ func TestMultiValidator_Validate_Valid_JWT(t *testing.T) {
}
}

func TestMultiValidator_NoExpIatNbf(t *testing.T) {
func TestValidator_MultipleKeySets_NoExpIatNbf(t *testing.T) {
tp := oidc.StartTestProvider(t, oidc.WithTestPort(8181))
tp2 := oidc.StartTestProvider(t, oidc.WithTestPort(8182))

Expand Down Expand Up @@ -1164,11 +1179,11 @@ func TestMultiValidator_NoExpIatNbf(t *testing.T) {
token := tt.args.token(tt.args.claims)

// Create the validator with the KeySet
validator, err := NewMultiValidator([]KeySet{keySet1, keySet2})
v, err := NewValidator(keySet1, keySet2)
require.NoError(t, err)

// Validate the JWT claims against expected values
got, err := validator.ValidateAllowMissingIatNbfExp(ctx, token, tt.args.expected)
got, err := v.ValidateAllowMissingIatNbfExp(ctx, token, tt.args.expected)

// Expect to get back the same claims that were serialized in the JWT
require.NoError(t, err)
Expand All @@ -1178,8 +1193,9 @@ func TestMultiValidator_NoExpIatNbf(t *testing.T) {
}
}

// TestValidator_Validate_Valid_JWT tests cases where a JWT is expected to be invalid.
func TestMultiValidator_Validate_Invalid_JWT(t *testing.T) {
// TestValidator_MultipleKeySets_Validate_Invalid_JWT tests cases where a JWT is expected to be invalid where the
// validator is initialized with multiple KeySets.
func TestValidator_MultipleKeySets_Validate_Invalid_JWT(t *testing.T) {
tp := oidc.StartTestProvider(t, oidc.WithTestPort(8181))
tp2 := oidc.StartTestProvider(t, oidc.WithTestPort(8182))

Expand Down Expand Up @@ -1543,11 +1559,11 @@ func TestMultiValidator_Validate_Invalid_JWT(t *testing.T) {
token := tt.args.token(tt.args.claims)

// Create the validator with the KeySet
validator, err := NewMultiValidator([]KeySet{keySet1, keySet2})
v, err := NewValidator(keySet1, keySet2)
require.NoError(t, err)

// Validate the JWT claims against expected values
got, err := validator.Validate(ctx, token, tt.args.expected)
got, err := v.Validate(ctx, token, tt.args.expected)

// Expect an error and nil claims
require.Error(t, err)
Expand All @@ -1556,69 +1572,6 @@ func TestMultiValidator_Validate_Invalid_JWT(t *testing.T) {
}
}

func TestNewMultiValidator(t *testing.T) {
type args struct {
keySets func() []KeySet
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "new multiValidator with single keySet",
args: args{
keySets: func() []KeySet {
ks, err := NewJSONWebKeySet(context.Background(),
"https://issuer.com/"+wellKnownJWKS, "")
require.NoError(t, err)
return []KeySet{ks}
},
},
},
{
name: "new multiValidator with multiple keySets",
args: args{
keySets: func() []KeySet {
kSets := make([]KeySet, 0, 2)
ks, err := NewJSONWebKeySet(context.Background(),
"https://issuer.com/"+wellKnownJWKS, "")
require.NoError(t, err)

kSets = append(kSets, ks)
ks, err = NewJSONWebKeySet(context.Background(),
"https://issuer2.com/"+wellKnownJWKS, "")
require.NoError(t, err)

kSets = append(kSets, ks)

return kSets
},
},
},
{
name: "new multiValidator with no keySets",
args: args{
keySets: func() []KeySet {
return nil
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewMultiValidator(tt.args.keySets())
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.NotNil(t, got)
})
}
}

func Test_validateAudience(t *testing.T) {
type args struct {
expectedAudiences []string
Expand Down

0 comments on commit 8ecd48f

Please sign in to comment.