diff --git a/docs/99-faq.md b/docs/99-faq.md index b71383dc3..800615838 100644 --- a/docs/99-faq.md +++ b/docs/99-faq.md @@ -77,3 +77,30 @@ After all, the only thing we can say is that you should check that you are parsi ## Why are you generating so many fields? Because a lot of the code is repetitive. For example, maintaining the 15 fields in a JWE header in all parts of the code (getter methods, setter methods, marshaling/unmarshaling) is doable but very very very cumbersome. We think that resources used to watching out for typos and other minor problems that may arise during maintenance is better spent elsewhere by automating generation of consistent code. + +## Why is (jwk.Key).Algorithm() and jwa.KeyAlgorithm so confusing? + +To start, we sympathize. Please read on for the reason(s) why things are the way they are. + +First you must understand that JWKs can be used for multiple different purposes, including but not limited to JWS and JWE (key encryption). And the `alg` field is supposed to carry to what purpose the JWK is supposed to be used. + +This means that a JWK can, in jwx terms, carry either `jwa.SignatureAlgorithm` or `jwa.KeyEncryptionAlgorithm`. + +In order to allow passing either `jwa.SignatureAlgorithm` or `jwa.KeyEncrypionAlgorithm`, we initially implemented +`(jwk.Key).Algorithm()` as a string, so it was possible to just change the type depending on the situation. + +This caused a bit of confusion for some users because this field was the only "untyped" field that potentially could have been typed. Most notably, some people wanted to do the following, but couldn't: + +```go +jwt.Verify(token, jwt.WithVerify(key.Algorithm(), key)) +``` + +Since version 2.0.0 `jwk.Key` now stores the `alg` field as a `jwa.KeyAlgorithm` type, which is just an interface that covers `jwa.SignatureAlgorithm`, `jwa.KeyEncryptionAlgorithm`, or any other type that we may need to represent in the future. + +Now you should be able to just pass the `alg` value to most high-level functions and methods such as `jwt.Verify`, `jws.Sign`, and `jwe.Encrypt` + +### When do we use `jwa.KeyAlgorithm` + +There are some functions that accept `jwa.KeyAlgorithm`, while there are others that expect `jwa.SignatureAlgorithm` or `jwa.KeyEncryptionAlgorithm`. So when do we use which? + +The guideline is as follows: If it's a high-level function/method that the users regularly use, use `jwa.KeyAlgorithm`. For example, almost everybody who use `jwt` will want to verify the JWS signed payload, so `jwt.Sign()`, and `jwt.Verify()` expect `jwa.KeyAlgorithm`. On the other hand, `jwt.Serializer` uses `jwa.SignatureAlgorithm` and such. This is a low-level utility, and users are not really meant to use it for their most basic needs: therefore they use the specific algorithm type. diff --git a/jwa/jwa.go b/jwa/jwa.go index c34ad166d..3fb1e13d1 100644 --- a/jwa/jwa.go +++ b/jwa/jwa.go @@ -2,3 +2,60 @@ // Package jwa defines the various algorithm described in https://tools.ietf.org/html/rfc7518 package jwa + +import "fmt" + +// KeyAlgorithm is a workaround for jwk.Key being able to contain different +// types of algorithms in its `alg` field. +// +// Previously the storage for the `alg` field was represented as a string, +// but this caused some users to wonder why the field was not typed appropriately +// like other fields. +// +// Ideally we would like to keep track of Signature Algorithms and +// Content Encryption Algorithms separately, and force the APIs to +// type-check at compile time, but this allows users to pass a value from a +// jwk.Key directly +type KeyAlgorithm interface { + String() string +} + +// InvalidKeyAlgorithm represents an algorithm that the library is not aware of. +type InvalidKeyAlgorithm string + +func (s InvalidKeyAlgorithm) String() string { + return string(s) +} + +func (InvalidKeyAlgorithm) Accept(_ interface{}) error { + return fmt.Errorf(`jwa.InvalidKeyAlgorithm does not support Accept() method calls`) +} + +// KeyAlgorithmFrom takes either a string, `jwa.SignatureAlgorithm` or `jwa.KeyEncryptionAlgorithm` +// and returns a `jwa.KeyAlgorithm`. +// +// If the value cannot be handled, it returns an `jwa.InvalidKeyAlgorithm` +// object instead of returning an error. This design choice was made to allow +// users to directly pass the return value to functions such as `jws.Sign()` +func KeyAlgorithmFrom(v interface{}) KeyAlgorithm { + switch v := v.(type) { + case SignatureAlgorithm: + return v + case ContentEncryptionAlgorithm: + return v + case string: + var salg SignatureAlgorithm + if err := salg.Accept(v); err == nil { + return salg + } + + var kealg KeyEncryptionAlgorithm + if err := kealg.Accept(v); err == nil { + return kealg + } + + return InvalidKeyAlgorithm(v) + default: + return InvalidKeyAlgorithm(fmt.Sprintf("%s", v)) + } +} diff --git a/jwe/gh402_test.go b/jwe/gh402_test.go index 5330dcd81..c82aa2c70 100644 --- a/jwe/gh402_test.go +++ b/jwe/gh402_test.go @@ -66,7 +66,9 @@ func TestGH402(t *testing.T) { m := jwe.NewMessage() // Test WithPostParse while we're at it plain, err := jwe.Decrypt([]byte(data), - "invalid algorithm", + // This is a really cheesy way of creating a jwa.KeyEncryptionAlgorithm + // but a bogus one. + jwa.KeyEncryptionAlgorithm("invalid algorithm"), nil, jwe.WithMessage(m), jwe.WithPostParser(pp), diff --git a/jwe/jwe.go b/jwe/jwe.go index 8b45287f2..2e184b1e8 100644 --- a/jwe/jwe.go +++ b/jwe/jwe.go @@ -28,8 +28,17 @@ var registry = json.NewRegistry() // Encrypt takes the plaintext payload and encrypts it in JWE compact format. // `key` should be a public key, and it may be a raw key (e.g. rsa.PublicKey) or a jwk.Key // +// `alg` accepts a `jwa.KeyAlgorithm` for convenience so you can directly pass +// the result of `(jwk.Key).Algorithm()`, but in practice it must be of type +// `jwa.KeyEncryptionAlgorithm` or otherwise it will cause an error. +// // Encrypt currently does not support multi-recipient messages. -func Encrypt(payload []byte, keyalg jwa.KeyEncryptionAlgorithm, key interface{}, contentalg jwa.ContentEncryptionAlgorithm, compressalg jwa.CompressionAlgorithm, options ...EncryptOption) ([]byte, error) { +func Encrypt(payload []byte, alg jwa.KeyAlgorithm, key interface{}, contentalg jwa.ContentEncryptionAlgorithm, compressalg jwa.CompressionAlgorithm, options ...EncryptOption) ([]byte, error) { + keyalg, ok := alg.(jwa.KeyEncryptionAlgorithm) + if !ok { + return nil, errors.Errorf(`expected alg to be jwa.KeyEncryptionAlgorithm, but got %T`, alg) + } + var protected Headers for _, option := range options { //nolint:forcetypeassert @@ -212,11 +221,20 @@ func (ctx *decryptCtx) SetMessage(m *Message) { // key to decrypt the JWE message, and returns the decrypted payload. // The JWE message can be either compact or full JSON format. // +// `alg` accepts a `jwa.KeyAlgorithm` for convenience so you can directly pass +// the result of `(jwk.Key).Algorithm()`, but in practice it must be of type +// `jwa.KeyEncryptionAlgorithm` or otherwise it will cause an error. +// // `key` must be a private key. It can be either in its raw format (e.g. *rsa.PrivateKey) or a jwk.Key -func Decrypt(buf []byte, alg jwa.KeyEncryptionAlgorithm, key interface{}, options ...DecryptOption) ([]byte, error) { +func Decrypt(buf []byte, alg jwa.KeyAlgorithm, key interface{}, options ...DecryptOption) ([]byte, error) { + keyalg, ok := alg.(jwa.KeyEncryptionAlgorithm) + if !ok { + return nil, errors.Errorf(`expected alg to be jwa.KeyEncryptionAlgorithm, but got %T`, alg) + } + var ctx decryptCtx ctx.key = key - ctx.alg = alg + ctx.alg = keyalg var dst *Message var postParse PostParser diff --git a/jwk/ecdsa_gen.go b/jwk/ecdsa_gen.go index e5d6374b6..9dd3c0772 100644 --- a/jwk/ecdsa_gen.go +++ b/jwk/ecdsa_gen.go @@ -36,7 +36,7 @@ type ECDSAPublicKey interface { } type ecdsaPublicKey struct { - algorithm *string // https://tools.ietf.org/html/rfc7517#section-4.4 + algorithm *jwa.KeyAlgorithm // https://tools.ietf.org/html/rfc7517#section-4.4 crv *jwa.EllipticCurveAlgorithm keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4 keyOps *KeyOperationList // https://tools.ietf.org/html/rfc7517#section-4.3 @@ -67,11 +67,11 @@ func (h ecdsaPublicKey) KeyType() jwa.KeyType { return jwa.EC } -func (h *ecdsaPublicKey) Algorithm() string { +func (h *ecdsaPublicKey) Algorithm() jwa.KeyAlgorithm { if h.algorithm != nil { return *(h.algorithm) } - return "" + return jwa.InvalidKeyAlgorithm("") } func (h *ecdsaPublicKey) Crv() jwa.EllipticCurveAlgorithm { @@ -266,10 +266,12 @@ func (h *ecdsaPublicKey) setNoLock(name string, value interface{}) error { return nil case AlgorithmKey: switch v := value.(type) { - case string: - h.algorithm = &v + case string, jwa.SignatureAlgorithm, jwa.ContentEncryptionAlgorithm: + var tmp = jwa.KeyAlgorithmFrom(v) + h.algorithm = &tmp case fmt.Stringer: - tmp := v.String() + s := v.String() + var tmp = jwa.KeyAlgorithmFrom(s) h.algorithm = &tmp default: return errors.Errorf(`invalid type for %s key: %T`, AlgorithmKey, value) @@ -442,9 +444,12 @@ LOOP: return errors.Errorf(`invalid kty value for RSAPublicKey (%s)`, val) } case AlgorithmKey: - if err := json.AssignNextStringToken(&h.algorithm, dec); err != nil { + var s string + if err := dec.Decode(&s); err != nil { return errors.Wrapf(err, `failed to decode value for key %s`, AlgorithmKey) } + alg := jwa.KeyAlgorithmFrom(s) + h.algorithm = &alg case ECDSACrvKey: var decoded jwa.EllipticCurveAlgorithm if err := dec.Decode(&decoded); err != nil { @@ -597,7 +602,7 @@ type ECDSAPrivateKey interface { } type ecdsaPrivateKey struct { - algorithm *string // https://tools.ietf.org/html/rfc7517#section-4.4 + algorithm *jwa.KeyAlgorithm // https://tools.ietf.org/html/rfc7517#section-4.4 crv *jwa.EllipticCurveAlgorithm d []byte keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4 @@ -629,11 +634,11 @@ func (h ecdsaPrivateKey) KeyType() jwa.KeyType { return jwa.EC } -func (h *ecdsaPrivateKey) Algorithm() string { +func (h *ecdsaPrivateKey) Algorithm() jwa.KeyAlgorithm { if h.algorithm != nil { return *(h.algorithm) } - return "" + return jwa.InvalidKeyAlgorithm("") } func (h *ecdsaPrivateKey) Crv() jwa.EllipticCurveAlgorithm { @@ -840,10 +845,12 @@ func (h *ecdsaPrivateKey) setNoLock(name string, value interface{}) error { return nil case AlgorithmKey: switch v := value.(type) { - case string: - h.algorithm = &v + case string, jwa.SignatureAlgorithm, jwa.ContentEncryptionAlgorithm: + var tmp = jwa.KeyAlgorithmFrom(v) + h.algorithm = &tmp case fmt.Stringer: - tmp := v.String() + s := v.String() + var tmp = jwa.KeyAlgorithmFrom(s) h.algorithm = &tmp default: return errors.Errorf(`invalid type for %s key: %T`, AlgorithmKey, value) @@ -1025,9 +1032,12 @@ LOOP: return errors.Errorf(`invalid kty value for RSAPublicKey (%s)`, val) } case AlgorithmKey: - if err := json.AssignNextStringToken(&h.algorithm, dec); err != nil { + var s string + if err := dec.Decode(&s); err != nil { return errors.Wrapf(err, `failed to decode value for key %s`, AlgorithmKey) } + alg := jwa.KeyAlgorithmFrom(s) + h.algorithm = &alg case ECDSACrvKey: var decoded jwa.EllipticCurveAlgorithm if err := dec.Decode(&decoded); err != nil { diff --git a/jwk/headers_test.go b/jwk/headers_test.go index 389a440e7..236e42617 100644 --- a/jwk/headers_test.go +++ b/jwk/headers_test.go @@ -2,7 +2,6 @@ package jwk_test import ( "context" - "fmt" "testing" "github.com/lestrrat-go/jwx/jwa" @@ -98,8 +97,8 @@ func TestHeader(t *testing.T) { if !assert.NoError(t, h.Set("Default", dummy), `Setting "Default" should succeed`) { return } - if h.Algorithm() != "" { - t.Fatalf("Algorithm should be empty string") + if !assert.Empty(t, h.Algorithm().String(), "Algorithm should be empty string") { + return } if h.KeyID() != "" { t.Fatalf("KeyID should be empty string") @@ -128,7 +127,7 @@ func TestHeader(t *testing.T) { return } - if !assert.Equal(t, value.(fmt.Stringer).String(), got, "values match") { + if !assert.Equal(t, value, got, "values match") { return } } diff --git a/jwk/interface_gen.go b/jwk/interface_gen.go index 1f94e540b..b6e603e35 100644 --- a/jwk/interface_gen.go +++ b/jwk/interface_gen.go @@ -95,7 +95,12 @@ type Key interface { // KeyOps returns `key_ops` of a JWK KeyOps() KeyOperationList // Algorithm returns `alg` of a JWK - Algorithm() string + + // Algorithm returns the value of the `alg` field + // + // This field may contain either `jwk.SignatureAlgorithm` or `jwk.KeyEncryptionAlgorithm`. + // This is why there exists a `jwa.KeyAlgorithm` type that encompases both types. + Algorithm() jwa.KeyAlgorithm // KeyID returns `kid` of a JWK KeyID() string // X509URL returns `x58` of a JWK diff --git a/jwk/internal/cmd/genheader/main.go b/jwk/internal/cmd/genheader/main.go index d66371288..47632fcba 100644 --- a/jwk/internal/cmd/genheader/main.go +++ b/jwk/internal/cmd/genheader/main.go @@ -53,6 +53,7 @@ type KeyType struct { func _main() error { codegen.RegisterZeroVal(`jwa.EllipticCurveAlgorithm`, `jwa.InvalidEllipticCurve`) codegen.RegisterZeroVal(`jwa.KeyType`, `jwa.InvalidKeyType`) + codegen.RegisterZeroVal(`jwa.KeyAlgorithm`, `jwa.InvalidKeyAlgorithm("")`) var objectsFile = flag.String("objects", "objects.yml", "") flag.Parse() @@ -411,10 +412,12 @@ func generateObject(o *codegen.Output, kt *KeyType, obj *codegen.Object) error { o.L("case %s:", keyName) if f.Name(false) == `algorithm` { o.L("switch v := value.(type) {") - o.L("case string:") - o.L("h.algorithm = &v") + o.L("case string, jwa.SignatureAlgorithm, jwa.ContentEncryptionAlgorithm:") + o.L("var tmp = jwa.KeyAlgorithmFrom(v)") + o.L("h.algorithm = &tmp") o.L("case fmt.Stringer:") - o.L("tmp := v.String()") + o.L("s := v.String()") + o.L("var tmp = jwa.KeyAlgorithmFrom(s)") o.L("h.algorithm = &tmp") o.L("default:") o.L("return errors.Errorf(`invalid type for %%s key: %%T`, %s, value)", keyName) @@ -542,6 +545,14 @@ func generateObject(o *codegen.Output, kt *KeyType, obj *codegen.Object) error { o.L("if err := json.AssignNextStringToken(&h.%s, dec); err != nil {", f.Name(false)) o.L("return errors.Wrapf(err, `failed to decode value for key %%s`, %sKey)", f.Name(true)) o.L("}") + } else if f.Type() == "jwa.KeyAlgorithm" { + o.L("case %sKey:", f.Name(true)) + o.L("var s string") + o.L("if err := dec.Decode(&s); err != nil {") + o.L("return errors.Wrapf(err, `failed to decode value for key %%s`, %sKey)", f.Name(true)) + o.L("}") + o.L("alg := jwa.KeyAlgorithmFrom(s)") + o.L("h.%s = &alg", f.Name(false)) } else if f.Type() == "[]byte" { name := f.Name(true) switch f.Name(false) { @@ -746,6 +757,12 @@ func generateGenericHeaders(fields codegen.FieldList) error { o.L("KeyType() jwa.KeyType") for _, f := range fields { o.L("// %s returns `%s` of a JWK", f.GetterMethod(true), f.JSON()) + if f.Name(false) == "algorithm" { + o.LL("// Algorithm returns the value of the `alg` field") + o.L("//") + o.L("// This field may contain either `jwk.SignatureAlgorithm` or `jwk.KeyEncryptionAlgorithm`.") + o.L("// This is why there exists a `jwa.KeyAlgorithm` type that encompases both types.") + } o.L("%s() ", f.GetterMethod(true)) if v := fieldGetterReturnValue(f); v != "" { o.R("%s", v) diff --git a/jwk/internal/cmd/genheader/objects.yml b/jwk/internal/cmd/genheader/objects.yml index 778f2412c..5dfb7405e 100644 --- a/jwk/internal/cmd/genheader/objects.yml +++ b/jwk/internal/cmd/genheader/objects.yml @@ -13,6 +13,7 @@ std_fields: json: alg is_std: true comment: https://tools.ietf.org/html/rfc7517#section-4.4 + type: jwa.KeyAlgorithm - name: keyID json: kid is_std: true diff --git a/jwk/iterator_test.go b/jwk/iterator_test.go index c6221d1a6..c3bdd7650 100644 --- a/jwk/iterator_test.go +++ b/jwk/iterator_test.go @@ -14,7 +14,7 @@ import ( func TestIterator(t *testing.T) { commonValues := map[string]interface{}{ - jwk.AlgorithmKey: "dummy", + jwk.AlgorithmKey: jwa.KeyAlgorithmFrom("dummy"), jwk.KeyIDKey: "dummy-kid", jwk.KeyUsageKey: "dummy-usage", jwk.KeyOpsKey: jwk.KeyOperationList{jwk.KeyOpSign, jwk.KeyOpVerify, jwk.KeyOpEncrypt, jwk.KeyOpDecrypt, jwk.KeyOpWrapKey, jwk.KeyOpUnwrapKey, jwk.KeyOpDeriveKey, jwk.KeyOpDeriveBits}, diff --git a/jwk/jwk_test.go b/jwk/jwk_test.go index df0e6886f..c7b138ddf 100644 --- a/jwk/jwk_test.go +++ b/jwk/jwk_test.go @@ -54,7 +54,7 @@ func init() { commonDef = map[string]keyDef{ jwk.AlgorithmKey: { Method: "Algorithm", - Value: "random-algorithm", + Value: jwa.KeyAlgorithmFrom("random-algorithm"), }, jwk.KeyIDKey: { Method: "KeyID", diff --git a/jwk/okp_gen.go b/jwk/okp_gen.go index 022c285f2..69523373d 100644 --- a/jwk/okp_gen.go +++ b/jwk/okp_gen.go @@ -33,7 +33,7 @@ type OKPPublicKey interface { } type okpPublicKey struct { - algorithm *string // https://tools.ietf.org/html/rfc7517#section-4.4 + algorithm *jwa.KeyAlgorithm // https://tools.ietf.org/html/rfc7517#section-4.4 crv *jwa.EllipticCurveAlgorithm keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4 keyOps *KeyOperationList // https://tools.ietf.org/html/rfc7517#section-4.3 @@ -63,11 +63,11 @@ func (h okpPublicKey) KeyType() jwa.KeyType { return jwa.OKP } -func (h *okpPublicKey) Algorithm() string { +func (h *okpPublicKey) Algorithm() jwa.KeyAlgorithm { if h.algorithm != nil { return *(h.algorithm) } - return "" + return jwa.InvalidKeyAlgorithm("") } func (h *okpPublicKey) Crv() jwa.EllipticCurveAlgorithm { @@ -250,10 +250,12 @@ func (h *okpPublicKey) setNoLock(name string, value interface{}) error { return nil case AlgorithmKey: switch v := value.(type) { - case string: - h.algorithm = &v + case string, jwa.SignatureAlgorithm, jwa.ContentEncryptionAlgorithm: + var tmp = jwa.KeyAlgorithmFrom(v) + h.algorithm = &tmp case fmt.Stringer: - tmp := v.String() + s := v.String() + var tmp = jwa.KeyAlgorithmFrom(s) h.algorithm = &tmp default: return errors.Errorf(`invalid type for %s key: %T`, AlgorithmKey, value) @@ -417,9 +419,12 @@ LOOP: return errors.Errorf(`invalid kty value for RSAPublicKey (%s)`, val) } case AlgorithmKey: - if err := json.AssignNextStringToken(&h.algorithm, dec); err != nil { + var s string + if err := dec.Decode(&s); err != nil { return errors.Wrapf(err, `failed to decode value for key %s`, AlgorithmKey) } + alg := jwa.KeyAlgorithmFrom(s) + h.algorithm = &alg case OKPCrvKey: var decoded jwa.EllipticCurveAlgorithm if err := dec.Decode(&decoded); err != nil { @@ -564,7 +569,7 @@ type OKPPrivateKey interface { } type okpPrivateKey struct { - algorithm *string // https://tools.ietf.org/html/rfc7517#section-4.4 + algorithm *jwa.KeyAlgorithm // https://tools.ietf.org/html/rfc7517#section-4.4 crv *jwa.EllipticCurveAlgorithm d []byte keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4 @@ -595,11 +600,11 @@ func (h okpPrivateKey) KeyType() jwa.KeyType { return jwa.OKP } -func (h *okpPrivateKey) Algorithm() string { +func (h *okpPrivateKey) Algorithm() jwa.KeyAlgorithm { if h.algorithm != nil { return *(h.algorithm) } - return "" + return jwa.InvalidKeyAlgorithm("") } func (h *okpPrivateKey) Crv() jwa.EllipticCurveAlgorithm { @@ -794,10 +799,12 @@ func (h *okpPrivateKey) setNoLock(name string, value interface{}) error { return nil case AlgorithmKey: switch v := value.(type) { - case string: - h.algorithm = &v + case string, jwa.SignatureAlgorithm, jwa.ContentEncryptionAlgorithm: + var tmp = jwa.KeyAlgorithmFrom(v) + h.algorithm = &tmp case fmt.Stringer: - tmp := v.String() + s := v.String() + var tmp = jwa.KeyAlgorithmFrom(s) h.algorithm = &tmp default: return errors.Errorf(`invalid type for %s key: %T`, AlgorithmKey, value) @@ -970,9 +977,12 @@ LOOP: return errors.Errorf(`invalid kty value for RSAPublicKey (%s)`, val) } case AlgorithmKey: - if err := json.AssignNextStringToken(&h.algorithm, dec); err != nil { + var s string + if err := dec.Decode(&s); err != nil { return errors.Wrapf(err, `failed to decode value for key %s`, AlgorithmKey) } + alg := jwa.KeyAlgorithmFrom(s) + h.algorithm = &alg case OKPCrvKey: var decoded jwa.EllipticCurveAlgorithm if err := dec.Decode(&decoded); err != nil { diff --git a/jwk/rsa_gen.go b/jwk/rsa_gen.go index 21de3bf1e..d5e60bc17 100644 --- a/jwk/rsa_gen.go +++ b/jwk/rsa_gen.go @@ -39,7 +39,7 @@ type RSAPublicKey interface { } type rsaPublicKey struct { - algorithm *string // https://tools.ietf.org/html/rfc7517#section-4.4 + algorithm *jwa.KeyAlgorithm // https://tools.ietf.org/html/rfc7517#section-4.4 e []byte keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4 keyOps *KeyOperationList // https://tools.ietf.org/html/rfc7517#section-4.3 @@ -69,11 +69,11 @@ func (h rsaPublicKey) KeyType() jwa.KeyType { return jwa.RSA } -func (h *rsaPublicKey) Algorithm() string { +func (h *rsaPublicKey) Algorithm() jwa.KeyAlgorithm { if h.algorithm != nil { return *(h.algorithm) } - return "" + return jwa.InvalidKeyAlgorithm("") } func (h *rsaPublicKey) E() []byte { @@ -253,10 +253,12 @@ func (h *rsaPublicKey) setNoLock(name string, value interface{}) error { return nil case AlgorithmKey: switch v := value.(type) { - case string: - h.algorithm = &v + case string, jwa.SignatureAlgorithm, jwa.ContentEncryptionAlgorithm: + var tmp = jwa.KeyAlgorithmFrom(v) + h.algorithm = &tmp case fmt.Stringer: - tmp := v.String() + s := v.String() + var tmp = jwa.KeyAlgorithmFrom(s) h.algorithm = &tmp default: return errors.Errorf(`invalid type for %s key: %T`, AlgorithmKey, value) @@ -420,9 +422,12 @@ LOOP: return errors.Errorf(`invalid kty value for RSAPublicKey (%s)`, val) } case AlgorithmKey: - if err := json.AssignNextStringToken(&h.algorithm, dec); err != nil { + var s string + if err := dec.Decode(&s); err != nil { return errors.Wrapf(err, `failed to decode value for key %s`, AlgorithmKey) } + alg := jwa.KeyAlgorithmFrom(s) + h.algorithm = &alg case RSAEKey: if err := json.AssignNextBytesToken(&h.e, dec); err != nil { return errors.Wrapf(err, `failed to decode value for key %s`, RSAEKey) @@ -570,7 +575,7 @@ type RSAPrivateKey interface { } type rsaPrivateKey struct { - algorithm *string // https://tools.ietf.org/html/rfc7517#section-4.4 + algorithm *jwa.KeyAlgorithm // https://tools.ietf.org/html/rfc7517#section-4.4 d []byte dp []byte dq []byte @@ -606,11 +611,11 @@ func (h rsaPrivateKey) KeyType() jwa.KeyType { return jwa.RSA } -func (h *rsaPrivateKey) Algorithm() string { +func (h *rsaPrivateKey) Algorithm() jwa.KeyAlgorithm { if h.algorithm != nil { return *(h.algorithm) } - return "" + return jwa.InvalidKeyAlgorithm("") } func (h *rsaPrivateKey) D() []byte { @@ -862,10 +867,12 @@ func (h *rsaPrivateKey) setNoLock(name string, value interface{}) error { return nil case AlgorithmKey: switch v := value.(type) { - case string: - h.algorithm = &v + case string, jwa.SignatureAlgorithm, jwa.ContentEncryptionAlgorithm: + var tmp = jwa.KeyAlgorithmFrom(v) + h.algorithm = &tmp case fmt.Stringer: - tmp := v.String() + s := v.String() + var tmp = jwa.KeyAlgorithmFrom(s) h.algorithm = &tmp default: return errors.Errorf(`invalid type for %s key: %T`, AlgorithmKey, value) @@ -1083,9 +1090,12 @@ LOOP: return errors.Errorf(`invalid kty value for RSAPublicKey (%s)`, val) } case AlgorithmKey: - if err := json.AssignNextStringToken(&h.algorithm, dec); err != nil { + var s string + if err := dec.Decode(&s); err != nil { return errors.Wrapf(err, `failed to decode value for key %s`, AlgorithmKey) } + alg := jwa.KeyAlgorithmFrom(s) + h.algorithm = &alg case RSADKey: if err := json.AssignNextBytesToken(&h.d, dec); err != nil { return errors.Wrapf(err, `failed to decode value for key %s`, RSADKey) diff --git a/jwk/symmetric_gen.go b/jwk/symmetric_gen.go index 60b0fd972..58f972e22 100644 --- a/jwk/symmetric_gen.go +++ b/jwk/symmetric_gen.go @@ -30,7 +30,7 @@ type SymmetricKey interface { } type symmetricKey struct { - algorithm *string // https://tools.ietf.org/html/rfc7517#section-4.4 + algorithm *jwa.KeyAlgorithm // https://tools.ietf.org/html/rfc7517#section-4.4 keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4 keyOps *KeyOperationList // https://tools.ietf.org/html/rfc7517#section-4.3 keyUsage *string // https://tools.ietf.org/html/rfc7517#section-4.2 @@ -59,11 +59,11 @@ func (h symmetricKey) KeyType() jwa.KeyType { return jwa.OctetSeq } -func (h *symmetricKey) Algorithm() string { +func (h *symmetricKey) Algorithm() jwa.KeyAlgorithm { if h.algorithm != nil { return *(h.algorithm) } - return "" + return jwa.InvalidKeyAlgorithm("") } func (h *symmetricKey) KeyID() string { @@ -231,10 +231,12 @@ func (h *symmetricKey) setNoLock(name string, value interface{}) error { return nil case AlgorithmKey: switch v := value.(type) { - case string: - h.algorithm = &v + case string, jwa.SignatureAlgorithm, jwa.ContentEncryptionAlgorithm: + var tmp = jwa.KeyAlgorithmFrom(v) + h.algorithm = &tmp case fmt.Stringer: - tmp := v.String() + s := v.String() + var tmp = jwa.KeyAlgorithmFrom(s) h.algorithm = &tmp default: return errors.Errorf(`invalid type for %s key: %T`, AlgorithmKey, value) @@ -389,9 +391,12 @@ LOOP: return errors.Errorf(`invalid kty value for RSAPublicKey (%s)`, val) } case AlgorithmKey: - if err := json.AssignNextStringToken(&h.algorithm, dec); err != nil { + var s string + if err := dec.Decode(&s); err != nil { return errors.Wrapf(err, `failed to decode value for key %s`, AlgorithmKey) } + alg := jwa.KeyAlgorithmFrom(s) + h.algorithm = &alg case KeyIDKey: if err := json.AssignNextStringToken(&h.keyID, dec); err != nil { return errors.Wrapf(err, `failed to decode value for key %s`, KeyIDKey) diff --git a/jws/jws.go b/jws/jws.go index 1aea7255d..4e2f74233 100644 --- a/jws/jws.go +++ b/jws/jws.go @@ -108,7 +108,12 @@ var muSigner = &sync.Mutex{} // If you want to use a detached payload, use `jws.WithDetachedPayload()` as // one of the options. When you use this option, you must always set the // first parameter (`payload`) to `nil`, or the function will return an error -func Sign(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, options ...SignOption) ([]byte, error) { +func Sign(payload []byte, alg jwa.KeyAlgorithm, key interface{}, options ...SignOption) ([]byte, error) { + salg, ok := alg.(jwa.SignatureAlgorithm) + if !ok { + return nil, errors.Errorf(`expected alg to be jwa.SignatureAlgorithm, but got %T`, alg) + } + var hdrs Headers var detached bool for _, o := range options { @@ -126,14 +131,14 @@ func Sign(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, options . } muSigner.Lock() - signer, ok := signers[alg] + signer, ok := signers[salg] if !ok { - v, err := NewSigner(alg) + v, err := NewSigner(salg) if err != nil { muSigner.Unlock() return nil, errors.Wrap(err, `failed to create signer`) } - signers[alg] = v + signers[salg] = v signer = v } muSigner.Unlock() diff --git a/jwt/jwt.go b/jwt/jwt.go index 4773fbbe3..170a49078 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -324,12 +324,20 @@ OUTER: // // The algorithm specified in the `alg` parameter must be able to support // the type of key you provided, otherwise an error is returned. +// For convenience `alg` is of type jwa.KeyAlgorithm so you can pass +// the return value of `(jwk.Key).Algorithm()` directly, but in practice +// it must be an instance of jwa.SignatureAlgorithm, otherwise an error +// is returned. // // The protected header will also automatically have the `typ` field set // to the literal value `JWT`, unless you provide a custom value for it // by jwt.WithHeaders option. -func Sign(t Token, alg jwa.SignatureAlgorithm, key interface{}, options ...SignOption) ([]byte, error) { - return NewSerializer().Sign(alg, key, options...).Serialize(t) +func Sign(t Token, alg jwa.KeyAlgorithm, key interface{}, options ...SignOption) ([]byte, error) { + salg, ok := alg.(jwa.SignatureAlgorithm) + if !ok { + return nil, errors.Errorf(`jwt.Sign received %T for alg. Expected jwa.SignatureAlgorithm`, alg) + } + return NewSerializer().Sign(salg, key, options...).Serialize(t) } // Equal compares two JWT tokens. Do not use `reflect.Equal` or the like diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index 980f60fd4..7cd051a90 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -662,7 +662,7 @@ func TestSignJWK(t *testing.T) { key.Set(jwk.AlgorithmKey, jwa.RS256) tok := jwt.New() - signed, err := jwt.Sign(tok, jwa.SignatureAlgorithm(key.Algorithm()), key) + signed, err := jwt.Sign(tok, key.Algorithm(), key) assert.Nil(t, err) header, err := jws.ParseString(string(signed)) diff --git a/jwt/options.go b/jwt/options.go index dbeb4b4fd..ae89a70e3 100644 --- a/jwt/options.go +++ b/jwt/options.go @@ -371,16 +371,16 @@ func WithValidator(v Validator) ValidateOption { } type decryptParams struct { - alg jwa.KeyEncryptionAlgorithm + alg jwa.KeyAlgorithm key interface{} } type DecryptParameters interface { - Algorithm() jwa.KeyEncryptionAlgorithm + Algorithm() jwa.KeyAlgorithm Key() interface{} } -func (dp *decryptParams) Algorithm() jwa.KeyEncryptionAlgorithm { +func (dp *decryptParams) Algorithm() jwa.KeyAlgorithm { return dp.alg } @@ -390,7 +390,12 @@ func (dp *decryptParams) Key() interface{} { // WithDecrypt allows users to specify parameters for decryption using // `jwe.Decrypt`. You must specify this if your JWT is encrypted. -func WithDecrypt(alg jwa.KeyEncryptionAlgorithm, key interface{}) ParseOption { +// +// While `alg` accept jwa.KeyAlgorithm for convenience so you can +// directly pass the return value of `(jwk.Key).Algorithm()`, in practice +// the value must be of type jwa.SignatureAlgorithm. Otherwise the +// verification will fail +func WithDecrypt(alg jwa.KeyAlgorithm, key interface{}) ParseOption { return newParseOption(identDecrypt{}, &decryptParams{ alg: alg, key: key,