Skip to content

Commit

Permalink
fix: Add AWS session token to metadata requests (#958)
Browse files Browse the repository at this point in the history
* fix: add aws session token to metadata requests

* Fix tests

* update config name

* Add test for session token

* add coverage

* add coverage

* Run blacken and lint
  • Loading branch information
sai-sunder-s authored Feb 5, 2022
1 parent 3fd0987 commit 5c7f734
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 14 deletions.
65 changes: 55 additions & 10 deletions google/auth/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ def __init__(
self._cred_verification_url = credential_source.get(
"regional_cred_verification_url"
)
self._aws_session_token_url = credential_source.get("aws_session_token_url")
self._region = None
self._request_signer = None
self._target_resource = audience
Expand Down Expand Up @@ -458,15 +459,34 @@ def retrieve_subject_token(self, request):
Returns:
str: The retrieved subject token.
"""
# Fetch the session token required to make meta data endpoint calls to aws
if request is not None and self._aws_session_token_url is not None:
headers = {"X-aws-ec2-metadata-token-ttl-seconds": "21600"}

session_token_response = request(
url=self._aws_session_token_url, method="PUT", headers=headers
)

if session_token_response.status != 200:
raise exceptions.RefreshError(
"Unable to retrieve AWS Session Token", session_token_response.data
)

session_token = session_token_response.data
else:
session_token = None

# Initialize the request signer if not yet initialized after determining
# the current AWS region.
if self._request_signer is None:
self._region = self._get_region(request, self._region_url)
self._region = self._get_region(request, self._region_url, session_token)
self._request_signer = RequestSigner(self._region)

# Retrieve the AWS security credentials needed to generate the signed
# request.
aws_security_credentials = self._get_security_credentials(request)
aws_security_credentials = self._get_security_credentials(
request, session_token
)
# Generate the signed request to AWS STS GetCallerIdentity API.
# Use the required regional endpoint. Otherwise, the request will fail.
request_options = self._request_signer.get_request_options(
Expand Down Expand Up @@ -511,14 +531,16 @@ def retrieve_subject_token(self, request):
json.dumps(aws_signed_req, separators=(",", ":"), sort_keys=True)
)

def _get_region(self, request, url):
def _get_region(self, request, url, session_token):
"""Retrieves the current AWS region from either the AWS_REGION or
AWS_DEFAULT_REGION environment variable or from the AWS metadata server.
Args:
request (google.auth.transport.Request): A callable used to make
HTTP requests.
url (str): The AWS metadata server region URL.
session_token (str): The AWS session token to be added as a
header in the requests to AWS metadata endpoint.
Returns:
str: The current AWS region.
Expand All @@ -540,7 +562,12 @@ def _get_region(self, request, url):

if not self._region_url:
raise exceptions.RefreshError("Unable to determine AWS region")
response = request(url=self._region_url, method="GET")

headers = None
if session_token is not None:
headers = {"X-aws-ec2-metadata-token": session_token}

response = request(url=self._region_url, method="GET", headers=headers)

# Support both string and bytes type response.data.
response_body = (
Expand All @@ -558,14 +585,16 @@ def _get_region(self, request, url):
# Only the us-east-2 part should be used.
return response_body[:-1]

def _get_security_credentials(self, request):
def _get_security_credentials(self, request, session_token):
"""Retrieves the AWS security credentials required for signing AWS
requests from either the AWS security credentials environment variables
or from the AWS metadata server.
Args:
request (google.auth.transport.Request): A callable used to make
HTTP requests.
session_token (str): The AWS session token to be added as a
header in the requests to AWS metadata endpoint.
Returns:
Mapping[str, str]: The AWS security credentials dictionary object.
Expand All @@ -591,18 +620,20 @@ def _get_security_credentials(self, request):
}

# Get role name.
role_name = self._get_metadata_role_name(request)
role_name = self._get_metadata_role_name(request, session_token)

# Get security credentials.
credentials = self._get_metadata_security_credentials(request, role_name)
credentials = self._get_metadata_security_credentials(
request, role_name, session_token
)

return {
"access_key_id": credentials.get("AccessKeyId"),
"secret_access_key": credentials.get("SecretAccessKey"),
"security_token": credentials.get("Token"),
}

def _get_metadata_security_credentials(self, request, role_name):
def _get_metadata_security_credentials(self, request, role_name, session_token):
"""Retrieves the AWS security credentials required for signing AWS
requests from the AWS metadata server.
Expand All @@ -612,6 +643,8 @@ def _get_metadata_security_credentials(self, request, role_name):
role_name (str): The AWS role name required by the AWS metadata
server security_credentials endpoint in order to return the
credentials.
session_token (str): The AWS session token to be added as a
header in the requests to AWS metadata endpoint.
Returns:
Mapping[str, str]: The AWS metadata server security credentials
Expand All @@ -622,6 +655,9 @@ def _get_metadata_security_credentials(self, request, role_name):
retrieving the AWS security credentials.
"""
headers = {"Content-Type": "application/json"}
if session_token is not None:
headers["X-aws-ec2-metadata-token"] = session_token

response = request(
url="{}/{}".format(self._security_credentials_url, role_name),
method="GET",
Expand All @@ -644,7 +680,7 @@ def _get_metadata_security_credentials(self, request, role_name):

return credentials_response

def _get_metadata_role_name(self, request):
def _get_metadata_role_name(self, request, session_token):
"""Retrieves the AWS role currently attached to the current AWS
workload by querying the AWS metadata server. This is needed for the
AWS metadata server security credentials endpoint in order to retrieve
Expand All @@ -653,6 +689,8 @@ def _get_metadata_role_name(self, request):
Args:
request (google.auth.transport.Request): A callable used to make
HTTP requests.
session_token (str): The AWS session token to be added as a
header in the requests to AWS metadata endpoint.
Returns:
str: The AWS role name.
Expand All @@ -665,7 +703,14 @@ def _get_metadata_role_name(self, request):
raise exceptions.RefreshError(
"Unable to determine the AWS metadata server security credentials endpoint"
)
response = request(url=self._security_credentials_url, method="GET")

headers = None
if session_token is not None:
headers = {"X-aws-ec2-metadata-token": session_token}

response = request(
url=self._security_credentials_url, method="GET", headers=headers
)

# support both string and bytes type response.data
response_body = (
Expand Down
147 changes: 143 additions & 4 deletions tests/test_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
SUBJECT_TOKEN_TYPE = "urn:ietf:params:aws:token-type:aws4_request"
AUDIENCE = "//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/POOL_ID/providers/PROVIDER_ID"
REGION_URL = "http://169.254.169.254/latest/meta-data/placement/availability-zone"
AWS_SESSION_TOKEN_URL = "http://169.254.169.254/latest/api/token"
SECURITY_CREDS_URL = "http://169.254.169.254/latest/meta-data/iam/security-credentials"
CRED_VERIFICATION_URL = (
"https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15"
Expand Down Expand Up @@ -578,6 +579,7 @@ class TestCredentials(object):
"SecretAccessKey": SECRET_ACCESS_KEY,
"Token": TOKEN,
}
AWS_SESSION_TOKEN = "awssessiontoken"
AWS_SIGNATURE_TIME = "2020-08-11T06:55:22Z"
CREDENTIAL_SOURCE = {
"environment_id": "aws1",
Expand Down Expand Up @@ -654,13 +656,22 @@ def make_mock_request(
token_data=None,
impersonation_status=None,
impersonation_data=None,
session_token_status=None,
session_token_data=None,
):
"""Utility function to generate a mock HTTP request object.
This will facilitate testing various edge cases by specify how the
various endpoints will respond while generating a Google Access token
in an AWS environment.
"""
responses = []
if session_token_status:
# AWS session token request
session_response = mock.create_autospec(transport.Response, instance=True)
session_response.status = session_token_status
session_response.data = session_token_data
responses.append(session_response)

if region_status:
# AWS region request.
region_response = mock.create_autospec(transport.Response, instance=True)
Expand Down Expand Up @@ -735,14 +746,16 @@ def make_credentials(
)

@classmethod
def assert_aws_metadata_request_kwargs(cls, request_kwargs, url, headers=None):
def assert_aws_metadata_request_kwargs(
cls, request_kwargs, url, headers=None, method="GET"
):
assert request_kwargs["url"] == url
# All used AWS metadata server endpoints use GET HTTP method.
assert request_kwargs["method"] == "GET"
assert request_kwargs["method"] == method
if headers:
assert request_kwargs["headers"] == headers
else:
assert "headers" not in request_kwargs
assert "headers" not in request_kwargs or request_kwargs["headers"] is None
# None of the endpoints used require any data in request.
assert "body" not in request_kwargs

Expand Down Expand Up @@ -995,7 +1008,7 @@ def test_retrieve_subject_token_success_temp_creds_no_environment_vars(

credentials.retrieve_subject_token(new_request)

# Only 2 requests should be sent as the region is cached.
# Only 3 requests should be sent as the region is cached.
assert len(new_request.call_args_list) == 2
# Assert role request.
self.assert_aws_metadata_request_kwargs(
Expand All @@ -1008,6 +1021,132 @@ def test_retrieve_subject_token_success_temp_creds_no_environment_vars(
{"Content-Type": "application/json"},
)

@mock.patch("google.auth._helpers.utcnow")
def test_retrieve_subject_token_success_temp_creds_no_environment_vars_idmsv2(
self, utcnow
):
utcnow.return_value = datetime.datetime.strptime(
self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ"
)
request = self.make_mock_request(
region_status=http_client.OK,
region_name=self.AWS_REGION,
role_status=http_client.OK,
role_name=self.AWS_ROLE,
security_credentials_status=http_client.OK,
security_credentials_data=self.AWS_SECURITY_CREDENTIALS_RESPONSE,
session_token_status=http_client.OK,
session_token_data=self.AWS_SESSION_TOKEN,
)
credential_source_token_url = self.CREDENTIAL_SOURCE.copy()
credential_source_token_url["aws_session_token_url"] = AWS_SESSION_TOKEN_URL
credentials = self.make_credentials(
credential_source=credential_source_token_url
)

subject_token = credentials.retrieve_subject_token(request)

assert subject_token == self.make_serialized_aws_signed_request(
{
"access_key_id": ACCESS_KEY_ID,
"secret_access_key": SECRET_ACCESS_KEY,
"security_token": TOKEN,
}
)
# Assert session token request
self.assert_aws_metadata_request_kwargs(
request.call_args_list[0][1],
AWS_SESSION_TOKEN_URL,
{"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
"PUT",
)
# Assert region request.
self.assert_aws_metadata_request_kwargs(
request.call_args_list[1][1],
REGION_URL,
{"X-aws-ec2-metadata-token": self.AWS_SESSION_TOKEN},
)
# Assert role request.
self.assert_aws_metadata_request_kwargs(
request.call_args_list[2][1],
SECURITY_CREDS_URL,
{"X-aws-ec2-metadata-token": self.AWS_SESSION_TOKEN},
)
# Assert security credentials request.
self.assert_aws_metadata_request_kwargs(
request.call_args_list[3][1],
"{}/{}".format(SECURITY_CREDS_URL, self.AWS_ROLE),
{
"Content-Type": "application/json",
"X-aws-ec2-metadata-token": self.AWS_SESSION_TOKEN,
},
)

# Retrieve subject_token again. Region should not be queried again.
new_request = self.make_mock_request(
role_status=http_client.OK,
role_name=self.AWS_ROLE,
security_credentials_status=http_client.OK,
security_credentials_data=self.AWS_SECURITY_CREDENTIALS_RESPONSE,
session_token_status=http_client.OK,
session_token_data=self.AWS_SESSION_TOKEN,
)

credentials.retrieve_subject_token(new_request)

# Only 3 requests should be sent as the region is cached.
assert len(new_request.call_args_list) == 3
# Assert session token request
self.assert_aws_metadata_request_kwargs(
request.call_args_list[0][1],
AWS_SESSION_TOKEN_URL,
{"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
"PUT",
)
# Assert role request.
self.assert_aws_metadata_request_kwargs(
new_request.call_args_list[1][1],
SECURITY_CREDS_URL,
{"X-aws-ec2-metadata-token": self.AWS_SESSION_TOKEN},
)
# Assert security credentials request.
self.assert_aws_metadata_request_kwargs(
new_request.call_args_list[2][1],
"{}/{}".format(SECURITY_CREDS_URL, self.AWS_ROLE),
{
"Content-Type": "application/json",
"X-aws-ec2-metadata-token": self.AWS_SESSION_TOKEN,
},
)

@mock.patch("google.auth._helpers.utcnow")
def test_retrieve_subject_token_session_error_idmsv2(self, utcnow):
utcnow.return_value = datetime.datetime.strptime(
self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ"
)
request = self.make_mock_request(
session_token_status=http_client.UNAUTHORIZED,
session_token_data="unauthorized",
)
credential_source_token_url = self.CREDENTIAL_SOURCE.copy()
credential_source_token_url["aws_session_token_url"] = AWS_SESSION_TOKEN_URL
credentials = self.make_credentials(
credential_source=credential_source_token_url
)

with pytest.raises(exceptions.RefreshError) as excinfo:
credentials.retrieve_subject_token(request)

assert excinfo.match(r"Unable to retrieve AWS Session Token")

# Assert session token request
self.assert_aws_metadata_request_kwargs(
request.call_args_list[0][1],
AWS_SESSION_TOKEN_URL,
{"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
"PUT",
)

@mock.patch("google.auth._helpers.utcnow")
def test_retrieve_subject_token_success_permanent_creds_no_environment_vars(
self, utcnow
Expand Down

0 comments on commit 5c7f734

Please sign in to comment.