Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ability to wait for cache arrival in RequestCache #1327

Merged
merged 1 commit into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions doc/basics/requestcache_tutorial_1.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from asyncio import create_task, run, sleep
from asyncio import run

from ipv8.requestcache import NumberCacheWithName, RequestCache

Expand Down Expand Up @@ -29,12 +29,10 @@ async def bar() -> None:
"""
# Normally, you would add this to a network overlay instance.
request_cache = RequestCache()
request_cache.register_anonymous_task("Add later", foo, request_cache, delay=1.23)

_ = create_task(foo(request_cache))
cache = await request_cache.wait_for(MyState, 0)

while not request_cache.has(MyState, 0):
await sleep(0.1)
cache = request_cache.pop(MyState, 0)
print("I found a cache with the state:", cache.state)


Expand Down
43 changes: 42 additions & 1 deletion ipv8/requestcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from contextlib import contextmanager, suppress
from random import random
from threading import Lock
from typing import TYPE_CHECKING, TypeVar, overload
from typing import TYPE_CHECKING, TypeVar, cast, overload

from typing_extensions import Protocol

from .taskmanager import TaskManager
from .util import succeed

if TYPE_CHECKING:
from collections.abc import Generator, Iterable
Expand Down Expand Up @@ -156,6 +157,7 @@ def __init__(self) -> None:
self._logger = logging.getLogger(self.__class__.__name__)

self._identifiers: dict[str, NumberCache] = {}
self._waiters: dict[tuple[str, int], Future] = {}
self.lock = Lock()
self._shutdown = False

Expand Down Expand Up @@ -210,6 +212,9 @@ def add(self, cache: ACT) -> ACT | None:
timeout_delay = self._timeout_override

self.register_task(cache, self._on_timeout, cache, delay=timeout_delay)
waiter = self._waiters.pop((cache.prefix, cache.number), None)
if waiter is not None and not waiter.done():
waiter.set_result(cache)
return cache

@overload
Expand All @@ -228,6 +233,42 @@ def has(self, prefix: str | type[CacheTypeVar], number: int) -> bool:
return self._create_identifier(number, prefix) in self._identifiers
return self.has(prefix.name, number)

def _watch_future(self, key: tuple[str, int]) -> None:
"""
Ensure that a given future is killed after some timeout.
"""
future = self._waiters.pop(key, None)
if future is not None and not future.done():
future.cancel()

@overload
def wait_for(self, prefix: str, number: int, timeout: float | None = None) -> Future[NumberCache]:
pass

@overload
def wait_for(self, prefix: type[CacheTypeVar], number: int, timeout: float | None = None) -> Future[CacheTypeVar]:
pass

def wait_for(self, prefix: str | type[CacheTypeVar], number: int,
timeout: float | None = None) -> Future[CacheTypeVar] | Future[NumberCache]:
"""
Returns a future that fires if or when the given cache is registered.
"""
result = self.get(prefix, number)
if result is not None:
# This is just to please ``Mypy``: ``return succeed(result)`` is functionally equivalent in this block.
if isinstance(prefix, str):
return succeed(cast(NumberCache, result))
return succeed(cast(CacheTypeVar, result))
if isinstance(prefix, str):
fut: Future[NumberCache] = Future()
self._waiters[(prefix, number)] = fut
if timeout is not None:
self.register_anonymous_task(f"Watch RequestCache Future {fut}", self._watch_future, (prefix, number),
delay=timeout)
return self.register_anonymous_task(f"RequestCache wait for {prefix}", fut)
return cast(Future[CacheTypeVar], self.wait_for(prefix.name, number))

@overload
def get(self, prefix: str, number: int) -> NumberCache | None:
pass
Expand Down
88 changes: 61 additions & 27 deletions ipv8/test/test_requestcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ def on_timeout(self) -> None:
"""


class MockNamedNumberCache(NumberCache):
"""
A "normal" NumberCache that has a name.
"""

name = "test"


class TestRequestCache(TestBase):
"""
Tests related to the request cache.
Expand All @@ -115,6 +123,13 @@ def setUp(self) -> None:
super().setUp()
self.request_cache = RequestCache()

async def tearDown(self) -> None:
"""
Destroy the request cache.
"""
await self.request_cache.shutdown()
await super().tearDown()

async def test_shutdown(self) -> None:
"""
Test if RequestCache does not allow new Caches after shutdown().
Expand All @@ -134,7 +149,6 @@ async def test_timeout(self) -> None:
cache = MockCache(self.request_cache)
self.request_cache.add(cache)
await cache.timed_out
await self.request_cache.shutdown()

async def test_add_duplicate(self) -> None:
"""
Expand All @@ -145,16 +159,13 @@ async def test_add_duplicate(self) -> None:

self.assertIsNone(self.request_cache.add(cache))

await self.request_cache.shutdown()

async def test_timeout_future_default_value(self) -> None:
"""
Test if a registered future gets set to None on timeout.
"""
cache = MockRegisteredCache(self.request_cache)
self.request_cache.add(cache)
self.assertEqual(None, (await cache.timed_out))
await self.request_cache.shutdown()

async def test_timeout_future_custom_value(self) -> None:
"""
Expand All @@ -166,8 +177,6 @@ async def test_timeout_future_custom_value(self) -> None:
cache.managed_futures[0] = (cache.managed_futures[0][0], 123)
self.assertEqual(123, (await cache.timed_out))

await self.request_cache.shutdown()

async def test_timeout_future_exception(self) -> None:
"""
Test if a registered future raises an exception on timeout.
Expand All @@ -179,8 +188,6 @@ async def test_timeout_future_exception(self) -> None:
with self.assertRaises(RuntimeError):
await cache.timed_out

await self.request_cache.shutdown()

async def test_cancel_future_after_shutdown(self) -> None:
"""
Test if a registered future is cancelled when the RequestCache has shutdown.
Expand Down Expand Up @@ -211,8 +218,6 @@ async def test_passthrough_noargs(self) -> None:

self.assertTrue(cache.timed_out)

await self.request_cache.shutdown()

async def test_passthrough_timeout(self) -> None:
"""
Test if passthrough respects the timeout value.
Expand All @@ -225,8 +230,6 @@ async def test_passthrough_timeout(self) -> None:

self.assertFalse(cache.timed_out)

await self.request_cache.shutdown()

async def test_passthrough_filter_one_match(self) -> None:
"""
Test if passthrough filters correctly with one filter, that matches.
Expand All @@ -239,8 +242,6 @@ async def test_passthrough_filter_one_match(self) -> None:

self.assertTrue(cache.timed_out)

await self.request_cache.shutdown()

async def test_passthrough_filter_one_mismatch(self) -> None:
"""
Test if passthrough filters correctly with one filter, that doesn't match.
Expand All @@ -253,8 +254,6 @@ async def test_passthrough_filter_one_mismatch(self) -> None:

self.assertFalse(cache.timed_out)

await self.request_cache.shutdown()

async def test_passthrough_filter_many_match(self) -> None:
"""
Test if passthrough filters correctly with many filters, that all match.
Expand All @@ -267,8 +266,6 @@ async def test_passthrough_filter_many_match(self) -> None:

self.assertTrue(cache.timed_out)

await self.request_cache.shutdown()

async def test_passthrough_filter_some_match(self) -> None:
"""
Test if passthrough filters correctly with many filters, for which some match.
Expand All @@ -281,8 +278,6 @@ async def test_passthrough_filter_some_match(self) -> None:

self.assertTrue(cache.timed_out)

await self.request_cache.shutdown()

async def test_passthrough_filter_no_match(self) -> None:
"""
Test if passthrough filters correctly with many filters, for which none match.
Expand All @@ -295,8 +290,6 @@ async def test_passthrough_filter_no_match(self) -> None:

self.assertFalse(cache.timed_out)

await self.request_cache.shutdown()

async def test_has_by_class(self) -> None:
"""
Check if we can call ``.has()`` by cache class.
Expand All @@ -307,7 +300,52 @@ async def test_has_by_class(self) -> None:

self.assertTrue(self.request_cache.has(MockNamedCache, added.number))

await self.request_cache.shutdown()
async def test_wait_for_by_name(self) -> None:
"""
Check if we can call ``.wait_for()`` by cache name.
"""
cache = MockNamedNumberCache(self.request_cache, MockNamedNumberCache.name, 1337)
fut = self.request_cache.wait_for(MockNamedNumberCache.name, 1337)

added = self.request_cache.add(cache)
result = await fut

self.assertEqual(added, result)

async def test_wait_for_by_class(self) -> None:
"""
Check if we can call ``.wait_for()`` by cache class.
"""
cache = MockNamedNumberCache(self.request_cache, MockNamedNumberCache.name, 1337)
fut = self.request_cache.wait_for(MockNamedNumberCache, 1337)

added = self.request_cache.add(cache)
result = await fut

self.assertEqual(added, result)

async def test_wait_for_already_added(self) -> None:
"""
Check if we can ``.wait_for()`` returns when a cache is already available.
"""
cache = MockNamedNumberCache(self.request_cache, MockNamedNumberCache.name, 1337)
added = self.request_cache.add(cache)
fut = self.request_cache.wait_for(MockNamedNumberCache.name, 1337)

result = await fut

self.assertEqual(added, result)

async def test_wait_for_timeout(self) -> None:
"""
Check if ``.wait_for()`` cancels its future when a timeout occurs.
"""
fut = self.request_cache.wait_for(MockNamedNumberCache.name, 1337, timeout=0.0)

await sleep(0)

self.assertTrue(fut.done())
self.assertTrue(fut.cancelled())

async def test_get_by_class(self) -> None:
"""
Expand All @@ -318,8 +356,6 @@ async def test_get_by_class(self) -> None:

self.assertEqual(added, self.request_cache.get(MockNamedCache, added.number))

await self.request_cache.shutdown()

async def test_pop_by_class(self) -> None:
"""
Check if we can call ``.pop()`` by cache class.
Expand All @@ -328,5 +364,3 @@ async def test_pop_by_class(self) -> None:
added = self.request_cache.add(cache)

self.assertEqual(added, self.request_cache.pop(MockNamedCache, added.number))

await self.request_cache.shutdown()