Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ESD-32688: Improve locking and blocking associated with key retrieval #225

Merged
merged 2 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*.dll
*.so
*.dylib
.DS_Store

# Test binary, built with `go test -c`
*.test
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.19
require (
github.com/google/go-cmp v0.6.0
github.com/stretchr/testify v1.8.4
golang.org/x/sync v0.5.0
gopkg.in/go-jose/go-jose.v2 v2.6.1
)

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8=
golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/go-jose/go-jose.v2 v2.6.1 h1:qEzJlIDmG9q5VO0M/o8tGS65QMHMS1w01TQJB1VPJ4U=
Expand Down
37 changes: 32 additions & 5 deletions jwks/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"sync"
"time"

"golang.org/x/sync/semaphore"
"gopkg.in/go-jose/go-jose.v2"

"github.com/auth0/go-jwt-middleware/v2/internal/oidc"
Expand Down Expand Up @@ -97,11 +98,16 @@ func (p *Provider) KeyFunc(ctx context.Context) (interface{}, error) {
// CachingProvider handles getting JWKS from the specified IssuerURL
// and caching them for CacheTTL time. It exposes KeyFunc which adheres
// to the keyFunc signature that the Validator requires.
// When the CacheTTL value has been reached, a JWKS refresh will be triggered
// in the background and the existing cached JWKS will be returned until the
// JWKS cache is updated, or if the request errors then it will be evicted from
// the cache.
type CachingProvider struct {
*Provider
CacheTTL time.Duration
mu sync.Mutex
mu sync.RWMutex
cache map[string]cachedJWKS
sem semaphore.Weighted
}

type cachedJWKS struct {
Expand All @@ -120,24 +126,45 @@ func NewCachingProvider(issuerURL *url.URL, cacheTTL time.Duration, opts ...Prov
Provider: NewProvider(issuerURL, opts...),
CacheTTL: cacheTTL,
cache: map[string]cachedJWKS{},
sem: *semaphore.NewWeighted(1),
}
}

// KeyFunc adheres to the keyFunc signature that the Validator requires.
// While it returns an interface to adhere to keyFunc, as long as the
// error is nil the type will be *jose.JSONWebKeySet.
func (c *CachingProvider) KeyFunc(ctx context.Context) (interface{}, error) {
c.mu.Lock()
defer c.mu.Unlock()
c.mu.RLock()

issuer := c.IssuerURL.Hostname()

if cached, ok := c.cache[issuer]; ok {
if !time.Now().After(cached.expiresAt) {
return cached.jwks, nil
if time.Now().After(cached.expiresAt) && c.sem.TryAcquire(1) {
go func() {
defer c.sem.Release(1)
refreshCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
_, err := c.refreshKey(refreshCtx, issuer)

if err != nil {
c.mu.Lock()
delete(c.cache, issuer)
c.mu.Unlock()
}
}()
}
c.mu.RUnlock()
return cached.jwks, nil
}

c.mu.RUnlock()
return c.refreshKey(ctx, issuer)
}

func (c *CachingProvider) refreshKey(ctx context.Context, issuer string) (interface{}, error) {
c.mu.Lock()
defer c.mu.Unlock()

jwks, err := c.Provider.KeyFunc(ctx)
if err != nil {
return nil, err
Expand Down
100 changes: 93 additions & 7 deletions jwks/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"time"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/go-jose/go-jose.v2"

Expand Down Expand Up @@ -84,7 +85,8 @@ func Test_JWKSProvider(t *testing.T) {
}
})

t.Run("It re-caches the JWKS if they have expired when using CachingProvider", func(t *testing.T) {
t.Run("It eventually re-caches the JWKS if they have expired when using CachingProvider", func(t *testing.T) {
requestCount = 0
expiredCachedJWKS, err := generateJWKS()
require.NoError(t, err)

Expand All @@ -94,16 +96,20 @@ func Test_JWKSProvider(t *testing.T) {
expiresAt: time.Now().Add(-10 * time.Minute),
}

actualJWKS, err := provider.KeyFunc(context.Background())
returnedJWKS, err := provider.KeyFunc(context.Background())
require.NoError(t, err)

if !cmp.Equal(expectedJWKS, actualJWKS) {
t.Fatalf("jwks did not match: %s", cmp.Diff(expectedJWKS, actualJWKS))
if !cmp.Equal(expiredCachedJWKS, returnedJWKS) {
t.Fatalf("jwks did not match: %s", cmp.Diff(expiredCachedJWKS, returnedJWKS))
}

if !cmp.Equal(expectedJWKS, provider.cache[testServerURL.Hostname()].jwks) {
t.Fatalf("cached jwks did not match: %s", cmp.Diff(expectedJWKS, provider.cache[testServerURL.Hostname()].jwks))
}
require.EventuallyWithT(t, func(c *assert.CollectT) {
returnedJWKS, err := provider.KeyFunc(context.Background())
require.NoError(t, err)

assert.True(c, cmp.Equal(expectedJWKS, returnedJWKS))
assert.Equal(c, int32(2), requestCount)
}, 1*time.Second, 250*time.Millisecond, "JWKS did not update")

cacheExpiresAt := provider.cache[testServerURL.Hostname()].expiresAt
if !time.Now().Before(cacheExpiresAt) {
Expand Down Expand Up @@ -154,6 +160,86 @@ func Test_JWKSProvider(t *testing.T) {
}
},
)

t.Run("It only calls the API once when multiple requests come in when using the CachingProvider with expired cache", func(t *testing.T) {
initialJWKS, err := generateJWKS()
require.NoError(t, err)
requestCount = 0

provider := NewCachingProvider(testServerURL, 5*time.Minute)
provider.cache[testServerURL.Hostname()] = cachedJWKS{
jwks: initialJWKS,
expiresAt: time.Now(),
}

var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
_, _ = provider.KeyFunc(context.Background())
wg.Done()
}()
}
wg.Wait()

require.EventuallyWithT(t, func(c *assert.CollectT) {
returnedJWKS, err := provider.KeyFunc(context.Background())
require.NoError(t, err)

assert.True(c, cmp.Equal(expectedJWKS, returnedJWKS))
assert.Equal(c, int32(2), requestCount)
}, 1*time.Second, 250*time.Millisecond, "JWKS did not update")
})

t.Run("It only calls the API once when multiple requests come in when using the CachingProvider with no cache", func(t *testing.T) {
provider := NewCachingProvider(testServerURL, 5*time.Minute)
requestCount = 0

var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
_, _ = provider.KeyFunc(context.Background())
wg.Done()
}()
}
wg.Wait()

if requestCount != 2 {
t.Fatalf("only wanted 2 requests (well known and jwks) , but we got %d requests", requestCount)
}
})

t.Run("Should delete cache entry if the refresh request fails", func(t *testing.T) {
malformedURL, err := url.Parse(testServer.URL + "/malformed")
require.NoError(t, err)

expiredCachedJWKS, err := generateJWKS()
require.NoError(t, err)

provider := NewCachingProvider(malformedURL, 5*time.Minute)
provider.cache[malformedURL.Hostname()] = cachedJWKS{
jwks: expiredCachedJWKS,
expiresAt: time.Now().Add(-10 * time.Minute),
}

// Trigger the refresh of the JWKS, which should return the cached JWKS
returnedJWKS, err := provider.KeyFunc(context.Background())
require.NoError(t, err)
assert.Equal(t, expiredCachedJWKS, returnedJWKS)

// Eventually it should return a nil JWKS
require.EventuallyWithT(t, func(c *assert.CollectT) {
returnedJWKS, err := provider.KeyFunc(context.Background())
require.Error(t, err)

assert.Nil(c, returnedJWKS)

cachedJWKS := provider.cache[malformedURL.Hostname()].jwks

assert.Nil(t, cachedJWKS)
}, 1*time.Second, 250*time.Millisecond, "JWKS did not get uncached")
})
}

func generateJWKS() (*jose.JSONWebKeySet, error) {
Expand Down