Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints #284

Merged
merged 7 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions .github/workflows/unittest-mpich.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,6 @@ jobs:
label: linux-64-py-3-9-mpich
prefix: /usr/share/miniconda3/envs/my-env

- operating-system: ubuntu-latest
python-version: 3.8
label: linux-64-py-3-8-mpich
prefix: /usr/share/miniconda3/envs/my-env

steps:
- uses: actions/checkout@v2
- uses: conda-incubator/[email protected]
Expand Down
5 changes: 0 additions & 5 deletions .github/workflows/unittest-openmpi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,6 @@ jobs:
label: linux-64-py-3-9-openmpi
prefix: /usr/share/miniconda3/envs/my-env

- operating-system: ubuntu-latest
python-version: 3.8
label: linux-64-py-3-8-openmpi
prefix: /usr/share/miniconda3/envs/my-env

steps:
- uses: actions/checkout@v2
- uses: conda-incubator/[email protected]
Expand Down
33 changes: 17 additions & 16 deletions pympipool/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import shutil
from typing import Optional
from ._version import get_versions
from pympipool.mpi.executor import PyMPIExecutor
from pympipool.shared.interface import SLURM_COMMAND
Expand Down Expand Up @@ -69,30 +70,30 @@ class Executor:

def __init__(
self,
max_workers=1,
cores_per_worker=1,
threads_per_core=1,
gpus_per_worker=0,
oversubscribe=False,
init_function=None,
cwd=None,
max_workers: int = 1,
cores_per_worker: int = 1,
threads_per_core: int = 1,
gpus_per_worker: int = 0,
oversubscribe: bool = False,
init_function: Optional[callable] = None,
cwd: Optional[str] = None,
executor=None,
hostname_localhost=False,
hostname_localhost: bool = False,
):
# Use __new__() instead of __init__(). This function is only implemented to enable auto-completion.
pass

def __new__(
cls,
max_workers=1,
cores_per_worker=1,
threads_per_core=1,
gpus_per_worker=0,
oversubscribe=False,
init_function=None,
cwd=None,
max_workers: int = 1,
cores_per_worker: int = 1,
threads_per_core: int = 1,
gpus_per_worker: int = 0,
oversubscribe: bool = False,
init_function: Optional[callable] = None,
cwd: Optional[str] = None,
executor=None,
hostname_localhost=False,
hostname_localhost: bool = False,
):
"""
Instead of returning a pympipool.Executor object this function returns either a pympipool.mpi.PyMPIExecutor,
Expand Down
3 changes: 2 additions & 1 deletion pympipool/backend/serial.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from os.path import abspath
import sys
from typing import Optional

from pympipool.shared.communication import (
interface_connect,
Expand All @@ -10,7 +11,7 @@
from pympipool.shared.backend import call_funct, parse_arguments


def main(argument_lst=None):
def main(argument_lst: Optional[list[str]] = None):
if argument_lst is None:
argument_lst = sys.argv
argument_dict = parse_arguments(argument_lst=argument_lst)
Expand Down
33 changes: 17 additions & 16 deletions pympipool/flux/executor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from typing import Optional

import flux.job

Expand Down Expand Up @@ -56,14 +57,14 @@ class PyFluxExecutor(ExecutorBroker):

def __init__(
self,
max_workers=1,
cores_per_worker=1,
threads_per_core=1,
gpus_per_worker=0,
init_function=None,
cwd=None,
executor=None,
hostname_localhost=False,
max_workers: int = 1,
cores_per_worker: int = 1,
threads_per_core: int = 1,
gpus_per_worker: int = 0,
init_function: Optional[callable] = None,
cwd: Optional[str] = None,
executor: Optional[flux.job.FluxExecutor] = None,
hostname_localhost: Optional[bool] = False,
):
super().__init__()
self._set_process(
Expand Down Expand Up @@ -92,12 +93,12 @@ def __init__(
class FluxPythonInterface(BaseInterface):
def __init__(
self,
cwd=None,
cores=1,
threads_per_core=1,
gpus_per_core=0,
oversubscribe=False,
executor=None,
cwd: Optional[str] = None,
cores: int = 1,
threads_per_core: int = 1,
gpus_per_core: int = 0,
oversubscribe: bool = False,
executor: Optional[flux.job.FluxExecutor] = None,
):
super().__init__(
cwd=cwd,
Expand All @@ -109,7 +110,7 @@ def __init__(
self._executor = executor
self._future = None

def bootup(self, command_lst):
def bootup(self, command_lst: list[str]):
if self._oversubscribe:
raise ValueError(
"Oversubscribing is currently not supported for the Flux adapter."
Expand All @@ -129,7 +130,7 @@ def bootup(self, command_lst):
jobspec.cwd = self._cwd
self._future = self._executor.submit(jobspec)

def shutdown(self, wait=True):
def shutdown(self, wait: bool = True):
if self.poll():
self._future.cancel()
# The flux future objects are not instantly updated,
Expand Down
14 changes: 8 additions & 6 deletions pympipool/mpi/executor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from pympipool.shared.executorbase import (
execute_parallel_tasks,
ExecutorBroker,
Expand Down Expand Up @@ -51,12 +53,12 @@ class PyMPIExecutor(ExecutorBroker):

def __init__(
self,
max_workers=1,
cores_per_worker=1,
oversubscribe=False,
init_function=None,
cwd=None,
hostname_localhost=False,
max_workers: int = 1,
cores_per_worker: int = 1,
oversubscribe: bool = False,
init_function: Optional[callable] = None,
cwd: Optional[str] = None,
hostname_localhost: bool = False,
):
super().__init__()
self._set_process(
Expand Down
15 changes: 11 additions & 4 deletions pympipool/shared/backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Optional
import inspect


def call_funct(input_dict, funct=None, memory=None):
def call_funct(
input_dict: dict, funct: Optional[callable] = None, memory: Optional[dict] = None
) -> callable:
"""
Call function from dictionary

Expand Down Expand Up @@ -30,7 +33,7 @@ def funct(*args, **kwargs):
return funct(input_dict["fn"], *input_dict["args"], **input_dict["kwargs"])


def parse_arguments(argument_lst):
def parse_arguments(argument_lst: list[str]) -> dict:
"""
Simple function to parse command line arguments

Expand All @@ -50,7 +53,9 @@ def parse_arguments(argument_lst):
)


def update_default_dict_from_arguments(argument_lst, argument_dict, default_dict):
def update_default_dict_from_arguments(
argument_lst: list[str], argument_dict: dict, default_dict: dict
) -> dict:
default_dict.update(
{
k: argument_lst[argument_lst.index(v) + 1]
Expand All @@ -61,7 +66,9 @@ def update_default_dict_from_arguments(argument_lst, argument_dict, default_dict
return default_dict


def _update_dict_delta(dict_input, dict_output, keys_possible_lst):
def _update_dict_delta(
dict_input: dict, dict_output: dict, keys_possible_lst: list
) -> dict:
return {
k: v
for k, v in dict_input.items()
Expand Down
20 changes: 10 additions & 10 deletions pympipool/shared/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, interface=None):
self._process = None
self._interface = interface

def send_dict(self, input_dict):
def send_dict(self, input_dict: dict):
"""
Send a dictionary with instructions to a connected client process.

Expand All @@ -42,7 +42,7 @@ def receive_dict(self):
error_type = output["error_type"].split("'")[1]
raise eval(error_type)(output["error"])

def send_and_receive_dict(self, input_dict):
def send_and_receive_dict(self, input_dict: dict) -> dict:
"""
Combine both the send_dict() and receive_dict() function in a single call.

Expand All @@ -66,7 +66,7 @@ def bind_to_random_port(self):
"""
return self._socket.bind_to_random_port("tcp://*")

def bootup(self, command_lst):
def bootup(self, command_lst: list[str]):
"""
Boot up the client process to connect to the SocketInterface.

Expand All @@ -75,7 +75,7 @@ def bootup(self, command_lst):
"""
self._interface.bootup(command_lst=command_lst)

def shutdown(self, wait=True):
def shutdown(self, wait: bool = True):
result = None
if self._interface.poll():
result = self.send_and_receive_dict(
Expand All @@ -96,9 +96,9 @@ def __del__(self):


def interface_bootup(
command_lst,
command_lst: list[str],
connections,
hostname_localhost=False,
hostname_localhost: bool = False,
):
"""
Start interface for ZMQ communication
Expand Down Expand Up @@ -132,7 +132,7 @@ def interface_bootup(
return interface


def interface_connect(host, port):
def interface_connect(host: str, port: str):
"""
Connect to an existing SocketInterface instance by providing the hostname and the port as strings.

Expand All @@ -146,7 +146,7 @@ def interface_connect(host, port):
return context, socket


def interface_send(socket, result_dict):
def interface_send(socket: zmq.Socket, result_dict: dict):
"""
Send results to a SocketInterface instance.

Expand All @@ -157,7 +157,7 @@ def interface_send(socket, result_dict):
socket.send(cloudpickle.dumps(result_dict))


def interface_receive(socket):
def interface_receive(socket: zmq.Socket):
"""
Receive instructions from a SocketInterface instance.

Expand All @@ -167,7 +167,7 @@ def interface_receive(socket):
return cloudpickle.loads(socket.recv())


def interface_shutdown(socket, context):
def interface_shutdown(socket: zmq.Socket, context: zmq.Context):
"""
Close the connection to a SocketInterface instance.

Expand Down
Loading
Loading