From b4ca374676bf139768e09c76cb13f1f2d8f7e2f3 Mon Sep 17 00:00:00 2001 From: Reid Mello <30907815+rjmello@users.noreply.github.com> Date: Tue, 16 Jan 2024 15:46:43 -0500 Subject: [PATCH] Implement CurveZMQ in HTEX The interchange serves as a CurveZMQ server, while the executor and various managers serve as CurveZMQ clients. Thus, all communication between these entities is now encrypted. The HTEX `start` method generates new certs for each run in a private `certificates/` directory. We generate a single shared client cert because all clients will have access to this dir. We disable encryption by default, but users can enable it by setting the `encrypted` initialization argument for the HTEX to `True`. --- parsl/executors/high_throughput/executor.py | 53 ++++++++++++--- .../executors/high_throughput/interchange.py | 21 ++++-- .../high_throughput/process_worker_pool.py | 22 +++++-- parsl/executors/high_throughput/zmq_pipes.py | 36 +++++----- parsl/tests/configs/ad_hoc_cluster_htex.py | 1 + parsl/tests/configs/azure_single_node.py | 1 + parsl/tests/configs/bluewaters.py | 1 + parsl/tests/configs/bridges.py | 1 + parsl/tests/configs/cc_in2p3.py | 1 + parsl/tests/configs/comet.py | 1 + parsl/tests/configs/cooley_htex.py | 1 + parsl/tests/configs/ec2_single_node.py | 1 + parsl/tests/configs/ec2_spot.py | 1 + parsl/tests/configs/frontera.py | 1 + parsl/tests/configs/htex_ad_hoc_cluster.py | 1 + parsl/tests/configs/htex_local.py | 1 + parsl/tests/configs/htex_local_alternate.py | 1 + .../configs/htex_local_intask_staging.py | 1 + .../tests/configs/htex_local_rsync_staging.py | 1 + parsl/tests/configs/local_adhoc.py | 1 + parsl/tests/configs/midway.py | 1 + parsl/tests/configs/nscc_singapore.py | 1 + parsl/tests/configs/osg_htex.py | 1 + parsl/tests/configs/petrelkube.py | 1 + parsl/tests/configs/summit.py | 1 + parsl/tests/configs/swan_htex.py | 1 + parsl/tests/configs/theta.py | 1 + parsl/tests/manual_tests/htex_local.py | 1 + parsl/tests/manual_tests/test_ad_hoc_htex.py | 1 + .../test_fan_in_out_htex_remote.py | 1 + .../tests/manual_tests/test_memory_limits.py | 1 + parsl/tests/scaling_tests/htex_local.py | 1 + parsl/tests/sites/test_affinity.py | 1 + parsl/tests/sites/test_concurrent.py | 1 + parsl/tests/sites/test_dynamic_executor.py | 1 + parsl/tests/sites/test_worker_info.py | 1 + parsl/tests/test_htex/test_htex.py | 46 +++++++++++++ .../tests/test_htex/test_htex_zmq_binding.py | 66 +++++++++++++++---- .../test_regression/test_97_parallelism_0.py | 1 + .../test_scaling/test_block_error_handler.py | 11 ++-- .../test_scaling/test_regression_1621.py | 1 + parsl/tests/test_scaling/test_scale_down.py | 1 + 42 files changed, 233 insertions(+), 57 deletions(-) create mode 100644 parsl/tests/test_htex/test_htex.py diff --git a/parsl/executors/high_throughput/executor.py b/parsl/executors/high_throughput/executor.py index 2e63df5a72..46929f75a2 100644 --- a/parsl/executors/high_throughput/executor.py +++ b/parsl/executors/high_throughput/executor.py @@ -22,6 +22,7 @@ UnsupportedFeatureError ) +from parsl import curvezmq from parsl.executors.status_handling import BlockProviderExecutor from parsl.providers.base import ExecutionProvider from parsl.data_provider.staging import Staging @@ -174,6 +175,9 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin): worker_logdir_root : string In case of a remote file system, specify the path to where logs will be kept. + + encrypted : bool + Flag to enable/disable encryption (CurveZMQ). Default is False. """ @typeguard.typechecked @@ -199,7 +203,8 @@ def __init__(self, poll_period: int = 10, address_probe_timeout: Optional[int] = None, worker_logdir_root: Optional[str] = None, - block_error_handler: Union[bool, Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]] = True): + block_error_handler: Union[bool, Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]] = True, + encrypted: bool = False): logger.debug("Initializing HighThroughputExecutor") @@ -256,6 +261,8 @@ def __init__(self, self.run_dir = '.' self.worker_logdir_root = worker_logdir_root self.cpu_affinity = cpu_affinity + self.encrypted = encrypted + self.cert_dir = None if not launch_cmd: launch_cmd = ( @@ -267,6 +274,7 @@ def __init__(self, "--poll {poll_period} " "--task_port={task_port} " "--result_port={result_port} " + "--cert_dir {cert_dir} " "--logdir={logdir} " "--block_id={{block_id}} " "--hb_period={heartbeat_period} " @@ -280,6 +288,16 @@ def __init__(self, radio_mode = "htex" + @property + def logdir(self): + return "{}/{}".format(self.run_dir, self.label) + + @property + def worker_logdir(self): + if self.worker_logdir_root is not None: + return "{}/{}".format(self.worker_logdir_root, self.label) + return self.logdir + def initialize_scaling(self): """Compose the launch command and scale out the initial blocks. """ @@ -289,9 +307,6 @@ def initialize_scaling(self): address_probe_timeout_string = "" if self.address_probe_timeout: address_probe_timeout_string = "--address_probe_timeout={}".format(self.address_probe_timeout) - worker_logdir = "{}/{}".format(self.run_dir, self.label) - if self.worker_logdir_root is not None: - worker_logdir = "{}/{}".format(self.worker_logdir_root, self.label) l_cmd = self.launch_cmd.format(debug=debug_opts, prefetch_capacity=self.prefetch_capacity, @@ -306,7 +321,8 @@ def initialize_scaling(self): heartbeat_period=self.heartbeat_period, heartbeat_threshold=self.heartbeat_threshold, poll_period=self.poll_period, - logdir=worker_logdir, + cert_dir=self.cert_dir, + logdir=self.worker_logdir, cpu_affinity=self.cpu_affinity, accelerators=" ".join(self.available_accelerators)) self.launch_cmd = l_cmd @@ -327,9 +343,25 @@ def initialize_scaling(self): def start(self): """Create the Interchange process and connect to it. """ - self.outgoing_q = zmq_pipes.TasksOutgoing("127.0.0.1", self.interchange_port_range) - self.incoming_q = zmq_pipes.ResultsIncoming("127.0.0.1", self.interchange_port_range) - self.command_client = zmq_pipes.CommandClient("127.0.0.1", self.interchange_port_range) + if self.encrypted and self.cert_dir is None: + logger.debug("Creating CurveZMQ certificates") + self.cert_dir = curvezmq.create_certificates(self.logdir) + elif not self.encrypted and self.cert_dir: + raise AttributeError( + "The certificates directory path attribute (cert_dir) is defined, but the " + "encrypted attribute is set to False. You must either change cert_dir to " + "None or encrypted to True." + ) + + self.outgoing_q = zmq_pipes.TasksOutgoing( + curvezmq.ClientContext(self.cert_dir), "127.0.0.1", self.interchange_port_range + ) + self.incoming_q = zmq_pipes.ResultsIncoming( + curvezmq.ClientContext(self.cert_dir), "127.0.0.1", self.interchange_port_range + ) + self.command_client = zmq_pipes.CommandClient( + curvezmq.ClientContext(self.cert_dir), "127.0.0.1", self.interchange_port_range + ) self._queue_management_thread = None self._start_queue_management_thread() @@ -450,10 +482,11 @@ def _start_local_interchange_process(self): "worker_port_range": self.worker_port_range, "hub_address": self.hub_address, "hub_port": self.hub_port, - "logdir": "{}/{}".format(self.run_dir, self.label), + "logdir": self.logdir, "heartbeat_threshold": self.heartbeat_threshold, "poll_period": self.poll_period, - "logging_level": logging.DEBUG if self.worker_debug else logging.INFO + "logging_level": logging.DEBUG if self.worker_debug else logging.INFO, + "cert_dir": self.cert_dir, }, daemon=True, name="HTEX-Interchange" diff --git a/parsl/executors/high_throughput/interchange.py b/parsl/executors/high_throughput/interchange.py index 67b70aa78d..c65aecb57a 100644 --- a/parsl/executors/high_throughput/interchange.py +++ b/parsl/executors/high_throughput/interchange.py @@ -16,6 +16,7 @@ from typing import cast, Any, Dict, NoReturn, Sequence, Set, Optional, Tuple +from parsl import curvezmq from parsl.utils import setproctitle from parsl.version import VERSION as PARSL_VERSION from parsl.serialize import serialize as serialize_object @@ -79,6 +80,7 @@ def __init__(self, logdir: str = ".", logging_level: int = logging.INFO, poll_period: int = 10, + cert_dir: Optional[str] = None, ) -> None: """ Parameters @@ -120,7 +122,10 @@ def __init__(self, poll_period : int The main thread polling period, in milliseconds. Default: 10ms + cert_dir : str | None + Path to the certificate directory. Default: None """ + self.cert_dir = cert_dir self.logdir = logdir os.makedirs(self.logdir, exist_ok=True) @@ -134,15 +139,15 @@ def __init__(self, logger.info("Attempting connection to client at {} on ports: {},{},{}".format( client_address, client_ports[0], client_ports[1], client_ports[2])) - self.context = zmq.Context() - self.task_incoming = self.context.socket(zmq.DEALER) + self.zmq_context = curvezmq.ServerContext(self.cert_dir) + self.task_incoming = self.zmq_context.socket(zmq.DEALER) self.task_incoming.set_hwm(0) self.task_incoming.connect("tcp://{}:{}".format(client_address, client_ports[0])) - self.results_outgoing = self.context.socket(zmq.DEALER) + self.results_outgoing = self.zmq_context.socket(zmq.DEALER) self.results_outgoing.set_hwm(0) self.results_outgoing.connect("tcp://{}:{}".format(client_address, client_ports[1])) - self.command_channel = self.context.socket(zmq.REP) + self.command_channel = self.zmq_context.socket(zmq.REP) self.command_channel.connect("tcp://{}:{}".format(client_address, client_ports[2])) logger.info("Connected to client") @@ -155,9 +160,9 @@ def __init__(self, self.worker_ports = worker_ports self.worker_port_range = worker_port_range - self.task_outgoing = self.context.socket(zmq.ROUTER) + self.task_outgoing = self.zmq_context.socket(zmq.ROUTER) self.task_outgoing.set_hwm(0) - self.results_incoming = self.context.socket(zmq.ROUTER) + self.results_incoming = self.zmq_context.socket(zmq.ROUTER) self.results_incoming.set_hwm(0) if self.worker_ports: @@ -241,7 +246,8 @@ def task_puller(self) -> NoReturn: def _create_monitoring_channel(self) -> Optional[zmq.Socket]: if self.hub_address and self.hub_port: logger.info("Connecting to monitoring") - hub_channel = self.context.socket(zmq.DEALER) + # This is a one-off because monitoring is unencrypted + hub_channel = zmq.Context().socket(zmq.DEALER) hub_channel.set_hwm(0) hub_channel.connect("tcp://{}:{}".format(self.hub_address, self.hub_port)) logger.info("Monitoring enabled and connected to hub") @@ -379,6 +385,7 @@ def start(self) -> None: self.expire_bad_managers(interesting_managers, hub_channel) self.process_tasks_to_send(interesting_managers) + self.zmq_context.destroy() delta = time.time() - start logger.info("Processed {} tasks in {} seconds".format(self.count, delta)) logger.warning("Exiting") diff --git a/parsl/executors/high_throughput/process_worker_pool.py b/parsl/executors/high_throughput/process_worker_pool.py index 57cff2c783..aed9b1e402 100755 --- a/parsl/executors/high_throughput/process_worker_pool.py +++ b/parsl/executors/high_throughput/process_worker_pool.py @@ -20,8 +20,8 @@ from multiprocessing.managers import DictProxy from multiprocessing.sharedctypes import Synchronized +from parsl import curvezmq from parsl.process_loggers import wrap_with_logs - from parsl.version import VERSION as PARSL_VERSION from parsl.app.errors import RemoteExceptionWrapper from parsl.executors.high_throughput.errors import WorkerLost @@ -63,7 +63,8 @@ def __init__(self, *, heartbeat_period, poll_period, cpu_affinity, - available_accelerators: Sequence[str]): + available_accelerators: Sequence[str], + cert_dir: Optional[str]): """ Parameters ---------- @@ -118,6 +119,8 @@ def __init__(self, *, available_accelerators: list of str List of accelerators available to the workers. + cert_dir : str | None + Path to the certificate directory. """ logger.info("Manager started") @@ -137,15 +140,16 @@ def __init__(self, *, print("Failed to find a viable address to connect to interchange. Exiting") exit(5) - self.context = zmq.Context() - self.task_incoming = self.context.socket(zmq.DEALER) + self.cert_dir = cert_dir + self.zmq_context = curvezmq.ClientContext(self.cert_dir) + self.task_incoming = self.zmq_context.socket(zmq.DEALER) self.task_incoming.setsockopt(zmq.IDENTITY, uid.encode('utf-8')) # Linger is set to 0, so that the manager can exit even when there might be # messages in the pipe self.task_incoming.setsockopt(zmq.LINGER, 0) self.task_incoming.connect(task_q_url) - self.result_outgoing = self.context.socket(zmq.DEALER) + self.result_outgoing = self.zmq_context.socket(zmq.DEALER) self.result_outgoing.setsockopt(zmq.IDENTITY, uid.encode('utf-8')) self.result_outgoing.setsockopt(zmq.LINGER, 0) self.result_outgoing.connect(result_q_url) @@ -468,7 +472,7 @@ def start(self): self.task_incoming.close() self.result_outgoing.close() - self.context.term() + self.zmq_context.term() delta = time.time() - start logger.info("process_worker_pool ran for {} seconds".format(delta)) return @@ -703,6 +707,8 @@ def start_file_logger(filename, rank, name='parsl', level=logging.DEBUG, format_ help="Enable logging at DEBUG level") parser.add_argument("-a", "--addresses", default='', help="Comma separated list of addresses at which the interchange could be reached") + parser.add_argument("--cert_dir", required=True, + help="Path to certificate directory.") parser.add_argument("-l", "--logdir", default="process_worker_pool_logs", help="Process worker pool log directory") parser.add_argument("-u", "--uid", default=str(uuid.uuid4()).split('-')[-1], @@ -746,6 +752,7 @@ def start_file_logger(filename, rank, name='parsl', level=logging.DEBUG, format_ logger.info("Python version: {}".format(sys.version)) logger.info("Debug logging: {}".format(args.debug)) + logger.info("Certificates dir: {}".format(args.cert_dir)) logger.info("Log dir: {}".format(args.logdir)) logger.info("Manager ID: {}".format(args.uid)) logger.info("Block ID: {}".format(args.block_id)) @@ -777,7 +784,8 @@ def start_file_logger(filename, rank, name='parsl', level=logging.DEBUG, format_ heartbeat_period=int(args.hb_period), poll_period=int(args.poll), cpu_affinity=args.cpu_affinity, - available_accelerators=args.available_accelerators) + available_accelerators=args.available_accelerators, + cert_dir=None if args.cert_dir == "None" else args.cert_dir) manager.start() except Exception: diff --git a/parsl/executors/high_throughput/zmq_pipes.py b/parsl/executors/high_throughput/zmq_pipes.py index cd5c7f2e8b..72c1422d4d 100644 --- a/parsl/executors/high_throughput/zmq_pipes.py +++ b/parsl/executors/high_throughput/zmq_pipes.py @@ -4,24 +4,28 @@ import logging import threading +from parsl import curvezmq + logger = logging.getLogger(__name__) class CommandClient: """ CommandClient """ - def __init__(self, ip_address, port_range): + def __init__(self, zmq_context: curvezmq.ClientContext, ip_address, port_range): """ Parameters ---------- + zmq_context: curvezmq.ClientContext + CurveZMQ client context used to create secure sockets ip_address: str IP address of the client (where Parsl runs) port_range: tuple(int, int) Port range for the comms between client and interchange """ - self.context = zmq.Context() + self.zmq_context = zmq_context self.ip_address = ip_address self.port_range = port_range self.port = None @@ -33,7 +37,7 @@ def create_socket_and_bind(self): Upon recreating the socket, we bind to the same port. """ - self.zmq_socket = self.context.socket(zmq.REQ) + self.zmq_socket = self.zmq_context.socket(zmq.REQ) self.zmq_socket.setsockopt(zmq.LINGER, 0) if self.port is None: self.port = self.zmq_socket.bind_to_random_port("tcp://{}".format(self.ip_address), @@ -62,9 +66,7 @@ def run(self, message, max_retries=3): except zmq.ZMQError: logger.exception("Potential ZMQ REQ-REP deadlock caught") logger.info("Trying to reestablish context") - self.zmq_socket.close() - self.context.destroy() - self.context = zmq.Context() + self.zmq_context.recreate() self.create_socket_and_bind() else: break @@ -77,25 +79,27 @@ def run(self, message, max_retries=3): def close(self): self.zmq_socket.close() - self.context.term() + self.zmq_context.term() class TasksOutgoing: """ Outgoing task queue from the executor to the Interchange """ - def __init__(self, ip_address, port_range): + def __init__(self, zmq_context: curvezmq.ClientContext, ip_address, port_range): """ Parameters ---------- + zmq_context: curvezmq.ClientContext + CurveZMQ client context used to create secure sockets ip_address: str IP address of the client (where Parsl runs) port_range: tuple(int, int) Port range for the comms between client and interchange """ - self.context = zmq.Context() - self.zmq_socket = self.context.socket(zmq.DEALER) + self.zmq_context = zmq_context + self.zmq_socket = self.zmq_context.socket(zmq.DEALER) self.zmq_socket.set_hwm(0) self.port = self.zmq_socket.bind_to_random_port("tcp://{}".format(ip_address), min_port=port_range[0], @@ -127,26 +131,28 @@ def put(self, message): def close(self): self.zmq_socket.close() - self.context.term() + self.zmq_context.term() class ResultsIncoming: """ Incoming results queue from the Interchange to the executor """ - def __init__(self, ip_address, port_range): + def __init__(self, zmq_context: curvezmq.ClientContext, ip_address, port_range): """ Parameters ---------- + zmq_context: curvezmq.ClientContext + CurveZMQ client context used to create secure sockets ip_address: str IP address of the client (where Parsl runs) port_range: tuple(int, int) Port range for the comms between client and interchange """ - self.context = zmq.Context() - self.results_receiver = self.context.socket(zmq.DEALER) + self.zmq_context = zmq_context + self.results_receiver = self.zmq_context.socket(zmq.DEALER) self.results_receiver.set_hwm(0) self.port = self.results_receiver.bind_to_random_port("tcp://{}".format(ip_address), min_port=port_range[0], @@ -160,4 +166,4 @@ def get(self): def close(self): self.results_receiver.close() - self.context.term() + self.zmq_context.term() diff --git a/parsl/tests/configs/ad_hoc_cluster_htex.py b/parsl/tests/configs/ad_hoc_cluster_htex.py index 5c90d27918..0a3e9dc027 100644 --- a/parsl/tests/configs/ad_hoc_cluster_htex.py +++ b/parsl/tests/configs/ad_hoc_cluster_htex.py @@ -18,6 +18,7 @@ label='remote_htex', max_workers=2, worker_logdir_root=user_opts['adhoc']['script_dir'], + encrypted=True, provider=AdHocProvider( # Command to be run before starting a worker, such as: # 'module load Anaconda; source activate parsl_env'. diff --git a/parsl/tests/configs/azure_single_node.py b/parsl/tests/configs/azure_single_node.py index 90e90c19cd..17f3e7def9 100644 --- a/parsl/tests/configs/azure_single_node.py +++ b/parsl/tests/configs/azure_single_node.py @@ -40,6 +40,7 @@ storage_access=[HTTPInTaskStaging(), FTPInTaskStaging(), RSyncStaging(getpass.getuser() + "@" + user_opts['public_ip'])], label='azure_single_node', address=user_opts['public_ip'], + encrypted=True, provider=provider ) ] diff --git a/parsl/tests/configs/bluewaters.py b/parsl/tests/configs/bluewaters.py index 2e27eb8c4a..7cd088ecac 100644 --- a/parsl/tests/configs/bluewaters.py +++ b/parsl/tests/configs/bluewaters.py @@ -14,6 +14,7 @@ def fresh_config(): cores_per_worker=1, worker_debug=False, max_workers=1, + encrypted=True, provider=TorqueProvider( queue='normal', launcher=AprunLauncher(overrides="-b -- bwpy-environ --"), diff --git a/parsl/tests/configs/bridges.py b/parsl/tests/configs/bridges.py index 4dea2fe468..06d0c0cd43 100644 --- a/parsl/tests/configs/bridges.py +++ b/parsl/tests/configs/bridges.py @@ -14,6 +14,7 @@ def fresh_config(): # which compute nodes can communicate # address=address_by_interface('bond0.144'), max_workers=1, + encrypted=True, provider=SlurmProvider( user_opts['bridges']['partition'], # Partition / QOS nodes_per_block=2, diff --git a/parsl/tests/configs/cc_in2p3.py b/parsl/tests/configs/cc_in2p3.py index 9b54fa37e4..9a76d1f16e 100644 --- a/parsl/tests/configs/cc_in2p3.py +++ b/parsl/tests/configs/cc_in2p3.py @@ -12,6 +12,7 @@ def fresh_config(): HighThroughputExecutor( label='cc_in2p3_htex', max_workers=1, + encrypted=True, provider=GridEngineProvider( channel=LocalChannel(), nodes_per_block=2, diff --git a/parsl/tests/configs/comet.py b/parsl/tests/configs/comet.py index 1a253e26b2..8f39539509 100644 --- a/parsl/tests/configs/comet.py +++ b/parsl/tests/configs/comet.py @@ -11,6 +11,7 @@ def fresh_config(): HighThroughputExecutor( label='Comet_HTEX_multinode', max_workers=1, + encrypted=True, provider=SlurmProvider( 'debug', launcher=SrunLauncher(), diff --git a/parsl/tests/configs/cooley_htex.py b/parsl/tests/configs/cooley_htex.py index 4228da10b0..202379d0af 100644 --- a/parsl/tests/configs/cooley_htex.py +++ b/parsl/tests/configs/cooley_htex.py @@ -18,6 +18,7 @@ label="cooley_htex", worker_debug=False, cores_per_worker=1, + encrypted=True, provider=CobaltProvider( queue='debug', account=user_opts['cooley']['account'], diff --git a/parsl/tests/configs/ec2_single_node.py b/parsl/tests/configs/ec2_single_node.py index 61cccc8bea..92c8108add 100644 --- a/parsl/tests/configs/ec2_single_node.py +++ b/parsl/tests/configs/ec2_single_node.py @@ -28,6 +28,7 @@ HighThroughputExecutor( label='ec2_single_node', address=user_opts['public_ip'], + encrypted=True, provider=AWSProvider( user_opts['ec2']['image_id'], region=user_opts['ec2']['region'], diff --git a/parsl/tests/configs/ec2_spot.py b/parsl/tests/configs/ec2_spot.py index c693d9b17d..37f272e4de 100644 --- a/parsl/tests/configs/ec2_spot.py +++ b/parsl/tests/configs/ec2_spot.py @@ -15,6 +15,7 @@ HighThroughputExecutor( label='ec2_single_node', address=user_opts['public_ip'], + encrypted=True, provider=AWSProvider( user_opts['ec2']['image_id'], region=user_opts['ec2']['region'], diff --git a/parsl/tests/configs/frontera.py b/parsl/tests/configs/frontera.py index a003ce88c6..ba302d85b5 100644 --- a/parsl/tests/configs/frontera.py +++ b/parsl/tests/configs/frontera.py @@ -16,6 +16,7 @@ def fresh_config(): HighThroughputExecutor( label="frontera_htex", max_workers=1, + encrypted=True, provider=SlurmProvider( cmd_timeout=60, # Add extra time for slow scheduler responses channel=LocalChannel(), diff --git a/parsl/tests/configs/htex_ad_hoc_cluster.py b/parsl/tests/configs/htex_ad_hoc_cluster.py index 80949f1d1e..82ae7bf621 100644 --- a/parsl/tests/configs/htex_ad_hoc_cluster.py +++ b/parsl/tests/configs/htex_ad_hoc_cluster.py @@ -13,6 +13,7 @@ cores_per_worker=1, worker_debug=False, address=user_opts['public_ip'], + encrypted=True, provider=AdHocProvider( move_files=False, parallelism=1, diff --git a/parsl/tests/configs/htex_local.py b/parsl/tests/configs/htex_local.py index 2039918a4d..0f09fce9dc 100644 --- a/parsl/tests/configs/htex_local.py +++ b/parsl/tests/configs/htex_local.py @@ -13,6 +13,7 @@ def fresh_config(): label="htex_local", worker_debug=True, cores_per_worker=1, + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=1, diff --git a/parsl/tests/configs/htex_local_alternate.py b/parsl/tests/configs/htex_local_alternate.py index e3a7b6ff1e..2598a91b85 100644 --- a/parsl/tests/configs/htex_local_alternate.py +++ b/parsl/tests/configs/htex_local_alternate.py @@ -48,6 +48,7 @@ def fresh_config(): heartbeat_period=2, heartbeat_threshold=5, poll_period=100, + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=0, diff --git a/parsl/tests/configs/htex_local_intask_staging.py b/parsl/tests/configs/htex_local_intask_staging.py index c6ad88be69..634cbf1654 100644 --- a/parsl/tests/configs/htex_local_intask_staging.py +++ b/parsl/tests/configs/htex_local_intask_staging.py @@ -15,6 +15,7 @@ label="htex_Local", worker_debug=True, cores_per_worker=1, + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=1, diff --git a/parsl/tests/configs/htex_local_rsync_staging.py b/parsl/tests/configs/htex_local_rsync_staging.py index 514fbfaee7..6cb47e55e9 100644 --- a/parsl/tests/configs/htex_local_rsync_staging.py +++ b/parsl/tests/configs/htex_local_rsync_staging.py @@ -16,6 +16,7 @@ worker_debug=True, cores_per_worker=1, working_dir="./rsync-workdir/", + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=1, diff --git a/parsl/tests/configs/local_adhoc.py b/parsl/tests/configs/local_adhoc.py index 7ac6d7e781..96c08ee309 100644 --- a/parsl/tests/configs/local_adhoc.py +++ b/parsl/tests/configs/local_adhoc.py @@ -9,6 +9,7 @@ def fresh_config(): executors=[ HighThroughputExecutor( label='AdHoc', + encrypted=True, provider=AdHocProvider( channels=[LocalChannel(), LocalChannel()] ) diff --git a/parsl/tests/configs/midway.py b/parsl/tests/configs/midway.py index b5362838e0..ca32edd2b1 100644 --- a/parsl/tests/configs/midway.py +++ b/parsl/tests/configs/midway.py @@ -13,6 +13,7 @@ def fresh_config(): label='Midway_HTEX_multinode', worker_debug=False, max_workers=1, + encrypted=True, provider=SlurmProvider( 'broadwl', # Partition name, e.g 'broadwl' launcher=SrunLauncher(), diff --git a/parsl/tests/configs/nscc_singapore.py b/parsl/tests/configs/nscc_singapore.py index d8abbd45ba..c78018dc55 100644 --- a/parsl/tests/configs/nscc_singapore.py +++ b/parsl/tests/configs/nscc_singapore.py @@ -17,6 +17,7 @@ def fresh_config(): worker_debug=False, max_workers=1, address=address_by_interface('ib0'), + encrypted=True, provider=PBSProProvider( launcher=MpiRunLauncher(), # string to prepend to #PBS blocks in the submit diff --git a/parsl/tests/configs/osg_htex.py b/parsl/tests/configs/osg_htex.py index e4f1d4a98f..17250cca5d 100644 --- a/parsl/tests/configs/osg_htex.py +++ b/parsl/tests/configs/osg_htex.py @@ -14,6 +14,7 @@ HighThroughputExecutor( label='OSG_HTEX', max_workers=1, + encrypted=True, provider=CondorProvider( nodes_per_block=1, init_blocks=4, diff --git a/parsl/tests/configs/petrelkube.py b/parsl/tests/configs/petrelkube.py index dbc2afec39..6e61cd101e 100644 --- a/parsl/tests/configs/petrelkube.py +++ b/parsl/tests/configs/petrelkube.py @@ -23,6 +23,7 @@ def fresh_config(): # Address for the pod worker to connect back address=address_by_route(), + encrypted=True, provider=KubernetesProvider( namespace="dlhub-privileged", diff --git a/parsl/tests/configs/summit.py b/parsl/tests/configs/summit.py index ae0d009334..ed87366e55 100644 --- a/parsl/tests/configs/summit.py +++ b/parsl/tests/configs/summit.py @@ -21,6 +21,7 @@ def fresh_config(): # address=address_by_interface('ib0'), # This assumes Parsl is running on login node worker_port_range=(50000, 55000), max_workers=1, + encrypted=True, provider=LSFProvider( launcher=JsrunLauncher(), walltime="00:10:00", diff --git a/parsl/tests/configs/swan_htex.py b/parsl/tests/configs/swan_htex.py index 8db9f5a975..4884703a2a 100644 --- a/parsl/tests/configs/swan_htex.py +++ b/parsl/tests/configs/swan_htex.py @@ -24,6 +24,7 @@ executors=[ HighThroughputExecutor( label='swan_htex', + encrypted=True, provider=TorqueProvider( channel=SSHChannel( hostname='swan.cray.com', diff --git a/parsl/tests/configs/theta.py b/parsl/tests/configs/theta.py index f2ce169390..36092a948e 100644 --- a/parsl/tests/configs/theta.py +++ b/parsl/tests/configs/theta.py @@ -12,6 +12,7 @@ def fresh_config(): HighThroughputExecutor( label='theta_local_htex_multinode', max_workers=1, + encrypted=True, provider=CobaltProvider( queue=user_opts['theta']['queue'], account=user_opts['theta']['account'], diff --git a/parsl/tests/manual_tests/htex_local.py b/parsl/tests/manual_tests/htex_local.py index a2cd3fae7f..a5c0449be1 100644 --- a/parsl/tests/manual_tests/htex_local.py +++ b/parsl/tests/manual_tests/htex_local.py @@ -13,6 +13,7 @@ label="htex_local", # worker_debug=True, cores_per_worker=1, + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=1, diff --git a/parsl/tests/manual_tests/test_ad_hoc_htex.py b/parsl/tests/manual_tests/test_ad_hoc_htex.py index d022686d6e..030658f243 100644 --- a/parsl/tests/manual_tests/test_ad_hoc_htex.py +++ b/parsl/tests/manual_tests/test_ad_hoc_htex.py @@ -15,6 +15,7 @@ label='AdHoc', max_workers=2, worker_logdir_root="/scratch/midway2/yadunand/parsl_scripts", + encrypted=True, provider=AdHocProvider( worker_init="source /scratch/midway2/yadunand/parsl_env_setup.sh", channels=[SSHChannel(hostname=m, diff --git a/parsl/tests/manual_tests/test_fan_in_out_htex_remote.py b/parsl/tests/manual_tests/test_fan_in_out_htex_remote.py index 72dc4c7fd7..ce09e4bcad 100644 --- a/parsl/tests/manual_tests/test_fan_in_out_htex_remote.py +++ b/parsl/tests/manual_tests/test_fan_in_out_htex_remote.py @@ -16,6 +16,7 @@ def local_setup(): label="theta_htex", # worker_debug=True, cores_per_worker=4, + encrypted=True, provider=CobaltProvider( queue='debug-flat-quad', account='CSC249ADCD01', diff --git a/parsl/tests/manual_tests/test_memory_limits.py b/parsl/tests/manual_tests/test_memory_limits.py index 9024939827..be4507808f 100644 --- a/parsl/tests/manual_tests/test_memory_limits.py +++ b/parsl/tests/manual_tests/test_memory_limits.py @@ -28,6 +28,7 @@ def test_simple(mem_per_worker): mem_per_worker=mem_per_worker, cores_per_worker=0.1, suppress_failure=True, + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=1, diff --git a/parsl/tests/scaling_tests/htex_local.py b/parsl/tests/scaling_tests/htex_local.py index ff6ed5b96b..1037d194c4 100644 --- a/parsl/tests/scaling_tests/htex_local.py +++ b/parsl/tests/scaling_tests/htex_local.py @@ -10,6 +10,7 @@ label="htex_local", cores_per_worker=1, max_workers=8, + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=1, diff --git a/parsl/tests/sites/test_affinity.py b/parsl/tests/sites/test_affinity.py index b2b588aa83..db3760916f 100644 --- a/parsl/tests/sites/test_affinity.py +++ b/parsl/tests/sites/test_affinity.py @@ -18,6 +18,7 @@ def local_config(): max_workers=2, cpu_affinity='block', available_accelerators=2, + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=1, diff --git a/parsl/tests/sites/test_concurrent.py b/parsl/tests/sites/test_concurrent.py index 01c72fd6b9..a80e61a626 100644 --- a/parsl/tests/sites/test_concurrent.py +++ b/parsl/tests/sites/test_concurrent.py @@ -17,6 +17,7 @@ def make_config(): max_workers=2, heartbeat_period=2, heartbeat_threshold=4, + encrypted=True, ) ], strategy='none', diff --git a/parsl/tests/sites/test_dynamic_executor.py b/parsl/tests/sites/test_dynamic_executor.py index 9c04d8b126..c3d449a5cc 100644 --- a/parsl/tests/sites/test_dynamic_executor.py +++ b/parsl/tests/sites/test_dynamic_executor.py @@ -60,6 +60,7 @@ def test_dynamic_executor(): label='htex_local', cores_per_worker=1, max_workers=5, + encrypted=True, provider=LocalProvider( init_blocks=1, max_blocks=1, diff --git a/parsl/tests/sites/test_worker_info.py b/parsl/tests/sites/test_worker_info.py index d9bb3e4c93..296ddfaf05 100644 --- a/parsl/tests/sites/test_worker_info.py +++ b/parsl/tests/sites/test_worker_info.py @@ -15,6 +15,7 @@ def local_config(): label="htex_Local", worker_debug=True, max_workers=4, + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=1, diff --git a/parsl/tests/test_htex/test_htex.py b/parsl/tests/test_htex/test_htex.py new file mode 100644 index 0000000000..ffe693c69e --- /dev/null +++ b/parsl/tests/test_htex/test_htex.py @@ -0,0 +1,46 @@ +import pathlib + +import pytest + +from parsl import curvezmq +from parsl import HighThroughputExecutor + + +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False)) +@pytest.mark.parametrize("cert_dir_provided", (True, False)) +def test_htex_start_encrypted( + encrypted: bool, cert_dir_provided: bool, tmpd_cwd: pathlib.Path +): + htex = HighThroughputExecutor(encrypted=encrypted) + htex.run_dir = str(tmpd_cwd) + if cert_dir_provided: + provided_base_dir = tmpd_cwd / "provided" + provided_base_dir.mkdir() + cert_dir = curvezmq.create_certificates(provided_base_dir) + htex.cert_dir = cert_dir + else: + cert_dir = curvezmq.create_certificates(htex.logdir) + + if not encrypted and cert_dir_provided: + with pytest.raises(AttributeError) as pyt_e: + htex.start() + assert "change cert_dir to None" in str(pyt_e.value) + return + + htex.start() + + assert htex.encrypted is encrypted + if encrypted: + assert htex.cert_dir == cert_dir + assert htex.outgoing_q.zmq_context.cert_dir == cert_dir + assert htex.incoming_q.zmq_context.cert_dir == cert_dir + assert htex.command_client.zmq_context.cert_dir == cert_dir + assert isinstance(htex.outgoing_q.zmq_context, curvezmq.ClientContext) + assert isinstance(htex.incoming_q.zmq_context, curvezmq.ClientContext) + assert isinstance(htex.command_client.zmq_context, curvezmq.ClientContext) + else: + assert htex.cert_dir is None + assert htex.outgoing_q.zmq_context.cert_dir is None + assert htex.incoming_q.zmq_context.cert_dir is None + assert htex.command_client.zmq_context.cert_dir is None diff --git a/parsl/tests/test_htex/test_htex_zmq_binding.py b/parsl/tests/test_htex/test_htex_zmq_binding.py index 442f0431b6..41de9055b0 100644 --- a/parsl/tests/test_htex/test_htex_zmq_binding.py +++ b/parsl/tests/test_htex/test_htex_zmq_binding.py @@ -1,42 +1,82 @@ -import logging +import pathlib +from typing import Optional +from unittest import mock import psutil import pytest import zmq +from parsl import curvezmq from parsl.executors.high_throughput.interchange import Interchange -def test_interchange_binding_no_address(): - ix = Interchange() +@pytest.fixture +def encrypted(request: pytest.FixtureRequest): + if hasattr(request, "param"): + return request.param + return True + + +@pytest.fixture +def cert_dir(encrypted: bool, tmpd_cwd: pathlib.Path): + if not encrypted: + return None + return curvezmq.create_certificates(tmpd_cwd) + + +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +@mock.patch.object(curvezmq.ServerContext, "socket", return_value=mock.MagicMock()) +def test_interchange_curvezmq_sockets( + mock_socket: mock.MagicMock, cert_dir: Optional[str], encrypted: bool +): + address = "127.0.0.1" + ix = Interchange(interchange_address=address, cert_dir=cert_dir) + assert isinstance(ix.zmq_context, curvezmq.ServerContext) + assert ix.zmq_context.encrypted is encrypted + assert mock_socket.call_count == 5 + + +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_interchange_binding_no_address(cert_dir: Optional[str]): + ix = Interchange(cert_dir=cert_dir) assert ix.interchange_address == "*" -def test_interchange_binding_with_address(): +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_interchange_binding_with_address(cert_dir: Optional[str]): # Using loopback address address = "127.0.0.1" - ix = Interchange(interchange_address=address) + ix = Interchange(interchange_address=address, cert_dir=cert_dir) assert ix.interchange_address == address -def test_interchange_binding_with_non_ipv4_address(): +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_interchange_binding_with_non_ipv4_address(cert_dir: Optional[str]): # Confirm that a ipv4 address is required address = "localhost" with pytest.raises(zmq.error.ZMQError): - Interchange(interchange_address=address) + Interchange(interchange_address=address, cert_dir=cert_dir) -def test_interchange_binding_bad_address(): - """ Confirm that we raise a ZMQError when a bad address is supplied""" +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_interchange_binding_bad_address(cert_dir: Optional[str]): + """Confirm that we raise a ZMQError when a bad address is supplied""" address = "550.0.0.0" with pytest.raises(zmq.error.ZMQError): - Interchange(interchange_address=address) + Interchange(interchange_address=address, cert_dir=cert_dir) -def test_limited_interface_binding(): - """ When address is specified the worker_port would be bound to it rather than to 0.0.0.0""" +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_limited_interface_binding(cert_dir: Optional[str]): + """When address is specified the worker_port would be bound to it rather than to 0.0.0.0""" address = "127.0.0.1" - ix = Interchange(interchange_address=address) + ix = Interchange(interchange_address=address, cert_dir=cert_dir) ix.worker_result_port proc = psutil.Process() conns = proc.connections(kind="tcp") diff --git a/parsl/tests/test_regression/test_97_parallelism_0.py b/parsl/tests/test_regression/test_97_parallelism_0.py index ad47e3971f..de06bcd9df 100644 --- a/parsl/tests/test_regression/test_97_parallelism_0.py +++ b/parsl/tests/test_regression/test_97_parallelism_0.py @@ -15,6 +15,7 @@ def local_config() -> Config: label="htex_local", worker_debug=True, cores_per_worker=1, + encrypted=True, provider=LocalProvider( init_blocks=0, min_blocks=0, diff --git a/parsl/tests/test_scaling/test_block_error_handler.py b/parsl/tests/test_scaling/test_block_error_handler.py index 9d680212e3..7d9ffedd17 100644 --- a/parsl/tests/test_scaling/test_block_error_handler.py +++ b/parsl/tests/test_scaling/test_block_error_handler.py @@ -11,7 +11,7 @@ @pytest.mark.local def test_block_error_handler_false(): mock = Mock() - htex = HighThroughputExecutor(block_error_handler=False) + htex = HighThroughputExecutor(block_error_handler=False, encrypted=True) assert htex.block_error_handler is noop_error_handler htex.set_bad_state_and_fail_all = mock @@ -27,7 +27,7 @@ def test_block_error_handler_false(): @pytest.mark.local def test_block_error_handler_mock(): handler_mock = Mock() - htex = HighThroughputExecutor(block_error_handler=handler_mock) + htex = HighThroughputExecutor(block_error_handler=handler_mock, encrypted=True) assert htex.block_error_handler is handler_mock bad_jobs = {'1': JobStatus(JobState.FAILED), @@ -43,6 +43,7 @@ def test_block_error_handler_mock(): @pytest.mark.local def test_simple_error_handler(): htex = HighThroughputExecutor(block_error_handler=simple_error_handler, + encrypted=True, provider=LocalProvider(init_blocks=3)) assert htex.block_error_handler is simple_error_handler @@ -76,7 +77,7 @@ def test_simple_error_handler(): @pytest.mark.local def test_windowed_error_handler(): - htex = HighThroughputExecutor(block_error_handler=windowed_error_handler) + htex = HighThroughputExecutor(block_error_handler=windowed_error_handler, encrypted=True) assert htex.block_error_handler is windowed_error_handler bad_state_mock = Mock() @@ -110,7 +111,7 @@ def test_windowed_error_handler(): @pytest.mark.local def test_windowed_error_handler_sorting(): - htex = HighThroughputExecutor(block_error_handler=windowed_error_handler) + htex = HighThroughputExecutor(block_error_handler=windowed_error_handler, encrypted=True) assert htex.block_error_handler is windowed_error_handler bad_state_mock = Mock() @@ -136,7 +137,7 @@ def test_windowed_error_handler_sorting(): @pytest.mark.local def test_windowed_error_handler_with_threshold(): error_handler = partial(windowed_error_handler, threshold=2) - htex = HighThroughputExecutor(block_error_handler=error_handler) + htex = HighThroughputExecutor(block_error_handler=error_handler, encrypted=True) assert htex.block_error_handler is error_handler bad_state_mock = Mock() diff --git a/parsl/tests/test_scaling/test_regression_1621.py b/parsl/tests/test_scaling/test_regression_1621.py index 2a953c3fc7..bb367c379a 100644 --- a/parsl/tests/test_scaling/test_regression_1621.py +++ b/parsl/tests/test_scaling/test_regression_1621.py @@ -49,6 +49,7 @@ def test_one_block(tmpd_cwd): address="127.0.0.1", worker_debug=True, cores_per_worker=1, + encrypted=True, provider=oneshot_provider, worker_logdir_root=str(tmpd_cwd) ) diff --git a/parsl/tests/test_scaling/test_scale_down.py b/parsl/tests/test_scaling/test_scale_down.py index 0a23a5acb5..6ebe441760 100644 --- a/parsl/tests/test_scaling/test_scale_down.py +++ b/parsl/tests/test_scaling/test_scale_down.py @@ -28,6 +28,7 @@ def local_config(): label="htex_local", address="127.0.0.1", max_workers=1, + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=0,