-
Notifications
You must be signed in to change notification settings - Fork 198
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add MPIExecutor -- a wrapper class over HTEx which fixes or removes options irrelevant when enable_mpi_mode=True.
- Loading branch information
Showing
7 changed files
with
323 additions
and
361 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,11 @@ | ||
from parsl.executors.threads import ThreadPoolExecutor | ||
from parsl.executors.workqueue.executor import WorkQueueExecutor | ||
from parsl.executors.high_throughput.executor import HighThroughputExecutor | ||
from parsl.executors.high_throughput.mpi_executor import MPIExecutor | ||
from parsl.executors.flux.executor import FluxExecutor | ||
|
||
__all__ = ['ThreadPoolExecutor', | ||
'HighThroughputExecutor', | ||
'MPIExecutor', | ||
'WorkQueueExecutor', | ||
'FluxExecutor'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
"""A simplified interface for HTEx when running in MPI mode""" | ||
from typing import Optional, Tuple, List, Union, Callable, Dict | ||
|
||
import typeguard | ||
|
||
from parsl.data_provider.staging import Staging | ||
from parsl.executors.high_throughput.executor import HighThroughputExecutor, GENERAL_HTEX_PARAM_DOCS | ||
from parsl.executors.status_handling import BlockProviderExecutor | ||
from parsl.jobs.states import JobStatus | ||
from parsl.providers import LocalProvider | ||
from parsl.providers.base import ExecutionProvider | ||
|
||
|
||
class MPIExecutor(HighThroughputExecutor): | ||
__doc__ = f"""A version of :class:`~parsl.HighThroughputExecutor` tuned for executing multi-node (e.g., MPI) tasks. | ||
The Provider _must_ use the :class:`~parsl.launchers.SimpleLauncher`, | ||
which places a single pool of workers on the first node of a block. | ||
Each worker can then make system calls which use an MPI launcher (e.g., ``mpirun``, ``srun``) | ||
to spawn multi-node tasks. | ||
Specify the maximum number of multi-node tasks to run at once using ``max_workers_per_block``. | ||
The maximum number should be smaller than the ``nodes_per_block`` in the Provider. | ||
Parameters | ||
---------- | ||
max_workers_per_block: int | ||
Maximum number of MPI applications to run at once per block | ||
{GENERAL_HTEX_PARAM_DOCS} | ||
""" | ||
|
||
@typeguard.typechecked | ||
def __init__(self, | ||
label: str = 'MPIExecutor', | ||
provider: ExecutionProvider = LocalProvider(), | ||
launch_cmd: Optional[str] = None, | ||
address: Optional[str] = None, | ||
worker_ports: Optional[Tuple[int, int]] = None, | ||
worker_port_range: Optional[Tuple[int, int]] = (54000, 55000), | ||
interchange_port_range: Optional[Tuple[int, int]] = (55000, 56000), | ||
storage_access: Optional[List[Staging]] = None, | ||
working_dir: Optional[str] = None, | ||
worker_debug: bool = False, | ||
max_workers_per_block: int = 1, | ||
prefetch_capacity: int = 0, | ||
heartbeat_threshold: int = 120, | ||
heartbeat_period: int = 30, | ||
drain_period: Optional[int] = None, | ||
poll_period: int = 10, | ||
address_probe_timeout: Optional[int] = None, | ||
worker_logdir_root: Optional[str] = None, | ||
mpi_launcher: str = "mpiexec", | ||
block_error_handler: Union[bool, Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]] = True, | ||
encrypted: bool = False): | ||
super().__init__( | ||
# Hard-coded settings | ||
cores_per_worker=1e-9, # Ensures there will be at least an absurd number of workers | ||
enable_mpi_mode=True, | ||
max_workers_per_node=max_workers_per_block, | ||
|
||
# Everything else | ||
label=label, | ||
provider=provider, | ||
launch_cmd=launch_cmd, | ||
address=address, | ||
worker_ports=worker_ports, | ||
worker_port_range=worker_port_range, | ||
interchange_port_range=interchange_port_range, | ||
storage_access=storage_access, | ||
working_dir=working_dir, | ||
worker_debug=worker_debug, | ||
prefetch_capacity=prefetch_capacity, | ||
heartbeat_threshold=heartbeat_threshold, | ||
heartbeat_period=heartbeat_period, | ||
drain_period=drain_period, | ||
poll_period=poll_period, | ||
address_probe_timeout=address_probe_timeout, | ||
worker_logdir_root=worker_logdir_root, | ||
mpi_launcher=mpi_launcher, | ||
block_error_handler=block_error_handler, | ||
encrypted=encrypted | ||
) | ||
|
||
self.max_workers_per_block = max_workers_per_block |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
"""Tests for the wrapper class""" | ||
from inspect import signature | ||
from pathlib import Path | ||
|
||
import pytest | ||
|
||
import parsl | ||
from .test_mpi_mode_enabled import get_env_vars | ||
from parsl import HighThroughputExecutor, Config | ||
from parsl.launchers import SimpleLauncher | ||
from parsl.providers import LocalProvider | ||
from parsl.executors.high_throughput.mpi_executor import MPIExecutor | ||
|
||
cwd = Path(__file__).parent.absolute() | ||
pbs_nodefile = cwd.joinpath("mocks", "pbs_nodefile") | ||
|
||
|
||
def local_config(): | ||
return Config( | ||
executors=[ | ||
MPIExecutor( | ||
max_workers_per_block=1, | ||
provider=LocalProvider( | ||
worker_init=f"export PBS_NODEFILE={pbs_nodefile}", | ||
launcher=SimpleLauncher() | ||
) | ||
) | ||
] | ||
) | ||
|
||
|
||
@pytest.mark.local | ||
def test_docstring(): | ||
"""Ensure the old kwargs are copied over into the new class""" | ||
assert 'label' in MPIExecutor.__doc__ | ||
assert 'max_workers_per_block' in MPIExecutor.__doc__ | ||
assert 'available_accelerators' not in MPIExecutor.__doc__ | ||
|
||
|
||
@pytest.mark.local | ||
def test_init(): | ||
"""Ensure all relevant kwargs are copied over from HTEx""" | ||
|
||
new_kwargs = {'max_workers_per_block'} | ||
excluded_kwargs = {'available_accelerators', 'enable_mpi_mode', 'cores_per_worker', 'max_workers_per_node', | ||
'mem_per_worker', 'cpu_affinity', 'max_workers'} | ||
|
||
# Get the kwargs from both HTEx and MPIEx | ||
htex_kwargs = set(signature(HighThroughputExecutor.__init__).parameters) | ||
mpix_kwargs = set(signature(MPIExecutor.__init__).parameters) | ||
|
||
assert mpix_kwargs.difference(htex_kwargs) == new_kwargs | ||
assert len(mpix_kwargs.intersection(excluded_kwargs)) == 0 | ||
assert mpix_kwargs.union(excluded_kwargs).difference(new_kwargs) == htex_kwargs | ||
|
||
|
||
@pytest.mark.local | ||
def test_get_env(): | ||
future = get_env_vars(parsl_resource_specification={ | ||
"num_nodes": 3, | ||
"ranks_per_node": 5, | ||
}) | ||
env_vars = future.result() | ||
assert env_vars['PARSL_NUM_RANKS'] == '15' |