Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update sso credential provider to support token provider #4875

Merged
merged 8 commits into from
Jun 15, 2023
54 changes: 40 additions & 14 deletions aws/credentials/ssocreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,19 @@ type Provider struct {

// The URL that points to the organization's AWS Single Sign-On (AWS SSO) user portal.
StartURL string

// The filepath the cached token will be retrieved from. If unset Provider will
// use the startURL to determine the filepath at.
//
// ~/.aws/sso/cache/<sha1-hex-encoded-startURL>.json
//
// If custom cached token filepath is used, the Provider's startUrl
// parameter will be ignored.
CachedTokenFilepath string

// Used by the SSOCredentialProvider if a token configuration
// profile is used in the shared config
SSOTokenProvider *SSOTokenProvider
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aajtodd @lucix-aws i am curious about your opinion here.

in my V2 implementation i did this same thing.

at the time, not knowing Go that well (or the Go SDK), this seemed fine to me. but i think i prob should have used the TokenProvider interface rather than the SSOTokenProvider struct implementation. this is because our use of it doesnt seem to depend on anything specific to the SSOTokenProvider: its just calling RetrieveBearerToken. also, another benefit as @wty-Bryant pointed out in a conversation I had with him, if this were an interface we can more easily unit test it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At first glance I think that's ok if you want to go down that route.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok I will try to implement the interface route

}

// NewCredentials returns a new AWS Single Sign-On (AWS SSO) credential provider. The ConfigProvider is expected to be configured
Expand Down Expand Up @@ -88,13 +101,31 @@ func (p *Provider) Retrieve() (credentials.Value, error) {
// RetrieveWithContext retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal
// by exchanging the accessToken present in ~/.aws/sso/cache.
func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
tokenFile, err := loadTokenFile(p.StartURL)
if err != nil {
return credentials.Value{}, err
var accessToken *string
if p.SSOTokenProvider != nil {
token, err := p.SSOTokenProvider.RetrieveBearerToken(ctx)
if err != nil {
return credentials.Value{}, err
}
accessToken = &token.Value
} else {
if p.CachedTokenFilepath == "" {
cachedTokenFilePath, err := getCachedFilePath(p.StartURL)
if err != nil {
return credentials.Value{}, err
}
p.CachedTokenFilepath = cachedTokenFilePath
}

tokenFile, err := loadTokenFile(p.CachedTokenFilepath)
if err != nil {
return credentials.Value{}, err
}
accessToken = &tokenFile.AccessToken
}

output, err := p.Client.GetRoleCredentialsWithContext(ctx, &sso.GetRoleCredentialsInput{
AccessToken: &tokenFile.AccessToken,
AccessToken: accessToken,
AccountId: &p.AccountID,
RoleName: &p.RoleName,
})
Expand All @@ -113,13 +144,13 @@ func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Val
}, nil
}

func getCacheFileName(url string) (string, error) {
func getCachedFilePath(startUrl string) (string, error) {
hash := sha1.New()
_, err := hash.Write([]byte(url))
_, err := hash.Write([]byte(startUrl))
if err != nil {
return "", err
}
return strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json", nil
return filepath.Join(defaultCacheLocation(), strings.ToLower(hex.EncodeToString(hash.Sum(nil)))+".json"), nil
}

type token struct {
Expand All @@ -133,13 +164,8 @@ func (t token) Expired() bool {
return nowTime().Round(0).After(time.Time(t.ExpiresAt))
}

func loadTokenFile(startURL string) (t token, err error) {
key, err := getCacheFileName(startURL)
if err != nil {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
}

fileBytes, err := ioutil.ReadFile(filepath.Join(defaultCacheLocation(), key))
func loadTokenFile(cachedTokenPath string) (t token, err error) {
fileBytes, err := ioutil.ReadFile(cachedTokenPath)
if err != nil {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
}
Expand Down
50 changes: 41 additions & 9 deletions aws/credentials/ssocreds/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package ssocreds

import (
"fmt"
"path/filepath"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -88,11 +89,12 @@ func TestProvider(t *testing.T) {
defer restoreTime()

cases := map[string]struct {
Client mockClient
AccountID string
Region string
RoleName string
StartURL string
Client mockClient
AccountID string
Region string
RoleName string
StartURL string
CachedTokenFilePath string

ExpectedErr bool
ExpectedCredentials credentials.Value
Expand Down Expand Up @@ -131,6 +133,35 @@ func TestProvider(t *testing.T) {
},
ExpectedExpire: time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC),
},
"custom cached token file": {
Client: mockClient{
ExpectedAccountID: "012345678901",
ExpectedRoleName: "TestRole",
ExpectedAccessToken: "ZhbHVldGhpcyBpcyBub3QgYSByZWFsIH",
Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) {
return &sso.GetRoleCredentialsOutput{
RoleCredentials: &sso.RoleCredentials{
AccessKeyId: aws.String("AccessKey"),
SecretAccessKey: aws.String("SecretKey"),
SessionToken: aws.String("SessionToken"),
Expiration: aws.Int64(1611177743123),
},
}, nil
},
},
CachedTokenFilePath: filepath.Join("testdata", "custom_cached_token.json"),
AccountID: "012345678901",
Region: "us-west-2",
RoleName: "TestRole",
StartURL: "ignored value",
ExpectedCredentials: credentials.Value{
AccessKeyID: "AccessKey",
SecretAccessKey: "SecretKey",
SessionToken: "SessionToken",
ProviderName: ProviderName,
},
ExpectedExpire: time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC),
},
"expired access token": {
StartURL: "https://expired",
ExpectedErr: true,
Expand Down Expand Up @@ -158,10 +189,11 @@ func TestProvider(t *testing.T) {
tt.Client.t = t

provider := &Provider{
Client: tt.Client,
AccountID: tt.AccountID,
RoleName: tt.RoleName,
StartURL: tt.StartURL,
Client: tt.Client,
AccountID: tt.AccountID,
RoleName: tt.RoleName,
StartURL: tt.StartURL,
CachedTokenFilepath: tt.CachedTokenFilePath,
}

provider.Expiry.CurrentTime = nowTime
Expand Down
4 changes: 4 additions & 0 deletions aws/credentials/ssocreds/testdata/custom_cached_token.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"accessToken": "ZhbHVldGhpcyBpcyBub3QgYSByZWFsIH",
"expiresAt": "2021-01-19T23:00:00Z"
}
21 changes: 20 additions & 1 deletion aws/session/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
"github.com/aws/aws-sdk-go/service/ssooidc"
"github.com/aws/aws-sdk-go/service/sts"
)

Expand Down Expand Up @@ -173,8 +174,25 @@ func resolveSSOCredentials(cfg *aws.Config, sharedCfg sharedConfig, handlers req
return nil, err
}

var optFns []func(provider *ssocreds.Provider)
cfgCopy := cfg.Copy()
cfgCopy.Region = &sharedCfg.SSORegion

if sharedCfg.SSOSession != nil {
cfgCopy.Region = &sharedCfg.SSOSession.SSORegion
cachedPath, err := ssocreds.StandardCachedTokenFilepath(sharedCfg.SSOSession.Name)
if err != nil {
return nil, err
}
mySession := Must(NewSession())
oidcClient := ssooidc.New(mySession, cfgCopy)
tokenProvider := ssocreds.NewSSOTokenProvider(oidcClient, cachedPath)
optFns = append(optFns, func(p *ssocreds.Provider) {
p.SSOTokenProvider = tokenProvider
p.CachedTokenFilepath = cachedPath
})
} else {
cfgCopy.Region = &sharedCfg.SSORegion
}

return ssocreds.NewCredentials(
&Session{
Expand All @@ -184,6 +202,7 @@ func resolveSSOCredentials(cfg *aws.Config, sharedCfg sharedConfig, handlers req
sharedCfg.SSOAccountID,
sharedCfg.SSORoleName,
sharedCfg.SSOStartURL,
optFns...,
), nil
}

Expand Down