Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove wireserver fallback for imds calls #3152

Merged
merged 5 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion azurelinuxagent/common/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def initialize_vminfo_common_parameters(self, protocol):
logger.warn("Failed to get VM info from goal state; will be missing from telemetry: {0}", ustr(e))

try:
imds_client = get_imds_client(protocol.get_endpoint())
imds_client = get_imds_client()
imds_info = imds_client.get_compute()
parameters[CommonTelemetryEventSchema.Location].value = imds_info.location
parameters[CommonTelemetryEventSchema.SubscriptionId].value = imds_info.subscriptionId
Expand Down
11 changes: 4 additions & 7 deletions azurelinuxagent/common/protocol/imds.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
IMDS_INTERNAL_SERVER_ERROR = 3


def get_imds_client(wireserver_endpoint):
return ImdsClient(wireserver_endpoint)
def get_imds_client():
return ImdsClient()


# A *slightly* future proof list of endorsed distros.
Expand Down Expand Up @@ -256,7 +256,7 @@ def image_origin(self):


class ImdsClient(object):
def __init__(self, wireserver_endpoint, version=APIVERSION):
def __init__(self, version=APIVERSION):
self._api_version = version
self._headers = {
'User-Agent': restutil.HTTP_USER_AGENT,
Expand All @@ -268,7 +268,6 @@ def __init__(self, wireserver_endpoint, version=APIVERSION):
}
self._regex_ioerror = re.compile(r".*HTTP Failed. GET http://[^ ]+ -- IOError .*")
self._regex_throttled = re.compile(r".*HTTP Retry. GET http://[^ ]+ -- Status Code 429 .*")
self._wireserver_endpoint = wireserver_endpoint

def _get_metadata_url(self, endpoint, resource_path):
return BASE_METADATA_URI.format(endpoint, resource_path, self._api_version)
Expand Down Expand Up @@ -326,14 +325,12 @@ def get_metadata(self, resource_path, is_health):
endpoint = IMDS_ENDPOINT

status, resp = self._get_metadata_from_endpoint(endpoint, resource_path, headers)
if status == IMDS_CONNECTION_ERROR:
endpoint = self._wireserver_endpoint
status, resp = self._get_metadata_from_endpoint(endpoint, resource_path, headers)

if status == IMDS_RESPONSE_SUCCESS:
return MetadataResult(True, False, resp)
elif status == IMDS_INTERNAL_SERVER_ERROR:
return MetadataResult(False, True, resp)
# else it's a client-side error, e.g. IMDS_CONNECTION_ERROR
return MetadataResult(False, False, resp)

def get_compute(self):
Expand Down
6 changes: 3 additions & 3 deletions azurelinuxagent/ga/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,10 @@ class SendImdsHeartbeat(PeriodicOperation):
Periodic operation to report the IDMS's health. The signal is 'Healthy' when we have successfully called and validated
a response in the last _IMDS_HEALTH_PERIOD.
"""
def __init__(self, protocol_util, health_service):
def __init__(self, health_service):
super(SendImdsHeartbeat, self).__init__(SendImdsHeartbeat._IMDS_HEARTBEAT_PERIOD)
self.health_service = health_service
self.imds_client = get_imds_client(protocol_util.get_wireserver_endpoint())
self.imds_client = get_imds_client()
self.imds_error_state = ErrorState(min_timedelta=SendImdsHeartbeat._IMDS_HEALTH_PERIOD)

_IMDS_HEARTBEAT_PERIOD = datetime.timedelta(minutes=1)
Expand Down Expand Up @@ -298,7 +298,7 @@ def daemon(self):
PollResourceUsage(),
PollSystemWideResourceUsage(),
SendHostPluginHeartbeat(protocol, health_service),
SendImdsHeartbeat(protocol_util, health_service)
SendImdsHeartbeat(health_service)
]

report_network_configuration_changes = ReportNetworkConfigurationChanges()
Expand Down
2 changes: 1 addition & 1 deletion azurelinuxagent/ga/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def _get_vm_size(self, protocol):
"""
if self._vm_size is None:

imds_client = get_imds_client(protocol.get_endpoint())
imds_client = get_imds_client()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

************* Module azurelinuxagent.ga.update
azurelinuxagent/ga/update.py:478:27: W0613: Unused argument 'protocol' (unused-argument)

we can remove protocol arg. Or we can remove this method entirely I don't think it's being used

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks; i removed the entire method

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't catch that it was being mocked in tests still:
tests.lib.tools.AgentTestCaseWithGetVmSizeMock.setUp

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, updated


try:
imds_info = imds_client.get_compute()
Expand Down
155 changes: 71 additions & 84 deletions tests/common/protocol/test_imds.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class TestImds(AgentTestCase):
def test_get(self, mock_http_get):
mock_http_get.return_value = get_mock_compute_response()

test_subject = imds.ImdsClient(restutil.KNOWN_WIRESERVER_IP)
test_subject = imds.ImdsClient()
test_subject.get_compute()

self.assertEqual(1, mock_http_get.call_count)
Expand All @@ -71,21 +71,21 @@ def test_get(self, mock_http_get):
def test_get_bad_request(self, mock_http_get):
mock_http_get.return_value = MockHttpResponse(status=restutil.httpclient.BAD_REQUEST)

test_subject = imds.ImdsClient(restutil.KNOWN_WIRESERVER_IP)
test_subject = imds.ImdsClient()
self.assertRaises(HttpError, test_subject.get_compute)

@patch("azurelinuxagent.common.protocol.imds.restutil.http_get")
def test_get_internal_service_error(self, mock_http_get):
mock_http_get.return_value = MockHttpResponse(status=restutil.httpclient.INTERNAL_SERVER_ERROR)

test_subject = imds.ImdsClient(restutil.KNOWN_WIRESERVER_IP)
test_subject = imds.ImdsClient()
self.assertRaises(HttpError, test_subject.get_compute)

@patch("azurelinuxagent.common.protocol.imds.restutil.http_get")
def test_get_empty_response(self, mock_http_get):
mock_http_get.return_value = MockHttpResponse(status=httpclient.OK, body=''.encode('utf-8'))

test_subject = imds.ImdsClient(restutil.KNOWN_WIRESERVER_IP)
test_subject = imds.ImdsClient()
self.assertRaises(ValueError, test_subject.get_compute)

def test_deserialize_ComputeInfo(self):
Expand Down Expand Up @@ -359,7 +359,7 @@ def _imds_response(f):
return fh.read()

def _assert_validation(self, http_status_code, http_response, expected_valid, expected_response):
test_subject = imds.ImdsClient(restutil.KNOWN_WIRESERVER_IP)
test_subject = imds.ImdsClient()
with patch("azurelinuxagent.common.utils.restutil.http_get") as mock_http_get:
mock_http_get.return_value = MockHttpResponse(status=http_status_code,
reason='reason',
Expand All @@ -386,99 +386,86 @@ def test_endpoint_fallback(self):
# http GET calls and enforces a single GET call (fallback would cause 2) and
# checks the url called.

test_subject = imds.ImdsClient("foo.bar")
test_subject = imds.ImdsClient()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably change this test name + comment since we're no longer using a fallback method.


# ensure user-agent gets set correctly
for is_health, expected_useragent in [(False, restutil.HTTP_USER_AGENT), (True, restutil.HTTP_USER_AGENT_HEALTH)]:
# set a different resource path for health query to make debugging unit test easier
resource_path = 'something/health' if is_health else 'something'

for has_primary_ioerror in (False, True):
# secondary endpoint unreachable
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup(primary_ioerror=has_primary_ioerror, secondary_ioerror=True)
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertFalse(result.success) if has_primary_ioerror else self.assertTrue(result.success) # pylint: disable=expression-not-assigned
self.assertFalse(result.service_error)
if has_primary_ioerror:
self.assertEqual('IMDS error in /metadata/{0}: Unable to connect to endpoint'.format(resource_path), result.response)
else:
self.assertEqual('Mock success response', result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(2 if has_primary_ioerror else 1, test_subject._http_get.call_count)

# IMDS success
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup(primary_ioerror=has_primary_ioerror)
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertTrue(result.success)
self.assertFalse(result.service_error)
self.assertEqual('Mock success response', result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(2 if has_primary_ioerror else 1, test_subject._http_get.call_count)

# IMDS throttled
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup(primary_ioerror=has_primary_ioerror, throttled=True)
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertFalse(result.success)
self.assertFalse(result.service_error)
self.assertEqual('IMDS error in /metadata/{0}: Throttled'.format(resource_path), result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(2 if has_primary_ioerror else 1, test_subject._http_get.call_count)

# IMDS gone error
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup(primary_ioerror=has_primary_ioerror, gone_error=True)
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertFalse(result.success)
self.assertTrue(result.service_error)
self.assertEqual('IMDS error in /metadata/{0}: HTTP Failed with Status Code 410: Gone'.format(resource_path), result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(2 if has_primary_ioerror else 1, test_subject._http_get.call_count)

# IMDS bad request
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup(primary_ioerror=has_primary_ioerror, bad_request=True)
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertFalse(result.success)
self.assertFalse(result.service_error)
self.assertEqual('IMDS error in /metadata/{0}: [HTTP Failed] [404: reason] Mock not found'.format(resource_path), result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(2 if has_primary_ioerror else 1, test_subject._http_get.call_count)

def _mock_imds_setup(self, primary_ioerror=False, secondary_ioerror=False, gone_error=False, throttled=False, bad_request=False):
self._mock_imds_expect_fallback = primary_ioerror # pylint: disable=attribute-defined-outside-init
self._mock_imds_primary_ioerror = primary_ioerror # pylint: disable=attribute-defined-outside-init
self._mock_imds_secondary_ioerror = secondary_ioerror # pylint: disable=attribute-defined-outside-init
# IMDS success
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup()
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertTrue(result.success)
self.assertFalse(result.service_error)
self.assertEqual('Mock success response', result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(1, test_subject._http_get.call_count)

# Connection error
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup(ioerror=True)
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertFalse(result.success)
self.assertFalse(result.service_error)
self.assertEqual('IMDS error in /metadata/{0}: Unable to connect to endpoint'.format(resource_path), result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(1, test_subject._http_get.call_count)

# IMDS throttled
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup(throttled=True)
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertFalse(result.success)
self.assertFalse(result.service_error)
self.assertEqual('IMDS error in /metadata/{0}: Throttled'.format(resource_path), result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(1, test_subject._http_get.call_count)

# IMDS gone error
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup(gone_error=True)
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertFalse(result.success)
self.assertTrue(result.service_error)
self.assertEqual('IMDS error in /metadata/{0}: HTTP Failed with Status Code 410: Gone'.format(resource_path), result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(1, test_subject._http_get.call_count)

# IMDS bad request
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup(bad_request=True)
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertFalse(result.success)
self.assertFalse(result.service_error)
self.assertEqual('IMDS error in /metadata/{0}: [HTTP Failed] [404: reason] Mock not found'.format(resource_path), result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(1, test_subject._http_get.call_count)

def _mock_imds_setup(self, ioerror=False, gone_error=False, throttled=False, bad_request=False):
self._mock_imds_ioerror = ioerror # pylint: disable=attribute-defined-outside-init
self._mock_imds_gone_error = gone_error # pylint: disable=attribute-defined-outside-init
self._mock_imds_throttled = throttled # pylint: disable=attribute-defined-outside-init
self._mock_imds_bad_request = bad_request # pylint: disable=attribute-defined-outside-init

def _mock_http_get(self, *_, **kwargs):
if "foo.bar" == kwargs['endpoint'] and not self._mock_imds_expect_fallback:
raise Exception("Unexpected endpoint called")
if self._mock_imds_primary_ioerror and "169.254.169.254" == kwargs['endpoint']:
raise HttpError("[HTTP Failed] GET http://{0}/metadata/{1} -- IOError timed out -- 6 attempts made"
.format(kwargs['endpoint'], kwargs['resource_path']))
if self._mock_imds_secondary_ioerror and "foo.bar" == kwargs['endpoint']:
raise HttpError("[HTTP Failed] GET http://{0}/metadata/{1} -- IOError timed out -- 6 attempts made"
.format(kwargs['endpoint'], kwargs['resource_path']))
if self._mock_imds_ioerror:
raise HttpError("[HTTP Failed] GET http://{0}/metadata/{1} -- IOError timed out -- 6 attempts made".format(kwargs['endpoint'], kwargs['resource_path']))
if self._mock_imds_gone_error:
raise ResourceGoneError("Resource is gone")
if self._mock_imds_throttled:
raise HttpError("[HTTP Retry] GET http://{0}/metadata/{1} -- Status Code 429 -- 25 attempts made"
.format(kwargs['endpoint'], kwargs['resource_path']))
raise HttpError("[HTTP Retry] GET http://{0}/metadata/{1} -- Status Code 429 -- 25 attempts made".format(kwargs['endpoint'], kwargs['resource_path']))

resp = MagicMock()
resp.reason = 'reason'
Expand Down
Loading