diff --git a/google/auth/iam.py b/google/auth/iam.py index bba1624c1..1c2eb913b 100644 --- a/google/auth/iam.py +++ b/google/auth/iam.py @@ -23,10 +23,18 @@ import http.client as http_client import json +from google.auth import _exponential_backoff from google.auth import _helpers from google.auth import crypt from google.auth import exceptions +IAM_RETRY_CODES = { + http_client.INTERNAL_SERVER_ERROR, + http_client.BAD_GATEWAY, + http_client.SERVICE_UNAVAILABLE, + http_client.GATEWAY_TIMEOUT, +} + _IAM_SCOPE = ["https://www.googleapis.com/auth/iam"] @@ -88,15 +96,22 @@ def _make_signing_request(self, message): {"payload": base64.b64encode(message).decode("utf-8")} ).encode("utf-8") - self._credentials.before_request(self._request, method, url, headers) - response = self._request(url=url, method=method, body=body, headers=headers) + retries = _exponential_backoff.ExponentialBackoff() + for _ in retries: + self._credentials.before_request(self._request, method, url, headers) + + response = self._request(url=url, method=method, body=body, headers=headers) + + if response.status in IAM_RETRY_CODES: + continue - if response.status != http_client.OK: - raise exceptions.TransportError( - "Error calling the IAM signBlob API: {}".format(response.data) - ) + if response.status != http_client.OK: + raise exceptions.TransportError( + "Error calling the IAM signBlob API: {}".format(response.data) + ) - return json.loads(response.data.decode("utf-8")) + return json.loads(response.data.decode("utf-8")) + raise exceptions.TransportError("exhausted signBlob endpoint retries") @property def key_id(self): diff --git a/google/auth/impersonated_credentials.py b/google/auth/impersonated_credentials.py index c42a93643..afac4120b 100644 --- a/google/auth/impersonated_credentials.py +++ b/google/auth/impersonated_credentials.py @@ -31,6 +31,7 @@ import http.client as http_client import json +from google.auth import _exponential_backoff from google.auth import _helpers from google.auth import credentials from google.auth import exceptions @@ -288,18 +289,22 @@ def sign_bytes(self, message): authed_session = AuthorizedSession(self._source_credentials) try: - response = authed_session.post( - url=iam_sign_endpoint, headers=headers, json=body - ) + retries = _exponential_backoff.ExponentialBackoff() + for _ in retries: + response = authed_session.post( + url=iam_sign_endpoint, headers=headers, json=body + ) + if response.status_code in iam.IAM_RETRY_CODES: + continue + if response.status_code != http_client.OK: + raise exceptions.TransportError( + "Error calling sign_bytes: {}".format(response.json()) + ) + + return base64.b64decode(response.json()["signedBlob"]) finally: authed_session.close() - - if response.status_code != http_client.OK: - raise exceptions.TransportError( - "Error calling sign_bytes: {}".format(response.json()) - ) - - return base64.b64decode(response.json()["signedBlob"]) + raise exceptions.TransportError("exhausted signBlob endpoint retries") @property def signer_email(self): diff --git a/system_tests/secrets.tar.enc b/system_tests/secrets.tar.enc index ebddbfe05..4a6a25883 100644 Binary files a/system_tests/secrets.tar.enc and b/system_tests/secrets.tar.enc differ diff --git a/tests/test_iam.py b/tests/test_iam.py index 6706afb4b..01c2fa085 100644 --- a/tests/test_iam.py +++ b/tests/test_iam.py @@ -91,6 +91,7 @@ def test_sign_bytes(self): assert returned_signature == signature kwargs = request.call_args[1] assert kwargs["headers"]["Content-Type"] == "application/json" + request.call_count == 1 def test_sign_bytes_failure(self): request = make_request(http_client.UNAUTHORIZED) @@ -100,3 +101,15 @@ def test_sign_bytes_failure(self): with pytest.raises(exceptions.TransportError): signer.sign("123") + request.call_count == 1 + + @mock.patch("time.sleep", return_value=None) + def test_sign_bytes_retryable_failure(self, mock_time): + request = make_request(http_client.INTERNAL_SERVER_ERROR) + credentials = make_credentials() + + signer = iam.Signer(request, credentials, mock.sentinel.service_account_email) + + with pytest.raises(exceptions.TransportError): + signer.sign("123") + request.call_count == 3 diff --git a/tests/test_impersonated_credentials.py b/tests/test_impersonated_credentials.py index 83e260638..f467269e2 100644 --- a/tests/test_impersonated_credentials.py +++ b/tests/test_impersonated_credentials.py @@ -426,12 +426,28 @@ def test_sign_bytes_failure(self): "google.auth.transport.requests.AuthorizedSession.request", autospec=True ) as auth_session: data = {"error": {"code": 403, "message": "unauthorized"}} - auth_session.return_value = MockResponse(data, http_client.FORBIDDEN) + mock_response = MockResponse(data, http_client.UNAUTHORIZED) + auth_session.return_value = mock_response with pytest.raises(exceptions.TransportError) as excinfo: credentials.sign_bytes(b"foo") assert excinfo.match("'code': 403") + @mock.patch("time.sleep", return_value=None) + def test_sign_bytes_retryable_failure(self, mock_time): + credentials = self.make_credentials(lifetime=None) + + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.request", autospec=True + ) as auth_session: + data = {"error": {"code": 500, "message": "internal_failure"}} + mock_response = MockResponse(data, http_client.INTERNAL_SERVER_ERROR) + auth_session.return_value = mock_response + + with pytest.raises(exceptions.TransportError) as excinfo: + credentials.sign_bytes(b"foo") + assert excinfo.match("exhausted signBlob endpoint retries") + def test_with_quota_project(self): credentials = self.make_credentials()