Skip to content

Commit

Permalink
feat($mypy): support more function types working with @debounce or …
Browse files Browse the repository at this point in the history
…`@throttle`

* T = TypeVar("T")
* R = TypeVar("R")
  • Loading branch information
johnnymillergh committed May 3, 2023
1 parent 2c13a69 commit 3e88e1c
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 16 deletions.
46 changes: 32 additions & 14 deletions python_boilerplate/common/debounce_throttle.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,75 @@
import functools
import time
from typing import Any, Callable
from typing import Any, Callable, Optional, TypeVar

from loguru import logger

R = TypeVar("R")

def debounce(wait: float) -> Callable[..., Callable[..., None]]:

def debounce(wait: float) -> Callable[..., Callable[..., Optional[R]]]:
"""
Debounce function decorator.
@param wait: wait time in seconds
@return: a decorated function
:param wait: wait time in seconds
:return: a decorated function
"""

def decorator(func: Callable[..., None]) -> Callable[..., None]:
def decorator(func: Callable[..., Optional[R]]) -> Callable[..., Optional[R]]:
last_called: float = 0

@functools.wraps(func)
def debounced_func(*args: Any, **kwargs: Any) -> None:
def debounced_func(*args: Any, **kwargs: Any) -> Optional[R]:
nonlocal last_called
elapsed = time.monotonic() - last_called
if elapsed > wait:
func(*args, **kwargs)
logger.debug("Calling function due to elapsed > wait time")
result: Optional[R] = func(*args, **kwargs)
last_called = time.monotonic()
return result
else:
logger.debug(
f"Refused to call function {func.__qualname__}(args={args}, kwargs={kwargs})"
)
return None

return debounced_func

return decorator


def throttle(limit: float) -> Callable[..., Callable[..., None]]:
def throttle(limit: float) -> Callable[..., Callable[..., Optional[R]]]:
"""
Throttle function decorator.
@param limit: throttle limit in seconds
@return: a decorated function
:param limit: throttle limit in seconds
:return: a decorated function
"""

def decorator(func: Callable[..., None]) -> Callable[..., None]:
def decorator(func: Callable[..., Optional[R]]) -> Callable[..., Optional[R]]:
last_called: float = 0
called = False

@functools.wraps(func)
def throttled_func(*args: Any, **kwargs: Any) -> None:
def throttled_func(*args: Any, **kwargs: Any) -> Optional[R]:
nonlocal last_called, called
elapsed = time.monotonic() - last_called
if not called:
logger.debug("Calling func due to not called")
called = True
func(*args, **kwargs)
result = func(*args, **kwargs)
last_called = time.monotonic()
return result
elif elapsed > limit:
logger.debug("Calling func due to elapsed > limit")
func(*args, **kwargs)
result = func(*args, **kwargs)
last_called = time.monotonic()
return result
else:
logger.debug(
f"Refused to call function `{func.__qualname__}`(args={args}, kwargs={kwargs})"
)
return None

return throttled_func

Expand Down
4 changes: 2 additions & 2 deletions python_boilerplate/common/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def async_trace(func: Callable[..., R]) -> Callable[..., R]:
"""

@functools.wraps(func)
def wrapped(*arg: Any, **kwarg: Any) -> Any:
def wrapped(*arg: Any, **kwarg: Any) -> R:
function_arguments = {"arg": arg, "kwarg": kwarg}
trace_log = TraceLog(
called_by=inspect.stack()[1][3],
Expand Down Expand Up @@ -70,7 +70,7 @@ def trace(func: Callable[[Any], R]) -> Callable[[Any], R]:
"""

@functools.wraps(func)
def wrapped(*arg: Any, **kwarg: Any) -> Any:
def wrapped(*arg: Any, **kwarg: Any) -> R:
function_arguments = {"arg": arg, "kwarg": kwarg}
trace_log = TraceLog(
called_by=inspect.stack()[1][3],
Expand Down
38 changes: 38 additions & 0 deletions tests/common/test_debounce_throttle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from python_boilerplate.common.trace import async_trace, trace

times_for_debounce: int = 0
times_for_debounce2: int = 0

times_for_throttle: int = 0
times_for_throttle2: int = 0


def test_debounce(mocker: MockFixture) -> None:
Expand All @@ -20,6 +23,17 @@ def test_debounce(mocker: MockFixture) -> None:
call_count -= 1
spy.assert_called()
assert times_for_debounce == 1
result1 = debounce_function2()
assert result1 is not None
assert len(result1) > 0
logger.info(result1)
result2 = debounce_function2()
assert result2 is None
logger.info(result2)
result3 = debounce_function2()
assert result3 is None
logger.info(result3)
assert times_for_debounce2 == 1


def test_throttle(mocker: MockFixture) -> None:
Expand All @@ -36,6 +50,16 @@ def test_throttle(mocker: MockFixture) -> None:
assert False, f"Failed to test throttle_function(). {ex}"
spy.assert_called()
assert times_for_throttle >= 2
throttled1 = throttle_function2()
assert throttled1 is not None
logger.info(throttled1)
throttled2 = throttle_function2()
assert throttled2 is None
logger.info(throttled2)
throttled3 = throttle_function2()
assert throttled3 is None
logger.info(throttled3)
assert times_for_throttle2 == 1


@trace
Expand All @@ -48,6 +72,13 @@ def debounce_function(a_int: int) -> None:
)


@debounce(wait=1)
def debounce_function2() -> str:
global times_for_debounce2
times_for_debounce2 += 1
return f"debounce_function2 -> {times_for_debounce2}"


@async_trace
@throttle(limit=0.25)
def throttle_function(a_int: int) -> None:
Expand All @@ -56,3 +87,10 @@ def throttle_function(a_int: int) -> None:
logger.warning(
f"'throttle_function' was called with {a_int}, times: {times_for_throttle}"
)


@throttle(limit=0.25)
def throttle_function2() -> str:
global times_for_throttle2
times_for_throttle2 += 1
return f"throttle_function2 -> {times_for_throttle}"

0 comments on commit 3e88e1c

Please sign in to comment.