Skip to content

Commit

Permalink
chore: Modify exponential backoff implementation to have no initial s…
Browse files Browse the repository at this point in the history
…leep (#1547)

chore: No sleep on initial attempt in exponential backoff implementation

It is unintuitive that the initial attempt in the exponential backoff
loop sleeps. This can lead to subtle bugs in future call sites.

This patch refactors the exponential backoff to begin sleeping on the
2nd iteration so requests can be done in a single for loop.
  • Loading branch information
clundin25 authored Jul 11, 2024
1 parent 8338594 commit 461a3f5
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 43 deletions.
10 changes: 10 additions & 0 deletions google/auth/_exponential_backoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import random
import time

from google.auth import exceptions

# The default amount of retry attempts
_DEFAULT_RETRY_TOTAL_ATTEMPTS = 3

Expand Down Expand Up @@ -68,6 +70,11 @@ def __init__(
randomization_factor=_DEFAULT_RANDOMIZATION_FACTOR,
multiplier=_DEFAULT_MULTIPLIER,
):
if total_attempts < 1:
raise exceptions.InvalidValue(
f"total_attempts must be greater than or equal to 1 but was {total_attempts}"
)

self._total_attempts = total_attempts
self._initial_wait_seconds = initial_wait_seconds

Expand All @@ -87,6 +94,9 @@ def __next__(self):
raise StopIteration
self._backoff_count += 1

if self._backoff_count <= 1:
return self._backoff_count

jitter_variance = self._current_wait_in_seconds * self._randomization_factor
jitter = random.uniform(
self._current_wait_in_seconds - jitter_variance,
Expand Down
22 changes: 8 additions & 14 deletions google/oauth2/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,11 @@ def _token_endpoint_request_no_throw(
if headers:
headers_to_use.update(headers)

def _perform_request():
response_data = {}
retryable_error = False

retries = _exponential_backoff.ExponentialBackoff()
for _ in retries:
response = request(
method="POST", url=token_uri, headers=headers_to_use, body=body, **kwargs
)
Expand All @@ -192,7 +196,7 @@ def _perform_request():
if hasattr(response.data, "decode")
else response.data
)
response_data = ""

try:
# response_body should be a JSON
response_data = json.loads(response_body)
Expand All @@ -206,18 +210,8 @@ def _perform_request():
status_code=response.status, response_data=response_data
)

return False, response_data, retryable_error

request_succeeded, response_data, retryable_error = _perform_request()

if request_succeeded or not retryable_error or not can_retry:
return request_succeeded, response_data, retryable_error

retries = _exponential_backoff.ExponentialBackoff()
for _ in retries:
request_succeeded, response_data, retryable_error = _perform_request()
if request_succeeded or not retryable_error:
return request_succeeded, response_data, retryable_error
if not can_retry or not retryable_error:
return False, response_data, retryable_error

return False, response_data, retryable_error

Expand Down
20 changes: 7 additions & 13 deletions google/oauth2/_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ async def _token_endpoint_request_no_throw(
if access_token:
headers["Authorization"] = "Bearer {}".format(access_token)

async def _perform_request():
response_data = {}
retryable_error = False

retries = _exponential_backoff.ExponentialBackoff()
for _ in retries:
response = await request(
method="POST", url=token_uri, headers=headers, body=body
)
Expand All @@ -93,18 +97,8 @@ async def _perform_request():
status_code=response.status, response_data=response_data
)

return False, response_data, retryable_error

request_succeeded, response_data, retryable_error = await _perform_request()

if request_succeeded or not retryable_error or not can_retry:
return request_succeeded, response_data, retryable_error

retries = _exponential_backoff.ExponentialBackoff()
for _ in retries:
request_succeeded, response_data, retryable_error = await _perform_request()
if request_succeeded or not retryable_error:
return request_succeeded, response_data, retryable_error
if not can_retry or not retryable_error:
return False, response_data, retryable_error

return False, response_data, retryable_error

Expand Down
Binary file modified system_tests/secrets.tar.enc
Binary file not shown.
10 changes: 5 additions & 5 deletions tests/oauth2/test__client.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ def test__token_endpoint_request_internal_failure_error():
_client._token_endpoint_request(
request, "http://example.com", {"error_description": "internal_failure"}
)
# request should be called once and then with 3 retries
assert request.call_count == 4
# request with 2 retries
assert request.call_count == 3

request = make_request(
{"error": "internal_failure"}, status=http_client.BAD_REQUEST
Expand All @@ -205,8 +205,8 @@ def test__token_endpoint_request_internal_failure_error():
_client._token_endpoint_request(
request, "http://example.com", {"error": "internal_failure"}
)
# request should be called once and then with 3 retries
assert request.call_count == 4
# request with 2 retries
assert request.call_count == 3


def test__token_endpoint_request_internal_failure_and_retry_failure_error():
Expand Down Expand Up @@ -625,6 +625,6 @@ def test__token_endpoint_request_no_throw_with_retry(can_retry):
)

if can_retry:
assert mock_request.call_count == 4
assert mock_request.call_count == 3
else:
assert mock_request.call_count == 1
35 changes: 25 additions & 10 deletions tests/test__exponential_backoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.

import mock
import pytest # type: ignore

from google.auth import _exponential_backoff
from google.auth import exceptions


@mock.patch("time.sleep", return_value=None)
Expand All @@ -24,18 +26,31 @@ def test_exponential_backoff(mock_time):
iteration_count = 0

for attempt in eb:
backoff_interval = mock_time.call_args[0][0]
jitter = curr_wait * eb._randomization_factor

assert (curr_wait - jitter) <= backoff_interval <= (curr_wait + jitter)
assert attempt == iteration_count + 1
assert eb.backoff_count == iteration_count + 1
assert eb._current_wait_in_seconds == eb._multiplier ** (iteration_count + 1)

curr_wait = eb._current_wait_in_seconds
if attempt == 1:
assert mock_time.call_count == 0
else:
backoff_interval = mock_time.call_args[0][0]
jitter = curr_wait * eb._randomization_factor

assert (curr_wait - jitter) <= backoff_interval <= (curr_wait + jitter)
assert attempt == iteration_count + 1
assert eb.backoff_count == iteration_count + 1
assert eb._current_wait_in_seconds == eb._multiplier ** iteration_count

curr_wait = eb._current_wait_in_seconds
iteration_count += 1

assert eb.total_attempts == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS
assert eb.backoff_count == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS
assert iteration_count == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS
assert mock_time.call_count == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS
assert (
mock_time.call_count == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS - 1
)


def test_minimum_total_attempts():
with pytest.raises(exceptions.InvalidValue):
_exponential_backoff.ExponentialBackoff(total_attempts=0)
with pytest.raises(exceptions.InvalidValue):
_exponential_backoff.ExponentialBackoff(total_attempts=-1)
_exponential_backoff.ExponentialBackoff(total_attempts=1)
2 changes: 1 addition & 1 deletion tests_async/oauth2/test__client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,6 @@ async def test__token_endpoint_request_no_throw_with_retry(can_retry):
)

if can_retry:
assert mock_request.call_count == 4
assert mock_request.call_count == 3
else:
assert mock_request.call_count == 1

0 comments on commit 461a3f5

Please sign in to comment.