Skip to content

Commit

Permalink
Change JWK alg field (#553)
Browse files Browse the repository at this point in the history
* Fix return value (presumably broken during merge)

* Change the type of (jwk.Key).Algorithm

Allows users to directly pass (jwk.Key).Algorithm().
This forces us to jump through some hoops, but users (rightly)
do not get why (jwk.Key).Algorithm() is left as a string.

* Update docs on why jwa.KeyAlgorithm and where

* Update jws.Sign and jws.Verify

* appease linter

* run go generate

* Update Changes
  • Loading branch information
lestrrat committed Feb 22, 2022
1 parent a222ffb commit 2aa98ce
Show file tree
Hide file tree
Showing 18 changed files with 253 additions and 74 deletions.
27 changes: 27 additions & 0 deletions docs/99-faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,30 @@ After all, the only thing we can say is that you should check that you are parsi
## Why are you generating so many fields?

Because a lot of the code is repetitive. For example, maintaining the 15 fields in a JWE header in all parts of the code (getter methods, setter methods, marshaling/unmarshaling) is doable but very very very cumbersome. We think that resources used to watching out for typos and other minor problems that may arise during maintenance is better spent elsewhere by automating generation of consistent code.

## Why is (jwk.Key).Algorithm() and jwa.KeyAlgorithm so confusing?

To start, we sympathize. Please read on for the reason(s) why things are the way they are.

First you must understand that JWKs can be used for multiple different purposes, including but not limited to JWS and JWE (key encryption). And the `alg` field is supposed to carry to what purpose the JWK is supposed to be used.

This means that a JWK can, in jwx terms, carry either `jwa.SignatureAlgorithm` or `jwa.KeyEncryptionAlgorithm`.

In order to allow passing either `jwa.SignatureAlgorithm` or `jwa.KeyEncrypionAlgorithm`, we initially implemented
`(jwk.Key).Algorithm()` as a string, so it was possible to just change the type depending on the situation.

This caused a bit of confusion for some users because this field was the only "untyped" field that potentially could have been typed. Most notably, some people wanted to do the following, but couldn't:

```go
jwt.Verify(token, jwt.WithVerify(key.Algorithm(), key))
```

Since version 2.0.0 `jwk.Key` now stores the `alg` field as a `jwa.KeyAlgorithm` type, which is just an interface that covers `jwa.SignatureAlgorithm`, `jwa.KeyEncryptionAlgorithm`, or any other type that we may need to represent in the future.

Now you should be able to just pass the `alg` value to most high-level functions and methods such as `jwt.Verify`, `jws.Sign`, and `jwe.Encrypt`

### When do we use `jwa.KeyAlgorithm`

There are some functions that accept `jwa.KeyAlgorithm`, while there are others that expect `jwa.SignatureAlgorithm` or `jwa.KeyEncryptionAlgorithm`. So when do we use which?

The guideline is as follows: If it's a high-level function/method that the users regularly use, use `jwa.KeyAlgorithm`. For example, almost everybody who use `jwt` will want to verify the JWS signed payload, so `jwt.Sign()`, and `jwt.Verify()` expect `jwa.KeyAlgorithm`. On the other hand, `jwt.Serializer` uses `jwa.SignatureAlgorithm` and such. This is a low-level utility, and users are not really meant to use it for their most basic needs: therefore they use the specific algorithm type.
57 changes: 57 additions & 0 deletions jwa/jwa.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,60 @@

// Package jwa defines the various algorithm described in https://tools.ietf.org/html/rfc7518
package jwa

import "fmt"

// KeyAlgorithm is a workaround for jwk.Key being able to contain different
// types of algorithms in its `alg` field.
//
// Previously the storage for the `alg` field was represented as a string,
// but this caused some users to wonder why the field was not typed appropriately
// like other fields.
//
// Ideally we would like to keep track of Signature Algorithms and
// Content Encryption Algorithms separately, and force the APIs to
// type-check at compile time, but this allows users to pass a value from a
// jwk.Key directly
type KeyAlgorithm interface {
String() string
}

// InvalidKeyAlgorithm represents an algorithm that the library is not aware of.
type InvalidKeyAlgorithm string

func (s InvalidKeyAlgorithm) String() string {
return string(s)
}

func (InvalidKeyAlgorithm) Accept(_ interface{}) error {
return fmt.Errorf(`jwa.InvalidKeyAlgorithm does not support Accept() method calls`)
}

// KeyAlgorithmFrom takes either a string, `jwa.SignatureAlgorithm` or `jwa.KeyEncryptionAlgorithm`
// and returns a `jwa.KeyAlgorithm`.
//
// If the value cannot be handled, it returns an `jwa.InvalidKeyAlgorithm`
// object instead of returning an error. This design choice was made to allow
// users to directly pass the return value to functions such as `jws.Sign()`
func KeyAlgorithmFrom(v interface{}) KeyAlgorithm {
switch v := v.(type) {
case SignatureAlgorithm:
return v
case ContentEncryptionAlgorithm:
return v
case string:
var salg SignatureAlgorithm
if err := salg.Accept(v); err == nil {
return salg
}

var kealg KeyEncryptionAlgorithm
if err := kealg.Accept(v); err == nil {
return kealg
}

return InvalidKeyAlgorithm(v)
default:
return InvalidKeyAlgorithm(fmt.Sprintf("%s", v))
}
}
4 changes: 3 additions & 1 deletion jwe/gh402_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ func TestGH402(t *testing.T) {
m := jwe.NewMessage()
// Test WithPostParse while we're at it
plain, err := jwe.Decrypt([]byte(data),
"invalid algorithm",
// This is a really cheesy way of creating a jwa.KeyEncryptionAlgorithm
// but a bogus one.
jwa.KeyEncryptionAlgorithm("invalid algorithm"),
nil,
jwe.WithMessage(m),
jwe.WithPostParser(pp),
Expand Down
24 changes: 21 additions & 3 deletions jwe/jwe.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,17 @@ var registry = json.NewRegistry()
// Encrypt takes the plaintext payload and encrypts it in JWE compact format.
// `key` should be a public key, and it may be a raw key (e.g. rsa.PublicKey) or a jwk.Key
//
// `alg` accepts a `jwa.KeyAlgorithm` for convenience so you can directly pass
// the result of `(jwk.Key).Algorithm()`, but in practice it must be of type
// `jwa.KeyEncryptionAlgorithm` or otherwise it will cause an error.
//
// Encrypt currently does not support multi-recipient messages.
func Encrypt(payload []byte, keyalg jwa.KeyEncryptionAlgorithm, key interface{}, contentalg jwa.ContentEncryptionAlgorithm, compressalg jwa.CompressionAlgorithm, options ...EncryptOption) ([]byte, error) {
func Encrypt(payload []byte, alg jwa.KeyAlgorithm, key interface{}, contentalg jwa.ContentEncryptionAlgorithm, compressalg jwa.CompressionAlgorithm, options ...EncryptOption) ([]byte, error) {
keyalg, ok := alg.(jwa.KeyEncryptionAlgorithm)
if !ok {
return nil, errors.Errorf(`expected alg to be jwa.KeyEncryptionAlgorithm, but got %T`, alg)
}

var protected Headers
for _, option := range options {
//nolint:forcetypeassert
Expand Down Expand Up @@ -212,11 +221,20 @@ func (ctx *decryptCtx) SetMessage(m *Message) {
// key to decrypt the JWE message, and returns the decrypted payload.
// The JWE message can be either compact or full JSON format.
//
// `alg` accepts a `jwa.KeyAlgorithm` for convenience so you can directly pass
// the result of `(jwk.Key).Algorithm()`, but in practice it must be of type
// `jwa.KeyEncryptionAlgorithm` or otherwise it will cause an error.
//
// `key` must be a private key. It can be either in its raw format (e.g. *rsa.PrivateKey) or a jwk.Key
func Decrypt(buf []byte, alg jwa.KeyEncryptionAlgorithm, key interface{}, options ...DecryptOption) ([]byte, error) {
func Decrypt(buf []byte, alg jwa.KeyAlgorithm, key interface{}, options ...DecryptOption) ([]byte, error) {
keyalg, ok := alg.(jwa.KeyEncryptionAlgorithm)
if !ok {
return nil, errors.Errorf(`expected alg to be jwa.KeyEncryptionAlgorithm, but got %T`, alg)
}

var ctx decryptCtx
ctx.key = key
ctx.alg = alg
ctx.alg = keyalg

var dst *Message
var postParse PostParser
Expand Down
38 changes: 24 additions & 14 deletions jwk/ecdsa_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type ECDSAPublicKey interface {
}

type ecdsaPublicKey struct {
algorithm *string // https://tools.ietf.org/html/rfc7517#section-4.4
algorithm *jwa.KeyAlgorithm // https://tools.ietf.org/html/rfc7517#section-4.4
crv *jwa.EllipticCurveAlgorithm
keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4
keyOps *KeyOperationList // https://tools.ietf.org/html/rfc7517#section-4.3
Expand Down Expand Up @@ -67,11 +67,11 @@ func (h ecdsaPublicKey) KeyType() jwa.KeyType {
return jwa.EC
}

func (h *ecdsaPublicKey) Algorithm() string {
func (h *ecdsaPublicKey) Algorithm() jwa.KeyAlgorithm {
if h.algorithm != nil {
return *(h.algorithm)
}
return ""
return jwa.InvalidKeyAlgorithm("")
}

func (h *ecdsaPublicKey) Crv() jwa.EllipticCurveAlgorithm {
Expand Down Expand Up @@ -266,10 +266,12 @@ func (h *ecdsaPublicKey) setNoLock(name string, value interface{}) error {
return nil
case AlgorithmKey:
switch v := value.(type) {
case string:
h.algorithm = &v
case string, jwa.SignatureAlgorithm, jwa.ContentEncryptionAlgorithm:
var tmp = jwa.KeyAlgorithmFrom(v)
h.algorithm = &tmp
case fmt.Stringer:
tmp := v.String()
s := v.String()
var tmp = jwa.KeyAlgorithmFrom(s)
h.algorithm = &tmp
default:
return errors.Errorf(`invalid type for %s key: %T`, AlgorithmKey, value)
Expand Down Expand Up @@ -442,9 +444,12 @@ LOOP:
return errors.Errorf(`invalid kty value for RSAPublicKey (%s)`, val)
}
case AlgorithmKey:
if err := json.AssignNextStringToken(&h.algorithm, dec); err != nil {
var s string
if err := dec.Decode(&s); err != nil {
return errors.Wrapf(err, `failed to decode value for key %s`, AlgorithmKey)
}
alg := jwa.KeyAlgorithmFrom(s)
h.algorithm = &alg
case ECDSACrvKey:
var decoded jwa.EllipticCurveAlgorithm
if err := dec.Decode(&decoded); err != nil {
Expand Down Expand Up @@ -597,7 +602,7 @@ type ECDSAPrivateKey interface {
}

type ecdsaPrivateKey struct {
algorithm *string // https://tools.ietf.org/html/rfc7517#section-4.4
algorithm *jwa.KeyAlgorithm // https://tools.ietf.org/html/rfc7517#section-4.4
crv *jwa.EllipticCurveAlgorithm
d []byte
keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4
Expand Down Expand Up @@ -629,11 +634,11 @@ func (h ecdsaPrivateKey) KeyType() jwa.KeyType {
return jwa.EC
}

func (h *ecdsaPrivateKey) Algorithm() string {
func (h *ecdsaPrivateKey) Algorithm() jwa.KeyAlgorithm {
if h.algorithm != nil {
return *(h.algorithm)
}
return ""
return jwa.InvalidKeyAlgorithm("")
}

func (h *ecdsaPrivateKey) Crv() jwa.EllipticCurveAlgorithm {
Expand Down Expand Up @@ -840,10 +845,12 @@ func (h *ecdsaPrivateKey) setNoLock(name string, value interface{}) error {
return nil
case AlgorithmKey:
switch v := value.(type) {
case string:
h.algorithm = &v
case string, jwa.SignatureAlgorithm, jwa.ContentEncryptionAlgorithm:
var tmp = jwa.KeyAlgorithmFrom(v)
h.algorithm = &tmp
case fmt.Stringer:
tmp := v.String()
s := v.String()
var tmp = jwa.KeyAlgorithmFrom(s)
h.algorithm = &tmp
default:
return errors.Errorf(`invalid type for %s key: %T`, AlgorithmKey, value)
Expand Down Expand Up @@ -1025,9 +1032,12 @@ LOOP:
return errors.Errorf(`invalid kty value for RSAPublicKey (%s)`, val)
}
case AlgorithmKey:
if err := json.AssignNextStringToken(&h.algorithm, dec); err != nil {
var s string
if err := dec.Decode(&s); err != nil {
return errors.Wrapf(err, `failed to decode value for key %s`, AlgorithmKey)
}
alg := jwa.KeyAlgorithmFrom(s)
h.algorithm = &alg
case ECDSACrvKey:
var decoded jwa.EllipticCurveAlgorithm
if err := dec.Decode(&decoded); err != nil {
Expand Down
7 changes: 3 additions & 4 deletions jwk/headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package jwk_test

import (
"context"
"fmt"
"testing"

"github.com/lestrrat-go/jwx/jwa"
Expand Down Expand Up @@ -98,8 +97,8 @@ func TestHeader(t *testing.T) {
if !assert.NoError(t, h.Set("Default", dummy), `Setting "Default" should succeed`) {
return
}
if h.Algorithm() != "" {
t.Fatalf("Algorithm should be empty string")
if !assert.Empty(t, h.Algorithm().String(), "Algorithm should be empty string") {
return
}
if h.KeyID() != "" {
t.Fatalf("KeyID should be empty string")
Expand Down Expand Up @@ -128,7 +127,7 @@ func TestHeader(t *testing.T) {
return
}

if !assert.Equal(t, value.(fmt.Stringer).String(), got, "values match") {
if !assert.Equal(t, value, got, "values match") {
return
}
}
Expand Down
7 changes: 6 additions & 1 deletion jwk/interface_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,12 @@ type Key interface {
// KeyOps returns `key_ops` of a JWK
KeyOps() KeyOperationList
// Algorithm returns `alg` of a JWK
Algorithm() string

// Algorithm returns the value of the `alg` field
//
// This field may contain either `jwk.SignatureAlgorithm` or `jwk.KeyEncryptionAlgorithm`.
// This is why there exists a `jwa.KeyAlgorithm` type that encompases both types.
Algorithm() jwa.KeyAlgorithm
// KeyID returns `kid` of a JWK
KeyID() string
// X509URL returns `x58` of a JWK
Expand Down
23 changes: 20 additions & 3 deletions jwk/internal/cmd/genheader/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ type KeyType struct {
func _main() error {
codegen.RegisterZeroVal(`jwa.EllipticCurveAlgorithm`, `jwa.InvalidEllipticCurve`)
codegen.RegisterZeroVal(`jwa.KeyType`, `jwa.InvalidKeyType`)
codegen.RegisterZeroVal(`jwa.KeyAlgorithm`, `jwa.InvalidKeyAlgorithm("")`)

var objectsFile = flag.String("objects", "objects.yml", "")
flag.Parse()
Expand Down Expand Up @@ -411,10 +412,12 @@ func generateObject(o *codegen.Output, kt *KeyType, obj *codegen.Object) error {
o.L("case %s:", keyName)
if f.Name(false) == `algorithm` {
o.L("switch v := value.(type) {")
o.L("case string:")
o.L("h.algorithm = &v")
o.L("case string, jwa.SignatureAlgorithm, jwa.ContentEncryptionAlgorithm:")
o.L("var tmp = jwa.KeyAlgorithmFrom(v)")
o.L("h.algorithm = &tmp")
o.L("case fmt.Stringer:")
o.L("tmp := v.String()")
o.L("s := v.String()")
o.L("var tmp = jwa.KeyAlgorithmFrom(s)")
o.L("h.algorithm = &tmp")
o.L("default:")
o.L("return errors.Errorf(`invalid type for %%s key: %%T`, %s, value)", keyName)
Expand Down Expand Up @@ -542,6 +545,14 @@ func generateObject(o *codegen.Output, kt *KeyType, obj *codegen.Object) error {
o.L("if err := json.AssignNextStringToken(&h.%s, dec); err != nil {", f.Name(false))
o.L("return errors.Wrapf(err, `failed to decode value for key %%s`, %sKey)", f.Name(true))
o.L("}")
} else if f.Type() == "jwa.KeyAlgorithm" {
o.L("case %sKey:", f.Name(true))
o.L("var s string")
o.L("if err := dec.Decode(&s); err != nil {")
o.L("return errors.Wrapf(err, `failed to decode value for key %%s`, %sKey)", f.Name(true))
o.L("}")
o.L("alg := jwa.KeyAlgorithmFrom(s)")
o.L("h.%s = &alg", f.Name(false))
} else if f.Type() == "[]byte" {
name := f.Name(true)
switch f.Name(false) {
Expand Down Expand Up @@ -746,6 +757,12 @@ func generateGenericHeaders(fields codegen.FieldList) error {
o.L("KeyType() jwa.KeyType")
for _, f := range fields {
o.L("// %s returns `%s` of a JWK", f.GetterMethod(true), f.JSON())
if f.Name(false) == "algorithm" {
o.LL("// Algorithm returns the value of the `alg` field")
o.L("//")
o.L("// This field may contain either `jwk.SignatureAlgorithm` or `jwk.KeyEncryptionAlgorithm`.")
o.L("// This is why there exists a `jwa.KeyAlgorithm` type that encompases both types.")
}
o.L("%s() ", f.GetterMethod(true))
if v := fieldGetterReturnValue(f); v != "" {
o.R("%s", v)
Expand Down
1 change: 1 addition & 0 deletions jwk/internal/cmd/genheader/objects.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ std_fields:
json: alg
is_std: true
comment: https://tools.ietf.org/html/rfc7517#section-4.4
type: jwa.KeyAlgorithm
- name: keyID
json: kid
is_std: true
Expand Down
2 changes: 1 addition & 1 deletion jwk/iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

func TestIterator(t *testing.T) {
commonValues := map[string]interface{}{
jwk.AlgorithmKey: "dummy",
jwk.AlgorithmKey: jwa.KeyAlgorithmFrom("dummy"),
jwk.KeyIDKey: "dummy-kid",
jwk.KeyUsageKey: "dummy-usage",
jwk.KeyOpsKey: jwk.KeyOperationList{jwk.KeyOpSign, jwk.KeyOpVerify, jwk.KeyOpEncrypt, jwk.KeyOpDecrypt, jwk.KeyOpWrapKey, jwk.KeyOpUnwrapKey, jwk.KeyOpDeriveKey, jwk.KeyOpDeriveBits},
Expand Down
2 changes: 1 addition & 1 deletion jwk/jwk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func init() {
commonDef = map[string]keyDef{
jwk.AlgorithmKey: {
Method: "Algorithm",
Value: "random-algorithm",
Value: jwa.KeyAlgorithmFrom("random-algorithm"),
},
jwk.KeyIDKey: {
Method: "KeyID",
Expand Down
Loading

0 comments on commit 2aa98ce

Please sign in to comment.