Skip to content

Commit

Permalink
Add support for multiple JWKS urls
Browse files Browse the repository at this point in the history
  • Loading branch information
johnlanda committed Feb 1, 2024
1 parent 1d58e0e commit 4b2bdd5
Show file tree
Hide file tree
Showing 2 changed files with 1,096 additions and 7 deletions.
87 changes: 81 additions & 6 deletions jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"

"github.com/go-jose/go-jose/v3"
Expand All @@ -18,19 +19,26 @@ import (
// for validating the "nbf" (Not Before) and "exp" (Expiration Time) claims.
const DefaultLeewaySeconds = 150

type Validator interface {
Validate(ctx context.Context, token string, expected Expected) (map[string]interface{}, error)
ValidateAllowMissingIatNbfExp(ctx context.Context, token string, expected Expected) (map[string]interface{}, error)
}

// Validator validates JSON Web Tokens (JWT) by providing signature
// verification and claims set validation.
type Validator struct {
type validator struct {
keySet KeySet
}

var _ Validator = (*validator)(nil)

// NewValidator returns a Validator that uses the given KeySet to verify JWT signatures.
func NewValidator(keySet KeySet) (*Validator, error) {
func NewValidator(keySet KeySet) (Validator, error) {
if keySet == nil {
return nil, errors.New("keySet must not be nil")
}

return &Validator{
return &validator{
keySet: keySet,
}, nil
}
Expand Down Expand Up @@ -95,7 +103,7 @@ type Expected struct {
// and "exp" (Expiration Time) claims and after the time given by the "iat"
// (Issued At) claim, with configurable leeway. See Expected.Now() for details
// on how the current time is provided for validation.
func (v *Validator) Validate(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) {
func (v *validator) Validate(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) {
return v.validateAll(ctx, token, expected, false)
}

Expand All @@ -111,11 +119,11 @@ func (v *Validator) Validate(ctx context.Context, token string, expected Expecte
// of "nbf", "exp", and "iat" are missing, then this check is skipped. See
// Expected.Now() for details on how the current time is provided for
// validation.
func (v *Validator) ValidateAllowMissingIatNbfExp(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) {
func (v *validator) ValidateAllowMissingIatNbfExp(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) {
return v.validateAll(ctx, token, expected, true)
}

func (v *Validator) validateAll(ctx context.Context, token string, expected Expected, allowMissingIatExpNbf bool) (map[string]interface{}, error) {
func (v *validator) validateAll(ctx context.Context, token string, expected Expected, allowMissingIatExpNbf bool) (map[string]interface{}, error) {
// First, verify the signature to ensure subsequent validation is against verified claims
allClaims, err := v.keySet.VerifySignature(ctx, token)
if err != nil {
Expand Down Expand Up @@ -227,6 +235,73 @@ func (v *Validator) validateAll(ctx context.Context, token string, expected Expe
return allClaims, nil
}

// multiValidator validates JSON Web Tokens (JWT) by providing signature
// verification and claims set validation. Unlike Validator, multiValidator
// supports multiple KeySet to verify JWT signatures.
// multiValidator is in now way optimized and is just stubbed out as a POC.
type multiValidator struct {
validators []Validator
}

var _ Validator = (*multiValidator)(nil)

// NewMultiValidator returns a Validator that uses the given slice of KeySet to verify JWT signatures.
func NewMultiValidator(keySets []KeySet) (Validator, error) {
if len(keySets) == 0 {
return nil, errors.New("must provide at least one key set")
}

validators := make([]Validator, 0, len(keySets))

for _, keySet := range keySets {
v, err := NewValidator(keySet)
if err != nil {
return nil, err
}
validators = append(validators, v)
}

return &multiValidator{
validators: validators,
}, nil
}

func (m *multiValidator) Validate(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) {
for _, v := range m.validators {
claims, err := v.Validate(ctx, token, expected)
if err != nil {
if strings.Contains(err.Error(), "error verifying token signature") {
// this isn't the right key set, try the next one
continue
}
// this is the right key set, but there was an error validating the claims
return nil, err
}

return claims, nil
}

return nil, errors.New("no key set was able to verify the token signature")
}

func (m *multiValidator) ValidateAllowMissingIatNbfExp(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) {
for _, v := range m.validators {
claims, err := v.ValidateAllowMissingIatNbfExp(ctx, token, expected)
if err != nil {
if strings.Contains(err.Error(), "error verifying token signature") {
// this isn't the right key set, try the next one
continue
}
// this is the right key set, but there was an error validating the claims
return nil, err
}

return claims, nil
}

return nil, errors.New("no key set was able to verify the token signature")
}

// validateSigningAlgorithm checks whether the JWS "alg" (Algorithm) header
// parameter value for the given JWT matches any given in expectedAlgorithms.
// If expectedAlgorithms is empty, RS256 will be expected by default.
Expand Down
Loading

0 comments on commit 4b2bdd5

Please sign in to comment.