Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use IMDSv2 for autodiscovering EC2 region #5207

Merged
merged 3 commits into from
May 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changes/next-release/enhancement,-IMDS-59801.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "enhancement,",
"category": "IMDS",
"description": "Use IMDSv2 for autodiscovering EC2 region. Fixes `#4976 <https://github.com/aws/aws-cli/issues/4976>`__"
}
2 changes: 2 additions & 0 deletions awscli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,11 @@ def retrieve_region(self):
return None

def _get_region(self):
token = self._fetch_metadata_token()
response = self._get_request(
url_path=self._URL_PATH,
retry_func=self._default_retry,
token=token
)
availability_zone = response.text
region = availability_zone[:-1]
Expand Down
47 changes: 37 additions & 10 deletions tests/functional/test_clidriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@
from awscli.clidriver import create_clidriver


class RegionCapture(object):
def __init__(self):
self.region = None

def __call__(self, request, **kwargs):
url = request.url
region = re.match(
'https://.*?\.(.*?)\.amazonaws\.com', url).groups(1)[0]
self.region = region


class TestSession(BaseCLIDriverTest):
def setUp(self):
super(TestSession, self).setUp()
Expand Down Expand Up @@ -54,27 +65,43 @@ def add_response(self, body, status_code=200):
)
self._responses.append(response)

def assert_correct_region(self, expected_region, request, **kwargs):
url = request.url
region = re.match(
'https://.*?\.(.*?)\.amazonaws\.com', url).groups(1)[0]
self.assertEqual(expected_region, region)
def test_imds_region_is_used_as_fallback_wo_v2_support(self):
# Remove region override from the environment variables.
self.environ.pop('AWS_DEFAULT_REGION', 0)
# First response should be from the IMDS server for security token
# if server supports IMDSv1 only there will be no response for token
self.add_response(None)
# Then another response from the IMDS server for an availibility
# zone.
self.add_response(b'us-mars-2a')
# Once a region is fetched form the IMDS server we need to mock an
# XML response from ec2 so that the CLI driver doesn't throw an error
# during parsing.
self.add_response(
b'<?xml version="1.0" ?><foo><bar>text</bar></foo>')
capture = RegionCapture()
self.session.register('before-send.ec2.*', capture)
self.driver.main(['ec2', 'describe-instances'])
self.assertEqual(capture.region, 'us-mars-2')

def test_imds_region_is_used_as_fallback(self):
def test_imds_region_is_used_as_fallback_with_v2_support(self):
# Remove region override from the environment variables.
self.environ.pop('AWS_DEFAULT_REGION', 0)
# First response should be from the IMDS server for an availibility
# First response should be from the IMDS server for security token
# if server supports IMDSv2 it'll return token
self.add_response(b'token')
# Then another response from the IMDS server for an availibility
# zone.
self.add_response(b'us-mars-2a')
# Once a region is fetched form the IMDS server we need to mock an
# XML response from ec2 so that the CLI driver doesn't throw an error
# during parsing.
self.add_response(
b'<?xml version="1.0" ?><foo><bar>text</bar></foo>')
assert_correct_region = functools.partial(
self.assert_correct_region, 'us-mars-2')
self.session.register('before-send.ec2.*', assert_correct_region)
capture = RegionCapture()
self.session.register('before-send.ec2.*', capture)
self.driver.main(['ec2', 'describe-instances'])
self.assertEqual(capture.region, 'us-mars-2')


class TestPlugins(BaseCLIDriverTest):
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,9 @@ def add_get_region_imds_response(self, region=None):
region = self._region
self.add_imds_response(body=region.encode('utf-8'))

def add_imds_token_response(self):
self.add_imds_response(status_code=200, body=b'token')

def add_imds_connection_error(self, exception):
self._imds_responses.append(exception)

Expand Down Expand Up @@ -351,6 +354,7 @@ def test_disabling_env_var_not_true(self):
url = 'https://example.com/'
env = {'AWS_EC2_METADATA_DISABLED': 'false'}

self.add_imds_token_response()
self.add_get_region_imds_response()

fetcher = InstanceMetadataRegionFetcher(base_url=url, env=env)
Expand All @@ -361,6 +365,7 @@ def test_disabling_env_var_not_true(self):

def test_includes_user_agent_header(self):
user_agent = 'my-user-agent'
self.add_imds_token_response()
self.add_get_region_imds_response()

InstanceMetadataRegionFetcher(
Expand All @@ -372,6 +377,7 @@ def test_includes_user_agent_header(self):
def test_non_200_response_for_region_is_retried(self):
# Response for role name that have a non 200 status code should
# be retried.
self.add_imds_token_response()
self.add_imds_response(
status_code=429, body=b'{"message": "Slow down"}')
self.add_get_region_imds_response()
Expand All @@ -393,6 +399,7 @@ def test_empty_response_for_region_is_retried(self):
def test_non_200_response_is_retried(self):
# Response for creds that has a 200 status code but has an empty
# body should be retried.
self.add_imds_token_response()
self.add_imds_response(
status_code=429, body=b'{"message": "Slow down"}')
self.add_get_region_imds_response()
Expand All @@ -403,6 +410,7 @@ def test_non_200_response_is_retried(self):

def test_http_connection_errors_is_retried(self):
# Connection related errors should be retried
self.add_imds_token_response()
self.add_imds_connection_error(ConnectionClosedError(endpoint_url=''))
self.add_get_region_imds_response()
result = InstanceMetadataRegionFetcher(
Expand All @@ -421,6 +429,7 @@ def test_empty_response_is_retried(self):
self.assertEqual(result, expected_result)

def test_exhaust_retries_on_region_request(self):
self.add_imds_token_response()
self.add_imds_response(status_code=400, body=b'')
result = InstanceMetadataRegionFetcher(
num_attempts=1).retrieve_region()
Expand Down