From d99c66441ba2dbc2d3d120ac7c64e8f0f1f884fb Mon Sep 17 00:00:00 2001 From: Reid Mello <30907815+rjmello@users.noreply.github.com> Date: Tue, 16 Jan 2024 15:46:33 -0500 Subject: [PATCH 1/3] Create CurveZMQ context classes `ServerContext` and `ClientContext` replace the standard `zmq.Context`. They share many commonly used methods, including `term`, `destroy` and, most importantly, `socket`. The latter applies the necessary certs and options to each socket object. A connection requires a `ServerContext` on one end, which validates clients, and a `ClientContext` on the other, which validates the server. Certificates are generated via the `create_certificates` function. --- parsl/curvezmq.py | 202 ++++++++++++++++ parsl/tests/test_curvezmq.py | 455 +++++++++++++++++++++++++++++++++++ 2 files changed, 657 insertions(+) create mode 100644 parsl/curvezmq.py create mode 100644 parsl/tests/test_curvezmq.py diff --git a/parsl/curvezmq.py b/parsl/curvezmq.py new file mode 100644 index 0000000000..cd220c818f --- /dev/null +++ b/parsl/curvezmq.py @@ -0,0 +1,202 @@ +import os +from abc import ABCMeta, abstractmethod +from typing import Optional, Tuple, Union + +import zmq +import zmq.auth +from zmq.auth.thread import ThreadAuthenticator + + +def create_certificates(base_dir: Union[str, os.PathLike]): + """Create server and client certificates in a private directory. + + This will overwrite existing certificate files. + + Parameters + ---------- + base_dir : str | os.PathLike + Parent directory of the private certificates directory. + """ + cert_dir = os.path.join(base_dir, "certificates") + os.makedirs(cert_dir, mode=0o700, exist_ok=True) + + zmq.auth.create_certificates(cert_dir, name="server") + zmq.auth.create_certificates(cert_dir, name="client") + + return cert_dir + + +def _load_certificate( + cert_dir: Union[str, os.PathLike], name: str +) -> Tuple[bytes, bytes]: + if os.stat(cert_dir).st_mode & 0o777 != 0o700: + raise OSError(f"The certificates directory must be private: {cert_dir}") + + # pyzmq creates secret key files with the '.key_secret' extension + # Ref: https://github.com/zeromq/pyzmq/blob/ae615d4097ccfbc6b5c17de60355cbe6e00a6065/zmq/auth/certs.py#L73 + secret_key_file = os.path.join(cert_dir, f"{name}.key_secret") + public_key, secret_key = zmq.auth.load_certificate(secret_key_file) + if secret_key is None: + raise ValueError(f"No secret key found in {secret_key_file}") + + return public_key, secret_key + + +class BaseContext(metaclass=ABCMeta): + """Base CurveZMQ context""" + + def __init__(self, cert_dir: Optional[Union[str, os.PathLike]]) -> None: + self.cert_dir = cert_dir + self._ctx = zmq.Context() + + def __del__(self): + self.destroy() + + @property + def encrypted(self): + """Indicates whether encryption is enabled. + + False (disabled) when self.cert_dir is set to None. + """ + return self.cert_dir is not None + + @property + def closed(self): + return self._ctx.closed + + @abstractmethod + def socket(self, socket_type: int, *args, **kwargs) -> zmq.Socket: + """Create a socket associated with this context. + + This method will apply all necessary certificates and socket options. + + Parameters + ---------- + socket_type : int + The socket type, which can be any of the 0MQ socket types: REQ, REP, + PUB, SUB, PAIR, DEALER, ROUTER, PULL, PUSH, etc. + + args: + passed to the zmq.Context.socket method. + + kwargs: + passed to the zmq.Context.socket method. + """ + ... + + def term(self): + """Terminate the context.""" + self._ctx.term() + + def destroy(self, linger: Optional[int] = None): + """Close all sockets associated with this context and then terminate + the context. + + .. warning:: + + destroy involves calling ``zmq_close()``, which is **NOT** threadsafe. + If there are active sockets in other threads, this must not be called. + + Parameters + ---------- + linger : int, optional + If specified, set LINGER on sockets prior to closing them. + """ + self._ctx.destroy(linger) + + def recreate(self, linger: Optional[int] = None): + """Destroy then recreate the context. + + Parameters + ---------- + linger : int, optional + If specified, set LINGER on sockets prior to closing them. + """ + self.destroy(linger) + self._ctx = zmq.Context() + + +class ServerContext(BaseContext): + """CurveZMQ server context + + We create server sockets via the `ctx.socket` method, which automatically + applies the necessary certificates and socket options. + + We handle client certificate authentication in a separate dedicated thread. + + Parameters + ---------- + cert_dir : str | os.PathLike | None + Path to the certificate directory. Setting this to None will disable encryption. + """ + + def __init__(self, cert_dir: Optional[Union[str, os.PathLike]]) -> None: + super().__init__(cert_dir) + self.auth_thread = None + if self.encrypted: + self.auth_thread = self._start_auth_thread() + + def _start_auth_thread(self) -> ThreadAuthenticator: + auth_thread = ThreadAuthenticator(self._ctx) + auth_thread.start() + # Only allow certs that are in the cert dir + assert self.cert_dir # For mypy + auth_thread.configure_curve(domain="*", location=self.cert_dir) + return auth_thread + + def socket(self, socket_type: int, *args, **kwargs) -> zmq.Socket: + sock = self._ctx.socket(socket_type, *args, **kwargs) + if self.encrypted: + assert self.cert_dir # For mypy + _, secret_key = _load_certificate(self.cert_dir, name="server") + try: + # Only the clients need the server's public key to encrypt + # messages and verify the server's identity. + # Ref: http://curvezmq.org/page:read-the-docs + sock.setsockopt(zmq.CURVE_SECRETKEY, secret_key) + except zmq.ZMQError as e: + raise ValueError("Invalid CurveZMQ key format") from e + sock.setsockopt(zmq.CURVE_SERVER, True) # Must come before bind + return sock + + def term(self): + if self.auth_thread: + self.auth_thread.stop() + super().term() + + def destroy(self, linger: Optional[int] = None): + if self.auth_thread: + self.auth_thread.stop() + super().destroy(linger) + + def recreate(self, linger: Optional[int] = None): + super().recreate(linger) + if self.auth_thread: + self.auth_thread = self._start_auth_thread() + + +class ClientContext(BaseContext): + """CurveZMQ client context + + We create client sockets via the `ctx.socket` method, which automatically + applies the necessary certificates and socket options. + + Parameters + ---------- + cert_dir : str | os.PathLike | None + Path to the certificate directory. Setting this to None will disable encryption. + """ + + def socket(self, socket_type: int, *args, **kwargs) -> zmq.Socket: + sock = self._ctx.socket(socket_type, *args, **kwargs) + if self.encrypted: + assert self.cert_dir # For mypy + public_key, secret_key = _load_certificate(self.cert_dir, name="client") + server_public_key, _ = _load_certificate(self.cert_dir, name="server") + try: + sock.setsockopt(zmq.CURVE_PUBLICKEY, public_key) + sock.setsockopt(zmq.CURVE_SECRETKEY, secret_key) + sock.setsockopt(zmq.CURVE_SERVERKEY, server_public_key) + except zmq.ZMQError as e: + raise ValueError("Invalid CurveZMQ key format") from e + return sock diff --git a/parsl/tests/test_curvezmq.py b/parsl/tests/test_curvezmq.py new file mode 100644 index 0000000000..38b206ab67 --- /dev/null +++ b/parsl/tests/test_curvezmq.py @@ -0,0 +1,455 @@ +import os +import pathlib +from typing import Union +from unittest import mock + +import pytest +import zmq +import zmq.auth +from zmq.auth.thread import ThreadAuthenticator + +from parsl import curvezmq + +ADDR = "tcp://127.0.0.1" + + +def get_server_socket(ctx: curvezmq.ServerContext): + sock = ctx.socket(zmq.PULL) + sock.setsockopt(zmq.RCVTIMEO, 200) + sock.setsockopt(zmq.LINGER, 0) + port = sock.bind_to_random_port(ADDR) + return sock, port + + +def get_client_socket(ctx: curvezmq.ClientContext, port: int): + sock = ctx.socket(zmq.PUSH) + sock.setsockopt(zmq.SNDTIMEO, 200) + sock.setsockopt(zmq.LINGER, 0) + sock.connect(f"{ADDR}:{port}") + return sock + + +def get_external_server_socket( + ctx: Union[curvezmq.ServerContext, zmq.Context], secret_key: bytes +): + sock = ctx.socket(zmq.PULL) + sock.setsockopt(zmq.RCVTIMEO, 200) + sock.setsockopt(zmq.LINGER, 0) + sock.setsockopt(zmq.CURVE_SECRETKEY, secret_key) + sock.setsockopt(zmq.CURVE_SERVER, True) + port = sock.bind_to_random_port(ADDR) + return sock, port + + +def get_external_client_socket( + ctx: Union[curvezmq.ClientContext, zmq.Context], + public_key: bytes, + secret_key: bytes, + server_key: bytes, + port: int, +): + sock = ctx.socket(zmq.PUSH) + sock.setsockopt(zmq.LINGER, 0) + sock.setsockopt(zmq.CURVE_PUBLICKEY, public_key) + sock.setsockopt(zmq.CURVE_SECRETKEY, secret_key) + sock.setsockopt(zmq.CURVE_SERVERKEY, server_key) + sock.connect(f"{ADDR}:{port}") + return sock + + +@pytest.fixture +def encrypted(request: pytest.FixtureRequest): + if hasattr(request, "param"): + return request.param + return True + + +@pytest.fixture +def cert_dir(encrypted: bool, tmpd_cwd: pathlib.Path): + if not encrypted: + return None + return curvezmq.create_certificates(tmpd_cwd) + + +@pytest.fixture +def server_ctx(cert_dir: Union[str, None]): + ctx = curvezmq.ServerContext(cert_dir) + yield ctx + ctx.destroy() + + +@pytest.fixture +def client_ctx(cert_dir: Union[str, None]): + ctx = curvezmq.ClientContext(cert_dir) + yield ctx + ctx.destroy() + + +@pytest.fixture +def zmq_ctx(): + ctx = zmq.Context() + yield ctx + ctx.destroy() + + +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_client_context_init(cert_dir: Union[str, None]): + ctx = curvezmq.ClientContext(cert_dir=cert_dir) + + assert ctx.cert_dir == cert_dir + if cert_dir is None: + assert not ctx.encrypted + else: + assert ctx.encrypted + + ctx.destroy() + + +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_server_context_init(cert_dir: Union[str, None]): + ctx = curvezmq.ServerContext(cert_dir=cert_dir) + + assert ctx.cert_dir == cert_dir + if cert_dir is None: + assert not ctx.encrypted + assert not ctx.auth_thread + else: + assert ctx.encrypted + assert isinstance(ctx.auth_thread, ThreadAuthenticator) + + ctx.destroy() + + +@pytest.mark.local +def test_create_certificates(tmpd_cwd: pathlib.Path): + cert_dir = tmpd_cwd / "certificates" + assert not os.path.exists(cert_dir) + + ret = curvezmq.create_certificates(tmpd_cwd) + + assert str(cert_dir) == ret + assert os.path.exists(cert_dir) + assert os.stat(cert_dir).st_mode & 0o777 == 0o700 + assert len(os.listdir(cert_dir)) == 4 + + +@pytest.mark.local +def test_create_certificates_overwrite(tmpd_cwd: pathlib.Path): + cert_dir = curvezmq.create_certificates(tmpd_cwd) + client_pub_1, client_sec_1 = curvezmq._load_certificate(cert_dir, name="client") + server_pub_1, server_sec_1 = curvezmq._load_certificate(cert_dir, name="server") + + curvezmq.create_certificates(tmpd_cwd) + client_pub_2, client_sec_2 = curvezmq._load_certificate(cert_dir, name="client") + server_pub_2, server_sec_2 = curvezmq._load_certificate(cert_dir, name="server") + + assert client_pub_1 != client_pub_2 + assert client_sec_1 != client_sec_2 + assert server_pub_1 != server_pub_2 + assert server_sec_1 != server_sec_2 + + +@pytest.mark.local +def test_cert_dir_not_private(tmpd_cwd: pathlib.Path): + cert_dir = tmpd_cwd / "certificates" + os.makedirs(cert_dir, mode=0o777) + client_ctx = curvezmq.ClientContext(cert_dir) + server_ctx = curvezmq.ServerContext(cert_dir) + + err_msg = "directory must be private" + + with pytest.raises(OSError) as pyt_e: + client_ctx.socket(zmq.REQ) + assert err_msg in str(pyt_e.value) + + with pytest.raises(OSError) as pyt_e: + server_ctx.socket(zmq.REP) + assert err_msg in str(pyt_e.value) + + client_ctx.destroy() + server_ctx.destroy() + + +@pytest.mark.local +def test_missing_cert_dir(): + cert_dir = "/bad/cert/dir" + client_ctx = curvezmq.ClientContext(cert_dir) + server_ctx = curvezmq.ServerContext(cert_dir) + + err_msg = "No such file or directory" + + with pytest.raises(FileNotFoundError) as pyt_e: + client_ctx.socket(zmq.REQ) + assert err_msg in str(pyt_e.value) + + with pytest.raises(FileNotFoundError) as pyt_e: + server_ctx.socket(zmq.REP) + assert err_msg in str(pyt_e.value) + + client_ctx.destroy() + server_ctx.destroy() + + +@pytest.mark.local +def test_missing_secret_file(tmpd_cwd: pathlib.Path): + cert_dir = tmpd_cwd / "certificates" + os.makedirs(cert_dir, mode=0o700) + + client_ctx = curvezmq.ClientContext(cert_dir) + server_ctx = curvezmq.ServerContext(cert_dir) + + err_msg = "Invalid certificate file" + + with pytest.raises(OSError) as pyt_e: + client_ctx.socket(zmq.REQ) + assert err_msg in str(pyt_e.value) + + with pytest.raises(OSError) as pyt_e: + server_ctx.socket(zmq.REP) + assert err_msg in str(pyt_e.value) + + client_ctx.destroy() + server_ctx.destroy() + + +@pytest.mark.local +def test_bad_secret_file(tmpd_cwd: pathlib.Path): + client_sec = tmpd_cwd / "client.key_secret" + server_sec = tmpd_cwd / "server.key_secret" + client_sec.write_text("bad") + server_sec.write_text("boy") + + client_ctx = curvezmq.ClientContext(tmpd_cwd) + server_ctx = curvezmq.ServerContext(tmpd_cwd) + + err_msg = "No public key found" + + with pytest.raises(ValueError) as pyt_e: + client_ctx.socket(zmq.REQ) + assert err_msg in str(pyt_e.value) + + with pytest.raises(ValueError) as pyt_e: + server_ctx.socket(zmq.REP) + assert err_msg in str(pyt_e.value) + + client_ctx.destroy() + server_ctx.destroy() + + +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_client_context_term(client_ctx: curvezmq.ClientContext): + assert not client_ctx.closed + + client_ctx.term() + + assert client_ctx.closed + + +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_server_context_term(server_ctx: curvezmq.ServerContext, encrypted: bool): + assert not server_ctx.closed + if encrypted: + assert server_ctx.auth_thread + assert server_ctx.auth_thread.pipe + + server_ctx.term() + + assert server_ctx.closed + if encrypted: + assert server_ctx.auth_thread + assert not server_ctx.auth_thread.pipe + + +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_client_context_destroy(client_ctx: curvezmq.ClientContext): + sock = client_ctx.socket(zmq.REP) + + assert not client_ctx.closed + + client_ctx.destroy() + + assert sock.closed + assert client_ctx.closed + + +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_server_context_destroy(server_ctx: curvezmq.ServerContext, encrypted: bool): + sock = server_ctx.socket(zmq.REP) + + assert not server_ctx.closed + if encrypted: + assert server_ctx.auth_thread + assert server_ctx.auth_thread.pipe + + server_ctx.destroy() + + assert sock.closed + assert server_ctx.closed + if encrypted: + assert server_ctx.auth_thread + assert not server_ctx.auth_thread.pipe + + +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_client_context_recreate(client_ctx: curvezmq.ClientContext): + hidden_ctx = client_ctx._ctx + sock = client_ctx.socket(zmq.REQ) + + assert not sock.closed + assert not client_ctx.closed + + client_ctx.recreate() + + assert sock.closed + assert not client_ctx.closed + assert hidden_ctx != client_ctx._ctx + assert hidden_ctx.closed + + +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_server_context_recreate(server_ctx: curvezmq.ServerContext, encrypted: bool): + hidden_ctx = server_ctx._ctx + sock = server_ctx.socket(zmq.REP) + + assert not sock.closed + assert not server_ctx.closed + if encrypted: + assert server_ctx.auth_thread + auth_thread = server_ctx.auth_thread + assert auth_thread.pipe + + server_ctx.recreate() + + assert sock.closed + assert not server_ctx.closed + assert hidden_ctx.closed + assert hidden_ctx != server_ctx._ctx + if encrypted: + assert server_ctx.auth_thread + assert auth_thread != server_ctx.auth_thread + assert server_ctx.auth_thread.pipe + + +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_connection( + server_ctx: curvezmq.ServerContext, client_ctx: curvezmq.ClientContext +): + server_socket, port = get_server_socket(server_ctx) + client_socket = get_client_socket(client_ctx, port) + + msg = b"howdy" + client_socket.send(msg) + recv = server_socket.recv() + + assert recv == msg + + +@pytest.mark.local +@mock.patch.object(curvezmq, "_load_certificate") +def test_invalid_key_format( + mock_load_cert: mock.MagicMock, + server_ctx: curvezmq.ServerContext, + client_ctx: curvezmq.ClientContext, +): + mock_load_cert.return_value = (b"badkey", b"badkey") + + with pytest.raises(ValueError) as e1_info: + server_ctx.socket(zmq.REP) + with pytest.raises(ValueError) as e2_info: + client_ctx.socket(zmq.REQ) + e1, e2 = e1_info.exconly, e2_info.exconly + + assert str(e1) == str(e2) + assert "Invalid CurveZMQ key format" in str(e1) + + +@pytest.mark.local +def test_invalid_client_keys(server_ctx: curvezmq.ServerContext, zmq_ctx: zmq.Context): + server_socket, port = get_server_socket(server_ctx) + + cert_dir = server_ctx.cert_dir + assert cert_dir # For mypy + public_key, secret_key = curvezmq._load_certificate(cert_dir, "client") + server_key, _ = curvezmq._load_certificate(cert_dir, "server") + + BAD_PUB_KEY, BAD_SEC_KEY = zmq.curve_keypair() + msg = b"howdy" + + client_socket = get_external_client_socket( + zmq_ctx, + public_key, + secret_key, + server_key, + port, + ) + client_socket.send(msg) + assert server_socket.recv() == msg + + client_socket = get_external_client_socket( + zmq_ctx, + BAD_PUB_KEY, + BAD_SEC_KEY, + server_key, + port, + ) + client_socket.send(msg) + with pytest.raises(zmq.Again): + server_socket.recv() + + client_socket = get_external_client_socket( + zmq_ctx, + public_key, + secret_key, + BAD_PUB_KEY, + port, + ) + client_socket.send(msg) + with pytest.raises(zmq.Again): + server_socket.recv() + + # Ensure sockets are operational + client_socket = get_external_client_socket( + zmq_ctx, + public_key, + secret_key, + server_key, + port, + ) + client_socket.send(msg) + assert server_socket.recv() == msg + + +@pytest.mark.local +def test_invalid_server_key(client_ctx: curvezmq.ClientContext, zmq_ctx: zmq.Context): + cert_dir = client_ctx.cert_dir + assert cert_dir # For mypy + _, secret_key = curvezmq._load_certificate(cert_dir, "server") + + _, BAD_SEC_KEY = zmq.curve_keypair() + msg = b"howdy" + + server_socket, port = get_external_server_socket(zmq_ctx, secret_key) + client_socket = get_client_socket(client_ctx, port) + client_socket.send(msg) + assert server_socket.recv() == msg + + server_socket, port = get_external_server_socket(zmq_ctx, BAD_SEC_KEY) + client_socket = get_client_socket(client_ctx, port) + client_socket.send(msg) + with pytest.raises(zmq.Again): + server_socket.recv() + + # Ensure sockets are operational + server_socket, port = get_external_server_socket(zmq_ctx, secret_key) + client_socket = get_client_socket(client_ctx, port) + client_socket.send(msg) + assert server_socket.recv() == msg From 47488c10af16b2ef17756c4e8438be962b0a332e Mon Sep 17 00:00:00 2001 From: Reid Mello <30907815+rjmello@users.noreply.github.com> Date: Tue, 16 Jan 2024 15:46:43 -0500 Subject: [PATCH 2/3] Implement CurveZMQ in HTEX The interchange serves as a CurveZMQ server, while the executor and various managers serve as CurveZMQ clients. Thus, all communication between these entities is now encrypted. The HTEX `start` method generates new certs for each run in a private `certificates/` directory. We generate a single shared client cert because all clients will have access to this dir. We disable encryption by default, but users can enable it by setting the `encrypted` initialization argument for the HTEX to `True`. --- parsl/executors/high_throughput/executor.py | 53 ++++++++++++--- .../executors/high_throughput/interchange.py | 21 ++++-- .../high_throughput/process_worker_pool.py | 22 +++++-- parsl/executors/high_throughput/zmq_pipes.py | 36 +++++----- parsl/tests/configs/ad_hoc_cluster_htex.py | 1 + parsl/tests/configs/azure_single_node.py | 1 + parsl/tests/configs/bluewaters.py | 1 + parsl/tests/configs/bridges.py | 1 + parsl/tests/configs/cc_in2p3.py | 1 + parsl/tests/configs/comet.py | 1 + parsl/tests/configs/cooley_htex.py | 1 + parsl/tests/configs/ec2_single_node.py | 1 + parsl/tests/configs/ec2_spot.py | 1 + parsl/tests/configs/frontera.py | 1 + parsl/tests/configs/htex_ad_hoc_cluster.py | 1 + parsl/tests/configs/htex_local.py | 1 + parsl/tests/configs/htex_local_alternate.py | 1 + .../configs/htex_local_intask_staging.py | 1 + .../tests/configs/htex_local_rsync_staging.py | 1 + parsl/tests/configs/local_adhoc.py | 1 + parsl/tests/configs/midway.py | 1 + parsl/tests/configs/nscc_singapore.py | 1 + parsl/tests/configs/osg_htex.py | 1 + parsl/tests/configs/petrelkube.py | 1 + parsl/tests/configs/summit.py | 1 + parsl/tests/configs/swan_htex.py | 1 + parsl/tests/configs/theta.py | 1 + parsl/tests/manual_tests/htex_local.py | 1 + parsl/tests/manual_tests/test_ad_hoc_htex.py | 1 + .../test_fan_in_out_htex_remote.py | 1 + .../tests/manual_tests/test_memory_limits.py | 1 + parsl/tests/scaling_tests/htex_local.py | 1 + parsl/tests/sites/test_affinity.py | 1 + parsl/tests/sites/test_concurrent.py | 1 + parsl/tests/sites/test_dynamic_executor.py | 1 + parsl/tests/sites/test_worker_info.py | 1 + parsl/tests/test_htex/test_htex.py | 46 +++++++++++++ .../tests/test_htex/test_htex_zmq_binding.py | 66 +++++++++++++++---- .../test_regression/test_97_parallelism_0.py | 1 + .../test_scaling/test_block_error_handler.py | 11 ++-- .../test_scaling/test_regression_1621.py | 1 + parsl/tests/test_scaling/test_scale_down.py | 1 + 42 files changed, 233 insertions(+), 57 deletions(-) create mode 100644 parsl/tests/test_htex/test_htex.py diff --git a/parsl/executors/high_throughput/executor.py b/parsl/executors/high_throughput/executor.py index 2e63df5a72..46929f75a2 100644 --- a/parsl/executors/high_throughput/executor.py +++ b/parsl/executors/high_throughput/executor.py @@ -22,6 +22,7 @@ UnsupportedFeatureError ) +from parsl import curvezmq from parsl.executors.status_handling import BlockProviderExecutor from parsl.providers.base import ExecutionProvider from parsl.data_provider.staging import Staging @@ -174,6 +175,9 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin): worker_logdir_root : string In case of a remote file system, specify the path to where logs will be kept. + + encrypted : bool + Flag to enable/disable encryption (CurveZMQ). Default is False. """ @typeguard.typechecked @@ -199,7 +203,8 @@ def __init__(self, poll_period: int = 10, address_probe_timeout: Optional[int] = None, worker_logdir_root: Optional[str] = None, - block_error_handler: Union[bool, Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]] = True): + block_error_handler: Union[bool, Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]] = True, + encrypted: bool = False): logger.debug("Initializing HighThroughputExecutor") @@ -256,6 +261,8 @@ def __init__(self, self.run_dir = '.' self.worker_logdir_root = worker_logdir_root self.cpu_affinity = cpu_affinity + self.encrypted = encrypted + self.cert_dir = None if not launch_cmd: launch_cmd = ( @@ -267,6 +274,7 @@ def __init__(self, "--poll {poll_period} " "--task_port={task_port} " "--result_port={result_port} " + "--cert_dir {cert_dir} " "--logdir={logdir} " "--block_id={{block_id}} " "--hb_period={heartbeat_period} " @@ -280,6 +288,16 @@ def __init__(self, radio_mode = "htex" + @property + def logdir(self): + return "{}/{}".format(self.run_dir, self.label) + + @property + def worker_logdir(self): + if self.worker_logdir_root is not None: + return "{}/{}".format(self.worker_logdir_root, self.label) + return self.logdir + def initialize_scaling(self): """Compose the launch command and scale out the initial blocks. """ @@ -289,9 +307,6 @@ def initialize_scaling(self): address_probe_timeout_string = "" if self.address_probe_timeout: address_probe_timeout_string = "--address_probe_timeout={}".format(self.address_probe_timeout) - worker_logdir = "{}/{}".format(self.run_dir, self.label) - if self.worker_logdir_root is not None: - worker_logdir = "{}/{}".format(self.worker_logdir_root, self.label) l_cmd = self.launch_cmd.format(debug=debug_opts, prefetch_capacity=self.prefetch_capacity, @@ -306,7 +321,8 @@ def initialize_scaling(self): heartbeat_period=self.heartbeat_period, heartbeat_threshold=self.heartbeat_threshold, poll_period=self.poll_period, - logdir=worker_logdir, + cert_dir=self.cert_dir, + logdir=self.worker_logdir, cpu_affinity=self.cpu_affinity, accelerators=" ".join(self.available_accelerators)) self.launch_cmd = l_cmd @@ -327,9 +343,25 @@ def initialize_scaling(self): def start(self): """Create the Interchange process and connect to it. """ - self.outgoing_q = zmq_pipes.TasksOutgoing("127.0.0.1", self.interchange_port_range) - self.incoming_q = zmq_pipes.ResultsIncoming("127.0.0.1", self.interchange_port_range) - self.command_client = zmq_pipes.CommandClient("127.0.0.1", self.interchange_port_range) + if self.encrypted and self.cert_dir is None: + logger.debug("Creating CurveZMQ certificates") + self.cert_dir = curvezmq.create_certificates(self.logdir) + elif not self.encrypted and self.cert_dir: + raise AttributeError( + "The certificates directory path attribute (cert_dir) is defined, but the " + "encrypted attribute is set to False. You must either change cert_dir to " + "None or encrypted to True." + ) + + self.outgoing_q = zmq_pipes.TasksOutgoing( + curvezmq.ClientContext(self.cert_dir), "127.0.0.1", self.interchange_port_range + ) + self.incoming_q = zmq_pipes.ResultsIncoming( + curvezmq.ClientContext(self.cert_dir), "127.0.0.1", self.interchange_port_range + ) + self.command_client = zmq_pipes.CommandClient( + curvezmq.ClientContext(self.cert_dir), "127.0.0.1", self.interchange_port_range + ) self._queue_management_thread = None self._start_queue_management_thread() @@ -450,10 +482,11 @@ def _start_local_interchange_process(self): "worker_port_range": self.worker_port_range, "hub_address": self.hub_address, "hub_port": self.hub_port, - "logdir": "{}/{}".format(self.run_dir, self.label), + "logdir": self.logdir, "heartbeat_threshold": self.heartbeat_threshold, "poll_period": self.poll_period, - "logging_level": logging.DEBUG if self.worker_debug else logging.INFO + "logging_level": logging.DEBUG if self.worker_debug else logging.INFO, + "cert_dir": self.cert_dir, }, daemon=True, name="HTEX-Interchange" diff --git a/parsl/executors/high_throughput/interchange.py b/parsl/executors/high_throughput/interchange.py index 67b70aa78d..c65aecb57a 100644 --- a/parsl/executors/high_throughput/interchange.py +++ b/parsl/executors/high_throughput/interchange.py @@ -16,6 +16,7 @@ from typing import cast, Any, Dict, NoReturn, Sequence, Set, Optional, Tuple +from parsl import curvezmq from parsl.utils import setproctitle from parsl.version import VERSION as PARSL_VERSION from parsl.serialize import serialize as serialize_object @@ -79,6 +80,7 @@ def __init__(self, logdir: str = ".", logging_level: int = logging.INFO, poll_period: int = 10, + cert_dir: Optional[str] = None, ) -> None: """ Parameters @@ -120,7 +122,10 @@ def __init__(self, poll_period : int The main thread polling period, in milliseconds. Default: 10ms + cert_dir : str | None + Path to the certificate directory. Default: None """ + self.cert_dir = cert_dir self.logdir = logdir os.makedirs(self.logdir, exist_ok=True) @@ -134,15 +139,15 @@ def __init__(self, logger.info("Attempting connection to client at {} on ports: {},{},{}".format( client_address, client_ports[0], client_ports[1], client_ports[2])) - self.context = zmq.Context() - self.task_incoming = self.context.socket(zmq.DEALER) + self.zmq_context = curvezmq.ServerContext(self.cert_dir) + self.task_incoming = self.zmq_context.socket(zmq.DEALER) self.task_incoming.set_hwm(0) self.task_incoming.connect("tcp://{}:{}".format(client_address, client_ports[0])) - self.results_outgoing = self.context.socket(zmq.DEALER) + self.results_outgoing = self.zmq_context.socket(zmq.DEALER) self.results_outgoing.set_hwm(0) self.results_outgoing.connect("tcp://{}:{}".format(client_address, client_ports[1])) - self.command_channel = self.context.socket(zmq.REP) + self.command_channel = self.zmq_context.socket(zmq.REP) self.command_channel.connect("tcp://{}:{}".format(client_address, client_ports[2])) logger.info("Connected to client") @@ -155,9 +160,9 @@ def __init__(self, self.worker_ports = worker_ports self.worker_port_range = worker_port_range - self.task_outgoing = self.context.socket(zmq.ROUTER) + self.task_outgoing = self.zmq_context.socket(zmq.ROUTER) self.task_outgoing.set_hwm(0) - self.results_incoming = self.context.socket(zmq.ROUTER) + self.results_incoming = self.zmq_context.socket(zmq.ROUTER) self.results_incoming.set_hwm(0) if self.worker_ports: @@ -241,7 +246,8 @@ def task_puller(self) -> NoReturn: def _create_monitoring_channel(self) -> Optional[zmq.Socket]: if self.hub_address and self.hub_port: logger.info("Connecting to monitoring") - hub_channel = self.context.socket(zmq.DEALER) + # This is a one-off because monitoring is unencrypted + hub_channel = zmq.Context().socket(zmq.DEALER) hub_channel.set_hwm(0) hub_channel.connect("tcp://{}:{}".format(self.hub_address, self.hub_port)) logger.info("Monitoring enabled and connected to hub") @@ -379,6 +385,7 @@ def start(self) -> None: self.expire_bad_managers(interesting_managers, hub_channel) self.process_tasks_to_send(interesting_managers) + self.zmq_context.destroy() delta = time.time() - start logger.info("Processed {} tasks in {} seconds".format(self.count, delta)) logger.warning("Exiting") diff --git a/parsl/executors/high_throughput/process_worker_pool.py b/parsl/executors/high_throughput/process_worker_pool.py index 665a3ef33d..7826bec3dd 100755 --- a/parsl/executors/high_throughput/process_worker_pool.py +++ b/parsl/executors/high_throughput/process_worker_pool.py @@ -20,8 +20,8 @@ from multiprocessing.managers import DictProxy from multiprocessing.sharedctypes import Synchronized +from parsl import curvezmq from parsl.process_loggers import wrap_with_logs - from parsl.version import VERSION as PARSL_VERSION from parsl.app.errors import RemoteExceptionWrapper from parsl.executors.high_throughput.errors import WorkerLost @@ -63,7 +63,8 @@ def __init__(self, *, heartbeat_period, poll_period, cpu_affinity, - available_accelerators: Sequence[str]): + available_accelerators: Sequence[str], + cert_dir: Optional[str]): """ Parameters ---------- @@ -118,6 +119,8 @@ def __init__(self, *, available_accelerators: list of str List of accelerators available to the workers. + cert_dir : str | None + Path to the certificate directory. """ logger.info("Manager started") @@ -137,15 +140,16 @@ def __init__(self, *, print("Failed to find a viable address to connect to interchange. Exiting") exit(5) - self.context = zmq.Context() - self.task_incoming = self.context.socket(zmq.DEALER) + self.cert_dir = cert_dir + self.zmq_context = curvezmq.ClientContext(self.cert_dir) + self.task_incoming = self.zmq_context.socket(zmq.DEALER) self.task_incoming.setsockopt(zmq.IDENTITY, uid.encode('utf-8')) # Linger is set to 0, so that the manager can exit even when there might be # messages in the pipe self.task_incoming.setsockopt(zmq.LINGER, 0) self.task_incoming.connect(task_q_url) - self.result_outgoing = self.context.socket(zmq.DEALER) + self.result_outgoing = self.zmq_context.socket(zmq.DEALER) self.result_outgoing.setsockopt(zmq.IDENTITY, uid.encode('utf-8')) self.result_outgoing.setsockopt(zmq.LINGER, 0) self.result_outgoing.connect(result_q_url) @@ -468,7 +472,7 @@ def start(self): self.task_incoming.close() self.result_outgoing.close() - self.context.term() + self.zmq_context.term() delta = time.time() - start logger.info("process_worker_pool ran for {} seconds".format(delta)) return @@ -720,6 +724,8 @@ def start_file_logger(filename, rank, name='parsl', level=logging.DEBUG, format_ help="Enable logging at DEBUG level") parser.add_argument("-a", "--addresses", default='', help="Comma separated list of addresses at which the interchange could be reached") + parser.add_argument("--cert_dir", required=True, + help="Path to certificate directory.") parser.add_argument("-l", "--logdir", default="process_worker_pool_logs", help="Process worker pool log directory") parser.add_argument("-u", "--uid", default=str(uuid.uuid4()).split('-')[-1], @@ -773,6 +779,7 @@ def strategyorlist(s: str): logger.info("Python version: {}".format(sys.version)) logger.info("Debug logging: {}".format(args.debug)) + logger.info("Certificates dir: {}".format(args.cert_dir)) logger.info("Log dir: {}".format(args.logdir)) logger.info("Manager ID: {}".format(args.uid)) logger.info("Block ID: {}".format(args.block_id)) @@ -804,7 +811,8 @@ def strategyorlist(s: str): heartbeat_period=int(args.hb_period), poll_period=int(args.poll), cpu_affinity=args.cpu_affinity, - available_accelerators=args.available_accelerators) + available_accelerators=args.available_accelerators, + cert_dir=None if args.cert_dir == "None" else args.cert_dir) manager.start() except Exception: diff --git a/parsl/executors/high_throughput/zmq_pipes.py b/parsl/executors/high_throughput/zmq_pipes.py index cd5c7f2e8b..72c1422d4d 100644 --- a/parsl/executors/high_throughput/zmq_pipes.py +++ b/parsl/executors/high_throughput/zmq_pipes.py @@ -4,24 +4,28 @@ import logging import threading +from parsl import curvezmq + logger = logging.getLogger(__name__) class CommandClient: """ CommandClient """ - def __init__(self, ip_address, port_range): + def __init__(self, zmq_context: curvezmq.ClientContext, ip_address, port_range): """ Parameters ---------- + zmq_context: curvezmq.ClientContext + CurveZMQ client context used to create secure sockets ip_address: str IP address of the client (where Parsl runs) port_range: tuple(int, int) Port range for the comms between client and interchange """ - self.context = zmq.Context() + self.zmq_context = zmq_context self.ip_address = ip_address self.port_range = port_range self.port = None @@ -33,7 +37,7 @@ def create_socket_and_bind(self): Upon recreating the socket, we bind to the same port. """ - self.zmq_socket = self.context.socket(zmq.REQ) + self.zmq_socket = self.zmq_context.socket(zmq.REQ) self.zmq_socket.setsockopt(zmq.LINGER, 0) if self.port is None: self.port = self.zmq_socket.bind_to_random_port("tcp://{}".format(self.ip_address), @@ -62,9 +66,7 @@ def run(self, message, max_retries=3): except zmq.ZMQError: logger.exception("Potential ZMQ REQ-REP deadlock caught") logger.info("Trying to reestablish context") - self.zmq_socket.close() - self.context.destroy() - self.context = zmq.Context() + self.zmq_context.recreate() self.create_socket_and_bind() else: break @@ -77,25 +79,27 @@ def run(self, message, max_retries=3): def close(self): self.zmq_socket.close() - self.context.term() + self.zmq_context.term() class TasksOutgoing: """ Outgoing task queue from the executor to the Interchange """ - def __init__(self, ip_address, port_range): + def __init__(self, zmq_context: curvezmq.ClientContext, ip_address, port_range): """ Parameters ---------- + zmq_context: curvezmq.ClientContext + CurveZMQ client context used to create secure sockets ip_address: str IP address of the client (where Parsl runs) port_range: tuple(int, int) Port range for the comms between client and interchange """ - self.context = zmq.Context() - self.zmq_socket = self.context.socket(zmq.DEALER) + self.zmq_context = zmq_context + self.zmq_socket = self.zmq_context.socket(zmq.DEALER) self.zmq_socket.set_hwm(0) self.port = self.zmq_socket.bind_to_random_port("tcp://{}".format(ip_address), min_port=port_range[0], @@ -127,26 +131,28 @@ def put(self, message): def close(self): self.zmq_socket.close() - self.context.term() + self.zmq_context.term() class ResultsIncoming: """ Incoming results queue from the Interchange to the executor """ - def __init__(self, ip_address, port_range): + def __init__(self, zmq_context: curvezmq.ClientContext, ip_address, port_range): """ Parameters ---------- + zmq_context: curvezmq.ClientContext + CurveZMQ client context used to create secure sockets ip_address: str IP address of the client (where Parsl runs) port_range: tuple(int, int) Port range for the comms between client and interchange """ - self.context = zmq.Context() - self.results_receiver = self.context.socket(zmq.DEALER) + self.zmq_context = zmq_context + self.results_receiver = self.zmq_context.socket(zmq.DEALER) self.results_receiver.set_hwm(0) self.port = self.results_receiver.bind_to_random_port("tcp://{}".format(ip_address), min_port=port_range[0], @@ -160,4 +166,4 @@ def get(self): def close(self): self.results_receiver.close() - self.context.term() + self.zmq_context.term() diff --git a/parsl/tests/configs/ad_hoc_cluster_htex.py b/parsl/tests/configs/ad_hoc_cluster_htex.py index 5c90d27918..0a3e9dc027 100644 --- a/parsl/tests/configs/ad_hoc_cluster_htex.py +++ b/parsl/tests/configs/ad_hoc_cluster_htex.py @@ -18,6 +18,7 @@ label='remote_htex', max_workers=2, worker_logdir_root=user_opts['adhoc']['script_dir'], + encrypted=True, provider=AdHocProvider( # Command to be run before starting a worker, such as: # 'module load Anaconda; source activate parsl_env'. diff --git a/parsl/tests/configs/azure_single_node.py b/parsl/tests/configs/azure_single_node.py index 90e90c19cd..17f3e7def9 100644 --- a/parsl/tests/configs/azure_single_node.py +++ b/parsl/tests/configs/azure_single_node.py @@ -40,6 +40,7 @@ storage_access=[HTTPInTaskStaging(), FTPInTaskStaging(), RSyncStaging(getpass.getuser() + "@" + user_opts['public_ip'])], label='azure_single_node', address=user_opts['public_ip'], + encrypted=True, provider=provider ) ] diff --git a/parsl/tests/configs/bluewaters.py b/parsl/tests/configs/bluewaters.py index 2e27eb8c4a..7cd088ecac 100644 --- a/parsl/tests/configs/bluewaters.py +++ b/parsl/tests/configs/bluewaters.py @@ -14,6 +14,7 @@ def fresh_config(): cores_per_worker=1, worker_debug=False, max_workers=1, + encrypted=True, provider=TorqueProvider( queue='normal', launcher=AprunLauncher(overrides="-b -- bwpy-environ --"), diff --git a/parsl/tests/configs/bridges.py b/parsl/tests/configs/bridges.py index 4dea2fe468..06d0c0cd43 100644 --- a/parsl/tests/configs/bridges.py +++ b/parsl/tests/configs/bridges.py @@ -14,6 +14,7 @@ def fresh_config(): # which compute nodes can communicate # address=address_by_interface('bond0.144'), max_workers=1, + encrypted=True, provider=SlurmProvider( user_opts['bridges']['partition'], # Partition / QOS nodes_per_block=2, diff --git a/parsl/tests/configs/cc_in2p3.py b/parsl/tests/configs/cc_in2p3.py index 9b54fa37e4..9a76d1f16e 100644 --- a/parsl/tests/configs/cc_in2p3.py +++ b/parsl/tests/configs/cc_in2p3.py @@ -12,6 +12,7 @@ def fresh_config(): HighThroughputExecutor( label='cc_in2p3_htex', max_workers=1, + encrypted=True, provider=GridEngineProvider( channel=LocalChannel(), nodes_per_block=2, diff --git a/parsl/tests/configs/comet.py b/parsl/tests/configs/comet.py index 1a253e26b2..8f39539509 100644 --- a/parsl/tests/configs/comet.py +++ b/parsl/tests/configs/comet.py @@ -11,6 +11,7 @@ def fresh_config(): HighThroughputExecutor( label='Comet_HTEX_multinode', max_workers=1, + encrypted=True, provider=SlurmProvider( 'debug', launcher=SrunLauncher(), diff --git a/parsl/tests/configs/cooley_htex.py b/parsl/tests/configs/cooley_htex.py index 4228da10b0..202379d0af 100644 --- a/parsl/tests/configs/cooley_htex.py +++ b/parsl/tests/configs/cooley_htex.py @@ -18,6 +18,7 @@ label="cooley_htex", worker_debug=False, cores_per_worker=1, + encrypted=True, provider=CobaltProvider( queue='debug', account=user_opts['cooley']['account'], diff --git a/parsl/tests/configs/ec2_single_node.py b/parsl/tests/configs/ec2_single_node.py index 61cccc8bea..92c8108add 100644 --- a/parsl/tests/configs/ec2_single_node.py +++ b/parsl/tests/configs/ec2_single_node.py @@ -28,6 +28,7 @@ HighThroughputExecutor( label='ec2_single_node', address=user_opts['public_ip'], + encrypted=True, provider=AWSProvider( user_opts['ec2']['image_id'], region=user_opts['ec2']['region'], diff --git a/parsl/tests/configs/ec2_spot.py b/parsl/tests/configs/ec2_spot.py index c693d9b17d..37f272e4de 100644 --- a/parsl/tests/configs/ec2_spot.py +++ b/parsl/tests/configs/ec2_spot.py @@ -15,6 +15,7 @@ HighThroughputExecutor( label='ec2_single_node', address=user_opts['public_ip'], + encrypted=True, provider=AWSProvider( user_opts['ec2']['image_id'], region=user_opts['ec2']['region'], diff --git a/parsl/tests/configs/frontera.py b/parsl/tests/configs/frontera.py index a003ce88c6..ba302d85b5 100644 --- a/parsl/tests/configs/frontera.py +++ b/parsl/tests/configs/frontera.py @@ -16,6 +16,7 @@ def fresh_config(): HighThroughputExecutor( label="frontera_htex", max_workers=1, + encrypted=True, provider=SlurmProvider( cmd_timeout=60, # Add extra time for slow scheduler responses channel=LocalChannel(), diff --git a/parsl/tests/configs/htex_ad_hoc_cluster.py b/parsl/tests/configs/htex_ad_hoc_cluster.py index 80949f1d1e..82ae7bf621 100644 --- a/parsl/tests/configs/htex_ad_hoc_cluster.py +++ b/parsl/tests/configs/htex_ad_hoc_cluster.py @@ -13,6 +13,7 @@ cores_per_worker=1, worker_debug=False, address=user_opts['public_ip'], + encrypted=True, provider=AdHocProvider( move_files=False, parallelism=1, diff --git a/parsl/tests/configs/htex_local.py b/parsl/tests/configs/htex_local.py index 2039918a4d..0f09fce9dc 100644 --- a/parsl/tests/configs/htex_local.py +++ b/parsl/tests/configs/htex_local.py @@ -13,6 +13,7 @@ def fresh_config(): label="htex_local", worker_debug=True, cores_per_worker=1, + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=1, diff --git a/parsl/tests/configs/htex_local_alternate.py b/parsl/tests/configs/htex_local_alternate.py index e3a7b6ff1e..2598a91b85 100644 --- a/parsl/tests/configs/htex_local_alternate.py +++ b/parsl/tests/configs/htex_local_alternate.py @@ -48,6 +48,7 @@ def fresh_config(): heartbeat_period=2, heartbeat_threshold=5, poll_period=100, + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=0, diff --git a/parsl/tests/configs/htex_local_intask_staging.py b/parsl/tests/configs/htex_local_intask_staging.py index c6ad88be69..634cbf1654 100644 --- a/parsl/tests/configs/htex_local_intask_staging.py +++ b/parsl/tests/configs/htex_local_intask_staging.py @@ -15,6 +15,7 @@ label="htex_Local", worker_debug=True, cores_per_worker=1, + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=1, diff --git a/parsl/tests/configs/htex_local_rsync_staging.py b/parsl/tests/configs/htex_local_rsync_staging.py index 514fbfaee7..6cb47e55e9 100644 --- a/parsl/tests/configs/htex_local_rsync_staging.py +++ b/parsl/tests/configs/htex_local_rsync_staging.py @@ -16,6 +16,7 @@ worker_debug=True, cores_per_worker=1, working_dir="./rsync-workdir/", + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=1, diff --git a/parsl/tests/configs/local_adhoc.py b/parsl/tests/configs/local_adhoc.py index 7ac6d7e781..96c08ee309 100644 --- a/parsl/tests/configs/local_adhoc.py +++ b/parsl/tests/configs/local_adhoc.py @@ -9,6 +9,7 @@ def fresh_config(): executors=[ HighThroughputExecutor( label='AdHoc', + encrypted=True, provider=AdHocProvider( channels=[LocalChannel(), LocalChannel()] ) diff --git a/parsl/tests/configs/midway.py b/parsl/tests/configs/midway.py index b5362838e0..ca32edd2b1 100644 --- a/parsl/tests/configs/midway.py +++ b/parsl/tests/configs/midway.py @@ -13,6 +13,7 @@ def fresh_config(): label='Midway_HTEX_multinode', worker_debug=False, max_workers=1, + encrypted=True, provider=SlurmProvider( 'broadwl', # Partition name, e.g 'broadwl' launcher=SrunLauncher(), diff --git a/parsl/tests/configs/nscc_singapore.py b/parsl/tests/configs/nscc_singapore.py index d8abbd45ba..c78018dc55 100644 --- a/parsl/tests/configs/nscc_singapore.py +++ b/parsl/tests/configs/nscc_singapore.py @@ -17,6 +17,7 @@ def fresh_config(): worker_debug=False, max_workers=1, address=address_by_interface('ib0'), + encrypted=True, provider=PBSProProvider( launcher=MpiRunLauncher(), # string to prepend to #PBS blocks in the submit diff --git a/parsl/tests/configs/osg_htex.py b/parsl/tests/configs/osg_htex.py index e4f1d4a98f..17250cca5d 100644 --- a/parsl/tests/configs/osg_htex.py +++ b/parsl/tests/configs/osg_htex.py @@ -14,6 +14,7 @@ HighThroughputExecutor( label='OSG_HTEX', max_workers=1, + encrypted=True, provider=CondorProvider( nodes_per_block=1, init_blocks=4, diff --git a/parsl/tests/configs/petrelkube.py b/parsl/tests/configs/petrelkube.py index dbc2afec39..6e61cd101e 100644 --- a/parsl/tests/configs/petrelkube.py +++ b/parsl/tests/configs/petrelkube.py @@ -23,6 +23,7 @@ def fresh_config(): # Address for the pod worker to connect back address=address_by_route(), + encrypted=True, provider=KubernetesProvider( namespace="dlhub-privileged", diff --git a/parsl/tests/configs/summit.py b/parsl/tests/configs/summit.py index ae0d009334..ed87366e55 100644 --- a/parsl/tests/configs/summit.py +++ b/parsl/tests/configs/summit.py @@ -21,6 +21,7 @@ def fresh_config(): # address=address_by_interface('ib0'), # This assumes Parsl is running on login node worker_port_range=(50000, 55000), max_workers=1, + encrypted=True, provider=LSFProvider( launcher=JsrunLauncher(), walltime="00:10:00", diff --git a/parsl/tests/configs/swan_htex.py b/parsl/tests/configs/swan_htex.py index 8db9f5a975..4884703a2a 100644 --- a/parsl/tests/configs/swan_htex.py +++ b/parsl/tests/configs/swan_htex.py @@ -24,6 +24,7 @@ executors=[ HighThroughputExecutor( label='swan_htex', + encrypted=True, provider=TorqueProvider( channel=SSHChannel( hostname='swan.cray.com', diff --git a/parsl/tests/configs/theta.py b/parsl/tests/configs/theta.py index f2ce169390..36092a948e 100644 --- a/parsl/tests/configs/theta.py +++ b/parsl/tests/configs/theta.py @@ -12,6 +12,7 @@ def fresh_config(): HighThroughputExecutor( label='theta_local_htex_multinode', max_workers=1, + encrypted=True, provider=CobaltProvider( queue=user_opts['theta']['queue'], account=user_opts['theta']['account'], diff --git a/parsl/tests/manual_tests/htex_local.py b/parsl/tests/manual_tests/htex_local.py index a2cd3fae7f..a5c0449be1 100644 --- a/parsl/tests/manual_tests/htex_local.py +++ b/parsl/tests/manual_tests/htex_local.py @@ -13,6 +13,7 @@ label="htex_local", # worker_debug=True, cores_per_worker=1, + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=1, diff --git a/parsl/tests/manual_tests/test_ad_hoc_htex.py b/parsl/tests/manual_tests/test_ad_hoc_htex.py index d022686d6e..030658f243 100644 --- a/parsl/tests/manual_tests/test_ad_hoc_htex.py +++ b/parsl/tests/manual_tests/test_ad_hoc_htex.py @@ -15,6 +15,7 @@ label='AdHoc', max_workers=2, worker_logdir_root="/scratch/midway2/yadunand/parsl_scripts", + encrypted=True, provider=AdHocProvider( worker_init="source /scratch/midway2/yadunand/parsl_env_setup.sh", channels=[SSHChannel(hostname=m, diff --git a/parsl/tests/manual_tests/test_fan_in_out_htex_remote.py b/parsl/tests/manual_tests/test_fan_in_out_htex_remote.py index 72dc4c7fd7..ce09e4bcad 100644 --- a/parsl/tests/manual_tests/test_fan_in_out_htex_remote.py +++ b/parsl/tests/manual_tests/test_fan_in_out_htex_remote.py @@ -16,6 +16,7 @@ def local_setup(): label="theta_htex", # worker_debug=True, cores_per_worker=4, + encrypted=True, provider=CobaltProvider( queue='debug-flat-quad', account='CSC249ADCD01', diff --git a/parsl/tests/manual_tests/test_memory_limits.py b/parsl/tests/manual_tests/test_memory_limits.py index 9024939827..be4507808f 100644 --- a/parsl/tests/manual_tests/test_memory_limits.py +++ b/parsl/tests/manual_tests/test_memory_limits.py @@ -28,6 +28,7 @@ def test_simple(mem_per_worker): mem_per_worker=mem_per_worker, cores_per_worker=0.1, suppress_failure=True, + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=1, diff --git a/parsl/tests/scaling_tests/htex_local.py b/parsl/tests/scaling_tests/htex_local.py index ff6ed5b96b..1037d194c4 100644 --- a/parsl/tests/scaling_tests/htex_local.py +++ b/parsl/tests/scaling_tests/htex_local.py @@ -10,6 +10,7 @@ label="htex_local", cores_per_worker=1, max_workers=8, + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=1, diff --git a/parsl/tests/sites/test_affinity.py b/parsl/tests/sites/test_affinity.py index b2b588aa83..db3760916f 100644 --- a/parsl/tests/sites/test_affinity.py +++ b/parsl/tests/sites/test_affinity.py @@ -18,6 +18,7 @@ def local_config(): max_workers=2, cpu_affinity='block', available_accelerators=2, + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=1, diff --git a/parsl/tests/sites/test_concurrent.py b/parsl/tests/sites/test_concurrent.py index 01c72fd6b9..a80e61a626 100644 --- a/parsl/tests/sites/test_concurrent.py +++ b/parsl/tests/sites/test_concurrent.py @@ -17,6 +17,7 @@ def make_config(): max_workers=2, heartbeat_period=2, heartbeat_threshold=4, + encrypted=True, ) ], strategy='none', diff --git a/parsl/tests/sites/test_dynamic_executor.py b/parsl/tests/sites/test_dynamic_executor.py index 9c04d8b126..c3d449a5cc 100644 --- a/parsl/tests/sites/test_dynamic_executor.py +++ b/parsl/tests/sites/test_dynamic_executor.py @@ -60,6 +60,7 @@ def test_dynamic_executor(): label='htex_local', cores_per_worker=1, max_workers=5, + encrypted=True, provider=LocalProvider( init_blocks=1, max_blocks=1, diff --git a/parsl/tests/sites/test_worker_info.py b/parsl/tests/sites/test_worker_info.py index d9bb3e4c93..296ddfaf05 100644 --- a/parsl/tests/sites/test_worker_info.py +++ b/parsl/tests/sites/test_worker_info.py @@ -15,6 +15,7 @@ def local_config(): label="htex_Local", worker_debug=True, max_workers=4, + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=1, diff --git a/parsl/tests/test_htex/test_htex.py b/parsl/tests/test_htex/test_htex.py new file mode 100644 index 0000000000..ffe693c69e --- /dev/null +++ b/parsl/tests/test_htex/test_htex.py @@ -0,0 +1,46 @@ +import pathlib + +import pytest + +from parsl import curvezmq +from parsl import HighThroughputExecutor + + +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False)) +@pytest.mark.parametrize("cert_dir_provided", (True, False)) +def test_htex_start_encrypted( + encrypted: bool, cert_dir_provided: bool, tmpd_cwd: pathlib.Path +): + htex = HighThroughputExecutor(encrypted=encrypted) + htex.run_dir = str(tmpd_cwd) + if cert_dir_provided: + provided_base_dir = tmpd_cwd / "provided" + provided_base_dir.mkdir() + cert_dir = curvezmq.create_certificates(provided_base_dir) + htex.cert_dir = cert_dir + else: + cert_dir = curvezmq.create_certificates(htex.logdir) + + if not encrypted and cert_dir_provided: + with pytest.raises(AttributeError) as pyt_e: + htex.start() + assert "change cert_dir to None" in str(pyt_e.value) + return + + htex.start() + + assert htex.encrypted is encrypted + if encrypted: + assert htex.cert_dir == cert_dir + assert htex.outgoing_q.zmq_context.cert_dir == cert_dir + assert htex.incoming_q.zmq_context.cert_dir == cert_dir + assert htex.command_client.zmq_context.cert_dir == cert_dir + assert isinstance(htex.outgoing_q.zmq_context, curvezmq.ClientContext) + assert isinstance(htex.incoming_q.zmq_context, curvezmq.ClientContext) + assert isinstance(htex.command_client.zmq_context, curvezmq.ClientContext) + else: + assert htex.cert_dir is None + assert htex.outgoing_q.zmq_context.cert_dir is None + assert htex.incoming_q.zmq_context.cert_dir is None + assert htex.command_client.zmq_context.cert_dir is None diff --git a/parsl/tests/test_htex/test_htex_zmq_binding.py b/parsl/tests/test_htex/test_htex_zmq_binding.py index 442f0431b6..41de9055b0 100644 --- a/parsl/tests/test_htex/test_htex_zmq_binding.py +++ b/parsl/tests/test_htex/test_htex_zmq_binding.py @@ -1,42 +1,82 @@ -import logging +import pathlib +from typing import Optional +from unittest import mock import psutil import pytest import zmq +from parsl import curvezmq from parsl.executors.high_throughput.interchange import Interchange -def test_interchange_binding_no_address(): - ix = Interchange() +@pytest.fixture +def encrypted(request: pytest.FixtureRequest): + if hasattr(request, "param"): + return request.param + return True + + +@pytest.fixture +def cert_dir(encrypted: bool, tmpd_cwd: pathlib.Path): + if not encrypted: + return None + return curvezmq.create_certificates(tmpd_cwd) + + +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +@mock.patch.object(curvezmq.ServerContext, "socket", return_value=mock.MagicMock()) +def test_interchange_curvezmq_sockets( + mock_socket: mock.MagicMock, cert_dir: Optional[str], encrypted: bool +): + address = "127.0.0.1" + ix = Interchange(interchange_address=address, cert_dir=cert_dir) + assert isinstance(ix.zmq_context, curvezmq.ServerContext) + assert ix.zmq_context.encrypted is encrypted + assert mock_socket.call_count == 5 + + +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_interchange_binding_no_address(cert_dir: Optional[str]): + ix = Interchange(cert_dir=cert_dir) assert ix.interchange_address == "*" -def test_interchange_binding_with_address(): +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_interchange_binding_with_address(cert_dir: Optional[str]): # Using loopback address address = "127.0.0.1" - ix = Interchange(interchange_address=address) + ix = Interchange(interchange_address=address, cert_dir=cert_dir) assert ix.interchange_address == address -def test_interchange_binding_with_non_ipv4_address(): +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_interchange_binding_with_non_ipv4_address(cert_dir: Optional[str]): # Confirm that a ipv4 address is required address = "localhost" with pytest.raises(zmq.error.ZMQError): - Interchange(interchange_address=address) + Interchange(interchange_address=address, cert_dir=cert_dir) -def test_interchange_binding_bad_address(): - """ Confirm that we raise a ZMQError when a bad address is supplied""" +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_interchange_binding_bad_address(cert_dir: Optional[str]): + """Confirm that we raise a ZMQError when a bad address is supplied""" address = "550.0.0.0" with pytest.raises(zmq.error.ZMQError): - Interchange(interchange_address=address) + Interchange(interchange_address=address, cert_dir=cert_dir) -def test_limited_interface_binding(): - """ When address is specified the worker_port would be bound to it rather than to 0.0.0.0""" +@pytest.mark.local +@pytest.mark.parametrize("encrypted", (True, False), indirect=True) +def test_limited_interface_binding(cert_dir: Optional[str]): + """When address is specified the worker_port would be bound to it rather than to 0.0.0.0""" address = "127.0.0.1" - ix = Interchange(interchange_address=address) + ix = Interchange(interchange_address=address, cert_dir=cert_dir) ix.worker_result_port proc = psutil.Process() conns = proc.connections(kind="tcp") diff --git a/parsl/tests/test_regression/test_97_parallelism_0.py b/parsl/tests/test_regression/test_97_parallelism_0.py index ad47e3971f..de06bcd9df 100644 --- a/parsl/tests/test_regression/test_97_parallelism_0.py +++ b/parsl/tests/test_regression/test_97_parallelism_0.py @@ -15,6 +15,7 @@ def local_config() -> Config: label="htex_local", worker_debug=True, cores_per_worker=1, + encrypted=True, provider=LocalProvider( init_blocks=0, min_blocks=0, diff --git a/parsl/tests/test_scaling/test_block_error_handler.py b/parsl/tests/test_scaling/test_block_error_handler.py index 9d680212e3..7d9ffedd17 100644 --- a/parsl/tests/test_scaling/test_block_error_handler.py +++ b/parsl/tests/test_scaling/test_block_error_handler.py @@ -11,7 +11,7 @@ @pytest.mark.local def test_block_error_handler_false(): mock = Mock() - htex = HighThroughputExecutor(block_error_handler=False) + htex = HighThroughputExecutor(block_error_handler=False, encrypted=True) assert htex.block_error_handler is noop_error_handler htex.set_bad_state_and_fail_all = mock @@ -27,7 +27,7 @@ def test_block_error_handler_false(): @pytest.mark.local def test_block_error_handler_mock(): handler_mock = Mock() - htex = HighThroughputExecutor(block_error_handler=handler_mock) + htex = HighThroughputExecutor(block_error_handler=handler_mock, encrypted=True) assert htex.block_error_handler is handler_mock bad_jobs = {'1': JobStatus(JobState.FAILED), @@ -43,6 +43,7 @@ def test_block_error_handler_mock(): @pytest.mark.local def test_simple_error_handler(): htex = HighThroughputExecutor(block_error_handler=simple_error_handler, + encrypted=True, provider=LocalProvider(init_blocks=3)) assert htex.block_error_handler is simple_error_handler @@ -76,7 +77,7 @@ def test_simple_error_handler(): @pytest.mark.local def test_windowed_error_handler(): - htex = HighThroughputExecutor(block_error_handler=windowed_error_handler) + htex = HighThroughputExecutor(block_error_handler=windowed_error_handler, encrypted=True) assert htex.block_error_handler is windowed_error_handler bad_state_mock = Mock() @@ -110,7 +111,7 @@ def test_windowed_error_handler(): @pytest.mark.local def test_windowed_error_handler_sorting(): - htex = HighThroughputExecutor(block_error_handler=windowed_error_handler) + htex = HighThroughputExecutor(block_error_handler=windowed_error_handler, encrypted=True) assert htex.block_error_handler is windowed_error_handler bad_state_mock = Mock() @@ -136,7 +137,7 @@ def test_windowed_error_handler_sorting(): @pytest.mark.local def test_windowed_error_handler_with_threshold(): error_handler = partial(windowed_error_handler, threshold=2) - htex = HighThroughputExecutor(block_error_handler=error_handler) + htex = HighThroughputExecutor(block_error_handler=error_handler, encrypted=True) assert htex.block_error_handler is error_handler bad_state_mock = Mock() diff --git a/parsl/tests/test_scaling/test_regression_1621.py b/parsl/tests/test_scaling/test_regression_1621.py index 2a953c3fc7..bb367c379a 100644 --- a/parsl/tests/test_scaling/test_regression_1621.py +++ b/parsl/tests/test_scaling/test_regression_1621.py @@ -49,6 +49,7 @@ def test_one_block(tmpd_cwd): address="127.0.0.1", worker_debug=True, cores_per_worker=1, + encrypted=True, provider=oneshot_provider, worker_logdir_root=str(tmpd_cwd) ) diff --git a/parsl/tests/test_scaling/test_scale_down.py b/parsl/tests/test_scaling/test_scale_down.py index 0a23a5acb5..6ebe441760 100644 --- a/parsl/tests/test_scaling/test_scale_down.py +++ b/parsl/tests/test_scaling/test_scale_down.py @@ -28,6 +28,7 @@ def local_config(): label="htex_local", address="127.0.0.1", max_workers=1, + encrypted=True, provider=LocalProvider( channel=LocalChannel(), init_blocks=0, From 431daefada930a7f2f1789f9d53b88b133cfc0fa Mon Sep 17 00:00:00 2001 From: Reid Mello <30907815+rjmello@users.noreply.github.com> Date: Tue, 30 Jan 2024 11:32:59 -0500 Subject: [PATCH 3/3] Add section about HTEX encryption to docs --- docs/userguide/execution.rst | 47 ++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/docs/userguide/execution.rst b/docs/userguide/execution.rst index 0601b38802..1a61e40e73 100644 --- a/docs/userguide/execution.rst +++ b/docs/userguide/execution.rst @@ -342,3 +342,50 @@ The following code snippet shows how apps can specify suitable executors in the def visualize(inputs=(), outputs=()): bash_array = " ".join(inputs) return "viz {} -o {}".format(bash_array, outputs[0]) + + +Encryption +---------- + +Users can enable encryption for the ``HighThroughputExecutor`` by setting its ``encrypted`` +initialization argument to ``True``. + +For example, + +.. code-block:: python + + from parsl.config import Config + from parsl.executors import HighThroughputExecutor + + config = Config( + executors=[ + HighThroughputExecutor( + encrypted=True + ) + ] + ) + +Under the hood, we use `CurveZMQ `_ to encrypt all communication channels +between the executor and related nodes. + +Encryption performance +^^^^^^^^^^^^^^^^^^^^^^ + +CurveZMQ depends on `libzmq `_ and `libsodium `_, +which `pyzmq `_ (a Parsl dependency) includes as part of its +installation via ``pip``. This installation path should work on most systems, but users have +reported significant performance degradation as a result. + +If you experience a significant performance hit after enabling encryption, we recommend installing +``pyzmq`` with conda: + +.. code-block:: bash + + conda install conda-forge::pyzmq + +Alternatively, you can `install libsodium `_, then +`install libzmq `_, then build ``pyzmq`` from source: + +.. code-block:: bash + + pip3 install parsl --no-binary pyzmq