Skip to content

Commit

Permalink
issue #122, asynchronous: add type hints to decorator functions
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Cafasso <[email protected]>
  • Loading branch information
noxdafox committed Nov 21, 2023
1 parent 36eefa4 commit 9527241
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 20 deletions.
54 changes: 40 additions & 14 deletions pebble/asynchronous/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@

from itertools import count
from functools import wraps
from typing import Any, Callable
from concurrent.futures import TimeoutError

from pebble.common import ProcessExpired
from pebble.common import process_execute, send_result
from pebble.common import launch_process, stop_process, SLEEP_UNIT


def process(*args, **kwargs):
def process(*args, **kwargs) -> Callable:
"""Runs the decorated function in a concurrent process,
taking care of the result and error management.
Expand Down Expand Up @@ -75,7 +76,13 @@ def decorating_function(function):
return decorating_function


def _process_wrapper(function, timeout, name, daemon, mp_context):
def _process_wrapper(
function: Callable,
timeout: float,
name: str,
daemon: bool,
mp_context: multiprocessing.context.BaseContext
) -> Callable:
if isinstance(function, types.FunctionType):
_register_function(function)

Expand All @@ -85,7 +92,7 @@ def _process_wrapper(function, timeout, name, daemon, mp_context):
start_method = 'spawn' if os.name == 'nt' else 'fork'

@wraps(function)
def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs) -> asyncio.Future:
loop = _get_asyncio_loop()
future = loop.create_future()
reader, writer = mp_context.Pipe(duplex=False)
Expand All @@ -109,7 +116,12 @@ def wrapper(*args, **kwargs):
return wrapper


async def _worker_handler(future, worker, pipe, timeout):
async def _worker_handler(
future: asyncio.Future,
worker: multiprocessing.Process,
pipe: multiprocessing.Pipe,
timeout: float
):
"""Worker lifecycle manager.
Waits for the worker to be perform its task,
Expand All @@ -130,18 +142,22 @@ async def _worker_handler(future, worker, pipe, timeout):
future.set_result(result)


async def _get_result(future, pipe, timeout):
async def _get_result(
future: asyncio.Future,
pipe: multiprocessing.Pipe,
timeout: float
) -> Any:
"""Waits for result and handles communication errors."""
counter = count(step=SLEEP_UNIT)

try:
while not pipe.poll():
if timeout is not None and next(counter) >= timeout:
return TimeoutError('Task Timeout', timeout)
elif future.cancelled():
if future.cancelled():
return asyncio.CancelledError()
else:
await asyncio.sleep(SLEEP_UNIT)

await asyncio.sleep(SLEEP_UNIT)

return pipe.recv()
except (EOFError, OSError):
Expand All @@ -150,7 +166,12 @@ async def _get_result(future, pipe, timeout):
return error


def _function_handler(function, args, kwargs, pipe):
def _function_handler(
function: Callable,
args: list,
kwargs: dict,
pipe: multiprocessing.Pipe
):
"""Runs the actual function in separate process and returns its result."""
signal.signal(signal.SIGINT, signal.SIG_IGN)

Expand All @@ -162,7 +183,12 @@ def _function_handler(function, args, kwargs, pipe):
send_result(writer, result)


def _validate_parameters(timeout, name, daemon, mp_context):
def _validate_parameters(
timeout: float,
name: str,
daemon: bool,
mp_context: multiprocessing.context.BaseContext
):
if timeout is not None and not isinstance(timeout, (int, float)):
raise TypeError('Timeout expected to be None or integer or float')
if name is not None and not isinstance(name, str):
Expand All @@ -174,7 +200,7 @@ def _validate_parameters(timeout, name, daemon, mp_context):
raise TypeError('Context expected to be None or multiprocessing.context')


def _get_asyncio_loop():
def _get_asyncio_loop() -> asyncio.BaseEventLoop:
"""Backwards compatible loop getter."""
try:
return asyncio.get_running_loop()
Expand All @@ -189,13 +215,13 @@ def _get_asyncio_loop():
_registered_functions = {}


def _register_function(function):
def _register_function(function: Callable) -> Callable:
_registered_functions[function.__qualname__] = function

return function


def _trampoline(name, module, *args, **kwargs):
def _trampoline(name: str, module: Any, *args, **kwargs) -> Any:
"""Trampoline function for decorators.
Lookups the function between the registered ones;
Expand All @@ -207,7 +233,7 @@ def _trampoline(name, module, *args, **kwargs):
return function(*args, **kwargs)


def _function_lookup(name, module):
def _function_lookup(name: str, module: Any) -> Callable:
"""Searches the function between the registered ones.
If not found, it imports the module forcing its registration.
Expand Down
18 changes: 12 additions & 6 deletions pebble/asynchronous/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@

import asyncio

from typing import Callable
from functools import wraps
from traceback import format_exc

from pebble.common import launch_thread


def thread(*args, **kwargs):
def thread(*args, **kwargs) -> Callable:
"""Runs the decorated function within a concurrent thread,
taking care of the result and error management.
Expand Down Expand Up @@ -56,9 +57,9 @@ def decorating_function(function):
return decorating_function


def _thread_wrapper(function, name, daemon):
def _thread_wrapper(function: Callable, name: str, daemon: bool) -> Callable:
@wraps(function)
def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs) -> asyncio.Future:
loop = _get_asyncio_loop()
future = loop.create_future()

Expand All @@ -69,7 +70,12 @@ def wrapper(*args, **kwargs):
return wrapper


def _function_handler(function, args, kwargs, future):
def _function_handler(
function: Callable,
args: list,
kwargs: dict,
future: asyncio.Future
):
"""Runs the actual function in separate thread and returns its result."""
loop = future.get_loop()

Expand All @@ -82,14 +88,14 @@ def _function_handler(function, args, kwargs, future):
loop.call_soon_threadsafe(future.set_result, result)


def _validate_parameters(name, daemon):
def _validate_parameters(name: str, daemon: bool):
if name is not None and not isinstance(name, str):
raise TypeError('Name expected to be None or string')
if daemon is not None and not isinstance(daemon, bool):
raise TypeError('Daemon expected to be None or bool')


def _get_asyncio_loop():
def _get_asyncio_loop() -> asyncio.BaseEventLoop:
"""Backwards compatible loop getter."""
try:
return asyncio.get_running_loop()
Expand Down

0 comments on commit 9527241

Please sign in to comment.