diff --git a/clients/go/admin/mocks/TokenSource.go b/clients/go/admin/mocks/TokenSource.go new file mode 100644 index 000000000..60cc87236 --- /dev/null +++ b/clients/go/admin/mocks/TokenSource.go @@ -0,0 +1,54 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + mock "github.com/stretchr/testify/mock" + oauth2 "golang.org/x/oauth2" +) + +// TokenSource is an autogenerated mock type for the TokenSource type +type TokenSource struct { + mock.Mock +} + +type TokenSource_Token struct { + *mock.Call +} + +func (_m TokenSource_Token) Return(_a0 *oauth2.Token, _a1 error) *TokenSource_Token { + return &TokenSource_Token{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *TokenSource) OnToken() *TokenSource_Token { + c_call := _m.On("Token") + return &TokenSource_Token{Call: c_call} +} + +func (_m *TokenSource) OnTokenMatch(matchers ...interface{}) *TokenSource_Token { + c_call := _m.On("Token", matchers...) + return &TokenSource_Token{Call: c_call} +} + +// Token provides a mock function with given fields: +func (_m *TokenSource) Token() (*oauth2.Token, error) { + ret := _m.Called() + + var r0 *oauth2.Token + if rf, ok := ret.Get(0).(func() *oauth2.Token); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*oauth2.Token) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/clients/go/admin/token_source_provider.go b/clients/go/admin/token_source_provider.go index c2a520d70..937b6d901 100644 --- a/clients/go/admin/token_source_provider.go +++ b/clients/go/admin/token_source_provider.go @@ -23,6 +23,11 @@ import ( "github.com/flyteorg/flytestdlib/logger" ) +//go:generate mockery -name TokenSource +type TokenSource interface { + Token() (*oauth2.Token, error) +} + const ( audienceKey = "audience" ) @@ -68,7 +73,7 @@ func NewTokenSourceProvider(ctx context.Context, cfg *Config, tokenCache cache.T } } - tokenProvider, err = NewClientCredentialsTokenSourceProvider(ctx, cfg, scopes, tokenURL, audienceValue) + tokenProvider, err = NewClientCredentialsTokenSourceProvider(ctx, cfg, scopes, tokenURL, tokenCache, audienceValue) if err != nil { return nil, err } @@ -163,10 +168,12 @@ func GetPKCEAuthTokenSource(ctx context.Context, pkceTokenOrchestrator pkce.Toke type ClientCredentialsTokenSourceProvider struct { ccConfig clientcredentials.Config - TokenRefreshWindow time.Duration + tokenRefreshWindow time.Duration + tokenCache cache.TokenCache } -func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, scopes []string, tokenURL string, audience string) (TokenSourceProvider, error) { +func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, scopes []string, tokenURL string, + tokenCache cache.TokenCache, audience string) (TokenSourceProvider, error) { var secret string if len(cfg.ClientSecretEnvVar) > 0 { secret = os.Getenv(cfg.ClientSecretEnvVar) @@ -183,6 +190,9 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, s endpointParams = url.Values{audienceKey: {audience}} } secret = strings.TrimSpace(secret) + if tokenCache == nil { + tokenCache = &cache.TokenCacheInMemoryProvider{} + } return ClientCredentialsTokenSourceProvider{ ccConfig: clientcredentials.Config{ ClientID: cfg.ClientID, @@ -191,56 +201,81 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, s Scopes: scopes, EndpointParams: endpointParams, }, - TokenRefreshWindow: cfg.TokenRefreshWindow.Duration}, nil + tokenRefreshWindow: cfg.TokenRefreshWindow.Duration, + tokenCache: tokenCache}, nil } func (p ClientCredentialsTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) { - if p.TokenRefreshWindow > 0 { + if p.tokenRefreshWindow > 0 { source := p.ccConfig.TokenSource(ctx) return &customTokenSource{ + ctx: ctx, new: source, mu: sync.Mutex{}, - t: nil, - tokenRefreshWindow: p.TokenRefreshWindow, + tokenRefreshWindow: p.tokenRefreshWindow, + tokenCache: p.tokenCache, }, nil } return p.ccConfig.TokenSource(ctx), nil } type customTokenSource struct { + ctx context.Context new oauth2.TokenSource mu sync.Mutex // guards everything else - t *oauth2.Token refreshTime time.Time failedToRefresh bool tokenRefreshWindow time.Duration + tokenCache cache.TokenCache +} + +// fetchTokenFromCache returns the cached token if available, and a bool indicating if we should try to refresh it. +// This function is not thread safe and should be called with the lock held. +func (s *customTokenSource) fetchTokenFromCache() (*oauth2.Token, bool) { + token, err := s.tokenCache.GetToken() + if err != nil { + logger.Infof(s.ctx, "no token found in cache") + return nil, false + } + if !token.Valid() { + logger.Infof(s.ctx, "cached token invalid") + return nil, false + } + if time.Now().After(s.refreshTime) && !s.failedToRefresh { + logger.Infof(s.ctx, "cached token refresh window exceeded") + return token, true + } + logger.Infof(s.ctx, "using cached token") + return token, false } func (s *customTokenSource) Token() (*oauth2.Token, error) { s.mu.Lock() defer s.mu.Unlock() - if s.t.Valid() { - if time.Now().After(s.refreshTime) && !s.failedToRefresh { - t, err := s.new.Token() - if err != nil { - s.failedToRefresh = true // don't try to refresh again before expiry - return s.t, nil - } - s.t = t - s.refreshTime = s.t.Expiry.Add(-getRandomDuration(s.tokenRefreshWindow)) - s.failedToRefresh = false - return s.t, nil - } - return s.t, nil + + cachedToken, needsRefresh := s.fetchTokenFromCache() + if cachedToken != nil && !needsRefresh { + return cachedToken, nil } - t, err := s.new.Token() + + token, err := s.new.Token() if err != nil { + if needsRefresh { + logger.Warnf(s.ctx, "failed to refresh token, using last cached token until expired") + s.failedToRefresh = true + return cachedToken, nil + } + logger.Errorf(s.ctx, "failed to refresh token") return nil, err } - s.t = t + logger.Infof(s.ctx, "refreshed token") + err = s.tokenCache.SaveToken(token) + if err != nil { + logger.Warnf(s.ctx, "failed to cache token, using anyway") + } s.failedToRefresh = false - s.refreshTime = s.t.Expiry.Add(-getRandomDuration(s.tokenRefreshWindow)) - return t, nil + s.refreshTime = token.Expiry.Add(-getRandomDuration(s.tokenRefreshWindow)) + return token, nil } // Get random duration between 0 and maxDuration diff --git a/clients/go/admin/token_source_test.go b/clients/go/admin/token_source_test.go index 9256e5e88..745237759 100644 --- a/clients/go/admin/token_source_test.go +++ b/clients/go/admin/token_source_test.go @@ -2,8 +2,10 @@ package admin import ( "context" + "fmt" "net/url" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -12,6 +14,7 @@ import ( tokenCacheMocks "github.com/flyteorg/flyteidl/clients/go/admin/cache/mocks" adminMocks "github.com/flyteorg/flyteidl/clients/go/admin/mocks" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flytestdlib/config" ) type DummyTestTokenSource struct { @@ -95,3 +98,182 @@ func TestNewTokenSourceProvider(t *testing.T) { assert.Equal(t, url.Values{audienceKey: {test.expectedAudience}}, clientCredSourceProvider.ccConfig.EndpointParams) } } + +func TestCustomTokenSource_fetchTokenFromCache(t *testing.T) { + ctx := context.Background() + cfg := GetConfig(ctx) + cfg.TokenRefreshWindow = config.Duration{Duration: time.Minute} + cfg.ClientSecretLocation = "" + + minuteAgo := time.Now().Add(-time.Minute) + hourAhead := time.Now().Add(time.Hour) + invalidToken := oauth2.Token{AccessToken: "foo", Expiry: minuteAgo} + validToken := oauth2.Token{AccessToken: "foo", Expiry: hourAhead} + + tests := []struct { + name string + refreshTime time.Time + failedToRefresh bool + token *oauth2.Token + expectToken bool + expectNeedsRefresh bool + }{ + { + name: "no token", + refreshTime: hourAhead, + failedToRefresh: false, + token: nil, + expectToken: false, + expectNeedsRefresh: false, + }, + { + name: "invalid token", + refreshTime: hourAhead, + failedToRefresh: false, + token: &invalidToken, + expectToken: false, + expectNeedsRefresh: false, + }, + { + name: "refresh exceeded", + refreshTime: minuteAgo, + failedToRefresh: false, + token: &validToken, + expectToken: false, + expectNeedsRefresh: false, + }, + { + name: "refresh exceeded failed", + refreshTime: minuteAgo, + failedToRefresh: true, + token: &validToken, + expectToken: false, + expectNeedsRefresh: false, + }, + { + name: "valid token", + refreshTime: hourAhead, + failedToRefresh: false, + token: &validToken, + expectToken: false, + expectNeedsRefresh: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tokenCache := &tokenCacheMocks.TokenCache{} + provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "") + assert.NoError(t, err) + source, err := provider.GetTokenSource(ctx) + assert.NoError(t, err) + customSource, ok := source.(*customTokenSource) + assert.True(t, ok) + + customSource.refreshTime = test.refreshTime + customSource.failedToRefresh = test.failedToRefresh + var tokenErr error = nil + if test.token == nil { + tokenErr = fmt.Errorf("no token") + } + tokenCache.OnGetToken().Return(nil, tokenErr).Once() + token, needsRefresh := customSource.fetchTokenFromCache() + if test.expectToken { + assert.NotNil(t, token) + } else { + assert.Nil(t, token) + } + assert.Equal(t, test.expectNeedsRefresh, needsRefresh) + }) + } +} + +func TestCustomTokenSource_Token(t *testing.T) { + ctx := context.Background() + cfg := GetConfig(ctx) + cfg.TokenRefreshWindow = config.Duration{Duration: time.Minute} + cfg.ClientSecretLocation = "" + + minuteAgo := time.Now().Add(-time.Minute) + hourAhead := time.Now().Add(time.Hour) + twoHourAhead := time.Now().Add(2 * time.Hour) + invalidToken := oauth2.Token{AccessToken: "foo", Expiry: minuteAgo} + validToken := oauth2.Token{AccessToken: "foo", Expiry: hourAhead} + newToken := oauth2.Token{AccessToken: "foo", Expiry: twoHourAhead} + + tests := []struct { + name string + refreshTime time.Time + failedToRefresh bool + token *oauth2.Token + newToken *oauth2.Token + expectedToken *oauth2.Token + }{ + { + name: "cached token", + refreshTime: hourAhead, + failedToRefresh: false, + token: &validToken, + newToken: nil, + expectedToken: &validToken, + }, + { + name: "failed refresh still valid", + refreshTime: minuteAgo, + failedToRefresh: false, + token: &validToken, + newToken: nil, + expectedToken: &validToken, + }, + { + name: "failed refresh invalid", + refreshTime: minuteAgo, + failedToRefresh: false, + token: &invalidToken, + newToken: nil, + expectedToken: nil, + }, + { + name: "refresh", + refreshTime: minuteAgo, + failedToRefresh: false, + token: &invalidToken, + newToken: &newToken, + expectedToken: &newToken, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tokenCache := &tokenCacheMocks.TokenCache{} + provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "") + assert.NoError(t, err) + source, err := provider.GetTokenSource(ctx) + assert.NoError(t, err) + customSource, ok := source.(*customTokenSource) + assert.True(t, ok) + + mockSource := &adminMocks.TokenSource{} + if test.newToken != nil { + mockSource.OnToken().Return(test.newToken, nil) + } else { + mockSource.OnToken().Return(nil, fmt.Errorf("refresh token failed")) + } + customSource.new = mockSource + customSource.refreshTime = test.refreshTime + customSource.failedToRefresh = test.failedToRefresh + tokenCache.OnGetToken().Return(test.token, nil).Once() + if test.newToken != nil { + tokenCache.OnSaveToken(test.newToken).Return(nil).Once() + } + token, err := source.Token() + if test.expectedToken != nil { + assert.Equal(t, test.expectedToken, token) + assert.NoError(t, err) + } else { + assert.Nil(t, token) + assert.Error(t, err) + } + }) + } +}