diff --git a/parsl/executors/high_throughput/executor.py b/parsl/executors/high_throughput/executor.py index 2e63df5a72..bcb1975826 100644 --- a/parsl/executors/high_throughput/executor.py +++ b/parsl/executors/high_throughput/executor.py @@ -22,6 +22,7 @@ UnsupportedFeatureError ) +from parsl import curvezmq from parsl.executors.status_handling import BlockProviderExecutor from parsl.providers.base import ExecutionProvider from parsl.data_provider.staging import Staging @@ -174,6 +175,9 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin): worker_logdir_root : string In case of a remote file system, specify the path to where logs will be kept. + + encrypted : bool + Flag to enable/disable encryption. Default is True. """ @typeguard.typechecked @@ -199,7 +203,8 @@ def __init__(self, poll_period: int = 10, address_probe_timeout: Optional[int] = None, worker_logdir_root: Optional[str] = None, - block_error_handler: Union[bool, Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]] = True): + block_error_handler: Union[bool, Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]] = True, + encrypted: bool = True): logger.debug("Initializing HighThroughputExecutor") @@ -256,6 +261,8 @@ def __init__(self, self.run_dir = '.' self.worker_logdir_root = worker_logdir_root self.cpu_affinity = cpu_affinity + self.encrypted = encrypted + self.cert_dir = None if not launch_cmd: launch_cmd = ( @@ -267,6 +274,7 @@ def __init__(self, "--poll {poll_period} " "--task_port={task_port} " "--result_port={result_port} " + "--cert_dir {cert_dir} " "--logdir={logdir} " "--block_id={{block_id}} " "--hb_period={heartbeat_period} " @@ -280,6 +288,16 @@ def __init__(self, radio_mode = "htex" + @property + def logdir(self): + return "{}/{}".format(self.run_dir, self.label) + + @property + def worker_logdir(self): + if self.worker_logdir_root is not None: + return "{}/{}".format(self.worker_logdir_root, self.label) + return self.logdir + def initialize_scaling(self): """Compose the launch command and scale out the initial blocks. """ @@ -289,9 +307,6 @@ def initialize_scaling(self): address_probe_timeout_string = "" if self.address_probe_timeout: address_probe_timeout_string = "--address_probe_timeout={}".format(self.address_probe_timeout) - worker_logdir = "{}/{}".format(self.run_dir, self.label) - if self.worker_logdir_root is not None: - worker_logdir = "{}/{}".format(self.worker_logdir_root, self.label) l_cmd = self.launch_cmd.format(debug=debug_opts, prefetch_capacity=self.prefetch_capacity, @@ -306,7 +321,8 @@ def initialize_scaling(self): heartbeat_period=self.heartbeat_period, heartbeat_threshold=self.heartbeat_threshold, poll_period=self.poll_period, - logdir=worker_logdir, + cert_dir=self.cert_dir, + logdir=self.worker_logdir, cpu_affinity=self.cpu_affinity, accelerators=" ".join(self.available_accelerators)) self.launch_cmd = l_cmd @@ -327,9 +343,25 @@ def initialize_scaling(self): def start(self): """Create the Interchange process and connect to it. """ - self.outgoing_q = zmq_pipes.TasksOutgoing("127.0.0.1", self.interchange_port_range) - self.incoming_q = zmq_pipes.ResultsIncoming("127.0.0.1", self.interchange_port_range) - self.command_client = zmq_pipes.CommandClient("127.0.0.1", self.interchange_port_range) + if self.encrypted and self.cert_dir is None: + logger.debug("Creating CurveZMQ certificates") + self.cert_dir = curvezmq.create_certificates(self.logdir) + elif not self.encrypted and self.cert_dir: + raise AttributeError( + "The certificates directory path attribute (cert_dir) is defined, but the " + "encrypted attribute is set to False. You must either change cert_dir to " + "None or encrypted to True." + ) + + self.outgoing_q = zmq_pipes.TasksOutgoing( + curvezmq.ClientContext(self.cert_dir), "127.0.0.1", self.interchange_port_range + ) + self.incoming_q = zmq_pipes.ResultsIncoming( + curvezmq.ClientContext(self.cert_dir), "127.0.0.1", self.interchange_port_range + ) + self.command_client = zmq_pipes.CommandClient( + curvezmq.ClientContext(self.cert_dir), "127.0.0.1", self.interchange_port_range + ) self._queue_management_thread = None self._start_queue_management_thread() @@ -450,10 +482,11 @@ def _start_local_interchange_process(self): "worker_port_range": self.worker_port_range, "hub_address": self.hub_address, "hub_port": self.hub_port, - "logdir": "{}/{}".format(self.run_dir, self.label), + "logdir": self.logdir, "heartbeat_threshold": self.heartbeat_threshold, "poll_period": self.poll_period, - "logging_level": logging.DEBUG if self.worker_debug else logging.INFO + "logging_level": logging.DEBUG if self.worker_debug else logging.INFO, + "cert_dir": self.cert_dir, }, daemon=True, name="HTEX-Interchange" diff --git a/parsl/executors/high_throughput/interchange.py b/parsl/executors/high_throughput/interchange.py index 67b70aa78d..483d1c4e21 100644 --- a/parsl/executors/high_throughput/interchange.py +++ b/parsl/executors/high_throughput/interchange.py @@ -16,6 +16,7 @@ from typing import cast, Any, Dict, NoReturn, Sequence, Set, Optional, Tuple +from parsl import curvezmq from parsl.utils import setproctitle from parsl.version import VERSION as PARSL_VERSION from parsl.serialize import serialize as serialize_object @@ -79,6 +80,7 @@ def __init__(self, logdir: str = ".", logging_level: int = logging.INFO, poll_period: int = 10, + cert_dir: Optional[str] = "certificates", ) -> None: """ Parameters @@ -120,7 +122,10 @@ def __init__(self, poll_period : int The main thread polling period, in milliseconds. Default: 10ms + cert_dir : str | None + Path to the certificate directory. Default: 'certificates' """ + self.cert_dir = cert_dir self.logdir = logdir os.makedirs(self.logdir, exist_ok=True) @@ -134,15 +139,15 @@ def __init__(self, logger.info("Attempting connection to client at {} on ports: {},{},{}".format( client_address, client_ports[0], client_ports[1], client_ports[2])) - self.context = zmq.Context() - self.task_incoming = self.context.socket(zmq.DEALER) + self.zmq_context = curvezmq.ServerContext(self.cert_dir) + self.task_incoming = self.zmq_context.socket(zmq.DEALER) self.task_incoming.set_hwm(0) self.task_incoming.connect("tcp://{}:{}".format(client_address, client_ports[0])) - self.results_outgoing = self.context.socket(zmq.DEALER) + self.results_outgoing = self.zmq_context.socket(zmq.DEALER) self.results_outgoing.set_hwm(0) self.results_outgoing.connect("tcp://{}:{}".format(client_address, client_ports[1])) - self.command_channel = self.context.socket(zmq.REP) + self.command_channel = self.zmq_context.socket(zmq.REP) self.command_channel.connect("tcp://{}:{}".format(client_address, client_ports[2])) logger.info("Connected to client") @@ -155,9 +160,9 @@ def __init__(self, self.worker_ports = worker_ports self.worker_port_range = worker_port_range - self.task_outgoing = self.context.socket(zmq.ROUTER) + self.task_outgoing = self.zmq_context.socket(zmq.ROUTER) self.task_outgoing.set_hwm(0) - self.results_incoming = self.context.socket(zmq.ROUTER) + self.results_incoming = self.zmq_context.socket(zmq.ROUTER) self.results_incoming.set_hwm(0) if self.worker_ports: @@ -241,7 +246,8 @@ def task_puller(self) -> NoReturn: def _create_monitoring_channel(self) -> Optional[zmq.Socket]: if self.hub_address and self.hub_port: logger.info("Connecting to monitoring") - hub_channel = self.context.socket(zmq.DEALER) + # This is a one-off because monitoring is unencrypted + hub_channel = zmq.Context().socket(zmq.DEALER) hub_channel.set_hwm(0) hub_channel.connect("tcp://{}:{}".format(self.hub_address, self.hub_port)) logger.info("Monitoring enabled and connected to hub") @@ -379,6 +385,7 @@ def start(self) -> None: self.expire_bad_managers(interesting_managers, hub_channel) self.process_tasks_to_send(interesting_managers) + self.zmq_context.destroy() delta = time.time() - start logger.info("Processed {} tasks in {} seconds".format(self.count, delta)) logger.warning("Exiting") diff --git a/parsl/executors/high_throughput/process_worker_pool.py b/parsl/executors/high_throughput/process_worker_pool.py index 57cff2c783..aed9b1e402 100755 --- a/parsl/executors/high_throughput/process_worker_pool.py +++ b/parsl/executors/high_throughput/process_worker_pool.py @@ -20,8 +20,8 @@ from multiprocessing.managers import DictProxy from multiprocessing.sharedctypes import Synchronized +from parsl import curvezmq from parsl.process_loggers import wrap_with_logs - from parsl.version import VERSION as PARSL_VERSION from parsl.app.errors import RemoteExceptionWrapper from parsl.executors.high_throughput.errors import WorkerLost @@ -63,7 +63,8 @@ def __init__(self, *, heartbeat_period, poll_period, cpu_affinity, - available_accelerators: Sequence[str]): + available_accelerators: Sequence[str], + cert_dir: Optional[str]): """ Parameters ---------- @@ -118,6 +119,8 @@ def __init__(self, *, available_accelerators: list of str List of accelerators available to the workers. + cert_dir : str | None + Path to the certificate directory. """ logger.info("Manager started") @@ -137,15 +140,16 @@ def __init__(self, *, print("Failed to find a viable address to connect to interchange. Exiting") exit(5) - self.context = zmq.Context() - self.task_incoming = self.context.socket(zmq.DEALER) + self.cert_dir = cert_dir + self.zmq_context = curvezmq.ClientContext(self.cert_dir) + self.task_incoming = self.zmq_context.socket(zmq.DEALER) self.task_incoming.setsockopt(zmq.IDENTITY, uid.encode('utf-8')) # Linger is set to 0, so that the manager can exit even when there might be # messages in the pipe self.task_incoming.setsockopt(zmq.LINGER, 0) self.task_incoming.connect(task_q_url) - self.result_outgoing = self.context.socket(zmq.DEALER) + self.result_outgoing = self.zmq_context.socket(zmq.DEALER) self.result_outgoing.setsockopt(zmq.IDENTITY, uid.encode('utf-8')) self.result_outgoing.setsockopt(zmq.LINGER, 0) self.result_outgoing.connect(result_q_url) @@ -468,7 +472,7 @@ def start(self): self.task_incoming.close() self.result_outgoing.close() - self.context.term() + self.zmq_context.term() delta = time.time() - start logger.info("process_worker_pool ran for {} seconds".format(delta)) return @@ -703,6 +707,8 @@ def start_file_logger(filename, rank, name='parsl', level=logging.DEBUG, format_ help="Enable logging at DEBUG level") parser.add_argument("-a", "--addresses", default='', help="Comma separated list of addresses at which the interchange could be reached") + parser.add_argument("--cert_dir", required=True, + help="Path to certificate directory.") parser.add_argument("-l", "--logdir", default="process_worker_pool_logs", help="Process worker pool log directory") parser.add_argument("-u", "--uid", default=str(uuid.uuid4()).split('-')[-1], @@ -746,6 +752,7 @@ def start_file_logger(filename, rank, name='parsl', level=logging.DEBUG, format_ logger.info("Python version: {}".format(sys.version)) logger.info("Debug logging: {}".format(args.debug)) + logger.info("Certificates dir: {}".format(args.cert_dir)) logger.info("Log dir: {}".format(args.logdir)) logger.info("Manager ID: {}".format(args.uid)) logger.info("Block ID: {}".format(args.block_id)) @@ -777,7 +784,8 @@ def start_file_logger(filename, rank, name='parsl', level=logging.DEBUG, format_ heartbeat_period=int(args.hb_period), poll_period=int(args.poll), cpu_affinity=args.cpu_affinity, - available_accelerators=args.available_accelerators) + available_accelerators=args.available_accelerators, + cert_dir=None if args.cert_dir == "None" else args.cert_dir) manager.start() except Exception: diff --git a/parsl/executors/high_throughput/zmq_pipes.py b/parsl/executors/high_throughput/zmq_pipes.py index cd5c7f2e8b..72c1422d4d 100644 --- a/parsl/executors/high_throughput/zmq_pipes.py +++ b/parsl/executors/high_throughput/zmq_pipes.py @@ -4,24 +4,28 @@ import logging import threading +from parsl import curvezmq + logger = logging.getLogger(__name__) class CommandClient: """ CommandClient """ - def __init__(self, ip_address, port_range): + def __init__(self, zmq_context: curvezmq.ClientContext, ip_address, port_range): """ Parameters ---------- + zmq_context: curvezmq.ClientContext + CurveZMQ client context used to create secure sockets ip_address: str IP address of the client (where Parsl runs) port_range: tuple(int, int) Port range for the comms between client and interchange """ - self.context = zmq.Context() + self.zmq_context = zmq_context self.ip_address = ip_address self.port_range = port_range self.port = None @@ -33,7 +37,7 @@ def create_socket_and_bind(self): Upon recreating the socket, we bind to the same port. """ - self.zmq_socket = self.context.socket(zmq.REQ) + self.zmq_socket = self.zmq_context.socket(zmq.REQ) self.zmq_socket.setsockopt(zmq.LINGER, 0) if self.port is None: self.port = self.zmq_socket.bind_to_random_port("tcp://{}".format(self.ip_address), @@ -62,9 +66,7 @@ def run(self, message, max_retries=3): except zmq.ZMQError: logger.exception("Potential ZMQ REQ-REP deadlock caught") logger.info("Trying to reestablish context") - self.zmq_socket.close() - self.context.destroy() - self.context = zmq.Context() + self.zmq_context.recreate() self.create_socket_and_bind() else: break @@ -77,25 +79,27 @@ def run(self, message, max_retries=3): def close(self): self.zmq_socket.close() - self.context.term() + self.zmq_context.term() class TasksOutgoing: """ Outgoing task queue from the executor to the Interchange """ - def __init__(self, ip_address, port_range): + def __init__(self, zmq_context: curvezmq.ClientContext, ip_address, port_range): """ Parameters ---------- + zmq_context: curvezmq.ClientContext + CurveZMQ client context used to create secure sockets ip_address: str IP address of the client (where Parsl runs) port_range: tuple(int, int) Port range for the comms between client and interchange """ - self.context = zmq.Context() - self.zmq_socket = self.context.socket(zmq.DEALER) + self.zmq_context = zmq_context + self.zmq_socket = self.zmq_context.socket(zmq.DEALER) self.zmq_socket.set_hwm(0) self.port = self.zmq_socket.bind_to_random_port("tcp://{}".format(ip_address), min_port=port_range[0], @@ -127,26 +131,28 @@ def put(self, message): def close(self): self.zmq_socket.close() - self.context.term() + self.zmq_context.term() class ResultsIncoming: """ Incoming results queue from the Interchange to the executor """ - def __init__(self, ip_address, port_range): + def __init__(self, zmq_context: curvezmq.ClientContext, ip_address, port_range): """ Parameters ---------- + zmq_context: curvezmq.ClientContext + CurveZMQ client context used to create secure sockets ip_address: str IP address of the client (where Parsl runs) port_range: tuple(int, int) Port range for the comms between client and interchange """ - self.context = zmq.Context() - self.results_receiver = self.context.socket(zmq.DEALER) + self.zmq_context = zmq_context + self.results_receiver = self.zmq_context.socket(zmq.DEALER) self.results_receiver.set_hwm(0) self.port = self.results_receiver.bind_to_random_port("tcp://{}".format(ip_address), min_port=port_range[0], @@ -160,4 +166,4 @@ def get(self): def close(self): self.results_receiver.close() - self.context.term() + self.zmq_context.term() diff --git a/parsl/tests/test_htex/test_htex.py b/parsl/tests/test_htex/test_htex.py new file mode 100644 index 0000000000..ffe693c69e --- /dev/null +++ b/parsl/tests/test_htex/test_htex.py @@ -0,0 +1,46 @@ +import pathlib + +import pytest + +from parsl import curvezmq +from parsl import HighThroughputExecutor + + +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False)) +@pytest.mark.parametrize("cert_dir_provided", (True, False)) +def test_htex_start_encrypted( + encrypted: bool, cert_dir_provided: bool, tmpd_cwd: pathlib.Path +): + htex = HighThroughputExecutor(encrypted=encrypted) + htex.run_dir = str(tmpd_cwd) + if cert_dir_provided: + provided_base_dir = tmpd_cwd / "provided" + provided_base_dir.mkdir() + cert_dir = curvezmq.create_certificates(provided_base_dir) + htex.cert_dir = cert_dir + else: + cert_dir = curvezmq.create_certificates(htex.logdir) + + if not encrypted and cert_dir_provided: + with pytest.raises(AttributeError) as pyt_e: + htex.start() + assert "change cert_dir to None" in str(pyt_e.value) + return + + htex.start() + + assert htex.encrypted is encrypted + if encrypted: + assert htex.cert_dir == cert_dir + assert htex.outgoing_q.zmq_context.cert_dir == cert_dir + assert htex.incoming_q.zmq_context.cert_dir == cert_dir + assert htex.command_client.zmq_context.cert_dir == cert_dir + assert isinstance(htex.outgoing_q.zmq_context, curvezmq.ClientContext) + assert isinstance(htex.incoming_q.zmq_context, curvezmq.ClientContext) + assert isinstance(htex.command_client.zmq_context, curvezmq.ClientContext) + else: + assert htex.cert_dir is None + assert htex.outgoing_q.zmq_context.cert_dir is None + assert htex.incoming_q.zmq_context.cert_dir is None + assert htex.command_client.zmq_context.cert_dir is None diff --git a/parsl/tests/test_htex/test_htex_zmq_binding.py b/parsl/tests/test_htex/test_htex_zmq_binding.py index 442f0431b6..41de9055b0 100644 --- a/parsl/tests/test_htex/test_htex_zmq_binding.py +++ b/parsl/tests/test_htex/test_htex_zmq_binding.py @@ -1,42 +1,82 @@ -import logging +import pathlib +from typing import Optional +from unittest import mock import psutil import pytest import zmq +from parsl import curvezmq from parsl.executors.high_throughput.interchange import Interchange -def test_interchange_binding_no_address(): - ix = Interchange() +@pytest.fixture +def encrypted(request: pytest.FixtureRequest): + if hasattr(request, "param"): + return request.param + return True + + +@pytest.fixture +def cert_dir(encrypted: bool, tmpd_cwd: pathlib.Path): + if not encrypted: + return None + return curvezmq.create_certificates(tmpd_cwd) + + +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +@mock.patch.object(curvezmq.ServerContext, "socket", return_value=mock.MagicMock()) +def test_interchange_curvezmq_sockets( + mock_socket: mock.MagicMock, cert_dir: Optional[str], encrypted: bool +): + address = "127.0.0.1" + ix = Interchange(interchange_address=address, cert_dir=cert_dir) + assert isinstance(ix.zmq_context, curvezmq.ServerContext) + assert ix.zmq_context.encrypted is encrypted + assert mock_socket.call_count == 5 + + +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_interchange_binding_no_address(cert_dir: Optional[str]): + ix = Interchange(cert_dir=cert_dir) assert ix.interchange_address == "*" -def test_interchange_binding_with_address(): +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_interchange_binding_with_address(cert_dir: Optional[str]): # Using loopback address address = "127.0.0.1" - ix = Interchange(interchange_address=address) + ix = Interchange(interchange_address=address, cert_dir=cert_dir) assert ix.interchange_address == address -def test_interchange_binding_with_non_ipv4_address(): +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_interchange_binding_with_non_ipv4_address(cert_dir: Optional[str]): # Confirm that a ipv4 address is required address = "localhost" with pytest.raises(zmq.error.ZMQError): - Interchange(interchange_address=address) + Interchange(interchange_address=address, cert_dir=cert_dir) -def test_interchange_binding_bad_address(): - """ Confirm that we raise a ZMQError when a bad address is supplied""" +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_interchange_binding_bad_address(cert_dir: Optional[str]): + """Confirm that we raise a ZMQError when a bad address is supplied""" address = "550.0.0.0" with pytest.raises(zmq.error.ZMQError): - Interchange(interchange_address=address) + Interchange(interchange_address=address, cert_dir=cert_dir) -def test_limited_interface_binding(): - """ When address is specified the worker_port would be bound to it rather than to 0.0.0.0""" +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_limited_interface_binding(cert_dir: Optional[str]): + """When address is specified the worker_port would be bound to it rather than to 0.0.0.0""" address = "127.0.0.1" - ix = Interchange(interchange_address=address) + ix = Interchange(interchange_address=address, cert_dir=cert_dir) ix.worker_result_port proc = psutil.Process() conns = proc.connections(kind="tcp")