Skip to content

Commit

Permalink
Use a future task for closing evicted sessions
Browse files Browse the repository at this point in the history
- When evicting sessions from the async cache, once the cache has already been cleared, pop out of the lock and issue a ``threading.Timer`` to close the evicted sessions with a bit more time than the `DEFAULT_TIMEOUT` for a call. This ensures any evicted session has time to go through with its call before it is actually closed.
  • Loading branch information
fselmo committed Aug 2, 2022
1 parent c3d00ed commit 806b513
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 26 deletions.
93 changes: 76 additions & 17 deletions tests/core/utilities/test_request.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import pytest

from aiohttp import (
Expand All @@ -20,6 +21,7 @@
)
from web3._utils.request import (
SessionCache,
get_async_session,
)


Expand Down Expand Up @@ -94,24 +96,7 @@ def test_precached_session(mocker):
assert adapter._pool_maxsize == 100


@pytest.mark.asyncio
async def test_async_precached_session(mocker):
    """
    Check that caching a session under the same URI twice keeps one cache
    entry, while caching it under a different URI adds a second entry.
    """
    session = ClientSession()

    # (URI to cache the session under, expected cache length afterwards)
    steps = (
        (URI, 1),  # first insert
        (URI, 1),  # same URI again: the entry must not be duplicated
        (f"{URI}/test", 2),  # different URI: a second entry is added
    )
    for uri, expected_size in steps:
        await request.cache_async_session(uri, session)
        assert len(request._async_session_cache) == expected_size


def test_cache_session_class():

cache = SessionCache(2)
evicted_items = cache.cache("1", "Hello1")
assert cache.get_cache_entry("1") == "Hello1"
Expand All @@ -130,9 +115,83 @@ def test_cache_session_class():
evicted_items = cache.cache("3", "Hello3")
assert "2" in cache
assert "3" in cache

assert "1" not in cache
assert "1" in evicted_items

with pytest.raises(KeyError):
# This should throw a KeyError since the cache size was 2 and 3 were inserted
# the first inserted cached item was removed and returned in evicted items
cache.get_cache_entry("1")


# -- async -- #


@pytest.mark.asyncio
async def test_async_precached_session():
    """
    Check that caching a session under the same URI twice keeps one cache
    entry, while caching it under a different URI adds a second entry.
    """
    precached = ClientSession()

    # Each pair is (URI to cache under, cache length expected right after).
    expectations = [
        (URI, 1),  # add a session
        (URI, 1),  # the same URI must not duplicate the session
        (f"{URI}/test", 2),  # a different URI adds another cached session
    ]
    for target_uri, cache_length in expectations:
        await request.cache_async_session(target_uri, precached)
        assert len(request._async_session_cache) == cache_length


@pytest.mark.asyncio
async def test_async_cache_does_not_close_session_before_a_call():
    """
    Check that a session evicted from a full async cache stays open long
    enough for its pending call, and is closed shortly afterwards.

    The cache is shrunk to a single entry so that every new URI evicts the
    previous session, and ``DEFAULT_TIMEOUT`` is shortened so the delayed
    close of evicted sessions happens quickly.
    """
    unique_uris = [
        "https://www.test1.com",
        "https://www.test2.com",
        "https://www.test3.com",
        "https://www.test4.com",
        "https://www.test5.com",
    ]

    # save default values
    session_cache_default = request._async_session_cache
    timeout_default = request.DEFAULT_TIMEOUT

    # set cache size to 1 and shorten the timeout; evicted sessions are closed
    # by a timer set to ``DEFAULT_TIMEOUT + 0.1`` seconds (0.11s here)
    request._async_session_cache = SessionCache(1)
    _timeout_for_testing = 0.01
    request.DEFAULT_TIMEOUT = _timeout_for_testing

    async def cache_uri_and_return_evicted_session(uri):
        _session = await get_async_session(uri)

        # sort of simulate a call taking 0.01s to return a response
        await asyncio.sleep(0.01)

        assert not _session.closed
        return _session

    try:
        tasks = [cache_uri_and_return_evicted_session(uri) for uri in unique_uris]

        evicted_sessions = await asyncio.gather(*tasks)
        assert len(evicted_sessions) == len(unique_uris)
        assert all(isinstance(s, ClientSession) for s in evicted_sessions)

        # last session remains in cache
        cache_data = request._async_session_cache._data
        assert len(cache_data) == 1

        # appropriately close the cached session before exiting test
        _key, cached_session = cache_data.popitem()
        await cached_session.close()
    finally:
        # -- teardown -- #
        # Always reset to default values, even when an assertion above fails,
        # so the patched module globals cannot leak into other tests.
        request._async_session_cache = session_cache_default
        request.DEFAULT_TIMEOUT = timeout_default

    # assert all evicted sessions were closed
    await asyncio.sleep(_timeout_for_testing + 0.1)
    assert all(session.closed for session in evicted_sessions)
72 changes: 63 additions & 9 deletions web3/_utils/request.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
import asyncio
from collections import (
OrderedDict,
)
from concurrent.futures import (
ThreadPoolExecutor,
)
import contextlib
import logging
import os
import threading
from typing import (
Any,
AsyncGenerator,
Dict,
List,
Union,
)

Expand All @@ -23,6 +31,11 @@
generate_cache_key,
)

# Obtain the logger through ``logging.getLogger`` so it takes part in the
# logger hierarchy and can be configured by the application. A directly
# instantiated ``logging.Logger`` has no parent and an effective level of
# NOTSET, so every debug message would be emitted. Library code attaches only
# a ``NullHandler`` and leaves output handlers to the application.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Default timeout, in seconds, used for the request helpers in this module and
# as the base delay before evicted async sessions are closed.
DEFAULT_TIMEOUT = 10


class SessionCache:
def __init__(self, size: int):
Expand Down Expand Up @@ -89,7 +102,7 @@ def get_session(
def get_response_from_get_request(
    endpoint_uri: URI, *args: Any, **kwargs: Any
) -> requests.Response:
    """
    Send a GET request to ``endpoint_uri`` on the session returned by
    ``get_session`` and return the response.

    Extra positional and keyword arguments are forwarded to ``session.get``.
    When the caller passes no ``timeout``, ``DEFAULT_TIMEOUT`` is used.
    """
    # Apply only the module-level default. An earlier hard-coded
    # ``setdefault("timeout", 10)`` would always win and make
    # ``DEFAULT_TIMEOUT`` ineffective.
    kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
    session = get_session(endpoint_uri)
    response = session.get(endpoint_uri, *args, **kwargs)
    return response
Expand All @@ -98,7 +111,7 @@ def get_response_from_get_request(
def get_response_from_post_request(
    endpoint_uri: URI, *args: Any, **kwargs: Any
) -> requests.Response:
    """
    Send a POST request to ``endpoint_uri`` on the session returned by
    ``get_session`` and return the response.

    Extra positional and keyword arguments are forwarded to ``session.post``.
    When the caller passes no ``timeout``, ``DEFAULT_TIMEOUT`` is used.
    """
    # Apply only the module-level default. An earlier hard-coded
    # ``setdefault("timeout", 10)`` would always win and make
    # ``DEFAULT_TIMEOUT`` ineffective.
    kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
    session = get_session(endpoint_uri)
    response = session.post(endpoint_uri, *args, **kwargs)
    return response
Expand All @@ -114,19 +127,31 @@ def make_post_request(

# --- async --- #


# Cache of sessions used by the async request helpers, constructed with
# ``size=20``.
_async_session_cache = SessionCache(size=20)
# Held (via ``async_lock``) while ``get_async_session`` reads from or writes to
# ``_async_session_cache``.
_async_session_cache_lock = threading.Lock()
# Single-worker thread pool on which ``async_lock`` runs the blocking
# ``lock.acquire`` call.
_pool = ThreadPoolExecutor(max_workers=1)


async def cache_async_session(endpoint_uri: URI, session: ClientSession) -> None:
    """
    Store ``session`` in the async session cache for ``endpoint_uri``.

    Delegates to ``get_async_session`` and discards the session it returns.
    """
    await get_async_session(endpoint_uri, session=session)


@contextlib.asynccontextmanager
async def async_lock(lock: threading.Lock) -> AsyncGenerator[None, None]:
    """
    Async context manager that holds a ``threading.Lock`` for the duration of
    the ``async with`` body.

    The blocking ``lock.acquire`` call is run on the module's ``_pool``
    executor and awaited, so waiting for the lock does not block the event
    loop. The lock is released when the body exits, including on error.
    """
    # This is always awaited from a coroutine, so ask for the running loop
    # explicitly instead of going through ``get_event_loop``.
    loop = asyncio.get_running_loop()
    # NOTE(review): if the awaiting task is cancelled while the executor job
    # is in progress, the lock may end up acquired with nobody to release it
    # -- confirm whether callers can be cancelled here.
    await loop.run_in_executor(_pool, lock.acquire)
    try:
        yield
    finally:
        lock.release()


async def get_async_session(
endpoint_uri: URI, session: ClientSession = None
) -> ClientSession:
cache_key = generate_cache_key(endpoint_uri)
with _async_session_cache_lock:
async with async_lock(_async_session_cache_lock):
evicted_items = None
if session is not None:
evicted_items = _async_session_cache.cache(cache_key, session)
Expand All @@ -135,16 +160,45 @@ async def get_async_session(
cache_key, ClientSession(raise_for_status=True)
)

if evicted_items is not None:
for _key, session in evicted_items.items():
await session.close()
return _async_session_cache.get_cache_entry(cache_key)
session = _async_session_cache.get_cache_entry(cache_key)
logger.debug(f"Session cached: {endpoint_uri}, {session}")

if evicted_items is not None:
# At this point the evicted sessions (if any) are already popped out of the
# cache and just stored in the `evicted_sessions` dict. So we can kick off a
# future task to close them and it should be safe to pop out of the lock here.
evicted_sessions = [session for _key, session in evicted_items.items()]
for session in evicted_sessions:
logger.debug(
f"Session cache full. Session evicted from cache: {session}",
)
# Kick off a future task, in a separate thread, to close the evicted
# sessions. In the case that the cache filled very quickly and some
# sessions have been evicted before their original request has been made,
# we set the timer to a bit more than the `DEFAULT_TIMEOUT` for a call. This
# should guarantee that any call from an evicted session can still be made
# before the session is closed.
threading.Timer(
DEFAULT_TIMEOUT + 0.1, _close_evicted_sessions, args=[evicted_sessions]
).start()

return session


def _close_evicted_sessions(evicted_sessions: List[ClientSession]) -> None:
    """
    Close every session in ``evicted_sessions`` and empty the list in place.

    A new event loop is created to run each session's ``close()`` coroutine
    to completion, and that loop is always closed before returning.

    :param evicted_sessions: sessions that were evicted from the async cache.
    """
    loop = asyncio.new_event_loop()
    try:
        # Iterate without mutating the list: popping by index while
        # enumerating skips every other element, which would leave sessions
        # open whenever more than one was evicted.
        for evicted_session in evicted_sessions:
            loop.run_until_complete(evicted_session.close())
            logger.debug(f"Closed evicted session: {evicted_session}")
        # All sessions are closed; drain the caller's list.
        evicted_sessions.clear()
    finally:
        # Release the temporary loop even if closing a session raised.
        loop.close()


async def async_get_response_from_get_request(
    endpoint_uri: URI, *args: Any, **kwargs: Any
) -> ClientResponse:
    """
    Send a GET request to ``endpoint_uri`` on the session returned by
    ``get_async_session`` and return the response.

    Extra positional and keyword arguments are forwarded to ``session.get``.
    When the caller passes no ``timeout``, ``ClientTimeout(DEFAULT_TIMEOUT)``
    is used.
    """
    # Apply only the module-level default. An earlier hard-coded
    # ``ClientTimeout(10)`` setdefault would always win and make
    # ``DEFAULT_TIMEOUT`` ineffective.
    kwargs.setdefault("timeout", ClientTimeout(DEFAULT_TIMEOUT))
    session = await get_async_session(endpoint_uri)
    response = await session.get(endpoint_uri, *args, **kwargs)
    return response
Expand All @@ -153,7 +207,7 @@ async def async_get_response_from_get_request(
async def async_get_response_from_post_request(
    endpoint_uri: URI, *args: Any, **kwargs: Any
) -> ClientResponse:
    """
    Send a POST request to ``endpoint_uri`` on the session returned by
    ``get_async_session`` and return the response.

    Extra positional and keyword arguments are forwarded to ``session.post``.
    When the caller passes no ``timeout``, ``ClientTimeout(DEFAULT_TIMEOUT)``
    is used.
    """
    # Apply only the module-level default. An earlier hard-coded
    # ``ClientTimeout(10)`` setdefault would always win and make
    # ``DEFAULT_TIMEOUT`` ineffective.
    kwargs.setdefault("timeout", ClientTimeout(DEFAULT_TIMEOUT))
    session = await get_async_session(endpoint_uri)
    response = await session.post(endpoint_uri, *args, **kwargs)
    return response
Expand Down

0 comments on commit 806b513

Please sign in to comment.