-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
AWS Single Sign-On (SSO) Provider Support (#3755)
* Implement SSO Provider Support * Implement Credential Chain Support for SSO Provider
- Loading branch information
Showing
15 changed files
with
811 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
// Package ssocreds provides a credential provider for retrieving temporary AWS credentials using an SSO access token. | ||
// | ||
// IMPORTANT: The provider in this package does not initiate or perform the AWS SSO login flow. The SDK provider | ||
// expects that you have already performed the SSO login flow using AWS CLI using the "aws sso login" command, or by | ||
// some other mechanism. The provider must find a valid non-expired access token for the AWS SSO user portal URL in | ||
// ~/.aws/sso/cache. If a cached token is not found, it is expired, or the file is malformed an error will be returned. | ||
// | ||
// Loading AWS SSO credentials with the AWS shared configuration file | ||
// | ||
// You can use configure AWS SSO credentials from the AWS shared configuration file by | ||
// providing the specifying the required keys in the profile: | ||
// | ||
// sso_account_id | ||
// sso_region | ||
// sso_role_name | ||
// sso_start_url | ||
// | ||
// For example, the following defines a profile "devsso" and specifies the AWS SSO parameters that defines the target | ||
// account, role, sign-on portal, and the region where the user portal is located. Note: all SSO arguments must be | ||
// provided, or an error will be returned. | ||
// | ||
// [profile devsso] | ||
// sso_start_url = https://my-sso-portal.awsapps.com/start | ||
// sso_role_name = SSOReadOnlyRole | ||
// sso_region = us-east-1 | ||
// sso_account_id = 123456789012 | ||
// | ||
// Using the config module, you can load the AWS SDK shared configuration, and specify that this profile be used to | ||
// retrieve credentials. For example: | ||
// | ||
// sess, err := session.NewSessionWithOptions(session.Options{ | ||
// SharedConfigState: session.SharedConfigEnable, | ||
// Profile: "devsso", | ||
// }) | ||
// if err != nil { | ||
// return err | ||
// } | ||
// | ||
// Programmatically loading AWS SSO credentials directly | ||
// | ||
// You can programmatically construct the AWS SSO Provider in your application, and provide the necessary information | ||
// to load and retrieve temporary credentials using an access token from ~/.aws/sso/cache. | ||
// | ||
// svc := sso.New(sess, &aws.Config{ | ||
// Region: aws.String("us-west-2"), // Client Region must correspond to the AWS SSO user portal region | ||
// }) | ||
// | ||
// provider := ssocreds.NewCredentialsWithClient(svc, "123456789012", "SSOReadOnlyRole", "https://my-sso-portal.awsapps.com/start") | ||
// | ||
// credentials, err := provider.Get() | ||
// if err != nil { | ||
// return err | ||
// } | ||
// | ||
// Additional Resources | ||
// | ||
// Configuring the AWS CLI to use AWS Single Sign-On: https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-sso.html | ||
// | ||
// AWS Single Sign-On User Guide: https://docs.aws.amazon.com/singlesignon/latest/userguide/what-is.html | ||
package ssocreds |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
// +build !windows | ||
|
||
package ssocreds | ||
|
||
import "os" | ||
|
||
func getHomeDirectory() string { | ||
return os.Getenv("HOME") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
package ssocreds | ||
|
||
import "os" | ||
|
||
func getHomeDirectory() string { | ||
return os.Getenv("USERPROFILE") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
package ssocreds | ||
|
||
import ( | ||
"crypto/sha1" | ||
"encoding/hex" | ||
"encoding/json" | ||
"fmt" | ||
"io/ioutil" | ||
"path/filepath" | ||
"strings" | ||
"time" | ||
|
||
"github.com/aws/aws-sdk-go/aws" | ||
"github.com/aws/aws-sdk-go/aws/awserr" | ||
"github.com/aws/aws-sdk-go/aws/client" | ||
"github.com/aws/aws-sdk-go/aws/credentials" | ||
"github.com/aws/aws-sdk-go/service/sso" | ||
"github.com/aws/aws-sdk-go/service/sso/ssoiface" | ||
) | ||
|
||
// ErrCodeSSOProviderInvalidToken is the code type that is returned if loaded token has expired or is otherwise invalid. | ||
// To refresh the SSO session run aws sso login with the corresponding profile. | ||
const ErrCodeSSOProviderInvalidToken = "SSOProviderInvalidToken" | ||
|
||
const invalidTokenMessage = "the SSO session has expired or is invalid" | ||
|
||
func init() { | ||
nowTime = time.Now | ||
defaultCacheLocation = defaultCacheLocationImpl | ||
} | ||
|
||
var nowTime func() time.Time | ||
|
||
// ProviderName is the name of the provider used to specify the source of credentials. | ||
const ProviderName = "SSOProvider" | ||
|
||
var defaultCacheLocation func() string | ||
|
||
func defaultCacheLocationImpl() string { | ||
return filepath.Join(getHomeDirectory(), ".aws", "sso", "cache") | ||
} | ||
|
||
// Provider is an AWS credential provider that retrieves temporary AWS credentials by exchanging an SSO login token. | ||
type Provider struct { | ||
credentials.Expiry | ||
|
||
// The Client which is configured for the AWS Region where the AWS SSO user portal is located. | ||
Client ssoiface.SSOAPI | ||
|
||
// The AWS account that is assigned to the user. | ||
AccountID string | ||
|
||
// The role name that is assigned to the user. | ||
RoleName string | ||
|
||
// The URL that points to the organization's AWS Single Sign-On (AWS SSO) user portal. | ||
StartURL string | ||
} | ||
|
||
// NewCredentials returns a new AWS Single Sign-On (AWS SSO) credential provider. The ConfigProvider is expected to be configured | ||
// for the AWS Region where the AWS SSO user portal is located. | ||
func NewCredentials(configProvider client.ConfigProvider, accountID, roleName, startURL string, optFns ...func(provider *Provider)) *credentials.Credentials { | ||
return NewCredentialsWithClient(sso.New(configProvider), accountID, roleName, startURL, optFns...) | ||
} | ||
|
||
// NewCredentialsWithClient returns a new AWS Single Sign-On (AWS SSO) credential provider. The provided client is expected to be configured | ||
// for the AWS Region where the AWS SSO user portal is located. | ||
func NewCredentialsWithClient(client ssoiface.SSOAPI, accountID, roleName, startURL string, optFns ...func(provider *Provider)) *credentials.Credentials { | ||
p := &Provider{ | ||
Client: client, | ||
AccountID: accountID, | ||
RoleName: roleName, | ||
StartURL: startURL, | ||
} | ||
|
||
for _, fn := range optFns { | ||
fn(p) | ||
} | ||
|
||
return credentials.NewCredentials(p) | ||
} | ||
|
||
// Retrieve retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal | ||
// by exchanging the accessToken present in ~/.aws/sso/cache. | ||
func (p *Provider) Retrieve() (credentials.Value, error) { | ||
return p.RetrieveWithContext(aws.BackgroundContext()) | ||
} | ||
|
||
// RetrieveWithContext retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal | ||
// by exchanging the accessToken present in ~/.aws/sso/cache. | ||
func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) { | ||
tokenFile, err := loadTokenFile(p.StartURL) | ||
if err != nil { | ||
return credentials.Value{}, err | ||
} | ||
|
||
output, err := p.Client.GetRoleCredentialsWithContext(ctx, &sso.GetRoleCredentialsInput{ | ||
AccessToken: &tokenFile.AccessToken, | ||
AccountId: &p.AccountID, | ||
RoleName: &p.RoleName, | ||
}) | ||
if err != nil { | ||
return credentials.Value{}, err | ||
} | ||
|
||
expireTime := time.Unix(0, aws.Int64Value(output.RoleCredentials.Expiration)*int64(time.Millisecond)).UTC() | ||
p.SetExpiration(expireTime, 0) | ||
|
||
return credentials.Value{ | ||
AccessKeyID: aws.StringValue(output.RoleCredentials.AccessKeyId), | ||
SecretAccessKey: aws.StringValue(output.RoleCredentials.SecretAccessKey), | ||
SessionToken: aws.StringValue(output.RoleCredentials.SessionToken), | ||
ProviderName: ProviderName, | ||
}, nil | ||
} | ||
|
||
func getCacheFileName(url string) (string, error) { | ||
hash := sha1.New() | ||
_, err := hash.Write([]byte(url)) | ||
if err != nil { | ||
return "", err | ||
} | ||
return strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json", nil | ||
} | ||
|
||
type rfc3339 time.Time | ||
|
||
func (r *rfc3339) UnmarshalJSON(bytes []byte) error { | ||
var value string | ||
|
||
if err := json.Unmarshal(bytes, &value); err != nil { | ||
return err | ||
} | ||
|
||
parse, err := time.Parse(time.RFC3339, value) | ||
if err != nil { | ||
return fmt.Errorf("expected RFC3339 timestamp: %v", err) | ||
} | ||
|
||
*r = rfc3339(parse) | ||
|
||
return nil | ||
} | ||
|
||
type token struct { | ||
AccessToken string `json:"accessToken"` | ||
ExpiresAt rfc3339 `json:"expiresAt"` | ||
Region string `json:"region,omitempty"` | ||
StartURL string `json:"startUrl,omitempty"` | ||
} | ||
|
||
func (t token) Expired() bool { | ||
return nowTime().Round(0).After(time.Time(t.ExpiresAt)) | ||
} | ||
|
||
func loadTokenFile(startURL string) (t token, err error) { | ||
key, err := getCacheFileName(startURL) | ||
if err != nil { | ||
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err) | ||
} | ||
|
||
fileBytes, err := ioutil.ReadFile(filepath.Join(defaultCacheLocation(), key)) | ||
if err != nil { | ||
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err) | ||
} | ||
|
||
if err := json.Unmarshal(fileBytes, &t); err != nil { | ||
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err) | ||
} | ||
|
||
if len(t.AccessToken) == 0 { | ||
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, nil) | ||
} | ||
|
||
if t.Expired() { | ||
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, nil) | ||
} | ||
|
||
return t, nil | ||
} |
Oops, something went wrong.