Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Use TokenCache in ClientCredentialsTokenSourceProvider #377

Merged
merged 6 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions clients/go/admin/mocks/TokenSource.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

85 changes: 60 additions & 25 deletions clients/go/admin/token_source_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ import (
"github.com/flyteorg/flytestdlib/logger"
)

//go:generate mockery -name TokenSource
type TokenSource interface {
Token() (*oauth2.Token, error)
}

const (
audienceKey = "audience"
)
Expand Down Expand Up @@ -68,7 +73,7 @@ func NewTokenSourceProvider(ctx context.Context, cfg *Config, tokenCache cache.T
}
}

tokenProvider, err = NewClientCredentialsTokenSourceProvider(ctx, cfg, scopes, tokenURL, audienceValue)
tokenProvider, err = NewClientCredentialsTokenSourceProvider(ctx, cfg, scopes, tokenURL, tokenCache, audienceValue)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -163,10 +168,12 @@ func GetPKCEAuthTokenSource(ctx context.Context, pkceTokenOrchestrator pkce.Toke

type ClientCredentialsTokenSourceProvider struct {
ccConfig clientcredentials.Config
TokenRefreshWindow time.Duration
tokenRefreshWindow time.Duration
tokenCache cache.TokenCache
}

func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, scopes []string, tokenURL string, audience string) (TokenSourceProvider, error) {
func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, scopes []string, tokenURL string,
tokenCache cache.TokenCache, audience string) (TokenSourceProvider, error) {
var secret string
if len(cfg.ClientSecretEnvVar) > 0 {
secret = os.Getenv(cfg.ClientSecretEnvVar)
Expand All @@ -183,6 +190,9 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, s
endpointParams = url.Values{audienceKey: {audience}}
}
secret = strings.TrimSpace(secret)
if tokenCache == nil {
tokenCache = &cache.TokenCacheInMemoryProvider{}
}
return ClientCredentialsTokenSourceProvider{
ccConfig: clientcredentials.Config{
ClientID: cfg.ClientID,
Expand All @@ -191,56 +201,81 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, s
Scopes: scopes,
EndpointParams: endpointParams,
},
TokenRefreshWindow: cfg.TokenRefreshWindow.Duration}, nil
tokenRefreshWindow: cfg.TokenRefreshWindow.Duration,
tokenCache: tokenCache}, nil
}

func (p ClientCredentialsTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) {
if p.TokenRefreshWindow > 0 {
if p.tokenRefreshWindow > 0 {
source := p.ccConfig.TokenSource(ctx)
return &customTokenSource{
ctx: ctx,
new: source,
mu: sync.Mutex{},
t: nil,
tokenRefreshWindow: p.TokenRefreshWindow,
tokenRefreshWindow: p.tokenRefreshWindow,
tokenCache: p.tokenCache,
}, nil
}
return p.ccConfig.TokenSource(ctx), nil
}

type customTokenSource struct {
ctx context.Context
new oauth2.TokenSource
mu sync.Mutex // guards everything else
t *oauth2.Token
refreshTime time.Time
failedToRefresh bool
tokenRefreshWindow time.Duration
tokenCache cache.TokenCache
}

// fetchTokenFromCache returns the cached token if available, and a bool indicating if we should try to refresh it.
// This function is not thread safe and should be called with the lock held.
func (s *customTokenSource) fetchTokenFromCache() (*oauth2.Token, bool) {
token, err := s.tokenCache.GetToken()
if err != nil {
logger.Infof(s.ctx, "no token found in cache")
return nil, false
}
if !token.Valid() {
logger.Infof(s.ctx, "cached token invalid")
return nil, false
}
if time.Now().After(s.refreshTime) && !s.failedToRefresh {
logger.Infof(s.ctx, "cached token refresh window exceeded")
return token, true
}
logger.Infof(s.ctx, "using cached token")
return token, false
}

func (s *customTokenSource) Token() (*oauth2.Token, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.t.Valid() {
if time.Now().After(s.refreshTime) && !s.failedToRefresh {
t, err := s.new.Token()
if err != nil {
s.failedToRefresh = true // don't try to refresh again before expiry
return s.t, nil
}
s.t = t
s.refreshTime = s.t.Expiry.Add(-getRandomDuration(s.tokenRefreshWindow))
s.failedToRefresh = false
return s.t, nil
}
return s.t, nil

cachedToken, needsRefresh := s.fetchTokenFromCache()
if cachedToken != nil && !needsRefresh {
return cachedToken, nil
}
t, err := s.new.Token()

token, err := s.new.Token()
if err != nil {
if needsRefresh {
logger.Warnf(s.ctx, "failed to refresh token, using last cached token until expired")
s.failedToRefresh = true
return cachedToken, nil
}
logger.Errorf(s.ctx, "failed to refresh token")
return nil, err
}
s.t = t
logger.Infof(s.ctx, "refreshed token")
err = s.tokenCache.SaveToken(token)
if err != nil {
logger.Warnf(s.ctx, "failed to cache token, using anyway")
}
s.failedToRefresh = false
s.refreshTime = s.t.Expiry.Add(-getRandomDuration(s.tokenRefreshWindow))
return t, nil
s.refreshTime = token.Expiry.Add(-getRandomDuration(s.tokenRefreshWindow))
return token, nil
}

// Get random duration between 0 and maxDuration
Expand Down
Loading