Skip to content

Commit

Permalink
Filecache WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
micahhausler committed Aug 29, 2024
1 parent 6c21d4d commit 5d3f1e1
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 110 deletions.
55 changes: 0 additions & 55 deletions pkg/filecache/converter.go

This file was deleted.

35 changes: 8 additions & 27 deletions pkg/filecache/filecache.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/gofrs/flock"
"github.com/spf13/afero"
"gopkg.in/yaml.v2"
Expand Down Expand Up @@ -136,7 +135,7 @@ type FileCacheProvider struct {
cachedCredential aws.Credentials // the cached credential, if it exists
}

var _ credentials.Provider = &FileCacheProvider{}
var _ aws.CredentialsProvider = &FileCacheProvider{}

// NewFileCacheProvider creates a new Provider implementation that wraps a provided Credentials,
// and works with an on disk cache to speed up credential usage when the cached copy is not expired.
Expand Down Expand Up @@ -200,26 +199,19 @@ func NewFileCacheProvider(clusterID, profile, roleARN string, provider aws.Crede
return resp, nil
}

// Retrieve() implements the Provider interface, returning the cached credential if is not expired,
// otherwise fetching the credential from the underlying Provider and caching the results on disk
// Retrieve() implements the aws.CredentialsProvider interface, returning the cached credential if is not expired,
// otherwise fetching the credential from the underlying CredentialProvider and caching the results on disk
// with an expiration time.
func (f *FileCacheProvider) Retrieve() (credentials.Value, error) {
return f.RetrieveWithContext(context.Background())
}

// Retrieve() implements the Provider interface, returning the cached credential if is not expired,
// otherwise fetching the credential from the underlying Provider and caching the results on disk
// with an expiration time.
func (f *FileCacheProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
func (f *FileCacheProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
if !f.cachedCredential.Expired() && f.cachedCredential.HasKeys() {
// use the cached credential
return V2CredentialToV1Value(f.cachedCredential), nil
return f.cachedCredential, nil
} else {
_, _ = fmt.Fprintf(os.Stderr, "No cached credential available. Refreshing...\n")
// fetch the credentials from the underlying Provider
credential, err := f.provider.Retrieve(ctx)
if err != nil {
return V2CredentialToV1Value(credential), err
return credential, err
}

if credential.CanExpire {
Expand All @@ -235,7 +227,7 @@ func (f *FileCacheProvider) RetrieveWithContext(ctx context.Context) (credential
if !ok {
// can't get write lock to create/update cache, but still return the credential
_, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", f.filename, err)
return V2CredentialToV1Value(credential), nil
return credential, nil
}
f.cachedCredential = credential
// don't really care about read error. Either read the cache, or we create a new cache.
Expand All @@ -254,21 +246,10 @@ func (f *FileCacheProvider) RetrieveWithContext(ctx context.Context) (credential
_, _ = fmt.Fprintf(os.Stderr, "Unable to cache credential: %v\n", err)
err = nil
}
return V2CredentialToV1Value(credential), err
return credential, err
}
}

// IsExpired() implements the Provider interface, deferring to the cached credential first,
// but fall back to the underlying Provider if it is expired.
func (f *FileCacheProvider) IsExpired() bool {
return f.cachedCredential.CanExpire && f.cachedCredential.Expired()
}

// ExpiresAt implements the Expirer interface, and gives access to the expiration time of the credential
func (f *FileCacheProvider) ExpiresAt() time.Time {
return f.cachedCredential.Expires
}

// defaultCacheFilename returns the name of the credential cache file, which can either be
// set by environment variable, or use the default of ~/.kube/cache/aws-iam-authenticator/credentials.yaml
func defaultCacheFilename() string {
Expand Down
29 changes: 10 additions & 19 deletions pkg/filecache/filecache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,6 @@ func TestNewFileCacheProvider_ExistingARN(t *testing.T) {
t.Errorf("Cached credential should not be expired")
}

if p.ExpiresAt() != p.cachedCredential.Expires {
t.Errorf("Credential expiration time is not correct, expected %v, got %v",
p.cachedCredential.Expires, p.ExpiresAt())
}
}

func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) {
Expand All @@ -407,7 +403,7 @@ func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) {
)
validateFileCacheProvider(t, p, err, provider)

credential, err := p.Retrieve()
credential, err := p.Retrieve(context.Background())
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
Expand Down Expand Up @@ -442,12 +438,12 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) {
tfl.success = false
tfl.err = errors.New("lock stuck, needs wd-40")

credential, err := p.Retrieve()
credential, err := p.Retrieve(context.Background())
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if credential.AccessKeyID != "AKID" || credential.SecretAccessKey != "SECRET" ||
credential.SessionToken != "TOKEN" || credential.ProviderName != "stubProvider" {
credential.SessionToken != "TOKEN" || credential.Source != "stubProvider" {
t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential)
}
}
Expand All @@ -471,14 +467,14 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) {
)
validateFileCacheProvider(t, p, err, provider)

credential, err := p.Retrieve()
credential, err := p.Retrieve(context.Background())
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if credential.AccessKeyID != provider.creds.AccessKeyID ||
credential.SecretAccessKey != provider.creds.SecretAccessKey ||
credential.SessionToken != provider.creds.SessionToken ||
credential.ProviderName != provider.creds.Source {
credential.Source != provider.creds.Source {
t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential)
}

Expand Down Expand Up @@ -525,14 +521,14 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) {
// retrieve credential, which will fetch from underlying Provider
// same as TestFileCacheProvider_Retrieve_WithExpirer_Unwritable,
// but write to disk (code coverage)
credential, err := p.Retrieve()
credential, err := p.Retrieve(context.Background())
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if credential.AccessKeyID != provider.creds.AccessKeyID ||
credential.SecretAccessKey != provider.creds.SecretAccessKey ||
credential.SessionToken != provider.creds.SessionToken ||
credential.ProviderName != provider.creds.Source {
credential.Source != provider.creds.Source {
t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential)
}
}
Expand Down Expand Up @@ -567,19 +563,14 @@ func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) {
}))
validateFileCacheProvider(t, p, err, provider)

credential, err := p.Retrieve()
credential, err := p.Retrieve(context.Background())
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if credential.AccessKeyID != "ABC" || credential.SecretAccessKey != "DEF" ||
credential.SessionToken != "GHI" || credential.ProviderName != "JKL" {
credential.SessionToken != "GHI" || credential.Source != "JKL" ||
!credential.Expires.Equal(currentTime.Add(time.Hour*6)) {
t.Errorf("cached credential not returned")
}

if !p.ExpiresAt().Equal(currentTime.Add(time.Hour * 6)) {
t.Errorf("unexpected expiration time: got %s, wanted %s",
p.ExpiresAt().Format(time.RFC3339Nano),
currentTime.Add(time.Hour*6).Format(time.RFC3339Nano),
)
}
}
21 changes: 12 additions & 9 deletions pkg/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,22 +240,25 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) {
var profile string
if v := os.Getenv("AWS_PROFILE"); len(v) > 0 {
profile = v
} else {
profile = session.DefaultSharedConfigProfile
}
cfg, err := config.LoadDefaultConfig(context.Background(), loadOpts...)


// create a cacheing Provider wrapper around the Credentials
if cacheProvider, err := filecache.NewFileCacheProvider(
// Create a new config to get the default cred chain
cfg, err := config.LoadDefaultConfig(context.Background())
if err != nil {
return Token{}, fmt.Errorf("could not create config: %v", err)
}
// create a caching Provider wrapper around the Credentials
cacheProvider, err := filecache.NewFileCacheProvider(
options.ClusterID,
profile,
options.AssumeRoleARN,
cfg.Credentials); err == nil {
sess.Config.Credentials = credentials.NewCredentials(cacheProvider)
cfg.Credentials,
)
if err == nil {
loadOpts = append(loadOpts, config.WithCredentialsProvider(cacheProvider))
} else {
fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err)
}
}

cfg, err := config.LoadDefaultConfig(context.Background(), loadOpts...)
if err != nil {
Expand Down

0 comments on commit 5d3f1e1

Please sign in to comment.