Skip to content

Commit

Permalink
Remove unnecessary bits, rework so options are easier
Browse files Browse the repository at this point in the history
  • Loading branch information
lestrrat committed Feb 22, 2022
1 parent beade17 commit a222ffb
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 141 deletions.
10 changes: 10 additions & 0 deletions jwk/jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,16 @@ func PublicRawKeyOf(v interface{}) (interface{}, error) {
}
}

type SetFetcher interface {
Fetch(context.Context, string, ...FetchOption) (Set, error)
}

type SetFetchFunc func(context.Context, string, ...FetchOption) (Set, error)

func (f SetFetchFunc) Fetch(ctx context.Context, urlstring string, options ...FetchOption) (Set, error) {
return f(ctx, urlstring, options...)
}

// Fetch fetches a JWK resource specified by a URL. The url must be
// pointing to a resource that is supported by `net/http`.
//
Expand Down
74 changes: 22 additions & 52 deletions jws/jws.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,6 @@ func SignMulti(payload []byte, options ...Option) ([]byte, error) {
return json.Marshal(result)
}

type verifyCtx struct {
dst *Message
detachedPayload []byte
keyProviders []KeyProvider
keyUsed interface{}
}

var allowNoneWhitelist = jwk.WhitelistFunc(func(string) bool {
return false
})
Expand All @@ -231,18 +224,26 @@ var allowNoneWhitelist = jwk.WhitelistFunc(func(string) bool {
// If you need to access signatures and JOSE headers in a JWS message,
// use `Parse` function to get `Message` object.
func Verify(buf []byte, options ...VerifyOption) ([]byte, error) {
var ctx verifyCtx
var dst *Message
var detachedPayload []byte
var keyProviders []KeyProvider
var keyUsed interface{}

ctx := context.Background()

//nolint:forcetypeassert
for _, option := range options {
switch option.Ident() {
case identMessage{}:
ctx.dst = option.Value().(*Message)
dst = option.Value().(*Message)
case identDetachedPayload{}:
ctx.detachedPayload = option.Value().([]byte)
detachedPayload = option.Value().([]byte)
case identKeyProvider{}:
ctx.keyProviders = append(ctx.keyProviders, option.Value().(KeyProvider))
keyProviders = append(keyProviders, option.Value().(KeyProvider))
case identKeyUsed{}:
ctx.keyUsed = option.Value()
keyUsed = option.Value()
case identContext{}:
ctx = option.Value().(context.Context)
default:
return nil, errors.Errorf(`invalid jws.VerifyOption %q passed`, `With`+strings.TrimPrefix(fmt.Sprintf(`%T`, option.Ident()), `jws.ident`))
}
Expand All @@ -254,12 +255,12 @@ func Verify(buf []byte, options ...VerifyOption) ([]byte, error) {
}
defer msg.clearRaw()

if ctx.detachedPayload != nil {
if detachedPayload != nil {
if len(msg.payload) != 0 {
return nil, fmt.Errorf(`can't specify detached payload for JWS with payload`)
}

msg.payload = ctx.detachedPayload
msg.payload = detachedPayload
}

// Pre-compute the base64 encoded version of payload
Expand Down Expand Up @@ -296,9 +297,9 @@ func Verify(buf []byte, options ...VerifyOption) ([]byte, error) {
verifyBuf.WriteByte('.')
verifyBuf.WriteString(payload)

for i, kp := range ctx.keyProviders {
for i, kp := range keyProviders {
var sink algKeySink
if err := kp.FetchKeys(&sink, sig); err != nil {
if err := kp.FetchKeys(ctx, &sink, sig); err != nil {
return nil, fmt.Errorf(`key provider %d failed: %w`, i, err)
}

Expand All @@ -313,14 +314,14 @@ func Verify(buf []byte, options ...VerifyOption) ([]byte, error) {
continue
}

if ctx.keyUsed != nil {
if err := blackmagic.AssignIfCompatible(ctx.keyUsed, key); err != nil {
return nil, fmt.Errorf(`failed to assign used key (%T) to %T: %w`, key, ctx.keyUsed, err)
if keyUsed != nil {
if err := blackmagic.AssignIfCompatible(keyUsed, key); err != nil {
return nil, fmt.Errorf(`failed to assign used key (%T) to %T: %w`, key, keyUsed, err)
}
}

if ctx.dst != nil {
*(ctx.dst) = *msg
if dst != nil {
*(dst) = *msg
}

return msg.payload, nil
Expand All @@ -346,37 +347,6 @@ func getB64Value(hdr Headers) bool {
return b64
}

// JWKSetFetcher is used to fetch JWK Set spcified in the `jku` field.
type JWKSetFetcher interface {
Fetch(string) (jwk.Set, error)
}

// SimpleJWKSetFetcher is the default object used to fetch JWK Sets specified in `jku`,
// which uses `jwk.Fetch()`
//
// For more complicated cases, such as using `jwk.AutoRefetch`, you will have to
// create your custom instance of `jws.JWKSetFetcher`
type SimpleJWKSetFetcher struct {
options []jwk.FetchOption
}

func NewJWKSetFetcher(options ...jwk.FetchOption) *SimpleJWKSetFetcher {
// We shove this in the front so that the ser is forced to
// specify a whitelist
options = append([]jwk.FetchOption{jwk.WithFetchWhitelist(allowNoneWhitelist)}, options...)
return &SimpleJWKSetFetcher{options: options}
}

func (f *SimpleJWKSetFetcher) Fetch(u string) (jwk.Set, error) {
return jwk.Fetch(context.TODO(), u, f.options...)
}

type JWKSetFetchFunc func(string) (jwk.Set, error)

func (f JWKSetFetchFunc) Fetch(u string) (jwk.Set, error) {
return f(u)
}

// This is an "optimized" ioutil.ReadAll(). It will attempt to read
// all of the contents from the reader IF the reader is of a certain
// concrete type.
Expand Down
61 changes: 38 additions & 23 deletions jws/jws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1599,38 +1599,37 @@ func TestJKU(t *testing.T) {

t.Run("Compact", func(t *testing.T) {
testcases := []struct {
Name string
Error bool
Query string
Fetcher func() jws.JWKSetFetcher
Name string
Error bool
Query string
Fetcher func() jwk.SetFetcher
FetchOptions func() []jwk.FetchOption
}{
{
Name: "Fail without whitelist",
Error: true,
Fetcher: func() jws.JWKSetFetcher {
return jws.NewJWKSetFetcher(
jwk.WithHTTPClient(srv.Client()),
)
FetchOptions: func() []jwk.FetchOption {
return []jwk.FetchOption{jwk.WithHTTPClient(srv.Client())}
},
},
{
Name: "Success",
Fetcher: func() jws.JWKSetFetcher {
return jws.NewJWKSetFetcher(
FetchOptions: func() []jwk.FetchOption {
return []jwk.FetchOption{
jwk.WithFetchWhitelist(jwk.InsecureWhitelist{}),
jwk.WithHTTPClient(srv.Client()),
)
}
},
},
{
Name: "Rejected by whitelist",
Error: true,
Fetcher: func() jws.JWKSetFetcher {
FetchOptions: func() []jwk.FetchOption {
wl := jwk.NewMapWhitelist().Add(`https://github.com/lestrrat-go/jwx`)
return jws.NewJWKSetFetcher(
return []jwk.FetchOption{
jwk.WithFetchWhitelist(wl),
jwk.WithHTTPClient(srv.Client()),
)
}
},
},
{
Expand All @@ -1641,21 +1640,28 @@ func TestJKU(t *testing.T) {
Name: "Backoff",
Error: false,
Query: "type=backoff",
Fetcher: func() jws.JWKSetFetcher {
FetchOptions: func() []jwk.FetchOption {
bo := backoff.NewConstantPolicy(backoff.WithInterval(500 * time.Millisecond))
return jws.NewJWKSetFetcher(
return []jwk.FetchOption{
jwk.WithFetchWhitelist(jwk.InsecureWhitelist{}),
jwk.WithFetchBackoff(bo),
jwk.WithHTTPClient(srv.Client()),
)
}
},
},
{
Name: "JWKSetFetcher",
Fetcher: func() jws.JWKSetFetcher {
Fetcher: func() jwk.SetFetcher {
ar := jwk.NewAutoRefresh(context.TODO())
return jws.JWKSetFetchFunc(func(u string) (jwk.Set, error) {
ar.Configure(u, jwk.WithHTTPClient(srv.Client()))
return jwk.SetFetchFunc(func(ctx context.Context, u string, options ...jwk.FetchOption) (jwk.Set, error) {
var aropts []jwk.AutoRefreshOption
for _, option := range options {
aropts = append(aropts, option)
}
aropts = append(aropts, jwk.WithHTTPClient(srv.Client()))
aropts = append(aropts, jwk.WithFetchWhitelist(jwk.InsecureWhitelist{}))
ar.Configure(u, aropts...)

return ar.Fetch(context.TODO(), u)
})
},
Expand All @@ -1676,7 +1682,16 @@ func TestJKU(t *testing.T) {
return
}

decoded, err := jws.Verify(signed, jws.WithVerifyAuto(tc.Fetcher()))
var options []jwk.FetchOption
if f := tc.FetchOptions; f != nil {
options = append(options, f()...)
}

var fetcher jwk.SetFetcher
if f := tc.Fetcher; f != nil {
fetcher = f()
}
decoded, err := jws.Verify(signed, jws.WithVerifyAuto(fetcher, options...))
if tc.Error {
if !assert.Error(t, err, `jws.Verify should fail`) {
return
Expand Down Expand Up @@ -1790,8 +1805,8 @@ func TestJKU(t *testing.T) {
options = fn()
}
options = append(options, jwk.WithHTTPClient(srv.Client()))
f := jws.NewJWKSetFetcher(options...)
decoded, err := jws.Verify(signed, jws.WithVerifyAuto(f), jws.WithMessage(m))

decoded, err := jws.Verify(signed, jws.WithVerifyAuto(nil, options...), jws.WithMessage(m))
if tc.Error {
if !assert.Error(t, err, `jws.Verify should fail`) {
return
Expand Down
20 changes: 11 additions & 9 deletions jws/key_provider.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jws

import (
"context"
"fmt"
"net/url"
"sync"
Expand All @@ -10,7 +11,7 @@ import (
)

type KeyProvider interface {
FetchKeys(KeySink, *Signature) error
FetchKeys(context.Context, KeySink, *Signature) error
}

type KeySink interface {
Expand Down Expand Up @@ -38,7 +39,7 @@ type staticKeyProvider struct {
key interface{}
}

func (kp *staticKeyProvider) FetchKeys(sink KeySink, _ *Signature) error {
func (kp *staticKeyProvider) FetchKeys(_ context.Context, sink KeySink, _ *Signature) error {
sink.Key(kp.alg, kp.key)
return nil
}
Expand Down Expand Up @@ -91,7 +92,7 @@ func (kp *keySetProvider) selectKey(sink KeySink, key jwk.Key, sig *Signature) e
return nil
}

func (kp *keySetProvider) FetchKeys(sink KeySink, sig *Signature) error {
func (kp *keySetProvider) FetchKeys(_ context.Context, sink KeySink, sig *Signature) error {
if kp.requireKid {
var key jwk.Key

Expand Down Expand Up @@ -130,10 +131,11 @@ func (kp *keySetProvider) FetchKeys(sink KeySink, sig *Signature) error {
}

type jkuProvider struct {
fetcher JWKSetFetcher
fetcher jwk.SetFetcher
options []jwk.FetchOption
}

func (kp jkuProvider) FetchKeys(sink KeySink, sig *Signature) error {
func (kp jkuProvider) FetchKeys(ctx context.Context, sink KeySink, sig *Signature) error {
kid := sig.ProtectedHeaders().KeyID()
if kid == "" {
return nil
Expand All @@ -154,7 +156,7 @@ func (kp jkuProvider) FetchKeys(sink KeySink, sig *Signature) error {
return fmt.Errorf(`url in "jku" must be HTTPS`)
}

set, err := kp.fetcher.Fetch(u)
set, err := kp.fetcher.Fetch(ctx, u, kp.options...)
if err != nil {
return fmt.Errorf(`failed to fetch %q: %w`, u, err)
}
Expand Down Expand Up @@ -183,8 +185,8 @@ func (kp jkuProvider) FetchKeys(sink KeySink, sig *Signature) error {
return nil
}

type KeyProviderFunc func(KeySink, *Signature) error
type KeyProviderFunc func(context.Context, KeySink, *Signature) error

func (kp KeyProviderFunc) FetchKeys(sink KeySink, sig *Signature) error {
return kp(sink, sig)
func (kp KeyProviderFunc) FetchKeys(ctx context.Context, sink KeySink, sig *Signature) error {
return kp(ctx, sink, sig)
}
Loading

0 comments on commit a222ffb

Please sign in to comment.