Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update AWS auth backend iam_request_headers to be TypeHeader #5320

Merged
merged 16 commits into from
Sep 12, 2018
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions builtin/credential/aws/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ import (
logicaltest "github.com/hashicorp/vault/logical/testing"
)

const testVaultHeaderValue = "VaultAcceptanceTesting"
const testValidRoleName = "valid-role"
const testInvalidRoleName = "invalid-role"

func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
// create a backend
config := logical.TestBackendConfig()
Expand Down Expand Up @@ -1510,9 +1514,6 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
// it allows us to login to our role
// 6. Pass in a request that has a validly signed request, asking for
// the other role, ensure it fails
const testVaultHeaderValue = "VaultAcceptanceTesting"
const testValidRoleName = "valid-role"
const testInvalidRoleName = "invalid-role"

clientConfigData := map[string]interface{}{
"iam_server_id_header_value": testVaultHeaderValue,
Expand Down
51 changes: 4 additions & 47 deletions builtin/credential/aws/path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@ import (
"crypto/subtle"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"encoding/xml"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"reflect"
"regexp"
"strings"
"time"
Expand Down Expand Up @@ -89,7 +87,7 @@ when using iam auth_type.`,
This must match the request body included in the signature.`,
},
"iam_request_headers": {
Type: framework.TypeString,
Type: framework.TypeHeader,
Description: `Base64-encoded JSON representation of the request headers when auth_type is
catsby marked this conversation as resolved.
Show resolved Hide resolved
iam. This must at a minimum include the headers over
which AWS has included a signature.`,
Expand Down Expand Up @@ -1149,17 +1147,10 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
}
body := string(bodyRaw)

headersB64 := data.Get("iam_request_headers").(string)
if headersB64 == "" {
headers := data.Get("iam_request_headers").(http.Header)
if len(headers) == 0 {
return logical.ErrorResponse("missing iam_request_headers"), nil
}
headers, err := parseIamRequestHeaders(headersB64)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error parsing iam_request_headers: %v", err)), nil
}
if headers == nil {
return logical.ErrorResponse("nil response when parsing iam_request_headers"), nil
}

config, err := b.lockedClientConfigEntry(ctx, req.Storage)
if err != nil {
Expand Down Expand Up @@ -1491,41 +1482,6 @@ func parseGetCallerIdentityResponse(response string) (GetCallerIdentityResponse,
return result, err
}

func parseIamRequestHeaders(headersB64 string) (http.Header, error) {
headersJson, err := base64.StdEncoding.DecodeString(headersB64)
if err != nil {
return nil, fmt.Errorf("failed to base64 decode iam_request_headers")
}
var headersDecoded map[string]interface{}
err = jsonutil.DecodeJSON(headersJson, &headersDecoded)
if err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("failed to JSON decode iam_request_headers %q: {{err}}", headersJson), err)
}
headers := make(http.Header)
for k, v := range headersDecoded {
switch typedValue := v.(type) {
case string:
headers.Add(k, typedValue)
case json.Number:
headers.Add(k, typedValue.String())
case []interface{}:
for _, individualVal := range typedValue {
switch possibleStrVal := individualVal.(type) {
case string:
headers.Add(k, possibleStrVal)
case json.Number:
headers.Add(k, possibleStrVal.String())
default:
return nil, fmt.Errorf("header %q contains value %q that has type %s, not string", k, individualVal, reflect.TypeOf(individualVal))
}
}
default:
return nil, fmt.Errorf("header %q value %q has type %s, not string or []interface", k, typedValue, reflect.TypeOf(v))
}
}
return headers, nil
}

func submitCallerIdentityRequest(method, endpoint string, parsedUrl *url.URL, body string, headers http.Header) (*GetCallerIdentityResult, error) {
// NOTE: We need to ensure we're calling STS, instead of acting as an unintended network proxy
// The protection against this is that this method will only call the endpoint specified in the
Expand All @@ -1536,6 +1492,7 @@ func submitCallerIdentityRequest(method, endpoint string, parsedUrl *url.URL, bo
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}

response, err := client.Do(request)
if err != nil {
return nil, errwrap.Wrapf("error making request: {{err}}", err)
Expand Down
Loading