diff --git a/src/easynetwork/api_async/client/abc.py b/src/easynetwork/api_async/client/abc.py index bf994a2b..e5c36cfe 100644 --- a/src/easynetwork/api_async/client/abc.py +++ b/src/easynetwork/api_async/client/abc.py @@ -19,12 +19,12 @@ __all__ = ["AbstractAsyncNetworkClient"] import math -import time from abc import ABCMeta, abstractmethod from collections.abc import AsyncIterator from typing import TYPE_CHECKING, Generic, Self from ..._typevars import _ReceivedPacketT, _SentPacketT +from ...lowlevel import _utils from ...lowlevel.socket import SocketAddress if TYPE_CHECKING: @@ -220,20 +220,16 @@ async def iter_received_packets(self, *, timeout: float | None = 0) -> AsyncIter if timeout is None: timeout = math.inf - perf_counter = time.perf_counter timeout_after = self.get_backend().timeout while True: try: - with timeout_after(timeout): - _start = perf_counter() + with timeout_after(timeout), _utils.ElapsedTime() as elapsed: packet = await self.recv_packet() - _end = perf_counter() except OSError: return yield packet - timeout -= _end - _start - timeout = max(timeout, 0) + timeout = elapsed.recompute_timeout(timeout) @abstractmethod def get_backend(self) -> AsyncBackend: diff --git a/src/easynetwork/api_sync/client/abc.py b/src/easynetwork/api_sync/client/abc.py index bccf8976..02e59244 100644 --- a/src/easynetwork/api_sync/client/abc.py +++ b/src/easynetwork/api_sync/client/abc.py @@ -18,12 +18,12 @@ __all__ = ["AbstractNetworkClient"] -import time from abc import ABCMeta, abstractmethod from collections.abc import Iterator from typing import TYPE_CHECKING, Generic, Self from ..._typevars import _ReceivedPacketT, _SentPacketT +from ...lowlevel import _utils from ...lowlevel.socket import SocketAddress if TYPE_CHECKING: @@ -174,19 +174,15 @@ def iter_received_packets(self, *, timeout: float | None = 0) -> Iterator[_Recei Yields: the received packet. """ - perf_counter = time.perf_counter - while True: try: - _start = perf_counter() - packet = self.recv_packet(timeout=timeout) - _end = perf_counter() + with _utils.ElapsedTime() as elapsed: + packet = self.recv_packet(timeout=timeout) except OSError: return yield packet if timeout is not None: - timeout -= _end - _start - timeout = max(timeout, 0) + timeout = elapsed.recompute_timeout(timeout) @abstractmethod def fileno(self) -> int: diff --git a/src/easynetwork/api_sync/server/_base.py b/src/easynetwork/api_sync/server/_base.py index e1a98eac..18770d60 100644 --- a/src/easynetwork/api_sync/server/_base.py +++ b/src/easynetwork/api_sync/server/_base.py @@ -21,7 +21,6 @@ import concurrent.futures import contextlib import threading as _threading -import time from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, NoReturn @@ -84,11 +83,12 @@ def shutdown(self, timeout: float | None = None) -> None: if timeout is None: portal.run_coroutine(self.__server.shutdown) else: - _start = time.perf_counter() + elapsed = _utils.ElapsedTime() try: - portal.run_coroutine(self.__do_shutdown_with_timeout, timeout) + with elapsed: + portal.run_coroutine(self.__do_shutdown_with_timeout, timeout) finally: - timeout -= time.perf_counter() - _start + timeout = elapsed.recompute_timeout(timeout) self.__is_shutdown.wait(timeout) async def __do_shutdown_with_timeout(self, timeout_delay: float) -> None: diff --git a/src/easynetwork/api_sync/server/thread.py b/src/easynetwork/api_sync/server/thread.py index fc6b8df5..d34f1570 100644 --- a/src/easynetwork/api_sync/server/thread.py +++ b/src/easynetwork/api_sync/server/thread.py @@ -21,8 +21,8 @@ ] import threading as _threading -import time +from ...lowlevel import _utils from .abc import AbstractNetworkServer @@ -50,9 +50,8 @@ def run(self) -> None: self.__is_up_event.set() def join(self, timeout: float | None = None) -> None: - _start = time.perf_counter() - self.__server.shutdown(timeout=timeout) - _end = time.perf_counter() + with _utils.ElapsedTime() as elapsed: + self.__server.shutdown(timeout=timeout) if timeout is not None: - timeout -= _end - _start + timeout = elapsed.recompute_timeout(timeout) super().join(timeout=timeout) diff --git a/src/easynetwork/lowlevel/_utils.py b/src/easynetwork/lowlevel/_utils.py index 7d339c08..be82d93c 100644 --- a/src/easynetwork/lowlevel/_utils.py +++ b/src/easynetwork/lowlevel/_utils.py @@ -15,6 +15,7 @@ from __future__ import annotations __all__ = [ + "ElapsedTime", "check_real_socket_state", "check_socket_family", "check_socket_no_ssl", @@ -41,7 +42,7 @@ import threading import time from collections.abc import Callable, Iterable, Iterator -from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeGuard, TypeVar +from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, Self, TypeGuard, TypeVar try: import ssl as _ssl @@ -228,6 +229,43 @@ def remove_traceback_frames_in_place(exc: _ExcType, n: int) -> _ExcType: return exc.with_traceback(tb) +class ElapsedTime: + __slots__ = ("_current_time_func", "_start_time", "_end_time") + + def __init__(self) -> None: + self._current_time_func: Callable[[], float] = time.perf_counter + self._start_time: float | None = None + self._end_time: float | None = None + + def __enter__(self) -> Self: + if self._start_time is not None: + raise RuntimeError("Already entered") + self._start_time = self._current_time_func() + return self + + def __exit__(self, *args: Any) -> None: + end_time = self._current_time_func() + if self._end_time is not None: + raise RuntimeError("Already exited") + self._end_time = end_time + + def get_elapsed(self) -> float: + start_time = self._start_time + if start_time is None: + raise RuntimeError("Not entered") + end_time = self._end_time + if end_time is None: + raise RuntimeError("Within context") + return end_time - start_time + + def recompute_timeout(self, old_timeout: float) -> float: + elapsed_time = self.get_elapsed() + new_timeout = old_timeout - elapsed_time + if new_timeout < 0.0: + new_timeout = 0.0 + return new_timeout + + @contextlib.contextmanager def lock_with_timeout( lock: threading.RLock | threading.Lock, @@ -238,19 +276,16 @@ def lock_with_timeout( yield timeout return timeout = validate_timeout_delay(timeout, positive_check=True) - perf_counter = time.perf_counter with contextlib.ExitStack() as stack: # Try to acquire without blocking first if lock.acquire(blocking=False): stack.push(lock) else: - _start = perf_counter() - if timeout == 0 or not lock.acquire(True, timeout): - raise error_from_errno(_errno.ETIMEDOUT) + with ElapsedTime() as elapsed: + if timeout == 0 or not lock.acquire(True, timeout): + raise error_from_errno(_errno.ETIMEDOUT) stack.push(lock) - _end = perf_counter() - timeout -= _end - _start - timeout = max(timeout, 0.0) + timeout = elapsed.recompute_timeout(timeout) yield timeout diff --git a/src/easynetwork/lowlevel/api_sync/endpoints/stream.py b/src/easynetwork/lowlevel/api_sync/endpoints/stream.py index d540ad87..91d9b732 100644 --- a/src/easynetwork/lowlevel/api_sync/endpoints/stream.py +++ b/src/easynetwork/lowlevel/api_sync/endpoints/stream.py @@ -20,7 +20,6 @@ import errno as _errno import math -import time from collections.abc import Callable, Mapping from typing import Any, Generic, TypeGuard @@ -192,12 +191,10 @@ def recv_packet(self, *, timeout: float | None = None) -> _ReceivedPacketT: raise EOFError("end-of-stream") bufsize: int = self.__max_recv_size - perf_counter = time.perf_counter # pull function to local namespace while True: - _start = perf_counter() - chunk: bytes = transport.recv(bufsize, timeout) - _end = perf_counter() + with _utils.ElapsedTime() as elapsed: + chunk: bytes = transport.recv(bufsize, timeout) if not chunk: self.__eof_reached = True raise EOFError("end-of-stream") @@ -211,8 +208,7 @@ def recv_packet(self, *, timeout: float | None = None) -> _ReceivedPacketT: return next(consumer) except StopIteration: if timeout > 0: - timeout -= _end - _start - timeout = max(timeout, 0.0) + timeout = elapsed.recompute_timeout(timeout) elif buffer_not_full: break # Loop break diff --git a/src/easynetwork/lowlevel/api_sync/transports/abc.py b/src/easynetwork/lowlevel/api_sync/transports/abc.py index 8b7fe64b..8012ced8 100644 --- a/src/easynetwork/lowlevel/api_sync/transports/abc.py +++ b/src/easynetwork/lowlevel/api_sync/transports/abc.py @@ -26,11 +26,10 @@ "StreamWriteTransport", ] -import time from abc import ABCMeta, abstractmethod from collections.abc import Iterable -from ... import typed_attr +from ... import _utils, typed_attr class BaseTransport(typed_attr.TypedAttributeProvider, metaclass=ABCMeta): @@ -129,7 +128,6 @@ def send_all(self, data: bytes | bytearray | memoryview, timeout: float) -> None TimeoutError: Operation timed out. """ - perf_counter = time.perf_counter # pull function to local namespace total_sent: int = 0 with memoryview(data) as data: nb_bytes_to_send = len(data) @@ -139,15 +137,12 @@ def send_all(self, data: bytes | bytearray | memoryview, timeout: float) -> None raise RuntimeError("transport.send() returned a negative value") return while total_sent < nb_bytes_to_send: - with data[total_sent:] as buffer: - _start = perf_counter() + with data[total_sent:] as buffer, _utils.ElapsedTime() as elapsed: sent = self.send(buffer, timeout) - _end = perf_counter() if sent < 0: raise RuntimeError("transport.send() returned a negative value") total_sent += sent - timeout -= _end - _start - timeout = max(timeout, 0.0) + timeout = elapsed.recompute_timeout(timeout) def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | memoryview], timeout: float) -> None: """ diff --git a/src/easynetwork/lowlevel/api_sync/transports/base_selector.py b/src/easynetwork/lowlevel/api_sync/transports/base_selector.py index 08aab567..ae8a846e 100644 --- a/src/easynetwork/lowlevel/api_sync/transports/base_selector.py +++ b/src/easynetwork/lowlevel/api_sync/transports/base_selector.py @@ -31,7 +31,6 @@ import errno as _errno import math import selectors -import time from abc import abstractmethod from collections.abc import Callable from typing import TypeVar @@ -99,7 +98,6 @@ def _retry( callback: Callable[[], _R], timeout: float, ) -> _R: - perf_counter = time.perf_counter # pull function to local namespace timeout = _utils.validate_timeout_delay(timeout, positive_check=True) retry_interval = self._retry_interval event: int @@ -133,10 +131,9 @@ def _retry( if not available: raise RuntimeError("timeout error with infinite timeout") else: - _start = perf_counter() - available = bool(selector.select(wait_time)) - _end = perf_counter() - timeout -= _end - _start + with _utils.ElapsedTime() as elapsed: + available = bool(selector.select(wait_time)) + timeout = elapsed.recompute_timeout(timeout) if not available: if not is_retry_interval: break diff --git a/tests/unit_test/test_tools/test_utils.py b/tests/unit_test/test_tools/test_utils.py index 027f5cfa..de58de13 100644 --- a/tests/unit_test/test_tools/test_utils.py +++ b/tests/unit_test/test_tools/test_utils.py @@ -12,6 +12,7 @@ from easynetwork.exceptions import BusyResourceError from easynetwork.lowlevel._utils import ( + ElapsedTime, ResourceGuard, check_real_socket_state, check_socket_family, @@ -557,6 +558,58 @@ def func() -> None: assert len(list(traceback.walk_tb(exception.__traceback__))) == 0 +def test____ElapsedTime____catch_elapsed_time(mocker: MockerFixture) -> None: + # Arrange + now: float = 798546132.0 + mocker.patch("time.perf_counter", autospec=True, side_effect=[now, now + 12.0]) + + # Act + with ElapsedTime() as elapsed: + pass + + # Assert + assert elapsed.get_elapsed() == pytest.approx(12.0) + assert elapsed.recompute_timeout(42.4) == pytest.approx(30.4) + assert elapsed.recompute_timeout(8.0) == 0.0 + + +def test____ElapsedTime____not_reentrant() -> None: + # Arrange + with ElapsedTime() as elapsed: + # Act & Assert + with pytest.raises(RuntimeError, match=r"^Already entered$"): + with elapsed: + pytest.fail("Should not enter") + + +def test____ElapsedTime____double_exit() -> None: + # Arrange + + # Act & Assert + with pytest.raises(RuntimeError, match=r"^Already exited$"): + with contextlib.ExitStack() as stack: + elapsed = stack.enter_context(ElapsedTime()) + stack.push(elapsed) + + +def test____ElapsedTime____get_elapsed____not_entered() -> None: + # Arrange + elapsed = ElapsedTime() + + # Act & Assert + with pytest.raises(RuntimeError, match=r"^Not entered$"): + elapsed.get_elapsed() + + +def test____ElapsedTime____get_elapsed____within_context() -> None: + # Arrange + + # Act & Assert + with ElapsedTime() as elapsed: + with pytest.raises(RuntimeError, match=r"^Within context$"): + elapsed.get_elapsed() + + def test____lock_with_timeout____acquire_and_release_with_timeout_at_None() -> None: # Arrange lock = threading.Lock()