Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Async auth lock #359

Merged
merged 6 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions src/firebolt/client/auth/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from time import time
from typing import AsyncGenerator, Generator, Optional

from anyio import Lock
from anyio import Lock, get_current_task
from httpx import Auth as HttpxAuth
from httpx import Request, Response, codes

Expand Down Expand Up @@ -149,16 +149,27 @@ async def async_auth_flow(
if self.requires_request_body:
await request.aread()

async with self._lock:
flow = self.auth_flow(request)
request = next(flow)
if not self.token or self.expired:
await self._lock.acquire()
# If another task has already updated the token,
# we don't need to hold the lock
if self.token and not self.expired:
self._lock.release()

while True:
response = yield request
if self.requires_response_body:
await response.aread()
flow = self.auth_flow(request)
request = next(flow)

try:
request = flow.send(response)
except StopIteration:
break
while True:
response = yield request
if self.requires_response_body:
await response.aread()

try:
request = flow.send(response)
except StopIteration:
break
finally:
# token gets updated only after flow.send is called
# so unlock only after that
if self._lock.locked() and self._lock._owner_task == get_current_task():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there situations where we end up here, but the task that has acquired the lock if different?
In case no, can we still use context manager?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I remember reproducing it when the lock was still held but another task has reached this point. I don't remember specifics anymore (it's been a while) but I think there is a point when the token has been updated but the lock is not yet released so another thread progresses up until this point by skipping lock on L152.

Using a context manager is not really possible as we want to lock it only when doing the first auth. However, async_auth_flow is called on every request so we want to avoid always locking like we did before, since it leads to synchronous execution.

self._lock.release()
57 changes: 55 additions & 2 deletions tests/unit/client/V1/test_client_async.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import random
from queue import Queue
from re import Pattern, compile
from types import MethodType
from typing import Any, Callable
Expand All @@ -12,7 +14,7 @@
from firebolt.client.auth import Token, UsernamePassword
from firebolt.utils.urls import AUTH_URL
from firebolt.utils.util import fix_url_schema
from tests.unit.conftest import Response
from tests.unit.conftest import Response, retry_if_failed


async def test_client_retry(
Expand Down Expand Up @@ -157,7 +159,7 @@ def check_credentials(
httpx_mock.add_callback(check_credentials, url=auth_url)

async with AsyncClient(
auth=UsernamePassword(test_username, test_password),
auth=UsernamePassword(test_username, test_password, False),
api_endpoint=server,
) as c:
c._send_handling_redirects = MethodType(mock_send_handling_redirects, c)
Expand All @@ -167,3 +169,54 @@ def check_credentials(
nursery.start_soon(c.get, url)

assert checked_creds_times == 1


# test that client requests are truly concurrent
# and are executed not in order that they were started
# but in order of completion
@retry_if_failed(3)
async def test_true_concurent_requests(
httpx_mock: HTTPXMock,
test_username: str,
test_password: str,
auth_url: str,
auth_callback: Callable,
server: str,
):
url = "https://url"
CONCURENT_COUNT = 10

queue = Queue(CONCURENT_COUNT)

# create callback that uses check_token_callback but also pushes URl to a queue
async def check_token_callback_with_queue(request: Request, **kwargs) -> Response:
nonlocal queue
queue.put(str(request.url))
return Response(status_code=codes.OK, headers={"content-length": "0"})

async def mock_send_handling_redirects(self, *args: Any, **kwargs: Any) -> Response:
# simulate network delay so the context switches
# random delay to make sure that requests are not executed in order
await sleep(0.1 * random.random())
return await AsyncClient._send_handling_redirects(self, *args, **kwargs)

httpx_mock.add_callback(auth_callback, url=auth_url)

httpx_mock.add_callback(check_token_callback_with_queue, url=compile(f"{url}/."))

urls = [f"{url}/{i}" for i in range(CONCURENT_COUNT)]
async with AsyncClient(
auth=UsernamePassword(test_username, test_password),
api_endpoint=server,
) as c:
c._send_handling_redirects = MethodType(mock_send_handling_redirects, c)
async with open_nursery() as nursery:
for url in urls:
nursery.start_soon(c.get, url)

assert queue.qsize() == CONCURENT_COUNT
# Make sure the order is random and not sequential
assert list(queue.queue) != urls
# Cover the case when requests might be queued in reverse order
urls.reverse()
assert list(queue.queue) != urls
57 changes: 56 additions & 1 deletion tests/unit/client/V2/test_client_async.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import random
from queue import Queue
from re import Pattern, compile
from types import MethodType
from typing import Any, Callable
Expand All @@ -11,7 +13,7 @@
from firebolt.client.auth import Auth, ClientCredentials
from firebolt.utils.urls import AUTH_SERVICE_ACCOUNT_URL
from firebolt.utils.util import fix_url_schema
from tests.unit.conftest import Response
from tests.unit.conftest import Response, retry_if_failed


async def test_client_retry(
Expand Down Expand Up @@ -186,3 +188,56 @@ def check_credentials(
nursery.start_soon(c.get, url)

assert checked_creds_times == 1


# test that client requests are truly concurrent
# and are executed not in order that they were started
# but in order of completion
@retry_if_failed(3)
async def test_true_concurent_requests(
httpx_mock: HTTPXMock,
account_name: str,
client_id: str,
client_secret: str,
auth_url: str,
auth_callback: Callable,
server: str,
):
url = "https://url"
CONCURENT_COUNT = 10

queue = Queue(CONCURENT_COUNT)

# create callback that uses check_token_callback but also pushes URl to a queue
async def check_token_callback_with_queue(request: Request, **kwargs) -> Response:
nonlocal queue
queue.put(str(request.url))
return Response(status_code=codes.OK, headers={"content-length": "0"})

async def mock_send_handling_redirects(self, *args: Any, **kwargs: Any) -> Response:
# simulate network delay so the context switches
# random delay to make sure that requests are not executed in order
await sleep(0.1 * random.random())
return await AsyncClient._send_handling_redirects(self, *args, **kwargs)

httpx_mock.add_callback(auth_callback, url=auth_url)

httpx_mock.add_callback(check_token_callback_with_queue, url=compile(f"{url}/."))

urls = [f"{url}/{i}" for i in range(CONCURENT_COUNT)]
async with AsyncClient(
auth=ClientCredentials(client_id, client_secret),
api_endpoint=server,
account_name=account_name,
) as c:
c._send_handling_redirects = MethodType(mock_send_handling_redirects, c)
async with open_nursery() as nursery:
for url in urls:
nursery.start_soon(c.get, url)

assert queue.qsize() == CONCURENT_COUNT
# Make sure the order is random and not sequential
assert list(queue.queue) != urls
# Cover the case when requests might be queued in reverse order
urls.reverse()
assert list(queue.queue) != urls
20 changes: 20 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
from re import Pattern, compile
from typing import Callable

Expand Down Expand Up @@ -373,3 +374,22 @@ def settings(
account_name=account_name,
)
return seett


# Retry decorator that allows to retry test N number of times in case one
# ot the asserts fail
def retry_if_failed(num_retries):
def decorator(func):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
for i in range(num_retries):
try:
await func(*args, **kwargs)
break
except AssertionError as e:
if i == num_retries - 1:
raise e

return wrapper

return decorator
Loading