diff --git a/ecdsa.go b/ecdsa.go index f9773812..bc3129ca 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -53,8 +53,9 @@ func (m *SigningMethodECDSA) Alg() string { return m.Name } -// Implements the Verify method from SigningMethod -// For this verify method, key must be an ecdsa.PublicKey struct +// Implements the Verify method from SigningMethod. +// For this verify method, key must be in types of either *ecdsa.PublicKey or +// []*ecdsa.PublicKey (for rotation keys). func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error { var err error @@ -64,15 +65,6 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa return err } - // Get the key - var ecdsaKey *ecdsa.PublicKey - switch k := key.(type) { - case *ecdsa.PublicKey: - ecdsaKey = k - default: - return ErrInvalidKeyType - } - if len(sig) != 2*m.KeySize { return ErrECDSAVerification } @@ -80,19 +72,38 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa r := big.NewInt(0).SetBytes(sig[:m.KeySize]) s := big.NewInt(0).SetBytes(sig[m.KeySize:]) - // Create hasher if !m.Hash.Available() { return ErrHashUnavailable } - hasher := m.Hash.New() - hasher.Write([]byte(signingString)) - // Verify the signature - if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true { - return nil - } else { - return ErrECDSAVerification + var keys []*ecdsa.PublicKey + + // Get the key + switch v := key.(type) { + case *ecdsa.PublicKey: + keys = append(keys, v) + case []*ecdsa.PublicKey: + keys = v + default: + return ErrInvalidKeyType + } + + if len(keys) == 0 { + return ErrInvalidKeyType + } + var lastErr error + for _, ecdsaKey := range keys { + // Create hasher + hasher := m.Hash.New() + hasher.Write([]byte(signingString)) + + // Verify the signature + if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true { + return nil + } + lastErr = ErrECDSAVerification } + return lastErr } // Implements the Sign method from SigningMethod diff --git a/hmac.go b/hmac.go index addbe5d4..38ddfc8d 100644 --- a/hmac.go +++ b/hmac.go @@ -47,12 +47,6 @@ func (m *SigningMethodHMAC) Alg() string { // Verify the signature of HSXXX tokens. Returns nil if the signature is valid. func (m *SigningMethodHMAC) Verify(signingString, signature string, key interface{}) error { - // Verify the key is the right type - keyBytes, ok := key.([]byte) - if !ok { - return ErrInvalidKeyType - } - // Decode signature, for comparison sig, err := DecodeSegment(signature) if err != nil { @@ -64,17 +58,35 @@ func (m *SigningMethodHMAC) Verify(signingString, signature string, key interfac return ErrHashUnavailable } - // This signing method is symmetric, so we validate the signature - // by reproducing the signature from the signing string and key, then - // comparing that against the provided signature. - hasher := hmac.New(m.Hash.New, keyBytes) - hasher.Write([]byte(signingString)) - if !hmac.Equal(sig, hasher.Sum(nil)) { - return ErrSignatureInvalid + var keys [][]byte + + // Verify the key is the right type + switch v := key.(type) { + case []byte: + keys = append(keys, v) + case [][]byte: + keys = v + default: + return ErrInvalidKeyType } - // No validation errors. Signature is good. - return nil + if len(keys) == 0 { + return ErrInvalidKeyType + } + var lastErr error + for _, keyBytes := range keys { + // This signing method is symmetric, so we validate the signature + // by reproducing the signature from the signing string and key, then + // comparing that against the provided signature. + hasher := hmac.New(m.Hash.New, keyBytes) + hasher.Write([]byte(signingString)) + if hmac.Equal(sig, hasher.Sum(nil)) { + // No validation errors. Signature is good. + return nil + } + lastErr = ErrSignatureInvalid + } + return lastErr } // Implements the Sign method from SigningMethod for this signing method. diff --git a/hmac_test.go b/hmac_test.go index c7e114f4..95463ea9 100644 --- a/hmac_test.go +++ b/hmac_test.go @@ -62,6 +62,23 @@ func TestHMACVerify(t *testing.T) { } } +func TestHMACVerifyKeyRotation(t *testing.T) { + invalidKey1 := []byte("foo") + invalidKey2 := []byte("bar") + for _, data := range hmacTestData { + parts := strings.Split(data.tokenString, ".") + + method := jwt.GetSigningMethod(data.alg) + err := method.Verify(strings.Join(parts[0:2], "."), parts[2], [][]byte{invalidKey1, hmacTestKey, invalidKey2}) + if data.valid && err != nil { + t.Errorf("[%v] Error while verifying key: %v", data.name, err) + } + if !data.valid && err == nil { + t.Errorf("[%v] Invalid key passed validation", data.name) + } + } +} + func TestHMACSign(t *testing.T) { for _, data := range hmacTestData { if data.valid { diff --git a/rsa.go b/rsa.go index e4caf1ca..c8c93321 100644 --- a/rsa.go +++ b/rsa.go @@ -45,7 +45,8 @@ func (m *SigningMethodRSA) Alg() string { } // Implements the Verify method from SigningMethod -// For this signing method, must be an *rsa.PublicKey structure. +// For this signing method, key must be in types of either *rsa.PublicKey or +// []*rsa.PublicKey (for rotation keys). func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error { var err error @@ -55,22 +56,37 @@ func (m *SigningMethodRSA) Verify(signingString, signature string, key interface return err } - var rsaKey *rsa.PublicKey - var ok bool + if !m.Hash.Available() { + return ErrHashUnavailable + } + + var keys []*rsa.PublicKey - if rsaKey, ok = key.(*rsa.PublicKey); !ok { + switch v := key.(type) { + case *rsa.PublicKey: + keys = append(keys, v) + case []*rsa.PublicKey: + keys = v + default: return ErrInvalidKeyType } - // Create hasher - if !m.Hash.Available() { - return ErrHashUnavailable + if len(keys) == 0 { + return ErrInvalidKeyType } - hasher := m.Hash.New() - hasher.Write([]byte(signingString)) - - // Verify the signature - return rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig) + var lastErr error + for _, rsaKey := range keys { + // Create hasher + hasher := m.Hash.New() + hasher.Write([]byte(signingString)) + + // Verify the signature + lastErr = rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig) + if lastErr == nil { + return nil + } + } + return lastErr } // Implements the Sign method from SigningMethod diff --git a/rsa_pss.go b/rsa_pss.go index c0147086..cfaf51cc 100644 --- a/rsa_pss.go +++ b/rsa_pss.go @@ -80,7 +80,8 @@ func init() { } // Implements the Verify method from SigningMethod -// For this verify method, key must be an rsa.PublicKey struct +// For this verify method, key must be in the types of either *rsa.PublicKey or +// []*rsa.PublicKey (for rotation keys). func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interface{}) error { var err error @@ -90,27 +91,41 @@ func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interf return err } - var rsaKey *rsa.PublicKey - switch k := key.(type) { + if !m.Hash.Available() { + return ErrHashUnavailable + } + + var keys []*rsa.PublicKey + + switch v := key.(type) { case *rsa.PublicKey: - rsaKey = k + keys = append(keys, v) + case []*rsa.PublicKey: + keys = v default: return ErrInvalidKey } - // Create hasher - if !m.Hash.Available() { - return ErrHashUnavailable + if len(keys) == 0 { + return ErrInvalidKeyType } - hasher := m.Hash.New() - hasher.Write([]byte(signingString)) - - opts := m.Options - if m.VerifyOptions != nil { - opts = m.VerifyOptions + var lastErr error + for _, rsaKey := range keys { + // Create hasher + hasher := m.Hash.New() + hasher.Write([]byte(signingString)) + + opts := m.Options + if m.VerifyOptions != nil { + opts = m.VerifyOptions + } + + lastErr = rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, opts) + if lastErr == nil { + return nil + } } - - return rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, opts) + return lastErr } // Implements the Sign method from SigningMethod