Skip to content

Commit

Permalink
aws/ec2metadata: Add support for EC2Metadata client secure token (a…
Browse files Browse the repository at this point in the history
…ws#453)

Adds support for EC2Metadata client to use secure tokens provided by the IMDS. Modifies and adds tests to verify the behavior of the EC2Metadata client.

Fixes aws#457
  • Loading branch information
skotambkar authored and jasdel committed Dec 17, 2019
1 parent 0734f39 commit 13e3dc8
Show file tree
Hide file tree
Showing 7 changed files with 1,233 additions and 188 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG_PENDING.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ SDK Features

SDK Enhancements
---
* `aws/ec2metadata`: Adds support for EC2Metadata client to use secure tokens provided by the IMDS ([#453](https://github.com/aws/aws-sdk-go-v2/pull/453))
* Modifies and adds tests to verify the behavior of the EC2Metadata client.

SDK Bugs
--
Expand Down
108 changes: 92 additions & 16 deletions aws/ec2metadata/api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"net"
"net/http"
"os"
"strconv"
"strings"
"time"

Expand All @@ -22,7 +23,25 @@ import (
"github.com/aws/aws-sdk-go-v2/aws/defaults"
)

const disableServiceEnvVar = "AWS_EC2_METADATA_DISABLED"
const (
// ServiceName is the name of the service.
ServiceName = "ec2metadata"
disableServiceEnvVar = "AWS_EC2_METADATA_DISABLED"

// Headers for Token and TTL
ttlHeader = "x-aws-ec2-metadata-token-ttl-seconds"
tokenHeader = "x-aws-ec2-metadata-token"

// Named Handler constants
fetchTokenHandlerName = "FetchTokenHandler"
unmarshalMetadataHandlerName = "unmarshalMetadataHandler"
unmarshalTokenHandlerName = "unmarshalTokenHandler"
enableTokenProviderHandlerName = "enableTokenProviderHandler"

// TTL constants
defaultTTL = 21600 * time.Second
ttlExpirationWindow = 30 * time.Second
)

// A Client is an EC2 Instance Metadata service Client.
type Client struct {
Expand Down Expand Up @@ -61,7 +80,20 @@ func New(config aws.Config) *Client {
),
}

svc.Handlers.Unmarshal.PushBack(unmarshalHandler)
// token provider instance
tp := newTokenProvider(svc, defaultTTL)
// NamedHandler for fetching token
svc.Handlers.Sign.PushBackNamed(aws.NamedHandler{
Name: fetchTokenHandlerName,
Fn: tp.fetchTokenHandler,
})
// NamedHandler for enabling token provider
svc.Handlers.Complete.PushBackNamed(aws.NamedHandler{
Name: enableTokenProviderHandlerName,
Fn: tp.enableTokenProviderHandler,
})

svc.Handlers.Unmarshal.PushBackNamed(unmarshalHandler)
svc.Handlers.UnmarshalError.PushBack(unmarshalError)
svc.Handlers.Validate.Clear()
svc.Handlers.Validate.PushBack(validateEndpointHandler)
Expand Down Expand Up @@ -91,30 +123,74 @@ type metadataOutput struct {
Content string
}

func unmarshalHandler(r *aws.Request) {
defer r.HTTPResponse.Body.Close()
b := &bytes.Buffer{}
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata respose", err)
return
}
type tokenOutput struct {
Token string
TTL time.Duration
}

if data, ok := r.Data.(*metadataOutput); ok {
data.Content = b.String()
}
// unmarshal token handler is used to parse the response of a getToken operation
var unmarshalTokenHandler = aws.NamedHandler{
Name: unmarshalTokenHandlerName,
Fn: func(r *aws.Request) {
defer r.HTTPResponse.Body.Close()
var b bytes.Buffer
if _, err := io.Copy(&b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.NewRequestFailure(awserr.New(aws.ErrCodeSerialization,
"unable to unmarshal EC2 metadata response", err), r.HTTPResponse.StatusCode, r.RequestID)
return
}

v := r.HTTPResponse.Header.Get(ttlHeader)
data, ok := r.Data.(*tokenOutput)
if !ok {
return
}

data.Token = b.String()
// TTL is in seconds
i, err := strconv.ParseInt(v, 10, 64)
if err != nil {
r.Error = awserr.NewRequestFailure(awserr.New(aws.ParamFormatErrCode,
"unable to parse EC2 token TTL response", err), r.HTTPResponse.StatusCode, r.RequestID)
return
}
t := time.Duration(i) * time.Second
data.TTL = t
},
}

var unmarshalHandler = aws.NamedHandler{
Name: unmarshalMetadataHandlerName,
Fn: func(r *aws.Request) {
defer r.HTTPResponse.Body.Close()
var b bytes.Buffer
if _, err := io.Copy(&b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.NewRequestFailure(awserr.New(aws.ErrCodeSerialization,
"unable to unmarshal EC2 metadata response", err), r.HTTPResponse.StatusCode, r.RequestID)
return
}

if data, ok := r.Data.(*metadataOutput); ok {
data.Content = b.String()
}
},
}

func unmarshalError(r *aws.Request) {
defer r.HTTPResponse.Body.Close()
b := &bytes.Buffer{}
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata error respose", err)
var b bytes.Buffer

if _, err := io.Copy(&b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.NewRequestFailure(
awserr.New(aws.ErrCodeSerialization, "unable to unmarshal EC2 metadata error response", err),
r.HTTPResponse.StatusCode, r.RequestID)
return
}

// Response body format is not consistent between metadata endpoints.
// Grab the error message as a string and include that as the source error
r.Error = awserr.New("EC2MetadataError", "failed to make Client request", errors.New(b.String()))
r.Error = awserr.NewRequestFailure(awserr.New("EC2MetadataError", "failed to make EC2Metadata request", errors.New(b.String())),
r.HTTPResponse.StatusCode, r.RequestID)
}

func validateEndpointHandler(r *aws.Request) {
Expand Down
2 changes: 1 addition & 1 deletion aws/ec2metadata/api_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestClientDisableIMDS(t *testing.T) {
cfg.Logger = t

svc := ec2metadata.New(cfg)
resp, err := svc.Region()
resp, err := svc.GetUserData()
if err == nil {
t.Fatalf("expect error, got none")
}
Expand Down
54 changes: 43 additions & 11 deletions aws/ec2metadata/api_ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,49 @@ import (
"fmt"
"net/http"
"path"
"strconv"
"strings"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/awserr"
)

// getToken uses the duration to return a token for EC2 metadata service,
// or an error if the request failed.
func (c *Client) getToken(duration time.Duration) (tokenOutput, error) {
op := &aws.Operation{
Name: "GetToken",
HTTPMethod: "PUT",
HTTPPath: "/api/token",
}

var output tokenOutput
req := c.NewRequest(op, nil, &output)

// remove the fetch token handler from the request handlers to avoid infinite recursion
req.Handlers.Sign.RemoveByName(fetchTokenHandlerName)

// Swap the unmarshalMetadataHandler with unmarshalTokenHandler on this request.
req.Handlers.Unmarshal.Swap(unmarshalMetadataHandlerName, unmarshalTokenHandler)

ttl := strconv.FormatInt(int64(duration/time.Second), 10)
req.HTTPRequest.Header.Set(ttlHeader, ttl)

err := req.Send()

// Errors with bad request status should be returned.
if err != nil {
err = awserr.NewRequestFailure(
awserr.New(req.HTTPResponse.Status, http.StatusText(req.HTTPResponse.StatusCode), err),
req.HTTPResponse.StatusCode, req.RequestID)
}

return output, err
}

// GetMetadata uses the path provided to request information from the EC2
// instance metdata service. The content will be returned as a string, or
// instance metadata service. The content will be returned as a string, or
// error if the request failed.
func (c *Client) GetMetadata(p string) (string, error) {
op := &aws.Operation{
Expand All @@ -40,12 +74,6 @@ func (c *Client) GetUserData() (string, error) {

output := &metadataOutput{}
req := c.NewRequest(op, nil, output)
req.Handlers.UnmarshalError.PushBack(func(r *aws.Request) {
if r.HTTPResponse.StatusCode == http.StatusNotFound {
r.Error = awserr.New("NotFoundError", "user-data not found", r.Error)
}
})

return output.Content, req.Send()
}

Expand Down Expand Up @@ -113,13 +141,17 @@ func (c *Client) IAMInfo() (EC2IAMInfo, error) {

// Region returns the region the instance is running in.
func (c *Client) Region() (string, error) {
resp, err := c.GetMetadata("placement/availability-zone")
ec2InstanceIdentityDocument, err := c.GetInstanceIdentityDocument()
if err != nil {
return "", err
}

// returns region without the suffix. Eg: us-west-2a becomes us-west-2
return resp[:len(resp)-1], nil
// extract region from the ec2InstanceIdentityDocument
region := ec2InstanceIdentityDocument.Region
if len(region) == 0 {
return "", awserr.New("EC2MetadataError", "invalid region received for ec2metadata instance", nil)
}
// returns region
return region, nil
}

// Available returns if the application has access to the EC2 Instance Metadata
Expand Down
Loading

0 comments on commit 13e3dc8

Please sign in to comment.