Skip to content

Commit

Permalink
Add an sts_region parameter to the AWS auth engine's client config (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tyrannosaurus-becks committed Dec 11, 2019
1 parent b8c8e39 commit 41ec99a
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 5 deletions.
11 changes: 8 additions & 3 deletions builtin/credential/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,19 @@ func (b *backend) getRawClientConfig(ctx context.Context, s logical.Storage, reg
endpoint := aws.String("")
var maxRetries int = aws.UseServiceDefaultRetries
if config != nil {
// Override the default endpoint with the configured endpoint.
// Override the defaults with configured values.
switch {
case clientType == "ec2" && config.Endpoint != "":
endpoint = aws.String(config.Endpoint)
case clientType == "iam" && config.IAMEndpoint != "":
endpoint = aws.String(config.IAMEndpoint)
case clientType == "sts" && config.STSEndpoint != "":
endpoint = aws.String(config.STSEndpoint)
case clientType == "sts":
if config.STSEndpoint != "" {
endpoint = aws.String(config.STSEndpoint)
}
if config.STSRegion != "" {
region = config.STSRegion
}
}

credsConfig.AccessKey = config.AccessKey
Expand Down
20 changes: 19 additions & 1 deletion builtin/credential/aws/path_config_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ func (b *backend) pathConfigClient() *framework.Path {
Description: "URL to override the default generated endpoint for making AWS STS API calls.",
},

"sts_region": {
Type: framework.TypeString,
Default: "",
Description: "The region ID for the sts_endpoint, if set.",
},

"iam_server_id_header_value": {
Type: framework.TypeString,
Default: "",
Expand Down Expand Up @@ -127,6 +133,7 @@ func (b *backend) pathConfigClientRead(ctx context.Context, req *logical.Request
"endpoint": clientConfig.Endpoint,
"iam_endpoint": clientConfig.IAMEndpoint,
"sts_endpoint": clientConfig.STSEndpoint,
"sts_region": clientConfig.STSRegion,
"iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue,
"max_retries": clientConfig.MaxRetries,
},
Expand Down Expand Up @@ -217,7 +224,7 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
stsEndpointStr, ok := data.GetOk("sts_endpoint")
if ok {
if configEntry.STSEndpoint != stsEndpointStr.(string) {
// We don't directly cache STS clients as they are ever directly used.
// We don't directly cache STS clients as they are never directly used.
// However, they are potentially indirectly used as credential providers
// for the EC2 and IAM clients, and thus we would be indirectly caching
// them there. So, if we change the STS endpoint, we should flush those
Expand All @@ -229,6 +236,16 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
configEntry.STSEndpoint = data.Get("sts_endpoint").(string)
}

stsRegionStr, ok := data.GetOk("sts_region")
if ok {
if configEntry.STSRegion != stsRegionStr.(string) {
// Region is used when building STS clients. As such, all the comments
// regarding the sts_endpoint changing apply here as well.
changedCreds = true
configEntry.STSRegion = stsRegionStr.(string)
}
}

headerValStr, ok := data.GetOk("iam_server_id_header_value")
if ok {
if configEntry.IAMServerIdHeaderValue != headerValStr.(string) {
Expand Down Expand Up @@ -281,6 +298,7 @@ type clientConfig struct {
Endpoint string `json:"endpoint"`
IAMEndpoint string `json:"iam_endpoint"`
STSEndpoint string `json:"sts_endpoint"`
STSRegion string `json:"sts_region"`
IAMServerIdHeaderValue string `json:"iam_server_id_header_value"`
MaxRetries int `json:"max_retries"`
}
Expand Down
20 changes: 19 additions & 1 deletion builtin/credential/aws/path_config_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func TestBackend_pathConfigClient(t *testing.T) {

data := map[string]interface{}{
"sts_endpoint": "https://my-custom-sts-endpoint.example.com",
"sts_region": "us-east-2",
"iam_server_id_header_value": "vault_server_identification_314159",
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Expand All @@ -52,7 +53,6 @@ func TestBackend_pathConfigClient(t *testing.T) {
Data: data,
Storage: storage,
})

if err != nil {
t.Fatal(err)
}
Expand All @@ -75,8 +75,18 @@ func TestBackend_pathConfigClient(t *testing.T) {
t.Fatalf("expected iam_server_id_header_value: '%#v'; returned iam_server_id_header_value: '%#v'",
data["iam_server_id_header_value"], resp.Data["iam_server_id_header_value"])
}
if resp.Data["sts_endpoint"] != data["sts_endpoint"] {
t.Fatalf("expected sts_endpoint: '%#v'; returned sts_endpoint: '%#v'",
data["sts_endpoint"], resp.Data["sts_endpoint"])
}
if resp.Data["sts_region"] != data["sts_region"] {
t.Fatalf("expected sts_region: '%#v'; returned sts_region: '%#v'",
data["sts_region"], resp.Data["sts_region"])
}

data = map[string]interface{}{
"sts_endpoint": "https://my-custom-sts-endpoint2.example.com",
"sts_region": "us-west-1",
"iam_server_id_header_value": "vault_server_identification_2718281",
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Expand Down Expand Up @@ -108,4 +118,12 @@ func TestBackend_pathConfigClient(t *testing.T) {
t.Fatalf("expected iam_server_id_header_value: '%#v'; returned iam_server_id_header_value: '%#v'",
data["iam_server_id_header_value"], resp.Data["iam_server_id_header_value"])
}
if resp.Data["sts_endpoint"] != data["sts_endpoint"] {
t.Fatalf("expected sts_endpoint: '%#v'; returned sts_endpoint: '%#v'",
data["sts_endpoint"], resp.Data["sts_endpoint"])
}
if resp.Data["sts_region"] != data["sts_region"] {
t.Fatalf("expected sts_region: '%#v'; returned sts_region: '%#v'",
data["sts_region"], resp.Data["sts_region"])
}
}
87 changes: 87 additions & 0 deletions builtin/credential/aws/path_role_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package awsauth

import (
"context"
"os"
"reflect"
"strings"
"testing"

"github.com/go-test/deep"
"github.com/hashicorp/vault/helper/awsutil"
vlttesting "github.com/hashicorp/vault/helper/testhelpers/logical"
"github.com/hashicorp/vault/sdk/helper/policyutil"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/logical"
Expand Down Expand Up @@ -986,6 +989,90 @@ func TestAwsVersion(t *testing.T) {
}
}

// This test was used to reproduce https://github.com/hashicorp/vault/issues/7418
// and verify its fix.
// Please run it at least 3 times to ensure that passing tests are due to actually
// passing, rather than the region being randomly chosen tying to the one in the
// test through luck.
func TestRoleResolutionWithSTSEndpointConfigured(t *testing.T) {
if enabled := os.Getenv(vlttesting.TestEnvVar); enabled == "" {
t.Skip()
}

/* ARN of an AWS role that Vault can query during testing.
This role should exist in your current AWS account and your credentials
should have iam:GetRole permissions to query it.
*/
assumableRoleArn := os.Getenv("AWS_ASSUMABLE_ROLE_ARN")
if assumableRoleArn == "" {
t.Skip("skipping because AWS_ASSUMABLE_ROLE_ARN is unset")
}

// Ensure aws credentials are available locally for testing.
credsConfig := &awsutil.CredentialsConfig{}
credsChain, err := credsConfig.GenerateCredentialChain()
if err != nil {
t.Fatal(err)
}
_, err = credsChain.Get()
if err != nil {
t.SkipNow()
}

config := logical.TestBackendConfig()
storage := &logical.InmemStorage{}
config.StorageView = storage

b, err := Backend(config)
if err != nil {
t.Fatal(err)
}

err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}

// configure the client with an sts endpoint that should be used in creating the role
data := map[string]interface{}{
"sts_endpoint": "https://sts.eu-west-1.amazonaws.com",
// Note - if you comment this out, you can reproduce the error shown
// in the linked GH issue above. This essentially reproduces the problem
// we had when we didn't have an sts_region field.
"sts_region": "eu-west-1",
}
resp, err := b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.CreateOperation,
Path: "config/client",
Data: data,
Storage: storage,
})
if err != nil {
t.Fatal(err)
}
if resp != nil && resp.IsError() {
t.Fatalf("failed to create the role entry; resp: %#v", resp)
}

data = map[string]interface{}{
"auth_type": iamAuthType,
"bound_iam_principal_arn": assumableRoleArn,
"resolve_aws_unique_ids": true,
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.CreateOperation,
Path: "role/MyRoleName",
Data: data,
Storage: storage,
})
if err != nil {
t.Fatal(err)
}
if resp != nil && resp.IsError() {
t.Fatalf("failed to create the role entry; resp: %#v", resp)
}
}

func resolveArnToFakeUniqueId(_ context.Context, _ logical.Storage, _ string) (string, error) {
return "FakeUniqueId1", nil
}

0 comments on commit 41ec99a

Please sign in to comment.