Skip to content

Commit

Permalink
Add type checking with mypy (#535)
Browse files Browse the repository at this point in the history
* Add type checking with mypy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Improve coverage

* more fixes

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
jan-janssen and pre-commit-ci[bot] authored Dec 24, 2024
1 parent 5cf3ecc commit 95c9480
Show file tree
Hide file tree
Showing 24 changed files with 243 additions and 180 deletions.
22 changes: 22 additions & 0 deletions .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
name: MyPy

on:
push:
branches: [ main ]
pull_request:

jobs:
mypy:
runs-on: ubuntu-latest
steps:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: "3.13"
architecture: x64
- name: Checkout
uses: actions/checkout@v4
- name: Install mypy
run: pip install mypy
- name: Test
run: mypy --ignore-missing-imports ${{ github.event.repository.name }}
10 changes: 5 additions & 5 deletions executorlib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Callable, Optional

from executorlib._version import get_versions as _get_versions
from executorlib.interactive.executor import (
Expand All @@ -16,7 +16,7 @@
)

__version__ = _get_versions()["version"]
__all__ = []
__all__: list = []


class Executor:
Expand Down Expand Up @@ -100,7 +100,7 @@ def __init__(
pysqa_config_directory: Optional[str] = None,
hostname_localhost: Optional[bool] = None,
block_allocation: bool = False,
init_function: Optional[callable] = None,
init_function: Optional[Callable] = None,
disable_dependencies: bool = False,
refresh_rate: float = 0.01,
plot_dependency_graph: bool = False,
Expand All @@ -123,7 +123,7 @@ def __new__(
pysqa_config_directory: Optional[str] = None,
hostname_localhost: Optional[bool] = None,
block_allocation: bool = False,
init_function: Optional[callable] = None,
init_function: Optional[Callable] = None,
disable_dependencies: bool = False,
refresh_rate: float = 0.01,
plot_dependency_graph: bool = False,
Expand Down Expand Up @@ -177,7 +177,7 @@ def __new__(
plot_dependency_graph_filename (str): Name of the file to store the plotted graph in.
"""
default_resource_dict = {
default_resource_dict: dict = {
"cores": 1,
"threads_per_core": 1,
"gpus_per_core": 0,
Expand Down
6 changes: 3 additions & 3 deletions executorlib/backend/cache_parallel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pickle
import sys
import time
from typing import Any

import cloudpickle

Expand All @@ -24,7 +25,7 @@ def main() -> None:
"""
from mpi4py import MPI

MPI.pickle.__init__(
MPI.pickle.__init__( # type: ignore
cloudpickle.dumps,
cloudpickle.loads,
pickle.HIGHEST_PROTOCOL,
Expand All @@ -34,10 +35,9 @@ def main() -> None:
file_name = sys.argv[1]

time_start = time.time()
apply_dict = {}
if mpi_rank_zero:
apply_dict = backend_load_file(file_name=file_name)
else:
apply_dict = None
apply_dict = MPI.COMM_WORLD.bcast(apply_dict, root=0)
output = apply_dict["fn"].__call__(*apply_dict["args"], **apply_dict["kwargs"])
if mpi_size_larger_one:
Expand Down
12 changes: 6 additions & 6 deletions executorlib/backend/interactive_parallel.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pickle
import sys
from os.path import abspath
from typing import Optional

import cloudpickle
import zmq

from executorlib.standalone.interactive.backend import call_funct, parse_arguments
from executorlib.standalone.interactive.communication import (
Expand All @@ -24,7 +26,7 @@ def main() -> None:
"""
from mpi4py import MPI

MPI.pickle.__init__(
MPI.pickle.__init__( # type: ignore
cloudpickle.dumps,
cloudpickle.loads,
pickle.HIGHEST_PROTOCOL,
Expand All @@ -33,13 +35,12 @@ def main() -> None:
mpi_size_larger_one = MPI.COMM_WORLD.Get_size() > 1

argument_dict = parse_arguments(argument_lst=sys.argv)
context: Optional[zmq.Context] = None
socket: Optional[zmq.Socket] = None
if mpi_rank_zero:
context, socket = interface_connect(
host=argument_dict["host"], port=argument_dict["zmqport"]
)
else:
context = None
socket = None

memory = None

Expand All @@ -50,10 +51,9 @@ def main() -> None:

while True:
# Read from socket
input_dict: dict = {}
if mpi_rank_zero:
input_dict = interface_receive(socket=socket)
else:
input_dict = None
input_dict = MPI.COMM_WORLD.bcast(input_dict, root=0)

# Parse input
Expand Down
46 changes: 25 additions & 21 deletions executorlib/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from concurrent.futures import (
Future,
)
from typing import Optional
from typing import Callable, List, Optional, Union

from executorlib.standalone.inputcheck import check_resource_dict
from executorlib.standalone.queue import cancel_items_in_queue
Expand All @@ -27,8 +27,8 @@ def __init__(self, max_cores: Optional[int] = None):
"""
cloudpickle_register(ind=3)
self._max_cores = max_cores
self._future_queue: queue.Queue = queue.Queue()
self._process: Optional[RaisingThread] = None
self._future_queue: Optional[queue.Queue] = queue.Queue()
self._process: Optional[Union[RaisingThread, List[RaisingThread]]] = None

@property
def info(self) -> Optional[dict]:
Expand All @@ -39,21 +39,21 @@ def info(self) -> Optional[dict]:
Optional[dict]: Information about the executor.
"""
if self._process is not None and isinstance(self._process, list):
meta_data_dict = self._process[0]._kwargs.copy()
meta_data_dict = self._process[0].get_kwargs().copy()
if "future_queue" in meta_data_dict.keys():
del meta_data_dict["future_queue"]
meta_data_dict["max_workers"] = len(self._process)
return meta_data_dict
elif self._process is not None:
meta_data_dict = self._process._kwargs.copy()
meta_data_dict = self._process.get_kwargs().copy()
if "future_queue" in meta_data_dict.keys():
del meta_data_dict["future_queue"]
return meta_data_dict
else:
return None

@property
def future_queue(self) -> queue.Queue:
def future_queue(self) -> Optional[queue.Queue]:
"""
Get the future queue.
Expand All @@ -62,7 +62,7 @@ def future_queue(self) -> queue.Queue:
"""
return self._future_queue

def submit(self, fn: callable, *args, resource_dict: dict = {}, **kwargs) -> Future:
def submit(self, fn: Callable, *args, resource_dict: dict = {}, **kwargs) -> Future: # type: ignore
"""
Submits a callable to be executed with the given arguments.
Expand Down Expand Up @@ -97,16 +97,17 @@ def submit(self, fn: callable, *args, resource_dict: dict = {}, **kwargs) -> Fut
"The specified number of cores is larger than the available number of cores."
)
check_resource_dict(function=fn)
f = Future()
self._future_queue.put(
{
"fn": fn,
"args": args,
"kwargs": kwargs,
"future": f,
"resource_dict": resource_dict,
}
)
f: Future = Future()
if self._future_queue is not None:
self._future_queue.put(
{
"fn": fn,
"args": args,
"kwargs": kwargs,
"future": f,
"resource_dict": resource_dict,
}
)
return f

def shutdown(self, wait: bool = True, *, cancel_futures: bool = False):
Expand All @@ -124,11 +125,11 @@ def shutdown(self, wait: bool = True, *, cancel_futures: bool = False):
futures. Futures that are completed or running will not be
cancelled.
"""
if cancel_futures:
if cancel_futures and self._future_queue is not None:
cancel_items_in_queue(que=self._future_queue)
if self._process is not None:
if self._process is not None and self._future_queue is not None:
self._future_queue.put({"shutdown": True, "wait": wait})
if wait:
if wait and isinstance(self._process, RaisingThread):
self._process.join()
self._future_queue.join()
self._process = None
Expand All @@ -151,7 +152,10 @@ def __len__(self) -> int:
Returns:
int: The length of the executor.
"""
return self._future_queue.qsize()
queue_size = 0
if self._future_queue is not None:
queue_size = self._future_queue.qsize()
return queue_size

def __del__(self):
"""
Expand Down
18 changes: 9 additions & 9 deletions executorlib/cache/executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Callable, Optional

from executorlib.base.executor import ExecutorBase
from executorlib.cache.shared import execute_tasks_h5
Expand All @@ -21,16 +21,16 @@
from executorlib.cache.queue_spawner import execute_with_pysqa
except ImportError:
# If pysqa is not available fall back to executing tasks in a subprocess
execute_with_pysqa = execute_in_subprocess
execute_with_pysqa = execute_in_subprocess # type: ignore


class FileExecutor(ExecutorBase):
def __init__(
self,
cache_directory: str = "cache",
resource_dict: Optional[dict] = None,
execute_function: callable = execute_with_pysqa,
terminate_function: Optional[callable] = None,
execute_function: Callable = execute_with_pysqa,
terminate_function: Optional[Callable] = None,
pysqa_config_directory: Optional[str] = None,
backend: Optional[str] = None,
disable_dependencies: bool = False,
Expand All @@ -43,8 +43,8 @@ def __init__(
resource_dict (dict): A dictionary of resources required by the task. With the following keys:
- cores (int): number of MPI cores to be used for each function call
- cwd (str/None): current working directory where the parallel python task is executed
execute_function (callable, optional): The function to execute tasks. Defaults to execute_in_subprocess.
terminate_function (callable, optional): The function to terminate the tasks.
execute_function (Callable, optional): The function to execute tasks. Defaults to execute_in_subprocess.
terminate_function (Callable, optional): The function to terminate the tasks.
pysqa_config_directory (str, optional): path to the pysqa config directory (only for pysqa based backend).
backend (str, optional): name of the backend used to spawn tasks.
disable_dependencies (boolean): Disable resolving future objects during the submission.
Expand Down Expand Up @@ -81,9 +81,9 @@ def __init__(


def create_file_executor(
max_workers: int = 1,
max_workers: Optional[int] = None,
backend: str = "flux_submission",
max_cores: int = 1,
max_cores: Optional[int] = None,
cache_directory: Optional[str] = None,
resource_dict: Optional[dict] = None,
flux_executor=None,
Expand All @@ -93,7 +93,7 @@ def create_file_executor(
pysqa_config_directory: Optional[str] = None,
hostname_localhost: Optional[bool] = None,
block_allocation: bool = False,
init_function: Optional[callable] = None,
init_function: Optional[Callable] = None,
disable_dependencies: bool = False,
):
if cache_directory is None:
Expand Down
2 changes: 1 addition & 1 deletion executorlib/cache/queue_spawner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def execute_with_pysqa(
config_directory: Optional[str] = None,
backend: Optional[str] = None,
cache_directory: Optional[str] = None,
) -> Tuple[int, int]:
) -> Optional[int]:
"""
Execute a command by submitting it to the queuing system
Expand Down
16 changes: 9 additions & 7 deletions executorlib/cache/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import queue
import sys
from concurrent.futures import Future
from typing import Optional, Tuple
from typing import Any, Callable, Optional, Tuple

from executorlib.standalone.command import get_command_path
from executorlib.standalone.hdf import dump, get_output
Expand All @@ -21,7 +21,7 @@ def __init__(self, file_name: str):
"""
self._file_name = file_name

def result(self) -> str:
def result(self) -> Any:
"""
Get the result of the future item.
Expand Down Expand Up @@ -49,9 +49,9 @@ def done(self) -> bool:
def execute_tasks_h5(
future_queue: queue.Queue,
cache_directory: str,
execute_function: callable,
execute_function: Callable,
resource_dict: dict,
terminate_function: Optional[callable] = None,
terminate_function: Optional[Callable] = None,
pysqa_config_directory: Optional[str] = None,
backend: Optional[str] = None,
disable_dependencies: bool = False,
Expand All @@ -65,16 +65,18 @@ def execute_tasks_h5(
resource_dict (dict): A dictionary of resources required by the task. With the following keys:
- cores (int): number of MPI cores to be used for each function call
- cwd (str/None): current working directory where the parallel python task is executed
execute_function (callable): The function to execute the tasks.
terminate_function (callable): The function to terminate the tasks.
execute_function (Callable): The function to execute the tasks.
terminate_function (Callable): The function to terminate the tasks.
pysqa_config_directory (str, optional): path to the pysqa config directory (only for pysqa based backend).
backend (str, optional): name of the backend used to spawn tasks.
Returns:
None
"""
memory_dict, process_dict, file_name_dict = {}, {}, {}
memory_dict: dict = {}
process_dict: dict = {}
file_name_dict: dict = {}
while True:
task_dict = None
try:
Expand Down
Loading

0 comments on commit 95c9480

Please sign in to comment.