Skip to content

Commit

Permalink
AWS Single Sign-On (SSO) Provider Support (#3755)
Browse files Browse the repository at this point in the history
* Implement SSO Provider Support
* Implement Credential Chain Support for SSO Provider
  • Loading branch information
skmcgrail authored Jan 27, 2021
1 parent 7051404 commit 04e0775
Show file tree
Hide file tree
Showing 15 changed files with 811 additions and 39 deletions.
60 changes: 60 additions & 0 deletions aws/credentials/ssocreds/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Package ssocreds provides a credential provider for retrieving temporary AWS credentials using an SSO access token.
//
// IMPORTANT: The provider in this package does not initiate or perform the AWS SSO login flow. The SDK provider
// expects that you have already performed the SSO login flow using AWS CLI using the "aws sso login" command, or by
// some other mechanism. The provider must find a valid non-expired access token for the AWS SSO user portal URL in
// ~/.aws/sso/cache. If a cached token is not found, it is expired, or the file is malformed an error will be returned.
//
// Loading AWS SSO credentials with the AWS shared configuration file
//
// You can use configure AWS SSO credentials from the AWS shared configuration file by
// providing the specifying the required keys in the profile:
//
// sso_account_id
// sso_region
// sso_role_name
// sso_start_url
//
// For example, the following defines a profile "devsso" and specifies the AWS SSO parameters that defines the target
// account, role, sign-on portal, and the region where the user portal is located. Note: all SSO arguments must be
// provided, or an error will be returned.
//
// [profile devsso]
// sso_start_url = https://my-sso-portal.awsapps.com/start
// sso_role_name = SSOReadOnlyRole
// sso_region = us-east-1
// sso_account_id = 123456789012
//
// Using the config module, you can load the AWS SDK shared configuration, and specify that this profile be used to
// retrieve credentials. For example:
//
// sess, err := session.NewSessionWithOptions(session.Options{
// SharedConfigState: session.SharedConfigEnable,
// Profile: "devsso",
// })
// if err != nil {
// return err
// }
//
// Programmatically loading AWS SSO credentials directly
//
// You can programmatically construct the AWS SSO Provider in your application, and provide the necessary information
// to load and retrieve temporary credentials using an access token from ~/.aws/sso/cache.
//
// svc := sso.New(sess, &aws.Config{
// Region: aws.String("us-west-2"), // Client Region must correspond to the AWS SSO user portal region
// })
//
// provider := ssocreds.NewCredentialsWithClient(svc, "123456789012", "SSOReadOnlyRole", "https://my-sso-portal.awsapps.com/start")
//
// credentials, err := provider.Get()
// if err != nil {
// return err
// }
//
// Additional Resources
//
// Configuring the AWS CLI to use AWS Single Sign-On: https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-sso.html
//
// AWS Single Sign-On User Guide: https://docs.aws.amazon.com/singlesignon/latest/userguide/what-is.html
package ssocreds
9 changes: 9 additions & 0 deletions aws/credentials/ssocreds/os.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// +build !windows

package ssocreds

import "os"

func getHomeDirectory() string {
return os.Getenv("HOME")
}
7 changes: 7 additions & 0 deletions aws/credentials/ssocreds/os_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package ssocreds

import "os"

func getHomeDirectory() string {
return os.Getenv("USERPROFILE")
}
180 changes: 180 additions & 0 deletions aws/credentials/ssocreds/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
package ssocreds

import (
"crypto/sha1"
"encoding/hex"
"encoding/json"
"fmt"
"io/ioutil"
"path/filepath"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sso"
"github.com/aws/aws-sdk-go/service/sso/ssoiface"
)

// ErrCodeSSOProviderInvalidToken is the code type that is returned if loaded token has expired or is otherwise invalid.
// To refresh the SSO session run aws sso login with the corresponding profile.
const ErrCodeSSOProviderInvalidToken = "SSOProviderInvalidToken"

const invalidTokenMessage = "the SSO session has expired or is invalid"

func init() {
nowTime = time.Now
defaultCacheLocation = defaultCacheLocationImpl
}

var nowTime func() time.Time

// ProviderName is the name of the provider used to specify the source of credentials.
const ProviderName = "SSOProvider"

var defaultCacheLocation func() string

func defaultCacheLocationImpl() string {
return filepath.Join(getHomeDirectory(), ".aws", "sso", "cache")
}

// Provider is an AWS credential provider that retrieves temporary AWS credentials by exchanging an SSO login token.
type Provider struct {
credentials.Expiry

// The Client which is configured for the AWS Region where the AWS SSO user portal is located.
Client ssoiface.SSOAPI

// The AWS account that is assigned to the user.
AccountID string

// The role name that is assigned to the user.
RoleName string

// The URL that points to the organization's AWS Single Sign-On (AWS SSO) user portal.
StartURL string
}

// NewCredentials returns a new AWS Single Sign-On (AWS SSO) credential provider. The ConfigProvider is expected to be configured
// for the AWS Region where the AWS SSO user portal is located.
func NewCredentials(configProvider client.ConfigProvider, accountID, roleName, startURL string, optFns ...func(provider *Provider)) *credentials.Credentials {
return NewCredentialsWithClient(sso.New(configProvider), accountID, roleName, startURL, optFns...)
}

// NewCredentialsWithClient returns a new AWS Single Sign-On (AWS SSO) credential provider. The provided client is expected to be configured
// for the AWS Region where the AWS SSO user portal is located.
func NewCredentialsWithClient(client ssoiface.SSOAPI, accountID, roleName, startURL string, optFns ...func(provider *Provider)) *credentials.Credentials {
p := &Provider{
Client: client,
AccountID: accountID,
RoleName: roleName,
StartURL: startURL,
}

for _, fn := range optFns {
fn(p)
}

return credentials.NewCredentials(p)
}

// Retrieve retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal
// by exchanging the accessToken present in ~/.aws/sso/cache.
func (p *Provider) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(aws.BackgroundContext())
}

// RetrieveWithContext retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal
// by exchanging the accessToken present in ~/.aws/sso/cache.
func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
tokenFile, err := loadTokenFile(p.StartURL)
if err != nil {
return credentials.Value{}, err
}

output, err := p.Client.GetRoleCredentialsWithContext(ctx, &sso.GetRoleCredentialsInput{
AccessToken: &tokenFile.AccessToken,
AccountId: &p.AccountID,
RoleName: &p.RoleName,
})
if err != nil {
return credentials.Value{}, err
}

expireTime := time.Unix(0, aws.Int64Value(output.RoleCredentials.Expiration)*int64(time.Millisecond)).UTC()
p.SetExpiration(expireTime, 0)

return credentials.Value{
AccessKeyID: aws.StringValue(output.RoleCredentials.AccessKeyId),
SecretAccessKey: aws.StringValue(output.RoleCredentials.SecretAccessKey),
SessionToken: aws.StringValue(output.RoleCredentials.SessionToken),
ProviderName: ProviderName,
}, nil
}

func getCacheFileName(url string) (string, error) {
hash := sha1.New()
_, err := hash.Write([]byte(url))
if err != nil {
return "", err
}
return strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json", nil
}

type rfc3339 time.Time

func (r *rfc3339) UnmarshalJSON(bytes []byte) error {
var value string

if err := json.Unmarshal(bytes, &value); err != nil {
return err
}

parse, err := time.Parse(time.RFC3339, value)
if err != nil {
return fmt.Errorf("expected RFC3339 timestamp: %v", err)
}

*r = rfc3339(parse)

return nil
}

type token struct {
AccessToken string `json:"accessToken"`
ExpiresAt rfc3339 `json:"expiresAt"`
Region string `json:"region,omitempty"`
StartURL string `json:"startUrl,omitempty"`
}

func (t token) Expired() bool {
return nowTime().Round(0).After(time.Time(t.ExpiresAt))
}

func loadTokenFile(startURL string) (t token, err error) {
key, err := getCacheFileName(startURL)
if err != nil {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
}

fileBytes, err := ioutil.ReadFile(filepath.Join(defaultCacheLocation(), key))
if err != nil {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
}

if err := json.Unmarshal(fileBytes, &t); err != nil {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
}

if len(t.AccessToken) == 0 {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, nil)
}

if t.Expired() {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, nil)
}

return t, nil
}
Loading

0 comments on commit 04e0775

Please sign in to comment.