Skip to content

Commit

Permalink
Add type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-janssen committed Mar 18, 2024
1 parent 56f0158 commit 83a3788
Show file tree
Hide file tree
Showing 11 changed files with 103 additions and 99 deletions.
32 changes: 16 additions & 16 deletions pympipool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,30 +69,30 @@ class Executor:

def __init__(
self,
max_workers=1,
cores_per_worker=1,
threads_per_core=1,
gpus_per_worker=0,
oversubscribe=False,
init_function=None,
cwd=None,
max_workers: int = 1,
cores_per_worker: int = 1,
threads_per_core: int = 1,
gpus_per_worker: int = 0,
oversubscribe: bool = False,
init_function: callable = None,
cwd: str = None,
executor=None,
hostname_localhost=False,
hostname_localhost: bool =False,
):
# Use __new__() instead of __init__(). This function is only implemented to enable auto-completion.
pass

def __new__(
cls,
max_workers=1,
cores_per_worker=1,
threads_per_core=1,
gpus_per_worker=0,
oversubscribe=False,
init_function=None,
cwd=None,
max_workers: int = 1,
cores_per_worker: int = 1,
threads_per_core: int = 1,
gpus_per_worker: int = 0,
oversubscribe: bool = False,
init_function: callable = None,
cwd: str = None,
executor=None,
hostname_localhost=False,
hostname_localhost: bool = False,
):
"""
Instead of returning a pympipool.Executor object this function returns either a pympipool.mpi.PyMPIExecutor,
Expand Down
2 changes: 1 addition & 1 deletion pympipool/backend/serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pympipool.shared.backend import call_funct, parse_arguments


def main(argument_lst=None):
def main(argument_lst: list[str] = None):
if argument_lst is None:
argument_lst = sys.argv
argument_dict = parse_arguments(argument_lst=argument_lst)
Expand Down
32 changes: 16 additions & 16 deletions pympipool/flux/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ class PyFluxExecutor(ExecutorBroker):

def __init__(
self,
max_workers=1,
cores_per_worker=1,
threads_per_core=1,
gpus_per_worker=0,
init_function=None,
cwd=None,
executor=None,
hostname_localhost=False,
max_workers: int = 1,
cores_per_worker: int = 1,
threads_per_core: int = 1,
gpus_per_worker: int = 0,
init_function: callable = None,
cwd: str = None,
executor: flux.job.FluxExecutor = None,
hostname_localhost: bool = False,
):
super().__init__()
self._set_process(
Expand Down Expand Up @@ -92,12 +92,12 @@ def __init__(
class FluxPythonInterface(BaseInterface):
def __init__(
self,
cwd=None,
cores=1,
threads_per_core=1,
gpus_per_core=0,
oversubscribe=False,
executor=None,
cwd: str = None,
cores: int = 1,
threads_per_core: int = 1,
gpus_per_core: int = 0,
oversubscribe: bool = False,
executor: flux.job.FluxExecutor = None,
):
super().__init__(
cwd=cwd,
Expand All @@ -109,7 +109,7 @@ def __init__(
self._executor = executor
self._future = None

def bootup(self, command_lst):
def bootup(self, command_lst: list[str]):
if self._oversubscribe:
raise ValueError(
"Oversubscribing is currently not supported for the Flux adapter."
Expand All @@ -129,7 +129,7 @@ def bootup(self, command_lst):
jobspec.cwd = self._cwd
self._future = self._executor.submit(jobspec)

def shutdown(self, wait=True):
def shutdown(self, wait: bool = True):
if self.poll():
self._future.cancel()
# The flux future objects are not instantly updated,
Expand Down
12 changes: 6 additions & 6 deletions pympipool/mpi/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ class PyMPIExecutor(ExecutorBroker):

def __init__(
self,
max_workers=1,
cores_per_worker=1,
oversubscribe=False,
init_function=None,
cwd=None,
hostname_localhost=False,
max_workers: int = 1,
cores_per_worker: int = 1,
oversubscribe: bool = False,
init_function: callable = None,
cwd: str = None,
hostname_localhost: bool = False,
):
super().__init__()
self._set_process(
Expand Down
8 changes: 4 additions & 4 deletions pympipool/shared/backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect


def call_funct(input_dict, funct=None, memory=None):
def call_funct(input_dict: dict, funct: callable = None, memory: dict = None) -> callable:
"""
Call function from dictionary
Expand Down Expand Up @@ -30,7 +30,7 @@ def funct(*args, **kwargs):
return funct(input_dict["fn"], *input_dict["args"], **input_dict["kwargs"])


def parse_arguments(argument_lst):
def parse_arguments(argument_lst: list[str]) -> dict:
"""
Simple function to parse command line arguments
Expand All @@ -50,7 +50,7 @@ def parse_arguments(argument_lst):
)


def update_default_dict_from_arguments(argument_lst, argument_dict, default_dict):
def update_default_dict_from_arguments(argument_lst: list[str], argument_dict: dict, default_dict: dict) -> dict:
default_dict.update(
{
k: argument_lst[argument_lst.index(v) + 1]
Expand All @@ -61,7 +61,7 @@ def update_default_dict_from_arguments(argument_lst, argument_dict, default_dict
return default_dict


def _update_dict_delta(dict_input, dict_output, keys_possible_lst):
def _update_dict_delta(dict_input: dict, dict_output: dict, keys_possible_lst: list) -> dict:
return {
k: v
for k, v in dict_input.items()
Expand Down
16 changes: 8 additions & 8 deletions pympipool/shared/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, interface=None):
self._process = None
self._interface = interface

def send_dict(self, input_dict):
def send_dict(self, input_dict: dict):
"""
Send a dictionary with instructions to a connected client process.
Expand All @@ -42,7 +42,7 @@ def receive_dict(self):
error_type = output["error_type"].split("'")[1]
raise eval(error_type)(output["error"])

def send_and_receive_dict(self, input_dict):
def send_and_receive_dict(self, input_dict: dict) -> dict:
"""
Combine both the send_dict() and receive_dict() function in a single call.
Expand All @@ -66,7 +66,7 @@ def bind_to_random_port(self):
"""
return self._socket.bind_to_random_port("tcp://*")

def bootup(self, command_lst):
def bootup(self, command_lst: list[str]):
"""
Boot up the client process to connect to the SocketInterface.
Expand All @@ -75,7 +75,7 @@ def bootup(self, command_lst):
"""
self._interface.bootup(command_lst=command_lst)

def shutdown(self, wait=True):
def shutdown(self, wait: bool = True):
result = None
if self._interface.poll():
result = self.send_and_receive_dict(
Expand All @@ -96,9 +96,9 @@ def __del__(self):


def interface_bootup(
command_lst,
command_lst: list[str],
connections,
hostname_localhost=False,
hostname_localhost: bool = False,
):
"""
Start interface for ZMQ communication
Expand Down Expand Up @@ -132,7 +132,7 @@ def interface_bootup(
return interface


def interface_connect(host, port):
def interface_connect(host: str, port: str):
"""
Connect to an existing SocketInterface instance by providing the hostname and the port as strings.
Expand All @@ -146,7 +146,7 @@ def interface_connect(host, port):
return context, socket


def interface_send(socket, result_dict):
def interface_send(socket, result_dict: dict):
"""
Send results to a SocketInterface instance.
Expand Down
31 changes: 16 additions & 15 deletions pympipool/shared/executorbase.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import threading
from concurrent.futures import (
Executor as FutureExecutor,
Future,
Expand All @@ -22,7 +23,7 @@ def __init__(self):
def future_queue(self):
return self._future_queue

def submit(self, fn, *args, **kwargs):
def submit(self, fn: callable, *args, **kwargs):
"""Submits a callable to be executed with the given arguments.
Schedules the callable to be executed as fn(*args, **kwargs) and returns
Expand All @@ -35,7 +36,7 @@ def submit(self, fn, *args, **kwargs):
self._future_queue.put({"fn": fn, "args": args, "kwargs": kwargs, "future": f})
return f

def shutdown(self, wait=True, *, cancel_futures=False):
def shutdown(self, wait: bool = True, *, cancel_futures: bool = False):
"""Clean-up the resources associated with the Executor.
It is safe to call this method several times. Otherwise, no other
Expand All @@ -58,7 +59,7 @@ def shutdown(self, wait=True, *, cancel_futures=False):
self._process = None
self._future_queue = None

def _set_process(self, process):
def _set_process(self, process: threading.Thread):
self._process = process
self._process.start()

Expand All @@ -71,13 +72,13 @@ def __del__(self):
except (AttributeError, RuntimeError):
pass

def _set_process(self, process):
def _set_process(self, process: threading.Thread):
self._process = process
self._process.start()


class ExecutorBroker(ExecutorBase):
def shutdown(self, wait=True, *, cancel_futures=False):
def shutdown(self, wait: bool = True, *, cancel_futures: bool = False):
"""Clean-up the resources associated with the Executor.
It is safe to call this method several times. Otherwise, no other
Expand All @@ -103,13 +104,13 @@ def shutdown(self, wait=True, *, cancel_futures=False):
self._process = None
self._future_queue = None

def _set_process(self, process):
def _set_process(self, process: threading.Thread):
self._process = process
for process in self._process:
process.start()


def cancel_items_in_queue(que):
def cancel_items_in_queue(que: queue.Queue):
"""
Cancel items which are still waiting in the queue. If the executor is busy tasks remain in the queue, so the future
objects have to be cancelled when the executor shuts down.
Expand All @@ -127,7 +128,7 @@ def cancel_items_in_queue(que):
break


def cloudpickle_register(ind=2):
def cloudpickle_register(ind: int = 2):
"""
Cloudpickle can either pickle by value or pickle by reference. The functions which are communicated have to
be pickled by value rather than by reference, so the module which calls the map function is pickled by value.
Expand All @@ -150,11 +151,11 @@ def cloudpickle_register(ind=2):


def execute_parallel_tasks(
future_queue,
cores,
future_queue: queue.Queue,
cores: int,
interface_class,
hostname_localhost=False,
init_function=None,
hostname_localhost: bool = False,
init_function: callable = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -183,7 +184,7 @@ def execute_parallel_tasks(
)


def execute_parallel_tasks_loop(interface, future_queue, init_function=None):
def execute_parallel_tasks_loop(interface, future_queue: queue.Queue, init_function: callable = None):
if init_function is not None:
interface.send_dict(
input_dict={"init": True, "fn": init_function, "args": (), "kwargs": {}}
Expand All @@ -209,7 +210,7 @@ def execute_parallel_tasks_loop(interface, future_queue, init_function=None):
future_queue.task_done()


def _get_backend_path(cores):
def _get_backend_path(cores: int):
command_lst = [sys.executable]
if cores > 1:
command_lst += [_get_command_path(executable="mpiexec.py")]
Expand All @@ -218,5 +219,5 @@ def _get_backend_path(cores):
return command_lst


def _get_command_path(executable):
def _get_command_path(executable: str):
return os.path.abspath(os.path.join(__file__, "..", "..", "backend", executable))
Loading

0 comments on commit 83a3788

Please sign in to comment.