Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sso-support #3610

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 174 additions & 0 deletions aws/credentials/ssocreds/sso_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
package ssocreds

import (
"encoding/json"
"io"
"io/ioutil"
"os"
"path/filepath"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/service/sso"
"github.com/aws/aws-sdk-go/service/sso/ssoiface"

"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)

// ProviderName is the name of the credentials provider.
const (
ErrCodeSSOCredentials = "SSOCredentialErr"
ProviderName = `SSOProvider`
)

// now is used to return a time.Time object representing
// the current time. This can be used to easily test and
// compare test values.
var now = time.Now

// awsSSOCachePath holds the path to the AWS SSO cache dir
var awsSSOCachePath string

// SSOProvider is used to retrieve credentials using an SSO access token
type SSOProvider struct {
credentials.Expiry

// Duration the STS credentials will be valid for. Truncated to seconds.
// If unset, the assumed role will use AssumeRoleWithWebIdentity's default
// expiry duration. See
// https://docs.aws.amazon.com/sdk-for-go/api/service/sts/#STS.AssumeRoleWithWebIdentity
// for more information.
Duration time.Duration

// The amount of time the credentials will be refreshed before they expire.
// This is useful refresh credentials before they expire to reduce risk of
// using credentials as they expire. If unset, will default to no expiry
// window.
ExpiryWindow time.Duration

client ssoiface.SSOAPI

accountID string
roleName string

cache *SSOCache
}

// SSOCache represents an AWS SSO cache file
type SSOCache struct {
StartURL string `json:"startUrl"`
Region string `json:"region"`
AccessToken string `json:"accessToken"`
ExpiresAt string `json:"expiresAt"`
}

// NewSSOCredentials will return a new set of temporary credentials based on the SSO role & token
func NewSSOCredentials(c client.ConfigProvider, ssoAccountID, ssoRoleName string) (*credentials.Credentials, error) {
svc := sso.New(c)
p, err := NewSSOProvider(svc, ssoAccountID, ssoRoleName)
if err != nil {
return nil, awserr.New(ErrCodeSSOCredentials, "failed to retrieve credentials", err)
}
return credentials.NewCredentials(p), nil
}

// NewSSOProvider will return a new SSOProvider configured with the
// details from the SSO cache
func NewSSOProvider(svc ssoiface.SSOAPI, accountID, roleName string) (*SSOProvider, error) {
cache, err := getCache(filepath.Join(shareddefaults.UserHomeDir(), ".aws/sso/cache"))
if err != nil {
return nil, err
}
return &SSOProvider{
client: svc,
accountID: accountID,
roleName: roleName,
cache: cache,
}, nil
}

// Retrieve will attempt to get a set of temporary credentials
// using an AWS SSO token from the SSO Cache
func (p *SSOProvider) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(aws.BackgroundContext())
}

// RetrieveWithContext will attempt to get a set of temporary credentials
// using an AWS SSO token from the SSO Cache
func (p *SSOProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
in := &sso.GetRoleCredentialsInput{
AccountId: &p.accountID,
RoleName: &p.roleName,
AccessToken: &p.cache.AccessToken,
}
req, resp := p.client.GetRoleCredentialsRequest(in)
req.SetContext(ctx)

if err := req.Send(); err != nil {
return credentials.Value{}, awserr.New(ErrCodeSSOCredentials, "failed to retrieve credentials", err)
}

t := time.Unix(0, *resp.RoleCredentials.Expiration*int64(time.Millisecond))
p.SetExpiration(t.UTC(), p.ExpiryWindow)

return credentials.Value{
ProviderName: "SSOCredentialProvider",
AccessKeyID: *resp.RoleCredentials.AccessKeyId,
SecretAccessKey: *resp.RoleCredentials.SecretAccessKey,
SessionToken: *resp.RoleCredentials.SessionToken,
}, nil
}

func getCache(cacheDir string) (*SSOCache, error) {

cache := &SSOCache{}

err := filepath.Walk(cacheDir, func(path string, info os.FileInfo, err error) error {
// handle failure accessing a path
if err != nil {
return err
}
// skip directories (excluding the cache dir itself)
if info.IsDir() && path != cacheDir {
return filepath.SkipDir
}
// skip anything that's not a json file
if !strings.HasSuffix(path, ".json") {
return nil
}
// skip the botocore files
if strings.HasPrefix(filepath.Base(path), "botocore-") {
return nil
}
// get the cache details from file
cache, err = getCacheFile(path)
if err != nil {
return err
}
return io.EOF
})

if err != nil && err != io.EOF {
return nil, err
}

return cache, nil

}

func getCacheFile(path string) (*SSOCache, error) {
cache := &SSOCache{}
b, err := ioutil.ReadFile(path)
if err != nil {
return nil, err
}
err = json.Unmarshal(b, cache)
if err != nil {
return nil, err
}
return cache, nil
}
26 changes: 26 additions & 0 deletions aws/session/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/processcreds"
"github.com/aws/aws-sdk-go/aws/credentials/ssocreds"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
Expand Down Expand Up @@ -119,6 +120,11 @@ func resolveCredsFromProfile(cfg *aws.Config,
sharedCfg.RoleSessionName,
)

case len(sharedCfg.SSORoleName) != 0:
return resolveCredsFromSSOToken(cfg, handlers,
sharedCfg,
)

default:
// Fallback to default credentials provider, include mock errors for
// the credential chain so user can identify why credentials failed to
Expand Down Expand Up @@ -265,3 +271,23 @@ func (c credProviderError) Retrieve() (credentials.Value, error) {
func (c credProviderError) IsExpired() bool {
return true
}

func resolveCredsFromSSOToken(cfg *aws.Config,
handlers request.Handlers,
sharedCfg sharedConfig,
) (*credentials.Credentials, error) {

creds, err := ssocreds.NewSSOCredentials(
&Session{
Config: cfg.WithRegion(sharedCfg.SSORegion),
Handlers: handlers.Copy(),
},
sharedCfg.SSOAccountID, sharedCfg.SSORoleName,
)

if err != nil {
return nil, err
}

return creds, nil
}
17 changes: 17 additions & 0 deletions aws/session/shared_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ const (
roleSessionNameKey = `role_session_name` // optional
roleDurationSecondsKey = "duration_seconds" // optional

// SSO Credentials group
ssoStartURLKey = `sso_start_url` // optional
ssoRegionKey = `sso_region` // optional
ssoAccountIDKey = `sso_account_id` // optional
ssoRoleNameKey = `sso_role_name` // optional

// CSM options
csmEnabledKey = `csm_enabled`
csmHostKey = `csm_host`
Expand Down Expand Up @@ -81,6 +87,11 @@ type sharedConfig struct {
MFASerial string
AssumeRoleDuration *time.Duration

SSOStartURL string
SSORegion string
SSOAccountID string
SSORoleName string

SourceProfileName string
SourceProfile *sharedConfig

Expand Down Expand Up @@ -277,6 +288,12 @@ func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile, e
updateString(&cfg.CredentialSource, section, credentialSourceKey)
updateString(&cfg.Region, section, regionKey)

// SSO Parameters
updateString(&cfg.SSOAccountID, section, ssoAccountIDKey)
updateString(&cfg.SSORegion, section, ssoRegionKey)
updateString(&cfg.SSORoleName, section, ssoRoleNameKey)
updateString(&cfg.SSOStartURL, section, ssoStartURLKey)

if section.Has(roleDurationSecondsKey) {
d := time.Duration(section.Int(roleDurationSecondsKey)) * time.Second
cfg.AssumeRoleDuration = &d
Expand Down
2 changes: 1 addition & 1 deletion aws/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ package aws
const SDKName = "aws-sdk-go"

// SDKVersion is the version of this SDK
const SDKVersion = "1.35.13"
const SDKVersion = "1.36.0"
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/aws/aws-sdk-go v1.35.12 h1:qpxQ/DXfgsTNSYn8mUaCgQiJkCjBP8iHKw5ju+wkucU=
github.com/aws/aws-sdk-go v1.35.12/go.mod h1:tlPOdRjfxPBpNIwqDj61rmsnA85v9jc0Ps9+muhnW+k=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
Expand Down