diff --git a/google/auth/compute_engine/_metadata.py b/google/auth/compute_engine/_metadata.py index 69b7b5245..b66d9f9b3 100644 --- a/google/auth/compute_engine/_metadata.py +++ b/google/auth/compute_engine/_metadata.py @@ -28,6 +28,7 @@ from google.auth import environment_vars from google.auth import exceptions from google.auth import metrics +from google.auth import transport from google.auth._exponential_backoff import ExponentialBackoff _LOGGER = logging.getLogger(__name__) @@ -204,7 +205,17 @@ def get( for attempt in backoff: try: response = request(url=url, method="GET", headers=headers_to_use) - break + if response.status in transport.DEFAULT_RETRYABLE_STATUS_CODES: + _LOGGER.warning( + "Compute Engine Metadata server unavailable on " + "attempt %s of %s. Response status: %s", + attempt, + retry_count, + response.status, + ) + continue + else: + break except exceptions.TransportError as e: _LOGGER.warning( diff --git a/system_tests/secrets.tar.enc b/system_tests/secrets.tar.enc index aab772bbb..a38b5d5bc 100644 Binary files a/system_tests/secrets.tar.enc and b/system_tests/secrets.tar.enc differ diff --git a/tests/compute_engine/test__metadata.py b/tests/compute_engine/test__metadata.py index 391422b04..f49886d71 100644 --- a/tests/compute_engine/test__metadata.py +++ b/tests/compute_engine/test__metadata.py @@ -431,6 +431,74 @@ def test_get_universe_domain_not_found(): assert universe_domain == "googleapis.com" +def test_get_universe_domain_retryable_error_failure(): + # Test that if the universe domain endpoint returns a retryable error + # we should retry. + # + # In this case, the error persists, and we still fail after retrying. + request = make_request("too many requests", status=http_client.TOO_MANY_REQUESTS) + + with pytest.raises(exceptions.TransportError) as excinfo: + _metadata.get_universe_domain(request) + + assert excinfo.match(r"Compute Engine Metadata server unavailable") + + request.assert_called_with( + method="GET", + url=_metadata._METADATA_ROOT + "universe/universe_domain", + headers=_metadata._METADATA_HEADERS, + ) + assert request.call_count == 5 + + +def test_get_universe_domain_retryable_error_success(): + # Test that if the universe domain endpoint returns a retryable error + # we should retry. + # + # In this case, the error is temporary, and we succeed after retrying. + request_error = make_request( + "too many requests", status=http_client.TOO_MANY_REQUESTS + ) + request_ok = make_request( + "fake_universe_domain", headers={"content-type": "text/plain"} + ) + + class _RequestErrorOnce: + """This class forwards the request parameters to `request_error` once. + + All subsequent calls are forwarded to `request_ok`. + """ + + def __init__(self, request_error, request_ok): + self._request_error = request_error + self._request_ok = request_ok + self._call_index = 0 + + def request(self, *args, **kwargs): + if self._call_index == 0: + self._call_index += 1 + return self._request_error(*args, **kwargs) + + return self._request_ok(*args, **kwargs) + + request = _RequestErrorOnce(request_error, request_ok).request + + universe_domain = _metadata.get_universe_domain(request) + + request_error.assert_called_once_with( + method="GET", + url=_metadata._METADATA_ROOT + "universe/universe_domain", + headers=_metadata._METADATA_HEADERS, + ) + request_ok.assert_called_once_with( + method="GET", + url=_metadata._METADATA_ROOT + "universe/universe_domain", + headers=_metadata._METADATA_HEADERS, + ) + + assert universe_domain == "fake_universe_domain" + + def test_get_universe_domain_other_error(): # Test that if the universe domain endpoint returns an error other than 404 # we should throw the error