From 975cca807777bff267c9c099456a31d58f262a59 Mon Sep 17 00:00:00 2001 From: Micah Hausler Date: Thu, 29 Aug 2024 11:05:38 -0500 Subject: [PATCH] Refactored token filecache The token filecache used to use a private global function for creating a filelock, and overrode it in tests with a hand-crafted mocks for filesystem and environment variable operations. This change adds adds injectability to the filecache's filesystem and file lock using afero. This change also will simplify future changes when updating the AWS SDK with new credential interfaces. Signed-off-by: Micah Hausler --- go.mod | 2 +- pkg/{token => filecache}/filecache.go | 173 +++++---- pkg/{token => filecache}/filecache_test.go | 407 ++++++++++++--------- pkg/token/token.go | 5 +- 4 files changed, 327 insertions(+), 260 deletions(-) rename pkg/{token => filecache}/filecache.go (71%) rename pkg/{token => filecache}/filecache_test.go (55%) diff --git a/go.mod b/go.mod index eb704fd2c..3d866c974 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/manifoldco/promptui v0.9.0 github.com/prometheus/client_golang v1.19.1 github.com/sirupsen/logrus v1.9.3 + github.com/spf13/afero v1.11.0 github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.18.2 golang.org/x/time v0.5.0 @@ -61,7 +62,6 @@ require ( github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect - github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect diff --git a/pkg/token/filecache.go b/pkg/filecache/filecache.go similarity index 71% rename from pkg/token/filecache.go rename to pkg/filecache/filecache.go index e1a0c2a84..41597edaa 100644 --- a/pkg/token/filecache.go +++ b/pkg/filecache/filecache.go @@ -1,4 +1,4 @@ -package token +package filecache import ( "context" @@ -12,68 +12,22 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials" "github.com/gofrs/flock" + "github.com/spf13/afero" "gopkg.in/yaml.v2" ) // env variable name for custom credential cache file location const cacheFileNameEnv = "AWS_IAM_AUTHENTICATOR_CACHE_FILE" -// A mockable filesystem interface -var f filesystem = osFS{} - -type filesystem interface { - Stat(filename string) (os.FileInfo, error) - ReadFile(filename string) ([]byte, error) - WriteFile(filename string, data []byte, perm os.FileMode) error - MkdirAll(path string, perm os.FileMode) error -} - -// default os based implementation -type osFS struct{} - -func (osFS) Stat(filename string) (os.FileInfo, error) { - return os.Stat(filename) -} - -func (osFS) ReadFile(filename string) ([]byte, error) { - return os.ReadFile(filename) -} - -func (osFS) WriteFile(filename string, data []byte, perm os.FileMode) error { - return os.WriteFile(filename, data, perm) -} - -func (osFS) MkdirAll(path string, perm os.FileMode) error { - return os.MkdirAll(path, perm) -} - -// A mockable environment interface -var e environment = osEnv{} - -type environment interface { - Getenv(key string) string - LookupEnv(key string) (string, bool) -} - -// default os based implementation -type osEnv struct{} - -func (osEnv) Getenv(key string) string { - return os.Getenv(key) -} - -func (osEnv) LookupEnv(key string) (string, bool) { - return os.LookupEnv(key) -} - -// A mockable flock interface -type filelock interface { +// FileLocker is a subset of the methods exposed by *flock.Flock +type FileLocker interface { Unlock() error TryLockContext(ctx context.Context, retryDelay time.Duration) (bool, error) TryRLockContext(ctx context.Context, retryDelay time.Duration) (bool, error) } -var newFlock = func(filename string) filelock { +// NewFileLocker returns a *flock.Flock that satisfies FileLocker +func NewFileLocker(filename string) FileLocker { return flock.New(filename) } @@ -135,11 +89,11 @@ func (c *cachedCredential) IsExpired() bool { // readCacheWhileLocked reads the contents of the credential cache and returns the // parsed yaml as a cacheFile object. This method must be called while a shared // lock is held on the filename. -func readCacheWhileLocked(filename string) (cache cacheFile, err error) { +func readCacheWhileLocked(fs afero.Fs, filename string) (cache cacheFile, err error) { cache = cacheFile{ map[string]map[string]map[string]cachedCredential{}, } - data, err := f.ReadFile(filename) + data, err := afero.ReadFile(fs, filename) if err != nil { err = fmt.Errorf("unable to open file %s: %v", filename, err) return @@ -155,45 +109,86 @@ func readCacheWhileLocked(filename string) (cache cacheFile, err error) { // writeCacheWhileLocked writes the contents of the credential cache using the // yaml marshaled form of the passed cacheFile object. This method must be // called while an exclusive lock is held on the filename. -func writeCacheWhileLocked(filename string, cache cacheFile) error { +func writeCacheWhileLocked(fs afero.Fs, filename string, cache cacheFile) error { data, err := yaml.Marshal(cache) if err == nil { // write privately owned by the user - err = f.WriteFile(filename, data, 0600) + err = afero.WriteFile(fs, filename, data, 0600) } return err } -// FileCacheProvider is a Provider implementation that wraps an underlying Provider +type FileCacheOpt func(*FileCacheProvider) + +// WithFs returns a FileCacheOpt that sets the cache's filesystem +func WithFs(fs afero.Fs) FileCacheOpt { + return func(p *FileCacheProvider) { + p.fs = fs + } +} + +// WithFilename returns a FileCacheOpt that sets the cache's file +func WithFilename(filename string) FileCacheOpt { + return func(p *FileCacheProvider) { + p.filename = filename + } +} + +// WithFileLockCreator returns a FileCacheOpt that sets the cache's FileLocker +// creation function +func WithFileLockerCreator(f func(string) FileLocker) FileCacheOpt { + return func(p *FileCacheProvider) { + p.filelockCreator = f + } +} + +// FileCacheProvider is a credentials.Provider implementation that wraps an underlying Provider // (contained in Credentials) and provides caching support for credentials for the // specified clusterID, profile, and roleARN (contained in cacheKey) type FileCacheProvider struct { + fs afero.Fs + filelockCreator func(string) FileLocker + filename string credentials *credentials.Credentials // the underlying implementation that has the *real* Provider cacheKey cacheKey // cache key parameters used to create Provider cachedCredential cachedCredential // the cached credential, if it exists } +var _ credentials.Provider = &FileCacheProvider{} + // NewFileCacheProvider creates a new Provider implementation that wraps a provided Credentials, // and works with an on disk cache to speed up credential usage when the cached copy is not expired. // If there are any problems accessing or initializing the cache, an error will be returned, and // callers should just use the existing credentials provider. -func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials.Credentials) (FileCacheProvider, error) { +func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials.Credentials, opts ...FileCacheOpt) (*FileCacheProvider, error) { if creds == nil { - return FileCacheProvider{}, errors.New("no underlying Credentials object provided") + return nil, errors.New("no underlying Credentials object provided") + } + + resp := &FileCacheProvider{ + fs: afero.NewOsFs(), + filelockCreator: NewFileLocker, + filename: defaultCacheFilename(), + credentials: creds, + cacheKey: cacheKey{clusterID, profile, roleARN}, + cachedCredential: cachedCredential{}, } - filename := CacheFilename() - cacheKey := cacheKey{clusterID, profile, roleARN} - cachedCredential := cachedCredential{} + + // override defaults + for _, opt := range opts { + opt(resp) + } + // ensure path to cache file exists - _ = f.MkdirAll(filepath.Dir(filename), 0700) - if info, err := f.Stat(filename); err == nil { + _ = resp.fs.MkdirAll(filepath.Dir(resp.filename), 0700) + if info, err := resp.fs.Stat(resp.filename); err == nil { if info.Mode()&0077 != 0 { // cache file has secret credentials and should only be accessible to the user, refuse to use it. - return FileCacheProvider{}, fmt.Errorf("cache file %s is not private", filename) + return nil, fmt.Errorf("cache file %s is not private", resp.filename) } // do file locking on cache to prevent inconsistent reads - lock := newFlock(filename) + lock := resp.filelockCreator(resp.filename) defer lock.Unlock() // wait up to a second for the file to lock ctx, cancel := context.WithTimeout(context.TODO(), time.Second) @@ -201,30 +196,26 @@ func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials ok, err := lock.TryRLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second if !ok { // unable to lock the cache, something is wrong, refuse to use it. - return FileCacheProvider{}, fmt.Errorf("unable to read lock file %s: %v", filename, err) + return nil, fmt.Errorf("unable to read lock file %s: %v", resp.filename, err) } - cache, err := readCacheWhileLocked(filename) + cache, err := readCacheWhileLocked(resp.fs, resp.filename) if err != nil { // can't read or parse cache, refuse to use it. - return FileCacheProvider{}, err + return nil, err } - cachedCredential = cache.Get(cacheKey) + resp.cachedCredential = cache.Get(resp.cacheKey) } else { if errors.Is(err, fs.ErrNotExist) { // cache file is missing. maybe this is the very first run? continue to use cache. - _, _ = fmt.Fprintf(os.Stderr, "Cache file %s does not exist.\n", filename) + _, _ = fmt.Fprintf(os.Stderr, "Cache file %s does not exist.\n", resp.filename) } else { - return FileCacheProvider{}, fmt.Errorf("couldn't stat cache file: %w", err) + return nil, fmt.Errorf("couldn't stat cache file: %w", err) } } - return FileCacheProvider{ - creds, - cacheKey, - cachedCredential, - }, nil + return resp, nil } // Retrieve() implements the Provider interface, returning the cached credential if is not expired, @@ -243,9 +234,9 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { } if expiration, err := f.credentials.ExpiresAt(); err == nil { // underlying provider supports Expirer interface, so we can cache - filename := CacheFilename() + // do file locking on cache to prevent inconsistent writes - lock := newFlock(filename) + lock := f.filelockCreator(f.filename) defer lock.Unlock() // wait up to a second for the file to lock ctx, cancel := context.WithTimeout(context.TODO(), time.Second) @@ -253,7 +244,7 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { ok, err := lock.TryLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second if !ok { // can't get write lock to create/update cache, but still return the credential - _, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", filename, err) + _, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", f.filename, err) return credential, nil } f.cachedCredential = cachedCredential{ @@ -262,12 +253,12 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { nil, } // don't really care about read error. Either read the cache, or we create a new cache. - cache, _ := readCacheWhileLocked(filename) + cache, _ := readCacheWhileLocked(f.fs, f.filename) cache.Put(f.cacheKey, f.cachedCredential) - err = writeCacheWhileLocked(filename, cache) + err = writeCacheWhileLocked(f.fs, f.filename, cache) if err != nil { // can't write cache, but still return the credential - _, _ = fmt.Fprintf(os.Stderr, "Unable to update credential cache %s: %v\n", filename, err) + _, _ = fmt.Fprintf(os.Stderr, "Unable to update credential cache %s: %v\n", f.filename, err) err = nil } else { _, _ = fmt.Fprintf(os.Stderr, "Updated cached credential\n") @@ -292,23 +283,23 @@ func (f *FileCacheProvider) ExpiresAt() time.Time { return f.cachedCredential.Expiration } -// CacheFilename returns the name of the credential cache file, which can either be +// defaultCacheFilename returns the name of the credential cache file, which can either be // set by environment variable, or use the default of ~/.kube/cache/aws-iam-authenticator/credentials.yaml -func CacheFilename() string { - if filename, ok := e.LookupEnv(cacheFileNameEnv); ok { +func defaultCacheFilename() string { + if filename := os.Getenv(cacheFileNameEnv); filename != "" { return filename } else { - return filepath.Join(UserHomeDir(), ".kube", "cache", "aws-iam-authenticator", "credentials.yaml") + return filepath.Join(userHomeDir(), ".kube", "cache", "aws-iam-authenticator", "credentials.yaml") } } -// UserHomeDir returns the home directory for the user the process is +// userHomeDir returns the home directory for the user the process is // running under. -func UserHomeDir() string { +func userHomeDir() string { if runtime.GOOS == "windows" { // Windows - return e.Getenv("USERPROFILE") + return os.Getenv("USERPROFILE") } // *nix - return e.Getenv("HOME") + return os.Getenv("HOME") } diff --git a/pkg/token/filecache_test.go b/pkg/filecache/filecache_test.go similarity index 55% rename from pkg/token/filecache_test.go rename to pkg/filecache/filecache_test.go index d69c75937..fd87bdf7c 100644 --- a/pkg/token/filecache_test.go +++ b/pkg/filecache/filecache_test.go @@ -1,21 +1,32 @@ -package token +package filecache import ( "bytes" "context" "errors" - "github.com/aws/aws-sdk-go/aws/credentials" + "fmt" + "io/fs" "os" "testing" "time" + + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/spf13/afero" +) + +const ( + testFilename = "/test.yaml" ) +// stubProvider implements credentials.Provider with configurable response values type stubProvider struct { creds credentials.Value expired bool err error } +var _ credentials.Provider = &stubProvider{} + func (s *stubProvider) Retrieve() (credentials.Value, error) { s.expired = false s.creds.ProviderName = "stubProvider" @@ -26,89 +37,54 @@ func (s *stubProvider) IsExpired() bool { return s.expired } +// stubProviderExpirer implements credentials.Expirer with configurable expiration type stubProviderExpirer struct { stubProvider expiration time.Time } +var _ credentials.Expirer = &stubProviderExpirer{} + func (s *stubProviderExpirer) ExpiresAt() time.Time { return s.expiration } +// testFileInfo implements fs.FileInfo with configurable response values type testFileInfo struct { name string size int64 - mode os.FileMode + mode fs.FileMode modTime time.Time } +var _ fs.FileInfo = &testFileInfo{} + func (fs *testFileInfo) Name() string { return fs.name } func (fs *testFileInfo) Size() int64 { return fs.size } -func (fs *testFileInfo) Mode() os.FileMode { return fs.mode } +func (fs *testFileInfo) Mode() fs.FileMode { return fs.mode } func (fs *testFileInfo) ModTime() time.Time { return fs.modTime } func (fs *testFileInfo) IsDir() bool { return fs.Mode().IsDir() } func (fs *testFileInfo) Sys() interface{} { return nil } +// testFs wraps afero.Fs with an overridable Stat() method type testFS struct { - filename string - fileinfo testFileInfo - data []byte + afero.Fs + + fileinfo fs.FileInfo err error - perm os.FileMode } -func (t *testFS) Stat(filename string) (os.FileInfo, error) { - t.filename = filename - if t.err == nil { - return &t.fileinfo, nil - } else { +func (t *testFS) Stat(filename string) (fs.FileInfo, error) { + if t.err != nil { return nil, t.err } + if t.fileinfo != nil { + return t.fileinfo, nil + } + return t.Fs.Stat(filename) } -func (t *testFS) ReadFile(filename string) ([]byte, error) { - t.filename = filename - return t.data, t.err -} - -func (t *testFS) WriteFile(filename string, data []byte, perm os.FileMode) error { - t.filename = filename - t.data = data - t.perm = perm - return t.err -} - -func (t *testFS) MkdirAll(path string, perm os.FileMode) error { - t.filename = path - t.perm = perm - return t.err -} - -func (t *testFS) reset() { - t.filename = "" - t.fileinfo = testFileInfo{} - t.data = []byte{} - t.err = nil - t.perm = 0600 -} - -type testEnv struct { - values map[string]string -} - -func (e *testEnv) Getenv(key string) string { - return e.values[key] -} - -func (e *testEnv) LookupEnv(key string) (string, bool) { - value, ok := e.values[key] - return value, ok -} - -func (e *testEnv) reset() { - e.values = map[string]string{} -} - +// testFileLock implements FileLocker with configurable response options type testFilelock struct { ctx context.Context retryDelay time.Duration @@ -116,6 +92,8 @@ type testFilelock struct { err error } +var _ FileLocker = &testFilelock{} + func (l *testFilelock) Unlock() error { return nil } @@ -132,28 +110,12 @@ func (l *testFilelock) TryRLockContext(ctx context.Context, retryDelay time.Dura return l.success, l.err } -func (l *testFilelock) reset() { - l.ctx = context.TODO() - l.retryDelay = 0 - l.success = true - l.err = nil -} - -func getMocks() (tf *testFS, te *testEnv, testFlock *testFilelock) { - tf = &testFS{} - tf.reset() - f = tf - te = &testEnv{} - te.reset() - e = te - testFlock = &testFilelock{} - testFlock.reset() - newFlock = func(filename string) filelock { - return testFlock - } - return +// getMocks returns a mocked filesystem and FileLocker +func getMocks() (*testFS, *testFilelock) { + return &testFS{Fs: afero.NewMemMapFs()}, &testFilelock{context.TODO(), 0, true, nil} } +// makeCredential returns a dummy AWS crdential func makeCredential() credentials.Value { return credentials.Value{ AccessKeyID: "AKID", @@ -163,7 +125,9 @@ func makeCredential() credentials.Value { } } -func validateFileCacheProvider(t *testing.T, p FileCacheProvider, err error, c *credentials.Credentials) { +// validateFileCacheProvider ensures that the cache provider is properly initialized +func validateFileCacheProvider(t *testing.T, p *FileCacheProvider, err error, c *credentials.Credentials) { + t.Helper() if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -181,21 +145,37 @@ func validateFileCacheProvider(t *testing.T, p FileCacheProvider, err error, c * } } +// testSetEnv sets an env var, and returns a cleanup func +func testSetEnv(t *testing.T, key, value string) func() { + t.Helper() + old := os.Getenv(key) + os.Setenv(key, value) + return func() { + if old == "" { + os.Unsetenv(key) + } else { + os.Setenv(key, old) + } + } +} + func TestCacheFilename(t *testing.T) { - _, te, _ := getMocks() - te.values["HOME"] = "homedir" // unix - te.values["USERPROFILE"] = "homedir" // windows + c1 := testSetEnv(t, "HOME", "homedir") + defer c1() + c2 := testSetEnv(t, "USERPROFILE", "homedir") + defer c2() - filename := CacheFilename() + filename := defaultCacheFilename() expected := "homedir/.kube/cache/aws-iam-authenticator/credentials.yaml" if filename != expected { t.Errorf("Incorrect default cacheFilename, expected %s, got %s", expected, filename) } - te.values["AWS_IAM_AUTHENTICATOR_CACHE_FILE"] = "special.yaml" - filename = CacheFilename() + c3 := testSetEnv(t, "AWS_IAM_AUTHENTICATOR_CACHE_FILE", "special.yaml") + defer c3() + filename = defaultCacheFilename() expected = "special.yaml" if filename != expected { t.Errorf("Incorrect custom cacheFilename, expected %s, got %s", @@ -206,85 +186,131 @@ func TestCacheFilename(t *testing.T) { func TestNewFileCacheProvider_Missing(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() + tfs, tfl := getMocks() - // missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + return tfl + })) validateFileCacheProvider(t, p, err, c) if !p.cachedCredential.IsExpired() { t.Errorf("missing cache file should result in expired cached credential") } - tf.err = nil } func TestNewFileCacheProvider_BadPermissions(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() + tfs, _ := getMocks() + // afero.MemMapFs always returns tempfile FileInfo, + // so we manually set the response to the Stat() call + tfs.fileinfo = &testFileInfo{mode: 0777} // bad permissions - tf.fileinfo.mode = 0777 - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + ) if err == nil { t.Errorf("Expected error due to public permissions") } - if tf.filename != CacheFilename() { - t.Errorf("unexpected file checked, expected %s, got %s", - CacheFilename(), tf.filename) + wantMsg := fmt.Sprintf("cache file %s is not private", testFilename) + if err.Error() != wantMsg { + t.Errorf("Incorrect error, wanted '%s', got '%s'", wantMsg, err.Error()) } } func TestNewFileCacheProvider_Unlockable(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - _, _, testFlock := getMocks() + tfs, tfl := getMocks() + tfs.Create(testFilename) // unable to lock - testFlock.success = false - testFlock.err = errors.New("lock stuck, needs wd-40") - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfl.success = false + tfl.err = errors.New("lock stuck, needs wd-40") + + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + return tfl + }), + ) if err == nil { t.Errorf("Expected error due to lock failure") } - testFlock.success = true - testFlock.err = nil } func TestNewFileCacheProvider_Unreadable(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() + tfs, _ := getMocks() + tfs.Create(testFilename) + // afero.MemMapFs always returns tempfile FileInfo, + // so we manually set the response to the Stat() call + tfs.fileinfo = &testFileInfo{mode: 0000} - // unable to read existing cache - tf.err = errors.New("read failure") - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + ) if err == nil { t.Errorf("Expected error due to read failure") + return + } + wantMsg := fmt.Sprintf("unable to read lock file %s: open %s: read-only file system", testFilename, testFilename) + if err.Error() != wantMsg { + t.Errorf("Incorrect error, wanted '%s', got '%s'", wantMsg, err.Error()) } - tf.err = nil } func TestNewFileCacheProvider_Unparseable(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() - - // unable to parse yaml - tf.data = []byte("invalid: yaml: file") - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + tfs.Create(testFilename) + + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + afero.WriteFile( + tfs, + testFilename, + []byte("invalid: yaml: file"), + 0700) + return tfl + }), + ) if err == nil { t.Errorf("Expected error due to bad yaml") } + wantMsg := fmt.Sprintf("unable to parse file %s: yaml: mapping values are not allowed in this context", testFilename) + if err.Error() != wantMsg { + t.Errorf("Incorrect error, wanted '%s', got '%s'", wantMsg, err.Error()) + } } func TestNewFileCacheProvider_Empty(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - _, _, _ = getMocks() + tfs, tfl := getMocks() // successfully parse existing but empty cache file - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + })) + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } validateFileCacheProvider(t, p, err, c) if !p.cachedCredential.IsExpired() { t.Errorf("empty cache file should result in expired cached credential") @@ -294,13 +320,24 @@ func TestNewFileCacheProvider_Empty(t *testing.T) { func TestNewFileCacheProvider_ExistingCluster(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() - - // successfully parse existing cluster without matching arn - tf.data = []byte(`clusters: + tfs, tfl := getMocks() + afero.WriteFile( + tfs, + testFilename, + []byte(`clusters: CLUSTER: -`) - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + ARN2: {} +`), + 0700) + // successfully parse existing cluster without matching arn + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + }), + ) validateFileCacheProvider(t, p, err, c) if !p.cachedCredential.IsExpired() { t.Errorf("missing arn in cache file should result in expired cached credential") @@ -310,10 +347,7 @@ func TestNewFileCacheProvider_ExistingCluster(t *testing.T) { func TestNewFileCacheProvider_ExistingARN(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() - - // successfully parse cluster with matching arn - tf.data = []byte(`clusters: + content := []byte(`clusters: CLUSTER: PROFILE: ARN: @@ -324,11 +358,27 @@ func TestNewFileCacheProvider_ExistingARN(t *testing.T) { providername: JKL expiration: 2018-01-02T03:04:56.789Z `) - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + tfs.Create(testFilename) + + // successfully parse cluster with matching arn + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + afero.WriteFile(tfs, testFilename, content, 0700) + return tfl + }), + ) + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } validateFileCacheProvider(t, p, err, c) if p.cachedCredential.Credential.AccessKeyID != "ABC" || p.cachedCredential.Credential.SecretAccessKey != "DEF" || p.cachedCredential.Credential.SessionToken != "GHI" || p.cachedCredential.Credential.ProviderName != "JKL" { - t.Errorf("cached credential not extracted correctly") + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } // fiddle with clock p.cachedCredential.currentTime = func() time.Time { @@ -353,11 +403,17 @@ func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { creds: providerCredential, }) - tf, _, _ := getMocks() - - // initialize from missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + // don't create the empty cache file, create it in the filelock creator + + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + }), + ) validateFileCacheProvider(t, p, err, c) credential, err := p.Retrieve() @@ -370,6 +426,7 @@ func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { } } +// makeExpirerCredentials returns an expiring credential func makeExpirerCredentials() (providerCredential credentials.Value, expiration time.Time, c *credentials.Credentials) { providerCredential = makeCredential() expiration = time.Date(2020, 9, 19, 13, 14, 0, 1000000, time.UTC) @@ -385,17 +442,23 @@ func makeExpirerCredentials() (providerCredential credentials.Value, expiration func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { providerCredential, _, c := makeExpirerCredentials() - tf, _, testFlock := getMocks() + tfs, tfl := getMocks() + // don't create the empty cache file, create it in the filelock creator - // initialize from missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + })) validateFileCacheProvider(t, p, err, c) // retrieve credential, which will fetch from underlying Provider // fail to get write lock - testFlock.success = false - testFlock.err = errors.New("lock stuck, needs wd-40") + tfl.success = false + tfl.err = errors.New("lock stuck, needs wd-40") + credential, err := p.Retrieve() if err != nil { t.Errorf("Unexpected error: %v", err) @@ -409,16 +472,19 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { providerCredential, expiration, c := makeExpirerCredentials() - tf, _, _ := getMocks() - - // initialize from missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + // don't create the file, let the FileLocker create it + + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + }), + ) validateFileCacheProvider(t, p, err, c) - // retrieve credential, which will fetch from underlying Provider - // fail to write cache - tf.err = errors.New("can't write cache") credential, err := p.Retrieve() if err != nil { t.Errorf("Unexpected error: %v", err) @@ -427,14 +493,7 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { t.Errorf("Cache did not return provider credential, got %v, expected %v", credential, providerCredential) } - if tf.filename != CacheFilename() { - t.Errorf("Wrote to wrong file, expected %v, got %v", - CacheFilename(), tf.filename) - } - if tf.perm != 0600 { - t.Errorf("Wrote with wrong permissions, expected %o, got %o", - 0600, tf.perm) - } + expectedData := []byte(`clusters: CLUSTER: PROFILE: @@ -446,22 +505,31 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { providername: stubProvider expiration: ` + expiration.Format(time.RFC3339Nano) + ` `) - if bytes.Compare(tf.data, expectedData) != 0 { + got, err := afero.ReadFile(tfs, testFilename) + if err != nil { + t.Errorf("unexpected error reading generated file: %v", err) + } + if !bytes.Equal(got, expectedData) { t.Errorf("Wrong data written to cache, expected: %s, got %s", - expectedData, tf.data) + expectedData, got) } } func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { providerCredential, _, c := makeExpirerCredentials() - tf, _, _ := getMocks() - - // initialize from missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + // don't create the file, let the FileLocker create it + + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + }), + ) validateFileCacheProvider(t, p, err, c) - tf.err = nil // retrieve credential, which will fetch from underlying Provider // same as TestFileCacheProvider_Retrieve_WithExpirer_Unwritable, @@ -478,11 +546,13 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) + currentTime := time.Date(2017, 12, 25, 12, 23, 45, 678, time.UTC) - tf, _, _ := getMocks() + tfs, tfl := getMocks() + tfs.Create(testFilename) // successfully parse cluster with matching arn - tf.data = []byte(`clusters: + content := []byte(`clusters: CLUSTER: PROFILE: ARN: @@ -491,15 +561,20 @@ func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { secretaccesskey: DEF sessiontoken: GHI providername: JKL - expiration: 2018-01-02T03:04:56.789Z + expiration: ` + currentTime.Add(time.Hour*6).Format(time.RFC3339Nano) + ` `) - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + afero.WriteFile(tfs, testFilename, content, 0700) + return tfl + })) validateFileCacheProvider(t, p, err, c) // fiddle with clock - p.cachedCredential.currentTime = func() time.Time { - return time.Date(2017, 12, 25, 12, 23, 45, 678, time.UTC) - } + p.cachedCredential.currentTime = func() time.Time { return currentTime } credential, err := p.Retrieve() if err != nil { diff --git a/pkg/token/token.go b/pkg/token/token.go index 16ab8d92b..d9d7fd2e8 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -44,6 +44,7 @@ import ( clientauthv1beta1 "k8s.io/client-go/pkg/apis/clientauthentication/v1beta1" "sigs.k8s.io/aws-iam-authenticator/pkg" "sigs.k8s.io/aws-iam-authenticator/pkg/arn" + "sigs.k8s.io/aws-iam-authenticator/pkg/filecache" "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" ) @@ -247,8 +248,8 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) { profile = session.DefaultSharedConfigProfile } // create a cacheing Provider wrapper around the Credentials - if cacheProvider, err := NewFileCacheProvider(options.ClusterID, profile, options.AssumeRoleARN, sess.Config.Credentials); err == nil { - sess.Config.Credentials = credentials.NewCredentials(&cacheProvider) + if cacheProvider, err := filecache.NewFileCacheProvider(options.ClusterID, profile, options.AssumeRoleARN, sess.Config.Credentials); err == nil { + sess.Config.Credentials = credentials.NewCredentials(cacheProvider) } else { fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err) }