From ba2e418bfb176d90f4f4ae562a84ff1f1914a05f Mon Sep 17 00:00:00 2001 From: Anton Osmond Date: Thu, 22 Oct 2020 18:04:12 +0100 Subject: [PATCH] Add sso-suuport --- aws/credentials/ssocreds/sso_provider.go | 174 +++++++++++++++++++++++ aws/session/credentials.go | 26 ++++ aws/session/shared_config.go | 17 +++ aws/version.go | 2 +- go.sum | 2 + 5 files changed, 220 insertions(+), 1 deletion(-) create mode 100644 aws/credentials/ssocreds/sso_provider.go diff --git a/aws/credentials/ssocreds/sso_provider.go b/aws/credentials/ssocreds/sso_provider.go new file mode 100644 index 00000000000..3c5777b6449 --- /dev/null +++ b/aws/credentials/ssocreds/sso_provider.go @@ -0,0 +1,174 @@ +package ssocreds + +import ( + "encoding/json" + "io" + "io/ioutil" + "os" + "path/filepath" + "strings" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/client" + "github.com/aws/aws-sdk-go/service/sso" + "github.com/aws/aws-sdk-go/service/sso/ssoiface" + + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/internal/shareddefaults" +) + +// ProviderName is the name of the credentials provider. +const ( + ErrCodeSSOCredentials = "SSOCredentialErr" + ProviderName = `SSOProvider` +) + +// now is used to return a time.Time object representing +// the current time. This can be used to easily test and +// compare test values. +var now = time.Now + +// awsSSOCachePath holds the path to the AWS SSO cache dir +var awsSSOCachePath string + +// SSOProvider is used to retrieve credentials using an SSO access token +type SSOProvider struct { + credentials.Expiry + + // Duration the STS credentials will be valid for. Truncated to seconds. + // If unset, the assumed role will use AssumeRoleWithWebIdentity's default + // expiry duration. See + // https://docs.aws.amazon.com/sdk-for-go/api/service/sts/#STS.AssumeRoleWithWebIdentity + // for more information. + Duration time.Duration + + // The amount of time the credentials will be refreshed before they expire. + // This is useful refresh credentials before they expire to reduce risk of + // using credentials as they expire. If unset, will default to no expiry + // window. + ExpiryWindow time.Duration + + client ssoiface.SSOAPI + + accountID string + roleName string + + cache *SSOCache +} + +// SSOCache represents an AWS SSO cache file +type SSOCache struct { + StartURL string `json:"startUrl"` + Region string `json:"region"` + AccessToken string `json:"accessToken"` + ExpiresAt string `json:"expiresAt"` +} + +// NewSSOCredentials will return a new set of temporary credentials based on the SSO role & token +func NewSSOCredentials(c client.ConfigProvider, ssoAccountID, ssoRoleName string) (*credentials.Credentials, error) { + svc := sso.New(c) + p, err := NewSSOProvider(svc, ssoAccountID, ssoRoleName) + if err != nil { + return nil, awserr.New(ErrCodeSSOCredentials, "failed to retrieve credentials", err) + } + return credentials.NewCredentials(p), nil +} + +// NewSSOProvider will return a new SSOProvider configured with the +// details from the SSO cache +func NewSSOProvider(svc ssoiface.SSOAPI, accountID, roleName string) (*SSOProvider, error) { + cache, err := getCache(filepath.Join(shareddefaults.UserHomeDir(), ".aws/sso/cache")) + if err != nil { + return nil, err + } + return &SSOProvider{ + client: svc, + accountID: accountID, + roleName: roleName, + cache: cache, + }, nil +} + +// Retrieve will attempt to get a set of temporary credentials +// using an AWS SSO token from the SSO Cache +func (p *SSOProvider) Retrieve() (credentials.Value, error) { + return p.RetrieveWithContext(aws.BackgroundContext()) +} + +// RetrieveWithContext will attempt to get a set of temporary credentials +// using an AWS SSO token from the SSO Cache +func (p *SSOProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) { + in := &sso.GetRoleCredentialsInput{ + AccountId: &p.accountID, + RoleName: &p.roleName, + AccessToken: &p.cache.AccessToken, + } + req, resp := p.client.GetRoleCredentialsRequest(in) + req.SetContext(ctx) + + if err := req.Send(); err != nil { + return credentials.Value{}, awserr.New(ErrCodeSSOCredentials, "failed to retrieve credentials", err) + } + + t := time.Unix(0, *resp.RoleCredentials.Expiration*int64(time.Millisecond)) + p.SetExpiration(t.UTC(), p.ExpiryWindow) + + return credentials.Value{ + ProviderName: "SSOCredentialProvider", + AccessKeyID: *resp.RoleCredentials.AccessKeyId, + SecretAccessKey: *resp.RoleCredentials.SecretAccessKey, + SessionToken: *resp.RoleCredentials.SessionToken, + }, nil +} + +func getCache(cacheDir string) (*SSOCache, error) { + + cache := &SSOCache{} + + err := filepath.Walk(cacheDir, func(path string, info os.FileInfo, err error) error { + // handle failure accessing a path + if err != nil { + return err + } + // skip directories (excluding the cache dir itself) + if info.IsDir() && path != cacheDir { + return filepath.SkipDir + } + // skip anything that's not a json file + if !strings.HasSuffix(path, ".json") { + return nil + } + // skip the botocore files + if strings.HasPrefix(filepath.Base(path), "botocore-") { + return nil + } + // get the cache details from file + cache, err = getCacheFile(path) + if err != nil { + return err + } + return io.EOF + }) + + if err != nil && err != io.EOF { + return nil, err + } + + return cache, nil + +} + +func getCacheFile(path string) (*SSOCache, error) { + cache := &SSOCache{} + b, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + err = json.Unmarshal(b, cache) + if err != nil { + return nil, err + } + return cache, nil +} diff --git a/aws/session/credentials.go b/aws/session/credentials.go index fe6dac1f476..cb1357df491 100644 --- a/aws/session/credentials.go +++ b/aws/session/credentials.go @@ -9,6 +9,7 @@ import ( "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials/processcreds" + "github.com/aws/aws-sdk-go/aws/credentials/ssocreds" "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/defaults" "github.com/aws/aws-sdk-go/aws/request" @@ -119,6 +120,11 @@ func resolveCredsFromProfile(cfg *aws.Config, sharedCfg.RoleSessionName, ) + case len(sharedCfg.SSORoleName) != 0: + return resolveCredsFromSSOToken(cfg, handlers, + sharedCfg, + ) + default: // Fallback to default credentials provider, include mock errors for // the credential chain so user can identify why credentials failed to @@ -265,3 +271,23 @@ func (c credProviderError) Retrieve() (credentials.Value, error) { func (c credProviderError) IsExpired() bool { return true } + +func resolveCredsFromSSOToken(cfg *aws.Config, + handlers request.Handlers, + sharedCfg sharedConfig, +) (*credentials.Credentials, error) { + + creds, err := ssocreds.NewSSOCredentials( + &Session{ + Config: cfg.WithRegion(sharedCfg.SSORegion), + Handlers: handlers.Copy(), + }, + sharedCfg.SSOAccountID, sharedCfg.SSORoleName, + ) + + if err != nil { + return nil, err + } + + return creds, nil +} diff --git a/aws/session/shared_config.go b/aws/session/shared_config.go index 680805a38ad..8085bda965f 100644 --- a/aws/session/shared_config.go +++ b/aws/session/shared_config.go @@ -25,6 +25,12 @@ const ( roleSessionNameKey = `role_session_name` // optional roleDurationSecondsKey = "duration_seconds" // optional + // SSO Credentials group + ssoStartURLKey = `sso_start_url` // optional + ssoRegionKey = `sso_region` // optional + ssoAccountIDKey = `sso_account_id` // optional + ssoRoleNameKey = `sso_role_name` // optional + // CSM options csmEnabledKey = `csm_enabled` csmHostKey = `csm_host` @@ -81,6 +87,11 @@ type sharedConfig struct { MFASerial string AssumeRoleDuration *time.Duration + SSOStartURL string + SSORegion string + SSOAccountID string + SSORoleName string + SourceProfileName string SourceProfile *sharedConfig @@ -277,6 +288,12 @@ func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile, e updateString(&cfg.CredentialSource, section, credentialSourceKey) updateString(&cfg.Region, section, regionKey) + // SSO Parameters + updateString(&cfg.SSOAccountID, section, ssoAccountIDKey) + updateString(&cfg.SSORegion, section, ssoRegionKey) + updateString(&cfg.SSORoleName, section, ssoRoleNameKey) + updateString(&cfg.SSOStartURL, section, ssoStartURLKey) + if section.Has(roleDurationSecondsKey) { d := time.Duration(section.Int(roleDurationSecondsKey)) * time.Second cfg.AssumeRoleDuration = &d diff --git a/aws/version.go b/aws/version.go index 8915ec5f3c7..c0e056ea83a 100644 --- a/aws/version.go +++ b/aws/version.go @@ -5,4 +5,4 @@ package aws const SDKName = "aws-sdk-go" // SDKVersion is the version of this SDK -const SDKVersion = "1.35.13" +const SDKVersion = "1.36.0" diff --git a/go.sum b/go.sum index f8852e49034..21530c38160 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/aws/aws-sdk-go v1.35.12 h1:qpxQ/DXfgsTNSYn8mUaCgQiJkCjBP8iHKw5ju+wkucU= +github.com/aws/aws-sdk-go v1.35.12/go.mod h1:tlPOdRjfxPBpNIwqDj61rmsnA85v9jc0Ps9+muhnW+k= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=