Skip to content

Commit

Permalink
Refactored token filecache
Browse files Browse the repository at this point in the history
The token filecache used to use a private global function for creating a
filelock, and overrode it in tests with a hand-crafted mocks for
filesystem and environment variable operations.

This change adds adds injectability to the filecache's filesystem and
file lock using afero. This change also will simplify future changes
when updating the AWS SDK with new credential interfaces.

Signed-off-by: Micah Hausler <[email protected]>
  • Loading branch information
micahhausler committed Aug 29, 2024
1 parent b154c1d commit 975cca8
Show file tree
Hide file tree
Showing 4 changed files with 327 additions and 260 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ require (
github.com/manifoldco/promptui v0.9.0
github.com/prometheus/client_golang v1.19.1
github.com/sirupsen/logrus v1.9.3
github.com/spf13/afero v1.11.0
github.com/spf13/cobra v1.8.1
github.com/spf13/viper v1.18.2
golang.org/x/time v0.5.0
Expand Down Expand Up @@ -61,7 +62,6 @@ require (
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
Expand Down
173 changes: 82 additions & 91 deletions pkg/token/filecache.go → pkg/filecache/filecache.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package token
package filecache

import (
"context"
Expand All @@ -12,68 +12,22 @@ import (

"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/gofrs/flock"
"github.com/spf13/afero"
"gopkg.in/yaml.v2"
)

// env variable name for custom credential cache file location
const cacheFileNameEnv = "AWS_IAM_AUTHENTICATOR_CACHE_FILE"

// A mockable filesystem interface
var f filesystem = osFS{}

type filesystem interface {
Stat(filename string) (os.FileInfo, error)
ReadFile(filename string) ([]byte, error)
WriteFile(filename string, data []byte, perm os.FileMode) error
MkdirAll(path string, perm os.FileMode) error
}

// default os based implementation
type osFS struct{}

func (osFS) Stat(filename string) (os.FileInfo, error) {
return os.Stat(filename)
}

func (osFS) ReadFile(filename string) ([]byte, error) {
return os.ReadFile(filename)
}

func (osFS) WriteFile(filename string, data []byte, perm os.FileMode) error {
return os.WriteFile(filename, data, perm)
}

func (osFS) MkdirAll(path string, perm os.FileMode) error {
return os.MkdirAll(path, perm)
}

// A mockable environment interface
var e environment = osEnv{}

type environment interface {
Getenv(key string) string
LookupEnv(key string) (string, bool)
}

// default os based implementation
type osEnv struct{}

func (osEnv) Getenv(key string) string {
return os.Getenv(key)
}

func (osEnv) LookupEnv(key string) (string, bool) {
return os.LookupEnv(key)
}

// A mockable flock interface
type filelock interface {
// FileLocker is a subset of the methods exposed by *flock.Flock
type FileLocker interface {
Unlock() error
TryLockContext(ctx context.Context, retryDelay time.Duration) (bool, error)
TryRLockContext(ctx context.Context, retryDelay time.Duration) (bool, error)
}

var newFlock = func(filename string) filelock {
// NewFileLocker returns a *flock.Flock that satisfies FileLocker
func NewFileLocker(filename string) FileLocker {
return flock.New(filename)
}

Expand Down Expand Up @@ -135,11 +89,11 @@ func (c *cachedCredential) IsExpired() bool {
// readCacheWhileLocked reads the contents of the credential cache and returns the
// parsed yaml as a cacheFile object. This method must be called while a shared
// lock is held on the filename.
func readCacheWhileLocked(filename string) (cache cacheFile, err error) {
func readCacheWhileLocked(fs afero.Fs, filename string) (cache cacheFile, err error) {
cache = cacheFile{
map[string]map[string]map[string]cachedCredential{},
}
data, err := f.ReadFile(filename)
data, err := afero.ReadFile(fs, filename)
if err != nil {
err = fmt.Errorf("unable to open file %s: %v", filename, err)
return
Expand All @@ -155,76 +109,113 @@ func readCacheWhileLocked(filename string) (cache cacheFile, err error) {
// writeCacheWhileLocked writes the contents of the credential cache using the
// yaml marshaled form of the passed cacheFile object. This method must be
// called while an exclusive lock is held on the filename.
func writeCacheWhileLocked(filename string, cache cacheFile) error {
func writeCacheWhileLocked(fs afero.Fs, filename string, cache cacheFile) error {
data, err := yaml.Marshal(cache)
if err == nil {
// write privately owned by the user
err = f.WriteFile(filename, data, 0600)
err = afero.WriteFile(fs, filename, data, 0600)
}
return err
}

// FileCacheProvider is a Provider implementation that wraps an underlying Provider
type FileCacheOpt func(*FileCacheProvider)

// WithFs returns a FileCacheOpt that sets the cache's filesystem
func WithFs(fs afero.Fs) FileCacheOpt {
return func(p *FileCacheProvider) {
p.fs = fs
}
}

// WithFilename returns a FileCacheOpt that sets the cache's file
func WithFilename(filename string) FileCacheOpt {
return func(p *FileCacheProvider) {
p.filename = filename
}
}

// WithFileLockCreator returns a FileCacheOpt that sets the cache's FileLocker
// creation function
func WithFileLockerCreator(f func(string) FileLocker) FileCacheOpt {
return func(p *FileCacheProvider) {
p.filelockCreator = f
}
}

// FileCacheProvider is a credentials.Provider implementation that wraps an underlying Provider
// (contained in Credentials) and provides caching support for credentials for the
// specified clusterID, profile, and roleARN (contained in cacheKey)
type FileCacheProvider struct {
fs afero.Fs
filelockCreator func(string) FileLocker
filename string
credentials *credentials.Credentials // the underlying implementation that has the *real* Provider
cacheKey cacheKey // cache key parameters used to create Provider
cachedCredential cachedCredential // the cached credential, if it exists
}

var _ credentials.Provider = &FileCacheProvider{}

// NewFileCacheProvider creates a new Provider implementation that wraps a provided Credentials,
// and works with an on disk cache to speed up credential usage when the cached copy is not expired.
// If there are any problems accessing or initializing the cache, an error will be returned, and
// callers should just use the existing credentials provider.
func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials.Credentials) (FileCacheProvider, error) {
func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials.Credentials, opts ...FileCacheOpt) (*FileCacheProvider, error) {
if creds == nil {
return FileCacheProvider{}, errors.New("no underlying Credentials object provided")
return nil, errors.New("no underlying Credentials object provided")
}

resp := &FileCacheProvider{
fs: afero.NewOsFs(),
filelockCreator: NewFileLocker,
filename: defaultCacheFilename(),
credentials: creds,
cacheKey: cacheKey{clusterID, profile, roleARN},
cachedCredential: cachedCredential{},
}
filename := CacheFilename()
cacheKey := cacheKey{clusterID, profile, roleARN}
cachedCredential := cachedCredential{}

// override defaults
for _, opt := range opts {
opt(resp)
}

// ensure path to cache file exists
_ = f.MkdirAll(filepath.Dir(filename), 0700)
if info, err := f.Stat(filename); err == nil {
_ = resp.fs.MkdirAll(filepath.Dir(resp.filename), 0700)
if info, err := resp.fs.Stat(resp.filename); err == nil {
if info.Mode()&0077 != 0 {
// cache file has secret credentials and should only be accessible to the user, refuse to use it.
return FileCacheProvider{}, fmt.Errorf("cache file %s is not private", filename)
return nil, fmt.Errorf("cache file %s is not private", resp.filename)
}

// do file locking on cache to prevent inconsistent reads
lock := newFlock(filename)
lock := resp.filelockCreator(resp.filename)
defer lock.Unlock()
// wait up to a second for the file to lock
ctx, cancel := context.WithTimeout(context.TODO(), time.Second)
defer cancel()
ok, err := lock.TryRLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second
if !ok {
// unable to lock the cache, something is wrong, refuse to use it.
return FileCacheProvider{}, fmt.Errorf("unable to read lock file %s: %v", filename, err)
return nil, fmt.Errorf("unable to read lock file %s: %v", resp.filename, err)
}

cache, err := readCacheWhileLocked(filename)
cache, err := readCacheWhileLocked(resp.fs, resp.filename)
if err != nil {
// can't read or parse cache, refuse to use it.
return FileCacheProvider{}, err
return nil, err
}

cachedCredential = cache.Get(cacheKey)
resp.cachedCredential = cache.Get(resp.cacheKey)
} else {
if errors.Is(err, fs.ErrNotExist) {
// cache file is missing. maybe this is the very first run? continue to use cache.
_, _ = fmt.Fprintf(os.Stderr, "Cache file %s does not exist.\n", filename)
_, _ = fmt.Fprintf(os.Stderr, "Cache file %s does not exist.\n", resp.filename)
} else {
return FileCacheProvider{}, fmt.Errorf("couldn't stat cache file: %w", err)
return nil, fmt.Errorf("couldn't stat cache file: %w", err)
}
}

return FileCacheProvider{
creds,
cacheKey,
cachedCredential,
}, nil
return resp, nil
}

// Retrieve() implements the Provider interface, returning the cached credential if is not expired,
Expand All @@ -243,17 +234,17 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) {
}
if expiration, err := f.credentials.ExpiresAt(); err == nil {
// underlying provider supports Expirer interface, so we can cache
filename := CacheFilename()

// do file locking on cache to prevent inconsistent writes
lock := newFlock(filename)
lock := f.filelockCreator(f.filename)
defer lock.Unlock()
// wait up to a second for the file to lock
ctx, cancel := context.WithTimeout(context.TODO(), time.Second)
defer cancel()
ok, err := lock.TryLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second
if !ok {
// can't get write lock to create/update cache, but still return the credential
_, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", filename, err)
_, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", f.filename, err)
return credential, nil
}
f.cachedCredential = cachedCredential{
Expand All @@ -262,12 +253,12 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) {
nil,
}
// don't really care about read error. Either read the cache, or we create a new cache.
cache, _ := readCacheWhileLocked(filename)
cache, _ := readCacheWhileLocked(f.fs, f.filename)
cache.Put(f.cacheKey, f.cachedCredential)
err = writeCacheWhileLocked(filename, cache)
err = writeCacheWhileLocked(f.fs, f.filename, cache)
if err != nil {
// can't write cache, but still return the credential
_, _ = fmt.Fprintf(os.Stderr, "Unable to update credential cache %s: %v\n", filename, err)
_, _ = fmt.Fprintf(os.Stderr, "Unable to update credential cache %s: %v\n", f.filename, err)
err = nil
} else {
_, _ = fmt.Fprintf(os.Stderr, "Updated cached credential\n")
Expand All @@ -292,23 +283,23 @@ func (f *FileCacheProvider) ExpiresAt() time.Time {
return f.cachedCredential.Expiration
}

// CacheFilename returns the name of the credential cache file, which can either be
// defaultCacheFilename returns the name of the credential cache file, which can either be
// set by environment variable, or use the default of ~/.kube/cache/aws-iam-authenticator/credentials.yaml
func CacheFilename() string {
if filename, ok := e.LookupEnv(cacheFileNameEnv); ok {
func defaultCacheFilename() string {
if filename := os.Getenv(cacheFileNameEnv); filename != "" {
return filename
} else {
return filepath.Join(UserHomeDir(), ".kube", "cache", "aws-iam-authenticator", "credentials.yaml")
return filepath.Join(userHomeDir(), ".kube", "cache", "aws-iam-authenticator", "credentials.yaml")
}
}

// UserHomeDir returns the home directory for the user the process is
// userHomeDir returns the home directory for the user the process is
// running under.
func UserHomeDir() string {
func userHomeDir() string {
if runtime.GOOS == "windows" { // Windows
return e.Getenv("USERPROFILE")
return os.Getenv("USERPROFILE")
}

// *nix
return e.Getenv("HOME")
return os.Getenv("HOME")
}
Loading

0 comments on commit 975cca8

Please sign in to comment.