From 5c7f7342179d007e9e779ffe8734d540cdf36fde Mon Sep 17 00:00:00 2001 From: sai-sunder-s <4540365+sai-sunder-s@users.noreply.github.com> Date: Sat, 5 Feb 2022 00:11:49 +0000 Subject: [PATCH] fix: Add AWS session token to metadata requests (#958) * fix: add aws session token to metadata requests * Fix tests * update config name * Add test for session token * add coverage * add coverage * Run blacken and lint --- google/auth/aws.py | 65 +++++++++++++++++--- tests/test_aws.py | 147 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 198 insertions(+), 14 deletions(-) diff --git a/google/auth/aws.py b/google/auth/aws.py index 2fd96d058..358a1cf96 100644 --- a/google/auth/aws.py +++ b/google/auth/aws.py @@ -406,6 +406,7 @@ def __init__( self._cred_verification_url = credential_source.get( "regional_cred_verification_url" ) + self._aws_session_token_url = credential_source.get("aws_session_token_url") self._region = None self._request_signer = None self._target_resource = audience @@ -458,15 +459,34 @@ def retrieve_subject_token(self, request): Returns: str: The retrieved subject token. """ + # Fetch the session token required to make meta data endpoint calls to aws + if request is not None and self._aws_session_token_url is not None: + headers = {"X-aws-ec2-metadata-token-ttl-seconds": "21600"} + + session_token_response = request( + url=self._aws_session_token_url, method="PUT", headers=headers + ) + + if session_token_response.status != 200: + raise exceptions.RefreshError( + "Unable to retrieve AWS Session Token", session_token_response.data + ) + + session_token = session_token_response.data + else: + session_token = None + # Initialize the request signer if not yet initialized after determining # the current AWS region. if self._request_signer is None: - self._region = self._get_region(request, self._region_url) + self._region = self._get_region(request, self._region_url, session_token) self._request_signer = RequestSigner(self._region) # Retrieve the AWS security credentials needed to generate the signed # request. - aws_security_credentials = self._get_security_credentials(request) + aws_security_credentials = self._get_security_credentials( + request, session_token + ) # Generate the signed request to AWS STS GetCallerIdentity API. # Use the required regional endpoint. Otherwise, the request will fail. request_options = self._request_signer.get_request_options( @@ -511,7 +531,7 @@ def retrieve_subject_token(self, request): json.dumps(aws_signed_req, separators=(",", ":"), sort_keys=True) ) - def _get_region(self, request, url): + def _get_region(self, request, url, session_token): """Retrieves the current AWS region from either the AWS_REGION or AWS_DEFAULT_REGION environment variable or from the AWS metadata server. @@ -519,6 +539,8 @@ def _get_region(self, request, url): request (google.auth.transport.Request): A callable used to make HTTP requests. url (str): The AWS metadata server region URL. + session_token (str): The AWS session token to be added as a + header in the requests to AWS metadata endpoint. Returns: str: The current AWS region. @@ -540,7 +562,12 @@ def _get_region(self, request, url): if not self._region_url: raise exceptions.RefreshError("Unable to determine AWS region") - response = request(url=self._region_url, method="GET") + + headers = None + if session_token is not None: + headers = {"X-aws-ec2-metadata-token": session_token} + + response = request(url=self._region_url, method="GET", headers=headers) # Support both string and bytes type response.data. response_body = ( @@ -558,7 +585,7 @@ def _get_region(self, request, url): # Only the us-east-2 part should be used. return response_body[:-1] - def _get_security_credentials(self, request): + def _get_security_credentials(self, request, session_token): """Retrieves the AWS security credentials required for signing AWS requests from either the AWS security credentials environment variables or from the AWS metadata server. @@ -566,6 +593,8 @@ def _get_security_credentials(self, request): Args: request (google.auth.transport.Request): A callable used to make HTTP requests. + session_token (str): The AWS session token to be added as a + header in the requests to AWS metadata endpoint. Returns: Mapping[str, str]: The AWS security credentials dictionary object. @@ -591,10 +620,12 @@ def _get_security_credentials(self, request): } # Get role name. - role_name = self._get_metadata_role_name(request) + role_name = self._get_metadata_role_name(request, session_token) # Get security credentials. - credentials = self._get_metadata_security_credentials(request, role_name) + credentials = self._get_metadata_security_credentials( + request, role_name, session_token + ) return { "access_key_id": credentials.get("AccessKeyId"), @@ -602,7 +633,7 @@ def _get_security_credentials(self, request): "security_token": credentials.get("Token"), } - def _get_metadata_security_credentials(self, request, role_name): + def _get_metadata_security_credentials(self, request, role_name, session_token): """Retrieves the AWS security credentials required for signing AWS requests from the AWS metadata server. @@ -612,6 +643,8 @@ def _get_metadata_security_credentials(self, request, role_name): role_name (str): The AWS role name required by the AWS metadata server security_credentials endpoint in order to return the credentials. + session_token (str): The AWS session token to be added as a + header in the requests to AWS metadata endpoint. Returns: Mapping[str, str]: The AWS metadata server security credentials @@ -622,6 +655,9 @@ def _get_metadata_security_credentials(self, request, role_name): retrieving the AWS security credentials. """ headers = {"Content-Type": "application/json"} + if session_token is not None: + headers["X-aws-ec2-metadata-token"] = session_token + response = request( url="{}/{}".format(self._security_credentials_url, role_name), method="GET", @@ -644,7 +680,7 @@ def _get_metadata_security_credentials(self, request, role_name): return credentials_response - def _get_metadata_role_name(self, request): + def _get_metadata_role_name(self, request, session_token): """Retrieves the AWS role currently attached to the current AWS workload by querying the AWS metadata server. This is needed for the AWS metadata server security credentials endpoint in order to retrieve @@ -653,6 +689,8 @@ def _get_metadata_role_name(self, request): Args: request (google.auth.transport.Request): A callable used to make HTTP requests. + session_token (str): The AWS session token to be added as a + header in the requests to AWS metadata endpoint. Returns: str: The AWS role name. @@ -665,7 +703,14 @@ def _get_metadata_role_name(self, request): raise exceptions.RefreshError( "Unable to determine the AWS metadata server security credentials endpoint" ) - response = request(url=self._security_credentials_url, method="GET") + + headers = None + if session_token is not None: + headers = {"X-aws-ec2-metadata-token": session_token} + + response = request( + url=self._security_credentials_url, method="GET", headers=headers + ) # support both string and bytes type response.data response_body = ( diff --git a/tests/test_aws.py b/tests/test_aws.py index d37131afb..2bace5d07 100644 --- a/tests/test_aws.py +++ b/tests/test_aws.py @@ -42,6 +42,7 @@ SUBJECT_TOKEN_TYPE = "urn:ietf:params:aws:token-type:aws4_request" AUDIENCE = "//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/POOL_ID/providers/PROVIDER_ID" REGION_URL = "http://169.254.169.254/latest/meta-data/placement/availability-zone" +AWS_SESSION_TOKEN_URL = "http://169.254.169.254/latest/api/token" SECURITY_CREDS_URL = "http://169.254.169.254/latest/meta-data/iam/security-credentials" CRED_VERIFICATION_URL = ( "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15" @@ -578,6 +579,7 @@ class TestCredentials(object): "SecretAccessKey": SECRET_ACCESS_KEY, "Token": TOKEN, } + AWS_SESSION_TOKEN = "awssessiontoken" AWS_SIGNATURE_TIME = "2020-08-11T06:55:22Z" CREDENTIAL_SOURCE = { "environment_id": "aws1", @@ -654,6 +656,8 @@ def make_mock_request( token_data=None, impersonation_status=None, impersonation_data=None, + session_token_status=None, + session_token_data=None, ): """Utility function to generate a mock HTTP request object. This will facilitate testing various edge cases by specify how the @@ -661,6 +665,13 @@ def make_mock_request( in an AWS environment. """ responses = [] + if session_token_status: + # AWS session token request + session_response = mock.create_autospec(transport.Response, instance=True) + session_response.status = session_token_status + session_response.data = session_token_data + responses.append(session_response) + if region_status: # AWS region request. region_response = mock.create_autospec(transport.Response, instance=True) @@ -735,14 +746,16 @@ def make_credentials( ) @classmethod - def assert_aws_metadata_request_kwargs(cls, request_kwargs, url, headers=None): + def assert_aws_metadata_request_kwargs( + cls, request_kwargs, url, headers=None, method="GET" + ): assert request_kwargs["url"] == url # All used AWS metadata server endpoints use GET HTTP method. - assert request_kwargs["method"] == "GET" + assert request_kwargs["method"] == method if headers: assert request_kwargs["headers"] == headers else: - assert "headers" not in request_kwargs + assert "headers" not in request_kwargs or request_kwargs["headers"] is None # None of the endpoints used require any data in request. assert "body" not in request_kwargs @@ -995,7 +1008,7 @@ def test_retrieve_subject_token_success_temp_creds_no_environment_vars( credentials.retrieve_subject_token(new_request) - # Only 2 requests should be sent as the region is cached. + # Only 3 requests should be sent as the region is cached. assert len(new_request.call_args_list) == 2 # Assert role request. self.assert_aws_metadata_request_kwargs( @@ -1008,6 +1021,132 @@ def test_retrieve_subject_token_success_temp_creds_no_environment_vars( {"Content-Type": "application/json"}, ) + @mock.patch("google.auth._helpers.utcnow") + def test_retrieve_subject_token_success_temp_creds_no_environment_vars_idmsv2( + self, utcnow + ): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + request = self.make_mock_request( + region_status=http_client.OK, + region_name=self.AWS_REGION, + role_status=http_client.OK, + role_name=self.AWS_ROLE, + security_credentials_status=http_client.OK, + security_credentials_data=self.AWS_SECURITY_CREDENTIALS_RESPONSE, + session_token_status=http_client.OK, + session_token_data=self.AWS_SESSION_TOKEN, + ) + credential_source_token_url = self.CREDENTIAL_SOURCE.copy() + credential_source_token_url["aws_session_token_url"] = AWS_SESSION_TOKEN_URL + credentials = self.make_credentials( + credential_source=credential_source_token_url + ) + + subject_token = credentials.retrieve_subject_token(request) + + assert subject_token == self.make_serialized_aws_signed_request( + { + "access_key_id": ACCESS_KEY_ID, + "secret_access_key": SECRET_ACCESS_KEY, + "security_token": TOKEN, + } + ) + # Assert session token request + self.assert_aws_metadata_request_kwargs( + request.call_args_list[0][1], + AWS_SESSION_TOKEN_URL, + {"X-aws-ec2-metadata-token-ttl-seconds": "21600"}, + "PUT", + ) + # Assert region request. + self.assert_aws_metadata_request_kwargs( + request.call_args_list[1][1], + REGION_URL, + {"X-aws-ec2-metadata-token": self.AWS_SESSION_TOKEN}, + ) + # Assert role request. + self.assert_aws_metadata_request_kwargs( + request.call_args_list[2][1], + SECURITY_CREDS_URL, + {"X-aws-ec2-metadata-token": self.AWS_SESSION_TOKEN}, + ) + # Assert security credentials request. + self.assert_aws_metadata_request_kwargs( + request.call_args_list[3][1], + "{}/{}".format(SECURITY_CREDS_URL, self.AWS_ROLE), + { + "Content-Type": "application/json", + "X-aws-ec2-metadata-token": self.AWS_SESSION_TOKEN, + }, + ) + + # Retrieve subject_token again. Region should not be queried again. + new_request = self.make_mock_request( + role_status=http_client.OK, + role_name=self.AWS_ROLE, + security_credentials_status=http_client.OK, + security_credentials_data=self.AWS_SECURITY_CREDENTIALS_RESPONSE, + session_token_status=http_client.OK, + session_token_data=self.AWS_SESSION_TOKEN, + ) + + credentials.retrieve_subject_token(new_request) + + # Only 3 requests should be sent as the region is cached. + assert len(new_request.call_args_list) == 3 + # Assert session token request + self.assert_aws_metadata_request_kwargs( + request.call_args_list[0][1], + AWS_SESSION_TOKEN_URL, + {"X-aws-ec2-metadata-token-ttl-seconds": "21600"}, + "PUT", + ) + # Assert role request. + self.assert_aws_metadata_request_kwargs( + new_request.call_args_list[1][1], + SECURITY_CREDS_URL, + {"X-aws-ec2-metadata-token": self.AWS_SESSION_TOKEN}, + ) + # Assert security credentials request. + self.assert_aws_metadata_request_kwargs( + new_request.call_args_list[2][1], + "{}/{}".format(SECURITY_CREDS_URL, self.AWS_ROLE), + { + "Content-Type": "application/json", + "X-aws-ec2-metadata-token": self.AWS_SESSION_TOKEN, + }, + ) + + @mock.patch("google.auth._helpers.utcnow") + def test_retrieve_subject_token_session_error_idmsv2(self, utcnow): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + request = self.make_mock_request( + session_token_status=http_client.UNAUTHORIZED, + session_token_data="unauthorized", + ) + credential_source_token_url = self.CREDENTIAL_SOURCE.copy() + credential_source_token_url["aws_session_token_url"] = AWS_SESSION_TOKEN_URL + credentials = self.make_credentials( + credential_source=credential_source_token_url + ) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token(request) + + assert excinfo.match(r"Unable to retrieve AWS Session Token") + + # Assert session token request + self.assert_aws_metadata_request_kwargs( + request.call_args_list[0][1], + AWS_SESSION_TOKEN_URL, + {"X-aws-ec2-metadata-token-ttl-seconds": "21600"}, + "PUT", + ) + @mock.patch("google.auth._helpers.utcnow") def test_retrieve_subject_token_success_permanent_creds_no_environment_vars( self, utcnow