diff --git a/go.mod b/go.mod index eb704fd2c..3d866c974 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/manifoldco/promptui v0.9.0 github.com/prometheus/client_golang v1.19.1 github.com/sirupsen/logrus v1.9.3 + github.com/spf13/afero v1.11.0 github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.18.2 golang.org/x/time v0.5.0 @@ -61,7 +62,6 @@ require ( github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect - github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect diff --git a/pkg/token/filecache.go b/pkg/filecache/filecache.go similarity index 71% rename from pkg/token/filecache.go rename to pkg/filecache/filecache.go index e1a0c2a84..41597edaa 100644 --- a/pkg/token/filecache.go +++ b/pkg/filecache/filecache.go @@ -1,4 +1,4 @@ -package token +package filecache import ( "context" @@ -12,68 +12,22 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials" "github.com/gofrs/flock" + "github.com/spf13/afero" "gopkg.in/yaml.v2" ) // env variable name for custom credential cache file location const cacheFileNameEnv = "AWS_IAM_AUTHENTICATOR_CACHE_FILE" -// A mockable filesystem interface -var f filesystem = osFS{} - -type filesystem interface { - Stat(filename string) (os.FileInfo, error) - ReadFile(filename string) ([]byte, error) - WriteFile(filename string, data []byte, perm os.FileMode) error - MkdirAll(path string, perm os.FileMode) error -} - -// default os based implementation -type osFS struct{} - -func (osFS) Stat(filename string) (os.FileInfo, error) { - return os.Stat(filename) -} - -func (osFS) ReadFile(filename string) ([]byte, error) { - return os.ReadFile(filename) -} - -func (osFS) 
WriteFile(filename string, data []byte, perm os.FileMode) error { - return os.WriteFile(filename, data, perm) -} - -func (osFS) MkdirAll(path string, perm os.FileMode) error { - return os.MkdirAll(path, perm) -} - -// A mockable environment interface -var e environment = osEnv{} - -type environment interface { - Getenv(key string) string - LookupEnv(key string) (string, bool) -} - -// default os based implementation -type osEnv struct{} - -func (osEnv) Getenv(key string) string { - return os.Getenv(key) -} - -func (osEnv) LookupEnv(key string) (string, bool) { - return os.LookupEnv(key) -} - -// A mockable flock interface -type filelock interface { +// FileLocker is a subset of the methods exposed by *flock.Flock +type FileLocker interface { Unlock() error TryLockContext(ctx context.Context, retryDelay time.Duration) (bool, error) TryRLockContext(ctx context.Context, retryDelay time.Duration) (bool, error) } -var newFlock = func(filename string) filelock { +// NewFileLocker returns a *flock.Flock that satisfies FileLocker +func NewFileLocker(filename string) FileLocker { return flock.New(filename) } @@ -135,11 +89,11 @@ func (c *cachedCredential) IsExpired() bool { // readCacheWhileLocked reads the contents of the credential cache and returns the // parsed yaml as a cacheFile object. This method must be called while a shared // lock is held on the filename. 
-func readCacheWhileLocked(filename string) (cache cacheFile, err error) { +func readCacheWhileLocked(fs afero.Fs, filename string) (cache cacheFile, err error) { cache = cacheFile{ map[string]map[string]map[string]cachedCredential{}, } - data, err := f.ReadFile(filename) + data, err := afero.ReadFile(fs, filename) if err != nil { err = fmt.Errorf("unable to open file %s: %v", filename, err) return @@ -155,45 +109,86 @@ func readCacheWhileLocked(filename string) (cache cacheFile, err error) { // writeCacheWhileLocked writes the contents of the credential cache using the // yaml marshaled form of the passed cacheFile object. This method must be // called while an exclusive lock is held on the filename. -func writeCacheWhileLocked(filename string, cache cacheFile) error { +func writeCacheWhileLocked(fs afero.Fs, filename string, cache cacheFile) error { data, err := yaml.Marshal(cache) if err == nil { // write privately owned by the user - err = f.WriteFile(filename, data, 0600) + err = afero.WriteFile(fs, filename, data, 0600) } return err } -// FileCacheProvider is a Provider implementation that wraps an underlying Provider +type FileCacheOpt func(*FileCacheProvider) + +// WithFs returns a FileCacheOpt that sets the cache's filesystem +func WithFs(fs afero.Fs) FileCacheOpt { + return func(p *FileCacheProvider) { + p.fs = fs + } +} + +// WithFilename returns a FileCacheOpt that sets the cache's file +func WithFilename(filename string) FileCacheOpt { + return func(p *FileCacheProvider) { + p.filename = filename + } +} + +// WithFileLockerCreator returns a FileCacheOpt that sets the cache's FileLocker +// creation function +func WithFileLockerCreator(f func(string) FileLocker) FileCacheOpt { + return func(p *FileCacheProvider) { + p.filelockCreator = f + } +} + +// FileCacheProvider is a credentials.Provider implementation that wraps an underlying Provider // (contained in Credentials) and provides caching support for credentials for the // specified clusterID, profile, 
and roleARN (contained in cacheKey) type FileCacheProvider struct { + fs afero.Fs + filelockCreator func(string) FileLocker + filename string credentials *credentials.Credentials // the underlying implementation that has the *real* Provider cacheKey cacheKey // cache key parameters used to create Provider cachedCredential cachedCredential // the cached credential, if it exists } +var _ credentials.Provider = &FileCacheProvider{} + // NewFileCacheProvider creates a new Provider implementation that wraps a provided Credentials, // and works with an on disk cache to speed up credential usage when the cached copy is not expired. // If there are any problems accessing or initializing the cache, an error will be returned, and // callers should just use the existing credentials provider. -func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials.Credentials) (FileCacheProvider, error) { +func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials.Credentials, opts ...FileCacheOpt) (*FileCacheProvider, error) { if creds == nil { - return FileCacheProvider{}, errors.New("no underlying Credentials object provided") + return nil, errors.New("no underlying Credentials object provided") + } + + resp := &FileCacheProvider{ + fs: afero.NewOsFs(), + filelockCreator: NewFileLocker, + filename: defaultCacheFilename(), + credentials: creds, + cacheKey: cacheKey{clusterID, profile, roleARN}, + cachedCredential: cachedCredential{}, } - filename := CacheFilename() - cacheKey := cacheKey{clusterID, profile, roleARN} - cachedCredential := cachedCredential{} + + // override defaults + for _, opt := range opts { + opt(resp) + } + // ensure path to cache file exists - _ = f.MkdirAll(filepath.Dir(filename), 0700) - if info, err := f.Stat(filename); err == nil { + _ = resp.fs.MkdirAll(filepath.Dir(resp.filename), 0700) + if info, err := resp.fs.Stat(resp.filename); err == nil { if info.Mode()&0077 != 0 { // cache file has secret credentials and 
should only be accessible to the user, refuse to use it. - return FileCacheProvider{}, fmt.Errorf("cache file %s is not private", filename) + return nil, fmt.Errorf("cache file %s is not private", resp.filename) } // do file locking on cache to prevent inconsistent reads - lock := newFlock(filename) + lock := resp.filelockCreator(resp.filename) defer lock.Unlock() // wait up to a second for the file to lock ctx, cancel := context.WithTimeout(context.TODO(), time.Second) @@ -201,30 +196,26 @@ func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials ok, err := lock.TryRLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second if !ok { // unable to lock the cache, something is wrong, refuse to use it. - return FileCacheProvider{}, fmt.Errorf("unable to read lock file %s: %v", filename, err) + return nil, fmt.Errorf("unable to read lock file %s: %v", resp.filename, err) } - cache, err := readCacheWhileLocked(filename) + cache, err := readCacheWhileLocked(resp.fs, resp.filename) if err != nil { // can't read or parse cache, refuse to use it. - return FileCacheProvider{}, err + return nil, err } - cachedCredential = cache.Get(cacheKey) + resp.cachedCredential = cache.Get(resp.cacheKey) } else { if errors.Is(err, fs.ErrNotExist) { // cache file is missing. maybe this is the very first run? continue to use cache. 
- _, _ = fmt.Fprintf(os.Stderr, "Cache file %s does not exist.\n", filename) + _, _ = fmt.Fprintf(os.Stderr, "Cache file %s does not exist.\n", resp.filename) } else { - return FileCacheProvider{}, fmt.Errorf("couldn't stat cache file: %w", err) + return nil, fmt.Errorf("couldn't stat cache file: %w", err) } } - return FileCacheProvider{ - creds, - cacheKey, - cachedCredential, - }, nil + return resp, nil } // Retrieve() implements the Provider interface, returning the cached credential if is not expired, @@ -243,9 +234,9 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { } if expiration, err := f.credentials.ExpiresAt(); err == nil { // underlying provider supports Expirer interface, so we can cache - filename := CacheFilename() + // do file locking on cache to prevent inconsistent writes - lock := newFlock(filename) + lock := f.filelockCreator(f.filename) defer lock.Unlock() // wait up to a second for the file to lock ctx, cancel := context.WithTimeout(context.TODO(), time.Second) @@ -253,7 +244,7 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { ok, err := lock.TryLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second if !ok { // can't get write lock to create/update cache, but still return the credential - _, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", filename, err) + _, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", f.filename, err) return credential, nil } f.cachedCredential = cachedCredential{ @@ -262,12 +253,12 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { nil, } // don't really care about read error. Either read the cache, or we create a new cache. 
- cache, _ := readCacheWhileLocked(filename) + cache, _ := readCacheWhileLocked(f.fs, f.filename) cache.Put(f.cacheKey, f.cachedCredential) - err = writeCacheWhileLocked(filename, cache) + err = writeCacheWhileLocked(f.fs, f.filename, cache) if err != nil { // can't write cache, but still return the credential - _, _ = fmt.Fprintf(os.Stderr, "Unable to update credential cache %s: %v\n", filename, err) + _, _ = fmt.Fprintf(os.Stderr, "Unable to update credential cache %s: %v\n", f.filename, err) err = nil } else { _, _ = fmt.Fprintf(os.Stderr, "Updated cached credential\n") @@ -292,23 +283,23 @@ func (f *FileCacheProvider) ExpiresAt() time.Time { return f.cachedCredential.Expiration } -// CacheFilename returns the name of the credential cache file, which can either be +// defaultCacheFilename returns the name of the credential cache file, which can either be // set by environment variable, or use the default of ~/.kube/cache/aws-iam-authenticator/credentials.yaml -func CacheFilename() string { - if filename, ok := e.LookupEnv(cacheFileNameEnv); ok { +func defaultCacheFilename() string { + if filename := os.Getenv(cacheFileNameEnv); filename != "" { return filename } else { - return filepath.Join(UserHomeDir(), ".kube", "cache", "aws-iam-authenticator", "credentials.yaml") + return filepath.Join(userHomeDir(), ".kube", "cache", "aws-iam-authenticator", "credentials.yaml") } } -// UserHomeDir returns the home directory for the user the process is +// userHomeDir returns the home directory for the user the process is // running under. 
-func UserHomeDir() string { +func userHomeDir() string { if runtime.GOOS == "windows" { // Windows - return e.Getenv("USERPROFILE") + return os.Getenv("USERPROFILE") } // *nix - return e.Getenv("HOME") + return os.Getenv("HOME") } diff --git a/pkg/token/filecache_test.go b/pkg/filecache/filecache_test.go similarity index 55% rename from pkg/token/filecache_test.go rename to pkg/filecache/filecache_test.go index d69c75937..fd87bdf7c 100644 --- a/pkg/token/filecache_test.go +++ b/pkg/filecache/filecache_test.go @@ -1,21 +1,32 @@ -package token +package filecache import ( "bytes" "context" "errors" - "github.com/aws/aws-sdk-go/aws/credentials" + "fmt" + "io/fs" "os" "testing" "time" + + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/spf13/afero" +) + +const ( + testFilename = "/test.yaml" ) +// stubProvider implements credentials.Provider with configurable response values type stubProvider struct { creds credentials.Value expired bool err error } +var _ credentials.Provider = &stubProvider{} + func (s *stubProvider) Retrieve() (credentials.Value, error) { s.expired = false s.creds.ProviderName = "stubProvider" @@ -26,89 +37,54 @@ func (s *stubProvider) IsExpired() bool { return s.expired } +// stubProviderExpirer implements credentials.Expirer with configurable expiration type stubProviderExpirer struct { stubProvider expiration time.Time } +var _ credentials.Expirer = &stubProviderExpirer{} + func (s *stubProviderExpirer) ExpiresAt() time.Time { return s.expiration } +// testFileInfo implements fs.FileInfo with configurable response values type testFileInfo struct { name string size int64 - mode os.FileMode + mode fs.FileMode modTime time.Time } +var _ fs.FileInfo = &testFileInfo{} + func (fs *testFileInfo) Name() string { return fs.name } func (fs *testFileInfo) Size() int64 { return fs.size } -func (fs *testFileInfo) Mode() os.FileMode { return fs.mode } +func (fs *testFileInfo) Mode() fs.FileMode { return fs.mode } func (fs 
*testFileInfo) ModTime() time.Time { return fs.modTime } func (fs *testFileInfo) IsDir() bool { return fs.Mode().IsDir() } func (fs *testFileInfo) Sys() interface{} { return nil } +// testFS wraps afero.Fs with an overridable Stat() method type testFS struct { - filename string - fileinfo testFileInfo - data []byte + afero.Fs + + fileinfo fs.FileInfo err error - perm os.FileMode } -func (t *testFS) Stat(filename string) (os.FileInfo, error) { - t.filename = filename - if t.err == nil { - return &t.fileinfo, nil - } else { +func (t *testFS) Stat(filename string) (fs.FileInfo, error) { + if t.err != nil { return nil, t.err } + if t.fileinfo != nil { + return t.fileinfo, nil + } + return t.Fs.Stat(filename) } -func (t *testFS) ReadFile(filename string) ([]byte, error) { - t.filename = filename - return t.data, t.err -} - -func (t *testFS) WriteFile(filename string, data []byte, perm os.FileMode) error { - t.filename = filename - t.data = data - t.perm = perm - return t.err -} - -func (t *testFS) MkdirAll(path string, perm os.FileMode) error { - t.filename = path - t.perm = perm - return t.err -} - -func (t *testFS) reset() { - t.filename = "" - t.fileinfo = testFileInfo{} - t.data = []byte{} - t.err = nil - t.perm = 0600 -} - -type testEnv struct { - values map[string]string -} - -func (e *testEnv) Getenv(key string) string { - return e.values[key] -} - -func (e *testEnv) LookupEnv(key string) (string, bool) { - value, ok := e.values[key] - return value, ok -} - -func (e *testEnv) reset() { - e.values = map[string]string{} -} - +// testFilelock implements FileLocker with configurable response options type testFilelock struct { ctx context.Context retryDelay time.Duration @@ -116,6 +92,8 @@ type testFilelock struct { err error } +var _ FileLocker = &testFilelock{} + func (l *testFilelock) Unlock() error { return nil } @@ -132,28 +110,12 @@ func (l *testFilelock) TryRLockContext(ctx context.Context, retryDelay time.Dura return l.success, l.err } -func (l *testFilelock) 
reset() { - l.ctx = context.TODO() - l.retryDelay = 0 - l.success = true - l.err = nil -} - -func getMocks() (tf *testFS, te *testEnv, testFlock *testFilelock) { - tf = &testFS{} - tf.reset() - f = tf - te = &testEnv{} - te.reset() - e = te - testFlock = &testFilelock{} - testFlock.reset() - newFlock = func(filename string) filelock { - return testFlock - } - return +// getMocks returns a mocked filesystem and FileLocker +func getMocks() (*testFS, *testFilelock) { + return &testFS{Fs: afero.NewMemMapFs()}, &testFilelock{context.TODO(), 0, true, nil} } +// makeCredential returns a dummy AWS credential func makeCredential() credentials.Value { return credentials.Value{ AccessKeyID: "AKID", @@ -163,7 +125,9 @@ func makeCredential() credentials.Value { } } -func validateFileCacheProvider(t *testing.T, p FileCacheProvider, err error, c *credentials.Credentials) { +// validateFileCacheProvider ensures that the cache provider is properly initialized +func validateFileCacheProvider(t *testing.T, p *FileCacheProvider, err error, c *credentials.Credentials) { + t.Helper() if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -181,21 +145,37 @@ func validateFileCacheProvider(t *testing.T, p FileCacheProvider, err error, c * } } +// testSetEnv sets an env var, and returns a cleanup func +func testSetEnv(t *testing.T, key, value string) func() { + t.Helper() + old := os.Getenv(key) + os.Setenv(key, value) + return func() { + if old == "" { + os.Unsetenv(key) + } else { + os.Setenv(key, old) + } + } +} + func TestCacheFilename(t *testing.T) { - _, te, _ := getMocks() - te.values["HOME"] = "homedir" // unix - te.values["USERPROFILE"] = "homedir" // windows + c1 := testSetEnv(t, "HOME", "homedir") + defer c1() + c2 := testSetEnv(t, "USERPROFILE", "homedir") + defer c2() - filename := CacheFilename() + filename := defaultCacheFilename() expected := "homedir/.kube/cache/aws-iam-authenticator/credentials.yaml" if filename != expected { t.Errorf("Incorrect default cacheFilename, 
expected %s, got %s", expected, filename) } - te.values["AWS_IAM_AUTHENTICATOR_CACHE_FILE"] = "special.yaml" - filename = CacheFilename() + c3 := testSetEnv(t, "AWS_IAM_AUTHENTICATOR_CACHE_FILE", "special.yaml") + defer c3() + filename = defaultCacheFilename() expected = "special.yaml" if filename != expected { t.Errorf("Incorrect custom cacheFilename, expected %s, got %s", @@ -206,85 +186,131 @@ func TestCacheFilename(t *testing.T) { func TestNewFileCacheProvider_Missing(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() + tfs, tfl := getMocks() - // missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + return tfl + })) validateFileCacheProvider(t, p, err, c) if !p.cachedCredential.IsExpired() { t.Errorf("missing cache file should result in expired cached credential") } - tf.err = nil } func TestNewFileCacheProvider_BadPermissions(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() + tfs, _ := getMocks() + // afero.MemMapFs always returns tempfile FileInfo, + // so we manually set the response to the Stat() call + tfs.fileinfo = &testFileInfo{mode: 0777} // bad permissions - tf.fileinfo.mode = 0777 - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + ) if err == nil { t.Errorf("Expected error due to public permissions") } - if tf.filename != CacheFilename() { - t.Errorf("unexpected file checked, expected %s, got %s", - CacheFilename(), tf.filename) + wantMsg := fmt.Sprintf("cache file %s is not private", testFilename) + if err.Error() != wantMsg { + t.Errorf("Incorrect error, wanted '%s', got '%s'", wantMsg, err.Error()) } } func 
TestNewFileCacheProvider_Unlockable(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - _, _, testFlock := getMocks() + tfs, tfl := getMocks() + tfs.Create(testFilename) // unable to lock - testFlock.success = false - testFlock.err = errors.New("lock stuck, needs wd-40") - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfl.success = false + tfl.err = errors.New("lock stuck, needs wd-40") + + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + return tfl + }), + ) if err == nil { t.Errorf("Expected error due to lock failure") } - testFlock.success = true - testFlock.err = nil } func TestNewFileCacheProvider_Unreadable(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() + tfs, _ := getMocks() + tfs.Create(testFilename) + // afero.MemMapFs always returns tempfile FileInfo, + // so we manually set the response to the Stat() call + tfs.fileinfo = &testFileInfo{mode: 0000} - // unable to read existing cache - tf.err = errors.New("read failure") - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + ) if err == nil { t.Errorf("Expected error due to read failure") + return + } + wantMsg := fmt.Sprintf("unable to read lock file %s: open %s: read-only file system", testFilename, testFilename) + if err.Error() != wantMsg { + t.Errorf("Incorrect error, wanted '%s', got '%s'", wantMsg, err.Error()) } - tf.err = nil } func TestNewFileCacheProvider_Unparseable(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() - - // unable to parse yaml - tf.data = []byte("invalid: yaml: file") - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + tfs.Create(testFilename) + + _, err := NewFileCacheProvider("CLUSTER", 
"PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + afero.WriteFile( + tfs, + testFilename, + []byte("invalid: yaml: file"), + 0700) + return tfl + }), + ) if err == nil { t.Errorf("Expected error due to bad yaml") } + wantMsg := fmt.Sprintf("unable to parse file %s: yaml: mapping values are not allowed in this context", testFilename) + if err.Error() != wantMsg { + t.Errorf("Incorrect error, wanted '%s', got '%s'", wantMsg, err.Error()) + } } func TestNewFileCacheProvider_Empty(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - _, _, _ = getMocks() + tfs, tfl := getMocks() // successfully parse existing but empty cache file - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + })) + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } validateFileCacheProvider(t, p, err, c) if !p.cachedCredential.IsExpired() { t.Errorf("empty cache file should result in expired cached credential") @@ -294,13 +320,24 @@ func TestNewFileCacheProvider_Empty(t *testing.T) { func TestNewFileCacheProvider_ExistingCluster(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() - - // successfully parse existing cluster without matching arn - tf.data = []byte(`clusters: + tfs, tfl := getMocks() + afero.WriteFile( + tfs, + testFilename, + []byte(`clusters: CLUSTER: -`) - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + ARN2: {} +`), + 0700) + // successfully parse existing cluster without matching arn + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + }), + ) validateFileCacheProvider(t, 
p, err, c) if !p.cachedCredential.IsExpired() { t.Errorf("missing arn in cache file should result in expired cached credential") @@ -310,10 +347,7 @@ func TestNewFileCacheProvider_ExistingCluster(t *testing.T) { func TestNewFileCacheProvider_ExistingARN(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() - - // successfully parse cluster with matching arn - tf.data = []byte(`clusters: + content := []byte(`clusters: CLUSTER: PROFILE: ARN: @@ -324,11 +358,27 @@ func TestNewFileCacheProvider_ExistingARN(t *testing.T) { providername: JKL expiration: 2018-01-02T03:04:56.789Z `) - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + tfs.Create(testFilename) + + // successfully parse cluster with matching arn + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + afero.WriteFile(tfs, testFilename, content, 0700) + return tfl + }), + ) + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } validateFileCacheProvider(t, p, err, c) if p.cachedCredential.Credential.AccessKeyID != "ABC" || p.cachedCredential.Credential.SecretAccessKey != "DEF" || p.cachedCredential.Credential.SessionToken != "GHI" || p.cachedCredential.Credential.ProviderName != "JKL" { - t.Errorf("cached credential not extracted correctly") + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } // fiddle with clock p.cachedCredential.currentTime = func() time.Time { @@ -353,11 +403,17 @@ func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { creds: providerCredential, }) - tf, _, _ := getMocks() - - // initialize from missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + // don't create the empty cache file, create it in the filelock creator + + p, err := 
NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + }), + ) validateFileCacheProvider(t, p, err, c) credential, err := p.Retrieve() @@ -370,6 +426,7 @@ func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { } } +// makeExpirerCredentials returns an expiring credential func makeExpirerCredentials() (providerCredential credentials.Value, expiration time.Time, c *credentials.Credentials) { providerCredential = makeCredential() expiration = time.Date(2020, 9, 19, 13, 14, 0, 1000000, time.UTC) @@ -385,17 +442,23 @@ func makeExpirerCredentials() (providerCredential credentials.Value, expiration func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { providerCredential, _, c := makeExpirerCredentials() - tf, _, testFlock := getMocks() + tfs, tfl := getMocks() + // don't create the empty cache file, create it in the filelock creator - // initialize from missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + })) validateFileCacheProvider(t, p, err, c) // retrieve credential, which will fetch from underlying Provider // fail to get write lock - testFlock.success = false - testFlock.err = errors.New("lock stuck, needs wd-40") + tfl.success = false + tfl.err = errors.New("lock stuck, needs wd-40") + credential, err := p.Retrieve() if err != nil { t.Errorf("Unexpected error: %v", err) @@ -409,16 +472,19 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { providerCredential, expiration, c := makeExpirerCredentials() - tf, _, _ := getMocks() - - // initialize from 
missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + // don't create the file, let the FileLocker create it + + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + }), + ) validateFileCacheProvider(t, p, err, c) - // retrieve credential, which will fetch from underlying Provider - // fail to write cache - tf.err = errors.New("can't write cache") credential, err := p.Retrieve() if err != nil { t.Errorf("Unexpected error: %v", err) @@ -427,14 +493,7 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { t.Errorf("Cache did not return provider credential, got %v, expected %v", credential, providerCredential) } - if tf.filename != CacheFilename() { - t.Errorf("Wrote to wrong file, expected %v, got %v", - CacheFilename(), tf.filename) - } - if tf.perm != 0600 { - t.Errorf("Wrote with wrong permissions, expected %o, got %o", - 0600, tf.perm) - } + expectedData := []byte(`clusters: CLUSTER: PROFILE: @@ -446,22 +505,31 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { providername: stubProvider expiration: ` + expiration.Format(time.RFC3339Nano) + ` `) - if bytes.Compare(tf.data, expectedData) != 0 { + got, err := afero.ReadFile(tfs, testFilename) + if err != nil { + t.Errorf("unexpected error reading generated file: %v", err) + } + if !bytes.Equal(got, expectedData) { t.Errorf("Wrong data written to cache, expected: %s, got %s", - expectedData, tf.data) + expectedData, got) } } func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { providerCredential, _, c := makeExpirerCredentials() - tf, _, _ := getMocks() - - // initialize from missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + // don't 
create the file, let the FileLocker create it + + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + }), + ) validateFileCacheProvider(t, p, err, c) - tf.err = nil // retrieve credential, which will fetch from underlying Provider // same as TestFileCacheProvider_Retrieve_WithExpirer_Unwritable, @@ -478,11 +546,13 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) + currentTime := time.Date(2017, 12, 25, 12, 23, 45, 678, time.UTC) - tf, _, _ := getMocks() + tfs, tfl := getMocks() + tfs.Create(testFilename) // successfully parse cluster with matching arn - tf.data = []byte(`clusters: + content := []byte(`clusters: CLUSTER: PROFILE: ARN: @@ -491,15 +561,20 @@ func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { secretaccesskey: DEF sessiontoken: GHI providername: JKL - expiration: 2018-01-02T03:04:56.789Z + expiration: ` + currentTime.Add(time.Hour*6).Format(time.RFC3339Nano) + ` `) - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + afero.WriteFile(tfs, testFilename, content, 0700) + return tfl + })) validateFileCacheProvider(t, p, err, c) // fiddle with clock - p.cachedCredential.currentTime = func() time.Time { - return time.Date(2017, 12, 25, 12, 23, 45, 678, time.UTC) - } + p.cachedCredential.currentTime = func() time.Time { return currentTime } credential, err := p.Retrieve() if err != nil { diff --git a/pkg/token/token.go b/pkg/token/token.go index 16ab8d92b..d9d7fd2e8 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -44,6 +44,7 @@ import ( clientauthv1beta1 
"k8s.io/client-go/pkg/apis/clientauthentication/v1beta1" "sigs.k8s.io/aws-iam-authenticator/pkg" "sigs.k8s.io/aws-iam-authenticator/pkg/arn" + "sigs.k8s.io/aws-iam-authenticator/pkg/filecache" "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" ) @@ -247,8 +248,8 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) { profile = session.DefaultSharedConfigProfile } // create a cacheing Provider wrapper around the Credentials - if cacheProvider, err := NewFileCacheProvider(options.ClusterID, profile, options.AssumeRoleARN, sess.Config.Credentials); err == nil { - sess.Config.Credentials = credentials.NewCredentials(&cacheProvider) + if cacheProvider, err := filecache.NewFileCacheProvider(options.ClusterID, profile, options.AssumeRoleARN, sess.Config.Credentials); err == nil { + sess.Config.Credentials = credentials.NewCredentials(cacheProvider) } else { fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err) }