From d99c66441ba2dbc2d3d120ac7c64e8f0f1f884fb Mon Sep 17 00:00:00 2001
From: Reid Mello <30907815+rjmello@users.noreply.github.com>
Date: Tue, 16 Jan 2024 15:46:33 -0500
Subject: [PATCH 1/3] Create CurveZMQ context classes
`ServerContext` and `ClientContext` replace the standard `zmq.Context`.
They share many commonly used methods, including `term`, `destroy` and,
most importantly, `socket`. The latter applies the necessary certs and
options to each socket object.
A connection requires a `ServerContext` on one end, which validates
clients, and a `ClientContext` on the other, which validates the server.
Certificates are generated via the `create_certificates` function.
---
parsl/curvezmq.py | 202 ++++++++++++++++
parsl/tests/test_curvezmq.py | 455 +++++++++++++++++++++++++++++++++++
2 files changed, 657 insertions(+)
create mode 100644 parsl/curvezmq.py
create mode 100644 parsl/tests/test_curvezmq.py
diff --git a/parsl/curvezmq.py b/parsl/curvezmq.py
new file mode 100644
index 0000000000..cd220c818f
--- /dev/null
+++ b/parsl/curvezmq.py
@@ -0,0 +1,202 @@
+import os
+from abc import ABCMeta, abstractmethod
+from typing import Optional, Tuple, Union
+
+import zmq
+import zmq.auth
+from zmq.auth.thread import ThreadAuthenticator
+
+
+def create_certificates(base_dir: Union[str, os.PathLike]) -> str:
+    """Create server and client certificates in a private directory.
+
+    This will overwrite existing certificate files.
+
+    Parameters
+    ----------
+    base_dir : str | os.PathLike
+        Parent directory of the private certificates directory.
+
+    Returns
+    -------
+    str
+        Path to the ``certificates`` directory inside ``base_dir``.
+    """
+    cert_dir = os.path.join(base_dir, "certificates")
+    os.makedirs(cert_dir, mode=0o700, exist_ok=True)
+    # makedirs does not change the mode of a directory that already exists,
+    # and the mode of a new one is subject to the umask. Enforce the private
+    # mode explicitly, because _load_certificate rejects any directory whose
+    # permission bits are not exactly 0o700.
+    os.chmod(cert_dir, 0o700)
+
+    zmq.auth.create_certificates(cert_dir, name="server")
+    zmq.auth.create_certificates(cert_dir, name="client")
+
+    return cert_dir
+
+
+def _load_certificate(
+    cert_dir: Union[str, os.PathLike], name: str
+) -> Tuple[bytes, bytes]:
+    """Load the key pair stored as ``<name>.key_secret`` in ``cert_dir``.
+
+    Raises OSError when the permission bits of ``cert_dir`` are not exactly
+    0o700, and ValueError when the loaded certificate has no secret key.
+    """
+    dir_mode = os.stat(cert_dir).st_mode & 0o777
+    if dir_mode != 0o700:
+        raise OSError(f"The certificates directory must be private: {cert_dir}")
+
+    # pyzmq creates secret key files with the '.key_secret' extension
+    # Ref: https://github.com/zeromq/pyzmq/blob/ae615d4097ccfbc6b5c17de60355cbe6e00a6065/zmq/auth/certs.py#L73
+    key_path = os.path.join(cert_dir, f"{name}.key_secret")
+    public_key, secret_key = zmq.auth.load_certificate(key_path)
+    if secret_key is not None:
+        return public_key, secret_key
+
+    raise ValueError(f"No secret key found in {key_path}")
+
+
+class BaseContext(metaclass=ABCMeta):
+    """Shared behaviour of the CurveZMQ contexts.
+
+    Wraps a private ``zmq.Context`` and leaves socket creation to subclasses.
+    """
+
+    def __init__(self, cert_dir: Optional[Union[str, os.PathLike]]) -> None:
+        self.cert_dir = cert_dir
+        self._ctx = zmq.Context()
+
+    def __del__(self):
+        # Close sockets and terminate the wrapped context on collection
+        self.destroy()
+
+    @property
+    def encrypted(self):
+        """Whether encryption is enabled.
+
+        Encryption is disabled exactly when self.cert_dir is None.
+        """
+        if self.cert_dir is None:
+            return False
+        return True
+
+    @property
+    def closed(self):
+        """Whether the wrapped ``zmq.Context`` reports itself as closed."""
+        return self._ctx.closed
+
+    @abstractmethod
+    def socket(self, socket_type: int, *args, **kwargs) -> zmq.Socket:
+        """Create a socket bound to this context.
+
+        Implementations apply every certificate and socket option the
+        connection needs before returning the socket.
+
+        Parameters
+        ----------
+        socket_type : int
+            Any of the 0MQ socket types: REQ, REP, PUB, SUB, PAIR, DEALER,
+            ROUTER, PULL, PUSH, etc.
+
+        args:
+            forwarded to the zmq.Context.socket method.
+
+        kwargs:
+            forwarded to the zmq.Context.socket method.
+        """
+
+    def term(self):
+        """Terminate the wrapped context."""
+        self._ctx.term()
+
+    def destroy(self, linger: Optional[int] = None):
+        """Close every socket of this context, then terminate the context.
+
+        .. warning::
+
+            destroy involves calling ``zmq_close()``, which is **NOT** threadsafe.
+            If there are active sockets in other threads, this must not be called.
+
+        Parameters
+        ----------
+        linger : int, optional
+            If specified, set LINGER on sockets prior to closing them.
+        """
+        self._ctx.destroy(linger)
+
+    def recreate(self, linger: Optional[int] = None):
+        """Destroy the current context and replace it with a fresh one.
+
+        Parameters
+        ----------
+        linger : int, optional
+            If specified, set LINGER on sockets prior to closing them.
+        """
+        self.destroy(linger)
+        self._ctx = zmq.Context()
+
+
+class ServerContext(BaseContext):
+    """CurveZMQ server context
+
+    We create server sockets via the `ctx.socket` method, which automatically
+    applies the necessary certificates and socket options.
+
+    We handle client certificate authentication in a separate dedicated thread.
+
+    Parameters
+    ----------
+    cert_dir : str | os.PathLike | None
+        Path to the certificate directory. Setting this to None will disable encryption.
+    """
+
+    def __init__(self, cert_dir: Optional[Union[str, os.PathLike]]) -> None:
+        super().__init__(cert_dir)
+        self.auth_thread = None
+        if self.encrypted:
+            self.auth_thread = self._start_auth_thread()
+
+    def _start_auth_thread(self) -> ThreadAuthenticator:
+        """Start an authenticator thread restricted to certs in cert_dir."""
+        auth_thread = ThreadAuthenticator(self._ctx)
+        auth_thread.start()
+        # Only allow certs that are in the cert dir
+        assert self.cert_dir  # For mypy
+        auth_thread.configure_curve(domain="*", location=self.cert_dir)
+        return auth_thread
+
+    def socket(self, socket_type: int, *args, **kwargs) -> zmq.Socket:
+        """Create a server socket, with the server secret key when encrypted.
+
+        Raises
+        ------
+        OSError
+            If the certificates directory is not private.
+        ValueError
+            If the secret key is missing or has an invalid format.
+        """
+        sock = self._ctx.socket(socket_type, *args, **kwargs)
+        if not self.encrypted:
+            return sock
+
+        try:
+            assert self.cert_dir  # For mypy
+            _, secret_key = _load_certificate(self.cert_dir, name="server")
+            try:
+                # Only the clients need the server's public key to encrypt
+                # messages and verify the server's identity.
+                # Ref: http://curvezmq.org/page:read-the-docs
+                sock.setsockopt(zmq.CURVE_SECRETKEY, secret_key)
+            except zmq.ZMQError as e:
+                raise ValueError("Invalid CurveZMQ key format") from e
+            sock.setsockopt(zmq.CURVE_SERVER, True)  # Must come before bind
+        except Exception:
+            # The caller never receives this socket, so close it here rather
+            # than leave a half-configured socket attached to the context.
+            sock.close(linger=0)
+            raise
+        return sock
+
+    def term(self):
+        """Stop the authenticator thread, if any, then terminate the context."""
+        if self.auth_thread:
+            self.auth_thread.stop()
+        super().term()
+
+    def destroy(self, linger: Optional[int] = None):
+        """Stop the authenticator thread, if any, then destroy the context."""
+        if self.auth_thread:
+            self.auth_thread.stop()
+        super().destroy(linger)
+
+    def recreate(self, linger: Optional[int] = None):
+        """Recreate the context and restart the authenticator if one was in use."""
+        super().recreate(linger)
+        if self.auth_thread:
+            self.auth_thread = self._start_auth_thread()
+
+
+class ClientContext(BaseContext):
+    """CurveZMQ client context
+
+    We create client sockets via the `ctx.socket` method, which automatically
+    applies the necessary certificates and socket options.
+
+    Parameters
+    ----------
+    cert_dir : str | os.PathLike | None
+        Path to the certificate directory. Setting this to None will disable encryption.
+    """
+
+    def socket(self, socket_type: int, *args, **kwargs) -> zmq.Socket:
+        """Create a client socket, with client keys and server key when encrypted.
+
+        Raises
+        ------
+        OSError
+            If the certificates directory is not private.
+        ValueError
+            If a secret key is missing or a key has an invalid format.
+        """
+        sock = self._ctx.socket(socket_type, *args, **kwargs)
+        if not self.encrypted:
+            return sock
+
+        try:
+            assert self.cert_dir  # For mypy
+            public_key, secret_key = _load_certificate(self.cert_dir, name="client")
+            server_public_key, _ = _load_certificate(self.cert_dir, name="server")
+            try:
+                sock.setsockopt(zmq.CURVE_PUBLICKEY, public_key)
+                sock.setsockopt(zmq.CURVE_SECRETKEY, secret_key)
+                sock.setsockopt(zmq.CURVE_SERVERKEY, server_public_key)
+            except zmq.ZMQError as e:
+                raise ValueError("Invalid CurveZMQ key format") from e
+        except Exception:
+            # The caller never receives this socket, so close it here rather
+            # than leave a half-configured socket attached to the context.
+            sock.close(linger=0)
+            raise
+        return sock
diff --git a/parsl/tests/test_curvezmq.py b/parsl/tests/test_curvezmq.py
new file mode 100644
index 0000000000..38b206ab67
--- /dev/null
+++ b/parsl/tests/test_curvezmq.py
@@ -0,0 +1,455 @@
+import os
+import pathlib
+from typing import Union
+from unittest import mock
+
+import pytest
+import zmq
+import zmq.auth
+from zmq.auth.thread import ThreadAuthenticator
+
+from parsl import curvezmq
+
+ADDR = "tcp://127.0.0.1"
+
+
+def get_server_socket(ctx: curvezmq.ServerContext):
+    """Bind a PULL socket from ``ctx`` to a random port; return (socket, port)."""
+    server_sock = ctx.socket(zmq.PULL)
+    for option, value in ((zmq.RCVTIMEO, 200), (zmq.LINGER, 0)):
+        server_sock.setsockopt(option, value)
+    bound_port = server_sock.bind_to_random_port(ADDR)
+    return server_sock, bound_port
+
+
+def get_client_socket(ctx: curvezmq.ClientContext, port: int):
+    """Connect a PUSH socket from ``ctx`` to ``port`` on the test address."""
+    client_sock = ctx.socket(zmq.PUSH)
+    for option, value in ((zmq.SNDTIMEO, 200), (zmq.LINGER, 0)):
+        client_sock.setsockopt(option, value)
+    endpoint = f"{ADDR}:{port}"
+    client_sock.connect(endpoint)
+    return client_sock
+
+
+def get_external_server_socket(
+    ctx: Union[curvezmq.ServerContext, zmq.Context], secret_key: bytes
+):
+    """Bind a PULL socket whose CURVE server options are set by hand.
+
+    Returns the socket and the random port it was bound to.
+    """
+    server_sock = ctx.socket(zmq.PULL)
+    options = (
+        (zmq.RCVTIMEO, 200),
+        (zmq.LINGER, 0),
+        (zmq.CURVE_SECRETKEY, secret_key),
+        (zmq.CURVE_SERVER, True),
+    )
+    for option, value in options:
+        server_sock.setsockopt(option, value)
+    bound_port = server_sock.bind_to_random_port(ADDR)
+    return server_sock, bound_port
+
+
+def get_external_client_socket(
+    ctx: Union[curvezmq.ClientContext, zmq.Context],
+    public_key: bytes,
+    secret_key: bytes,
+    server_key: bytes,
+    port: int,
+):
+    """Connect a PUSH socket whose CURVE client options are set by hand."""
+    client_sock = ctx.socket(zmq.PUSH)
+    options = (
+        (zmq.LINGER, 0),
+        (zmq.CURVE_PUBLICKEY, public_key),
+        (zmq.CURVE_SECRETKEY, secret_key),
+        (zmq.CURVE_SERVERKEY, server_key),
+    )
+    for option, value in options:
+        client_sock.setsockopt(option, value)
+    endpoint = f"{ADDR}:{port}"
+    client_sock.connect(endpoint)
+    return client_sock
+
+
+@pytest.fixture
+def encrypted(request: pytest.FixtureRequest):
+    """Return the indirect parameter when one is given, otherwise True."""
+    return getattr(request, "param", True)
+
+
+@pytest.fixture
+def cert_dir(encrypted: bool, tmpd_cwd: pathlib.Path):
+    """Create certificates under ``tmpd_cwd`` when encrypted, else return None."""
+    if encrypted:
+        return curvezmq.create_certificates(tmpd_cwd)
+    return None
+
+
+@pytest.fixture
+def server_ctx(cert_dir: Union[str, None]):
+    """Yield a ServerContext for ``cert_dir`` and destroy it afterwards."""
+    context = curvezmq.ServerContext(cert_dir)
+    yield context
+    context.destroy()
+
+
+@pytest.fixture
+def client_ctx(cert_dir: Union[str, None]):
+    """Yield a ClientContext for ``cert_dir`` and destroy it afterwards."""
+    context = curvezmq.ClientContext(cert_dir)
+    yield context
+    context.destroy()
+
+
+@pytest.fixture
+def zmq_ctx():
+    """Yield a plain ``zmq.Context`` and destroy it afterwards."""
+    context = zmq.Context()
+    yield context
+    context.destroy()
+
+
+@pytest.mark.local
+@pytest.mark.parametrize("encrypted", (True, False), indirect=True)
+def test_client_context_init(cert_dir: Union[str, None]):
+    """A new ClientContext stores cert_dir and is encrypted iff it is set."""
+    client = curvezmq.ClientContext(cert_dir=cert_dir)
+    expect_encrypted = cert_dir is not None
+
+    assert client.cert_dir == cert_dir
+    assert bool(client.encrypted) == expect_encrypted
+
+    client.destroy()
+
+
+@pytest.mark.local
+@pytest.mark.parametrize("encrypted", (True, False), indirect=True)
+def test_server_context_init(cert_dir: Union[str, None]):
+    """A new ServerContext has an authenticator thread only when encrypted."""
+    server = curvezmq.ServerContext(cert_dir=cert_dir)
+
+    assert server.cert_dir == cert_dir
+    if cert_dir is not None:
+        assert server.encrypted
+        assert isinstance(server.auth_thread, ThreadAuthenticator)
+    else:
+        assert not server.encrypted
+        assert not server.auth_thread
+
+    server.destroy()
+
+
+@pytest.mark.local
+def test_create_certificates(tmpd_cwd: pathlib.Path):
+    """create_certificates makes a private directory holding four files."""
+    expected_dir = tmpd_cwd / "certificates"
+    assert not os.path.exists(expected_dir)
+
+    returned_dir = curvezmq.create_certificates(tmpd_cwd)
+
+    assert str(expected_dir) == returned_dir
+    assert os.path.exists(expected_dir)
+    permission_bits = os.stat(expected_dir).st_mode & 0o777
+    assert permission_bits == 0o700
+    created_files = os.listdir(expected_dir)
+    assert len(created_files) == 4
+
+
+@pytest.mark.local
+def test_create_certificates_overwrite(tmpd_cwd: pathlib.Path):
+    """A second create_certificates call replaces every key with a new one."""
+
+    def load_keys(directory):
+        # Client pair first, then server pair, each as (public, secret)
+        return [
+            key
+            for cert_name in ("client", "server")
+            for key in curvezmq._load_certificate(directory, name=cert_name)
+        ]
+
+    cert_dir = curvezmq.create_certificates(tmpd_cwd)
+    first_keys = load_keys(cert_dir)
+
+    curvezmq.create_certificates(tmpd_cwd)
+    second_keys = load_keys(cert_dir)
+
+    for old_key, new_key in zip(first_keys, second_keys):
+        assert old_key != new_key
+
+
+@pytest.mark.local
+def test_cert_dir_not_private(tmpd_cwd: pathlib.Path):
+    """Socket creation raises OSError when the cert dir is made with mode 0o777."""
+    cert_dir = tmpd_cwd / "certificates"
+    os.makedirs(cert_dir, mode=0o777)
+    client_ctx = curvezmq.ClientContext(cert_dir)
+    server_ctx = curvezmq.ServerContext(cert_dir)
+
+    err_msg = "directory must be private"
+
+    for context, socket_type in ((client_ctx, zmq.REQ), (server_ctx, zmq.REP)):
+        with pytest.raises(OSError) as pyt_e:
+            context.socket(socket_type)
+        assert err_msg in str(pyt_e.value)
+
+    client_ctx.destroy()
+    server_ctx.destroy()
+
+
+@pytest.mark.local
+def test_missing_cert_dir():
+    """Socket creation raises FileNotFoundError for a missing cert dir."""
+    cert_dir = "/bad/cert/dir"
+    client_ctx = curvezmq.ClientContext(cert_dir)
+    server_ctx = curvezmq.ServerContext(cert_dir)
+
+    err_msg = "No such file or directory"
+
+    for context, socket_type in ((client_ctx, zmq.REQ), (server_ctx, zmq.REP)):
+        with pytest.raises(FileNotFoundError) as pyt_e:
+            context.socket(socket_type)
+        assert err_msg in str(pyt_e.value)
+
+    client_ctx.destroy()
+    server_ctx.destroy()
+
+
+@pytest.mark.local
+def test_missing_secret_file(tmpd_cwd: pathlib.Path):
+    """Socket creation raises OSError when the cert dir holds no key files."""
+    cert_dir = tmpd_cwd / "certificates"
+    os.makedirs(cert_dir, mode=0o700)
+
+    client_ctx = curvezmq.ClientContext(cert_dir)
+    server_ctx = curvezmq.ServerContext(cert_dir)
+
+    err_msg = "Invalid certificate file"
+
+    for context, socket_type in ((client_ctx, zmq.REQ), (server_ctx, zmq.REP)):
+        with pytest.raises(OSError) as pyt_e:
+            context.socket(socket_type)
+        assert err_msg in str(pyt_e.value)
+
+    client_ctx.destroy()
+    server_ctx.destroy()
+
+
+@pytest.mark.local
+def test_bad_secret_file(tmpd_cwd: pathlib.Path):
+    """Socket creation raises ValueError when the secret files hold junk."""
+    junk_contents = (("client.key_secret", "bad"), ("server.key_secret", "boy"))
+    for file_name, content in junk_contents:
+        (tmpd_cwd / file_name).write_text(content)
+
+    client_ctx = curvezmq.ClientContext(tmpd_cwd)
+    server_ctx = curvezmq.ServerContext(tmpd_cwd)
+
+    err_msg = "No public key found"
+
+    for context, socket_type in ((client_ctx, zmq.REQ), (server_ctx, zmq.REP)):
+        with pytest.raises(ValueError) as pyt_e:
+            context.socket(socket_type)
+        assert err_msg in str(pyt_e.value)
+
+    client_ctx.destroy()
+    server_ctx.destroy()
+
+
+@pytest.mark.local
+@pytest.mark.parametrize("encrypted", (True, False), indirect=True)
+def test_client_context_term(client_ctx: curvezmq.ClientContext):
+    """term() moves a client context from open to closed."""
+    closed_before = client_ctx.closed
+    assert not closed_before
+
+    client_ctx.term()
+
+    closed_after = client_ctx.closed
+    assert closed_after
+
+
+@pytest.mark.local
+@pytest.mark.parametrize("encrypted", (True, False), indirect=True)
+def test_server_context_term(server_ctx: curvezmq.ServerContext, encrypted: bool):
+    """term() closes the server context and clears the authenticator pipe."""
+    assert not server_ctx.closed
+    if encrypted:
+        authenticator = server_ctx.auth_thread
+        assert authenticator
+        assert authenticator.pipe
+
+    server_ctx.term()
+
+    assert server_ctx.closed
+    if encrypted:
+        authenticator = server_ctx.auth_thread
+        assert authenticator
+        assert not authenticator.pipe
+
+
+@pytest.mark.local
+@pytest.mark.parametrize("encrypted", (True, False), indirect=True)
+def test_client_context_destroy(client_ctx: curvezmq.ClientContext):
+    """destroy() closes both the context and a socket created from it."""
+    open_socket = client_ctx.socket(zmq.REP)
+
+    closed_before = client_ctx.closed
+    assert not closed_before
+
+    client_ctx.destroy()
+
+    assert open_socket.closed
+    assert client_ctx.closed
+
+
+@pytest.mark.local
+@pytest.mark.parametrize("encrypted", (True, False), indirect=True)
+def test_server_context_destroy(server_ctx: curvezmq.ServerContext, encrypted: bool):
+    """destroy() closes context and socket and clears the authenticator pipe."""
+    open_socket = server_ctx.socket(zmq.REP)
+
+    assert not server_ctx.closed
+    if encrypted:
+        authenticator = server_ctx.auth_thread
+        assert authenticator
+        assert authenticator.pipe
+
+    server_ctx.destroy()
+
+    assert open_socket.closed
+    assert server_ctx.closed
+    if encrypted:
+        authenticator = server_ctx.auth_thread
+        assert authenticator
+        assert not authenticator.pipe
+
+
+@pytest.mark.local
+@pytest.mark.parametrize("encrypted", (True, False), indirect=True)
+def test_client_context_recreate(client_ctx: curvezmq.ClientContext):
+    """recreate() closes the old context and socket and installs a new context."""
+    previous_ctx = client_ctx._ctx
+    old_socket = client_ctx.socket(zmq.REQ)
+
+    assert not old_socket.closed
+    assert not client_ctx.closed
+
+    client_ctx.recreate()
+
+    assert old_socket.closed
+    assert not client_ctx.closed
+    assert previous_ctx != client_ctx._ctx
+    assert previous_ctx.closed
+
+
+@pytest.mark.local
+@pytest.mark.parametrize("encrypted", (True, False), indirect=True)
+def test_server_context_recreate(server_ctx: curvezmq.ServerContext, encrypted: bool):
+    """recreate() replaces the context and, when encrypted, the authenticator."""
+    previous_ctx = server_ctx._ctx
+    old_socket = server_ctx.socket(zmq.REP)
+
+    assert not old_socket.closed
+    assert not server_ctx.closed
+    if encrypted:
+        assert server_ctx.auth_thread
+        previous_auth = server_ctx.auth_thread
+        assert previous_auth.pipe
+
+    server_ctx.recreate()
+
+    assert old_socket.closed
+    assert not server_ctx.closed
+    assert previous_ctx.closed
+    assert previous_ctx != server_ctx._ctx
+    if encrypted:
+        current_auth = server_ctx.auth_thread
+        assert current_auth
+        assert previous_auth != current_auth
+        assert current_auth.pipe
+
+
+@pytest.mark.local
+@pytest.mark.parametrize("encrypted", (True, False), indirect=True)
+def test_connection(
+    server_ctx: curvezmq.ServerContext, client_ctx: curvezmq.ClientContext
+):
+    """A message pushed by the client context reaches the server context."""
+    server_socket, port = get_server_socket(server_ctx)
+    client_socket = get_client_socket(client_ctx, port)
+
+    payload = b"howdy"
+    client_socket.send(payload)
+    received = server_socket.recv()
+
+    assert received == payload
+
+
+@pytest.mark.local
+@mock.patch.object(curvezmq, "_load_certificate")
+def test_invalid_key_format(
+    mock_load_cert: mock.MagicMock,
+    server_ctx: curvezmq.ServerContext,
+    client_ctx: curvezmq.ClientContext,
+):
+    """Both contexts raise the same ValueError for malformed keys."""
+    mock_load_cert.return_value = (b"badkey", b"badkey")
+
+    with pytest.raises(ValueError) as e1_info:
+        server_ctx.socket(zmq.REP)
+    with pytest.raises(ValueError) as e2_info:
+        client_ctx.socket(zmq.REQ)
+    # exconly is a method: call it so the assertions compare the exception
+    # text rather than the repr of two bound-method objects.
+    e1, e2 = e1_info.exconly(), e2_info.exconly()
+
+    assert e1 == e2
+    assert "Invalid CurveZMQ key format" in e1
+
+
+@pytest.mark.local
+def test_invalid_client_keys(server_ctx: curvezmq.ServerContext, zmq_ctx: zmq.Context):
+ """Check which hand-configured client key sets get a message through.
+
+ Sends with the real keys, with a freshly generated client key pair, with a
+ wrong server key, and finally with the real keys again.
+ """
+ server_socket, port = get_server_socket(server_ctx)
+
+ cert_dir = server_ctx.cert_dir
+ assert cert_dir # For mypy
+ public_key, secret_key = curvezmq._load_certificate(cert_dir, "client")
+ server_key, _ = curvezmq._load_certificate(cert_dir, "server")
+
+ # A freshly generated key pair that is not stored in the cert dir
+ BAD_PUB_KEY, BAD_SEC_KEY = zmq.curve_keypair()
+ msg = b"howdy"
+
+ # Real client keys and real server key: the message is received
+ client_socket = get_external_client_socket(
+ zmq_ctx,
+ public_key,
+ secret_key,
+ server_key,
+ port,
+ )
+ client_socket.send(msg)
+ assert server_socket.recv() == msg
+
+ # Generated client key pair: recv raises zmq.Again
+ # (get_server_socket sets RCVTIMEO to 200)
+ client_socket = get_external_client_socket(
+ zmq_ctx,
+ BAD_PUB_KEY,
+ BAD_SEC_KEY,
+ server_key,
+ port,
+ )
+ client_socket.send(msg)
+ with pytest.raises(zmq.Again):
+ server_socket.recv()
+
+ # Real client keys but the generated public key as server key: zmq.Again
+ client_socket = get_external_client_socket(
+ zmq_ctx,
+ public_key,
+ secret_key,
+ BAD_PUB_KEY,
+ port,
+ )
+ client_socket.send(msg)
+ with pytest.raises(zmq.Again):
+ server_socket.recv()
+
+ # Ensure sockets are operational
+ client_socket = get_external_client_socket(
+ zmq_ctx,
+ public_key,
+ secret_key,
+ server_key,
+ port,
+ )
+ client_socket.send(msg)
+ assert server_socket.recv() == msg
+
+
+@pytest.mark.local
+def test_invalid_server_key(client_ctx: curvezmq.ClientContext, zmq_ctx: zmq.Context):
+ """Check that a client context only reaches a server holding the stored key.
+
+ Servers are hand-configured on a plain zmq context, first with the stored
+ server secret key, then with a generated one, then with the stored one again.
+ """
+ cert_dir = client_ctx.cert_dir
+ assert cert_dir # For mypy
+ _, secret_key = curvezmq._load_certificate(cert_dir, "server")
+
+ # Secret half of a freshly generated key pair, not stored in the cert dir
+ _, BAD_SEC_KEY = zmq.curve_keypair()
+ msg = b"howdy"
+
+ # Server holds the stored secret key: the message is received
+ server_socket, port = get_external_server_socket(zmq_ctx, secret_key)
+ client_socket = get_client_socket(client_ctx, port)
+ client_socket.send(msg)
+ assert server_socket.recv() == msg
+
+ # Server holds the generated secret key: recv raises zmq.Again
+ # (get_external_server_socket sets RCVTIMEO to 200)
+ server_socket, port = get_external_server_socket(zmq_ctx, BAD_SEC_KEY)
+ client_socket = get_client_socket(client_ctx, port)
+ client_socket.send(msg)
+ with pytest.raises(zmq.Again):
+ server_socket.recv()
+
+ # Ensure sockets are operational
+ server_socket, port = get_external_server_socket(zmq_ctx, secret_key)
+ client_socket = get_client_socket(client_ctx, port)
+ client_socket.send(msg)
+ assert server_socket.recv() == msg
From 47488c10af16b2ef17756c4e8438be962b0a332e Mon Sep 17 00:00:00 2001
From: Reid Mello <30907815+rjmello@users.noreply.github.com>
Date: Tue, 16 Jan 2024 15:46:43 -0500
Subject: [PATCH 2/3] Implement CurveZMQ in HTEX
The interchange serves as a CurveZMQ server, while the executor and
various managers serve as CurveZMQ clients. Thus, all communication
between these entities is now encrypted.
The HTEX `start` method generates new certs for each run in a private
`certificates/` directory. We generate a single shared client cert
because all clients will have access to this dir.
We disable encryption by default, but users can enable it by setting the
`encrypted` initialization argument for the HTEX to `True`.
---
parsl/executors/high_throughput/executor.py | 53 ++++++++++++---
.../executors/high_throughput/interchange.py | 21 ++++--
.../high_throughput/process_worker_pool.py | 22 +++++--
parsl/executors/high_throughput/zmq_pipes.py | 36 +++++-----
parsl/tests/configs/ad_hoc_cluster_htex.py | 1 +
parsl/tests/configs/azure_single_node.py | 1 +
parsl/tests/configs/bluewaters.py | 1 +
parsl/tests/configs/bridges.py | 1 +
parsl/tests/configs/cc_in2p3.py | 1 +
parsl/tests/configs/comet.py | 1 +
parsl/tests/configs/cooley_htex.py | 1 +
parsl/tests/configs/ec2_single_node.py | 1 +
parsl/tests/configs/ec2_spot.py | 1 +
parsl/tests/configs/frontera.py | 1 +
parsl/tests/configs/htex_ad_hoc_cluster.py | 1 +
parsl/tests/configs/htex_local.py | 1 +
parsl/tests/configs/htex_local_alternate.py | 1 +
.../configs/htex_local_intask_staging.py | 1 +
.../tests/configs/htex_local_rsync_staging.py | 1 +
parsl/tests/configs/local_adhoc.py | 1 +
parsl/tests/configs/midway.py | 1 +
parsl/tests/configs/nscc_singapore.py | 1 +
parsl/tests/configs/osg_htex.py | 1 +
parsl/tests/configs/petrelkube.py | 1 +
parsl/tests/configs/summit.py | 1 +
parsl/tests/configs/swan_htex.py | 1 +
parsl/tests/configs/theta.py | 1 +
parsl/tests/manual_tests/htex_local.py | 1 +
parsl/tests/manual_tests/test_ad_hoc_htex.py | 1 +
.../test_fan_in_out_htex_remote.py | 1 +
.../tests/manual_tests/test_memory_limits.py | 1 +
parsl/tests/scaling_tests/htex_local.py | 1 +
parsl/tests/sites/test_affinity.py | 1 +
parsl/tests/sites/test_concurrent.py | 1 +
parsl/tests/sites/test_dynamic_executor.py | 1 +
parsl/tests/sites/test_worker_info.py | 1 +
parsl/tests/test_htex/test_htex.py | 46 +++++++++++++
.../tests/test_htex/test_htex_zmq_binding.py | 66 +++++++++++++++----
.../test_regression/test_97_parallelism_0.py | 1 +
.../test_scaling/test_block_error_handler.py | 11 ++--
.../test_scaling/test_regression_1621.py | 1 +
parsl/tests/test_scaling/test_scale_down.py | 1 +
42 files changed, 233 insertions(+), 57 deletions(-)
create mode 100644 parsl/tests/test_htex/test_htex.py
diff --git a/parsl/executors/high_throughput/executor.py b/parsl/executors/high_throughput/executor.py
index 2e63df5a72..46929f75a2 100644
--- a/parsl/executors/high_throughput/executor.py
+++ b/parsl/executors/high_throughput/executor.py
@@ -22,6 +22,7 @@
UnsupportedFeatureError
)
+from parsl import curvezmq
from parsl.executors.status_handling import BlockProviderExecutor
from parsl.providers.base import ExecutionProvider
from parsl.data_provider.staging import Staging
@@ -174,6 +175,9 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin):
worker_logdir_root : string
In case of a remote file system, specify the path to where logs will be kept.
+
+ encrypted : bool
+ Flag to enable/disable encryption (CurveZMQ). Default is False.
"""
@typeguard.typechecked
@@ -199,7 +203,8 @@ def __init__(self,
poll_period: int = 10,
address_probe_timeout: Optional[int] = None,
worker_logdir_root: Optional[str] = None,
- block_error_handler: Union[bool, Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]] = True):
+ block_error_handler: Union[bool, Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]] = True,
+ encrypted: bool = False):
logger.debug("Initializing HighThroughputExecutor")
@@ -256,6 +261,8 @@ def __init__(self,
self.run_dir = '.'
self.worker_logdir_root = worker_logdir_root
self.cpu_affinity = cpu_affinity
+ self.encrypted = encrypted
+ self.cert_dir = None
if not launch_cmd:
launch_cmd = (
@@ -267,6 +274,7 @@ def __init__(self,
"--poll {poll_period} "
"--task_port={task_port} "
"--result_port={result_port} "
+ "--cert_dir {cert_dir} "
"--logdir={logdir} "
"--block_id={{block_id}} "
"--hb_period={heartbeat_period} "
@@ -280,6 +288,16 @@ def __init__(self,
radio_mode = "htex"
+    @property
+    def logdir(self):
+        """Log directory of this executor: ``<run_dir>/<label>``."""
+        return f"{self.run_dir}/{self.label}"
+
+    @property
+    def worker_logdir(self):
+        """Worker log directory.
+
+        ``<worker_logdir_root>/<label>`` when worker_logdir_root is set,
+        otherwise the same path as ``logdir``.
+        """
+        if self.worker_logdir_root is None:
+            return self.logdir
+        return f"{self.worker_logdir_root}/{self.label}"
+
def initialize_scaling(self):
"""Compose the launch command and scale out the initial blocks.
"""
@@ -289,9 +307,6 @@ def initialize_scaling(self):
address_probe_timeout_string = ""
if self.address_probe_timeout:
address_probe_timeout_string = "--address_probe_timeout={}".format(self.address_probe_timeout)
- worker_logdir = "{}/{}".format(self.run_dir, self.label)
- if self.worker_logdir_root is not None:
- worker_logdir = "{}/{}".format(self.worker_logdir_root, self.label)
l_cmd = self.launch_cmd.format(debug=debug_opts,
prefetch_capacity=self.prefetch_capacity,
@@ -306,7 +321,8 @@ def initialize_scaling(self):
heartbeat_period=self.heartbeat_period,
heartbeat_threshold=self.heartbeat_threshold,
poll_period=self.poll_period,
- logdir=worker_logdir,
+ cert_dir=self.cert_dir,
+ logdir=self.worker_logdir,
cpu_affinity=self.cpu_affinity,
accelerators=" ".join(self.available_accelerators))
self.launch_cmd = l_cmd
@@ -327,9 +343,25 @@ def initialize_scaling(self):
def start(self):
"""Create the Interchange process and connect to it.
"""
- self.outgoing_q = zmq_pipes.TasksOutgoing("127.0.0.1", self.interchange_port_range)
- self.incoming_q = zmq_pipes.ResultsIncoming("127.0.0.1", self.interchange_port_range)
- self.command_client = zmq_pipes.CommandClient("127.0.0.1", self.interchange_port_range)
+ if self.encrypted and self.cert_dir is None:
+ logger.debug("Creating CurveZMQ certificates")
+ self.cert_dir = curvezmq.create_certificates(self.logdir)
+ elif not self.encrypted and self.cert_dir:
+ raise AttributeError(
+ "The certificates directory path attribute (cert_dir) is defined, but the "
+ "encrypted attribute is set to False. You must either change cert_dir to "
+ "None or encrypted to True."
+ )
+
+ self.outgoing_q = zmq_pipes.TasksOutgoing(
+ curvezmq.ClientContext(self.cert_dir), "127.0.0.1", self.interchange_port_range
+ )
+ self.incoming_q = zmq_pipes.ResultsIncoming(
+ curvezmq.ClientContext(self.cert_dir), "127.0.0.1", self.interchange_port_range
+ )
+ self.command_client = zmq_pipes.CommandClient(
+ curvezmq.ClientContext(self.cert_dir), "127.0.0.1", self.interchange_port_range
+ )
self._queue_management_thread = None
self._start_queue_management_thread()
@@ -450,10 +482,11 @@ def _start_local_interchange_process(self):
"worker_port_range": self.worker_port_range,
"hub_address": self.hub_address,
"hub_port": self.hub_port,
- "logdir": "{}/{}".format(self.run_dir, self.label),
+ "logdir": self.logdir,
"heartbeat_threshold": self.heartbeat_threshold,
"poll_period": self.poll_period,
- "logging_level": logging.DEBUG if self.worker_debug else logging.INFO
+ "logging_level": logging.DEBUG if self.worker_debug else logging.INFO,
+ "cert_dir": self.cert_dir,
},
daemon=True,
name="HTEX-Interchange"
diff --git a/parsl/executors/high_throughput/interchange.py b/parsl/executors/high_throughput/interchange.py
index 67b70aa78d..c65aecb57a 100644
--- a/parsl/executors/high_throughput/interchange.py
+++ b/parsl/executors/high_throughput/interchange.py
@@ -16,6 +16,7 @@
from typing import cast, Any, Dict, NoReturn, Sequence, Set, Optional, Tuple
+from parsl import curvezmq
from parsl.utils import setproctitle
from parsl.version import VERSION as PARSL_VERSION
from parsl.serialize import serialize as serialize_object
@@ -79,6 +80,7 @@ def __init__(self,
logdir: str = ".",
logging_level: int = logging.INFO,
poll_period: int = 10,
+ cert_dir: Optional[str] = None,
) -> None:
"""
Parameters
@@ -120,7 +122,10 @@ def __init__(self,
poll_period : int
The main thread polling period, in milliseconds. Default: 10ms
+ cert_dir : str | None
+ Path to the certificate directory. Default: None
"""
+ self.cert_dir = cert_dir
self.logdir = logdir
os.makedirs(self.logdir, exist_ok=True)
@@ -134,15 +139,15 @@ def __init__(self,
logger.info("Attempting connection to client at {} on ports: {},{},{}".format(
client_address, client_ports[0], client_ports[1], client_ports[2]))
- self.context = zmq.Context()
- self.task_incoming = self.context.socket(zmq.DEALER)
+ self.zmq_context = curvezmq.ServerContext(self.cert_dir)
+ self.task_incoming = self.zmq_context.socket(zmq.DEALER)
self.task_incoming.set_hwm(0)
self.task_incoming.connect("tcp://{}:{}".format(client_address, client_ports[0]))
- self.results_outgoing = self.context.socket(zmq.DEALER)
+ self.results_outgoing = self.zmq_context.socket(zmq.DEALER)
self.results_outgoing.set_hwm(0)
self.results_outgoing.connect("tcp://{}:{}".format(client_address, client_ports[1]))
- self.command_channel = self.context.socket(zmq.REP)
+ self.command_channel = self.zmq_context.socket(zmq.REP)
self.command_channel.connect("tcp://{}:{}".format(client_address, client_ports[2]))
logger.info("Connected to client")
@@ -155,9 +160,9 @@ def __init__(self,
self.worker_ports = worker_ports
self.worker_port_range = worker_port_range
- self.task_outgoing = self.context.socket(zmq.ROUTER)
+ self.task_outgoing = self.zmq_context.socket(zmq.ROUTER)
self.task_outgoing.set_hwm(0)
- self.results_incoming = self.context.socket(zmq.ROUTER)
+ self.results_incoming = self.zmq_context.socket(zmq.ROUTER)
self.results_incoming.set_hwm(0)
if self.worker_ports:
@@ -241,7 +246,8 @@ def task_puller(self) -> NoReturn:
def _create_monitoring_channel(self) -> Optional[zmq.Socket]:
if self.hub_address and self.hub_port:
logger.info("Connecting to monitoring")
- hub_channel = self.context.socket(zmq.DEALER)
+ # This is a one-off because monitoring is unencrypted
+ hub_channel = zmq.Context().socket(zmq.DEALER)
hub_channel.set_hwm(0)
hub_channel.connect("tcp://{}:{}".format(self.hub_address, self.hub_port))
logger.info("Monitoring enabled and connected to hub")
@@ -379,6 +385,7 @@ def start(self) -> None:
self.expire_bad_managers(interesting_managers, hub_channel)
self.process_tasks_to_send(interesting_managers)
+ self.zmq_context.destroy()
delta = time.time() - start
logger.info("Processed {} tasks in {} seconds".format(self.count, delta))
logger.warning("Exiting")
diff --git a/parsl/executors/high_throughput/process_worker_pool.py b/parsl/executors/high_throughput/process_worker_pool.py
index 665a3ef33d..7826bec3dd 100755
--- a/parsl/executors/high_throughput/process_worker_pool.py
+++ b/parsl/executors/high_throughput/process_worker_pool.py
@@ -20,8 +20,8 @@
from multiprocessing.managers import DictProxy
from multiprocessing.sharedctypes import Synchronized
+from parsl import curvezmq
from parsl.process_loggers import wrap_with_logs
-
from parsl.version import VERSION as PARSL_VERSION
from parsl.app.errors import RemoteExceptionWrapper
from parsl.executors.high_throughput.errors import WorkerLost
@@ -63,7 +63,8 @@ def __init__(self, *,
heartbeat_period,
poll_period,
cpu_affinity,
- available_accelerators: Sequence[str]):
+ available_accelerators: Sequence[str],
+ cert_dir: Optional[str]):
"""
Parameters
----------
@@ -118,6 +119,8 @@ def __init__(self, *,
available_accelerators: list of str
List of accelerators available to the workers.
+ cert_dir : str | None
+ Path to the certificate directory.
"""
logger.info("Manager started")
@@ -137,15 +140,16 @@ def __init__(self, *,
print("Failed to find a viable address to connect to interchange. Exiting")
exit(5)
- self.context = zmq.Context()
- self.task_incoming = self.context.socket(zmq.DEALER)
+ self.cert_dir = cert_dir
+ self.zmq_context = curvezmq.ClientContext(self.cert_dir)
+ self.task_incoming = self.zmq_context.socket(zmq.DEALER)
self.task_incoming.setsockopt(zmq.IDENTITY, uid.encode('utf-8'))
# Linger is set to 0, so that the manager can exit even when there might be
# messages in the pipe
self.task_incoming.setsockopt(zmq.LINGER, 0)
self.task_incoming.connect(task_q_url)
- self.result_outgoing = self.context.socket(zmq.DEALER)
+ self.result_outgoing = self.zmq_context.socket(zmq.DEALER)
self.result_outgoing.setsockopt(zmq.IDENTITY, uid.encode('utf-8'))
self.result_outgoing.setsockopt(zmq.LINGER, 0)
self.result_outgoing.connect(result_q_url)
@@ -468,7 +472,7 @@ def start(self):
self.task_incoming.close()
self.result_outgoing.close()
- self.context.term()
+ self.zmq_context.term()
delta = time.time() - start
logger.info("process_worker_pool ran for {} seconds".format(delta))
return
@@ -720,6 +724,8 @@ def start_file_logger(filename, rank, name='parsl', level=logging.DEBUG, format_
help="Enable logging at DEBUG level")
parser.add_argument("-a", "--addresses", default='',
help="Comma separated list of addresses at which the interchange could be reached")
+ parser.add_argument("--cert_dir", required=True,
+ help="Path to certificate directory.")
parser.add_argument("-l", "--logdir", default="process_worker_pool_logs",
help="Process worker pool log directory")
parser.add_argument("-u", "--uid", default=str(uuid.uuid4()).split('-')[-1],
@@ -773,6 +779,7 @@ def strategyorlist(s: str):
logger.info("Python version: {}".format(sys.version))
logger.info("Debug logging: {}".format(args.debug))
+ logger.info("Certificates dir: {}".format(args.cert_dir))
logger.info("Log dir: {}".format(args.logdir))
logger.info("Manager ID: {}".format(args.uid))
logger.info("Block ID: {}".format(args.block_id))
@@ -804,7 +811,8 @@ def strategyorlist(s: str):
heartbeat_period=int(args.hb_period),
poll_period=int(args.poll),
cpu_affinity=args.cpu_affinity,
- available_accelerators=args.available_accelerators)
+ available_accelerators=args.available_accelerators,
+ cert_dir=None if args.cert_dir == "None" else args.cert_dir)
manager.start()
except Exception:
diff --git a/parsl/executors/high_throughput/zmq_pipes.py b/parsl/executors/high_throughput/zmq_pipes.py
index cd5c7f2e8b..72c1422d4d 100644
--- a/parsl/executors/high_throughput/zmq_pipes.py
+++ b/parsl/executors/high_throughput/zmq_pipes.py
@@ -4,24 +4,28 @@
import logging
import threading
+from parsl import curvezmq
+
logger = logging.getLogger(__name__)
class CommandClient:
""" CommandClient
"""
- def __init__(self, ip_address, port_range):
+ def __init__(self, zmq_context: curvezmq.ClientContext, ip_address, port_range):
"""
Parameters
----------
+ zmq_context: curvezmq.ClientContext
+ CurveZMQ client context used to create secure sockets
ip_address: str
IP address of the client (where Parsl runs)
port_range: tuple(int, int)
Port range for the comms between client and interchange
"""
- self.context = zmq.Context()
+ self.zmq_context = zmq_context
self.ip_address = ip_address
self.port_range = port_range
self.port = None
@@ -33,7 +37,7 @@ def create_socket_and_bind(self):
Upon recreating the socket, we bind to the same port.
"""
- self.zmq_socket = self.context.socket(zmq.REQ)
+ self.zmq_socket = self.zmq_context.socket(zmq.REQ)
self.zmq_socket.setsockopt(zmq.LINGER, 0)
if self.port is None:
self.port = self.zmq_socket.bind_to_random_port("tcp://{}".format(self.ip_address),
@@ -62,9 +66,7 @@ def run(self, message, max_retries=3):
except zmq.ZMQError:
logger.exception("Potential ZMQ REQ-REP deadlock caught")
logger.info("Trying to reestablish context")
- self.zmq_socket.close()
- self.context.destroy()
- self.context = zmq.Context()
+ self.zmq_context.recreate()
self.create_socket_and_bind()
else:
break
@@ -77,25 +79,27 @@ def run(self, message, max_retries=3):
def close(self):
self.zmq_socket.close()
- self.context.term()
+ self.zmq_context.term()
class TasksOutgoing:
""" Outgoing task queue from the executor to the Interchange
"""
- def __init__(self, ip_address, port_range):
+ def __init__(self, zmq_context: curvezmq.ClientContext, ip_address, port_range):
"""
Parameters
----------
+ zmq_context: curvezmq.ClientContext
+ CurveZMQ client context used to create secure sockets
ip_address: str
IP address of the client (where Parsl runs)
port_range: tuple(int, int)
Port range for the comms between client and interchange
"""
- self.context = zmq.Context()
- self.zmq_socket = self.context.socket(zmq.DEALER)
+ self.zmq_context = zmq_context
+ self.zmq_socket = self.zmq_context.socket(zmq.DEALER)
self.zmq_socket.set_hwm(0)
self.port = self.zmq_socket.bind_to_random_port("tcp://{}".format(ip_address),
min_port=port_range[0],
@@ -127,26 +131,28 @@ def put(self, message):
def close(self):
self.zmq_socket.close()
- self.context.term()
+ self.zmq_context.term()
class ResultsIncoming:
""" Incoming results queue from the Interchange to the executor
"""
- def __init__(self, ip_address, port_range):
+ def __init__(self, zmq_context: curvezmq.ClientContext, ip_address, port_range):
"""
Parameters
----------
+ zmq_context: curvezmq.ClientContext
+ CurveZMQ client context used to create secure sockets
ip_address: str
IP address of the client (where Parsl runs)
port_range: tuple(int, int)
Port range for the comms between client and interchange
"""
- self.context = zmq.Context()
- self.results_receiver = self.context.socket(zmq.DEALER)
+ self.zmq_context = zmq_context
+ self.results_receiver = self.zmq_context.socket(zmq.DEALER)
self.results_receiver.set_hwm(0)
self.port = self.results_receiver.bind_to_random_port("tcp://{}".format(ip_address),
min_port=port_range[0],
@@ -160,4 +166,4 @@ def get(self):
def close(self):
self.results_receiver.close()
- self.context.term()
+ self.zmq_context.term()
diff --git a/parsl/tests/configs/ad_hoc_cluster_htex.py b/parsl/tests/configs/ad_hoc_cluster_htex.py
index 5c90d27918..0a3e9dc027 100644
--- a/parsl/tests/configs/ad_hoc_cluster_htex.py
+++ b/parsl/tests/configs/ad_hoc_cluster_htex.py
@@ -18,6 +18,7 @@
label='remote_htex',
max_workers=2,
worker_logdir_root=user_opts['adhoc']['script_dir'],
+ encrypted=True,
provider=AdHocProvider(
# Command to be run before starting a worker, such as:
# 'module load Anaconda; source activate parsl_env'.
diff --git a/parsl/tests/configs/azure_single_node.py b/parsl/tests/configs/azure_single_node.py
index 90e90c19cd..17f3e7def9 100644
--- a/parsl/tests/configs/azure_single_node.py
+++ b/parsl/tests/configs/azure_single_node.py
@@ -40,6 +40,7 @@
storage_access=[HTTPInTaskStaging(), FTPInTaskStaging(), RSyncStaging(getpass.getuser() + "@" + user_opts['public_ip'])],
label='azure_single_node',
address=user_opts['public_ip'],
+ encrypted=True,
provider=provider
)
]
diff --git a/parsl/tests/configs/bluewaters.py b/parsl/tests/configs/bluewaters.py
index 2e27eb8c4a..7cd088ecac 100644
--- a/parsl/tests/configs/bluewaters.py
+++ b/parsl/tests/configs/bluewaters.py
@@ -14,6 +14,7 @@ def fresh_config():
cores_per_worker=1,
worker_debug=False,
max_workers=1,
+ encrypted=True,
provider=TorqueProvider(
queue='normal',
launcher=AprunLauncher(overrides="-b -- bwpy-environ --"),
diff --git a/parsl/tests/configs/bridges.py b/parsl/tests/configs/bridges.py
index 4dea2fe468..06d0c0cd43 100644
--- a/parsl/tests/configs/bridges.py
+++ b/parsl/tests/configs/bridges.py
@@ -14,6 +14,7 @@ def fresh_config():
# which compute nodes can communicate
# address=address_by_interface('bond0.144'),
max_workers=1,
+ encrypted=True,
provider=SlurmProvider(
user_opts['bridges']['partition'], # Partition / QOS
nodes_per_block=2,
diff --git a/parsl/tests/configs/cc_in2p3.py b/parsl/tests/configs/cc_in2p3.py
index 9b54fa37e4..9a76d1f16e 100644
--- a/parsl/tests/configs/cc_in2p3.py
+++ b/parsl/tests/configs/cc_in2p3.py
@@ -12,6 +12,7 @@ def fresh_config():
HighThroughputExecutor(
label='cc_in2p3_htex',
max_workers=1,
+ encrypted=True,
provider=GridEngineProvider(
channel=LocalChannel(),
nodes_per_block=2,
diff --git a/parsl/tests/configs/comet.py b/parsl/tests/configs/comet.py
index 1a253e26b2..8f39539509 100644
--- a/parsl/tests/configs/comet.py
+++ b/parsl/tests/configs/comet.py
@@ -11,6 +11,7 @@ def fresh_config():
HighThroughputExecutor(
label='Comet_HTEX_multinode',
max_workers=1,
+ encrypted=True,
provider=SlurmProvider(
'debug',
launcher=SrunLauncher(),
diff --git a/parsl/tests/configs/cooley_htex.py b/parsl/tests/configs/cooley_htex.py
index 4228da10b0..202379d0af 100644
--- a/parsl/tests/configs/cooley_htex.py
+++ b/parsl/tests/configs/cooley_htex.py
@@ -18,6 +18,7 @@
label="cooley_htex",
worker_debug=False,
cores_per_worker=1,
+ encrypted=True,
provider=CobaltProvider(
queue='debug',
account=user_opts['cooley']['account'],
diff --git a/parsl/tests/configs/ec2_single_node.py b/parsl/tests/configs/ec2_single_node.py
index 61cccc8bea..92c8108add 100644
--- a/parsl/tests/configs/ec2_single_node.py
+++ b/parsl/tests/configs/ec2_single_node.py
@@ -28,6 +28,7 @@
HighThroughputExecutor(
label='ec2_single_node',
address=user_opts['public_ip'],
+ encrypted=True,
provider=AWSProvider(
user_opts['ec2']['image_id'],
region=user_opts['ec2']['region'],
diff --git a/parsl/tests/configs/ec2_spot.py b/parsl/tests/configs/ec2_spot.py
index c693d9b17d..37f272e4de 100644
--- a/parsl/tests/configs/ec2_spot.py
+++ b/parsl/tests/configs/ec2_spot.py
@@ -15,6 +15,7 @@
HighThroughputExecutor(
label='ec2_single_node',
address=user_opts['public_ip'],
+ encrypted=True,
provider=AWSProvider(
user_opts['ec2']['image_id'],
region=user_opts['ec2']['region'],
diff --git a/parsl/tests/configs/frontera.py b/parsl/tests/configs/frontera.py
index a003ce88c6..ba302d85b5 100644
--- a/parsl/tests/configs/frontera.py
+++ b/parsl/tests/configs/frontera.py
@@ -16,6 +16,7 @@ def fresh_config():
HighThroughputExecutor(
label="frontera_htex",
max_workers=1,
+ encrypted=True,
provider=SlurmProvider(
cmd_timeout=60, # Add extra time for slow scheduler responses
channel=LocalChannel(),
diff --git a/parsl/tests/configs/htex_ad_hoc_cluster.py b/parsl/tests/configs/htex_ad_hoc_cluster.py
index 80949f1d1e..82ae7bf621 100644
--- a/parsl/tests/configs/htex_ad_hoc_cluster.py
+++ b/parsl/tests/configs/htex_ad_hoc_cluster.py
@@ -13,6 +13,7 @@
cores_per_worker=1,
worker_debug=False,
address=user_opts['public_ip'],
+ encrypted=True,
provider=AdHocProvider(
move_files=False,
parallelism=1,
diff --git a/parsl/tests/configs/htex_local.py b/parsl/tests/configs/htex_local.py
index 2039918a4d..0f09fce9dc 100644
--- a/parsl/tests/configs/htex_local.py
+++ b/parsl/tests/configs/htex_local.py
@@ -13,6 +13,7 @@ def fresh_config():
label="htex_local",
worker_debug=True,
cores_per_worker=1,
+ encrypted=True,
provider=LocalProvider(
channel=LocalChannel(),
init_blocks=1,
diff --git a/parsl/tests/configs/htex_local_alternate.py b/parsl/tests/configs/htex_local_alternate.py
index e3a7b6ff1e..2598a91b85 100644
--- a/parsl/tests/configs/htex_local_alternate.py
+++ b/parsl/tests/configs/htex_local_alternate.py
@@ -48,6 +48,7 @@ def fresh_config():
heartbeat_period=2,
heartbeat_threshold=5,
poll_period=100,
+ encrypted=True,
provider=LocalProvider(
channel=LocalChannel(),
init_blocks=0,
diff --git a/parsl/tests/configs/htex_local_intask_staging.py b/parsl/tests/configs/htex_local_intask_staging.py
index c6ad88be69..634cbf1654 100644
--- a/parsl/tests/configs/htex_local_intask_staging.py
+++ b/parsl/tests/configs/htex_local_intask_staging.py
@@ -15,6 +15,7 @@
label="htex_Local",
worker_debug=True,
cores_per_worker=1,
+ encrypted=True,
provider=LocalProvider(
channel=LocalChannel(),
init_blocks=1,
diff --git a/parsl/tests/configs/htex_local_rsync_staging.py b/parsl/tests/configs/htex_local_rsync_staging.py
index 514fbfaee7..6cb47e55e9 100644
--- a/parsl/tests/configs/htex_local_rsync_staging.py
+++ b/parsl/tests/configs/htex_local_rsync_staging.py
@@ -16,6 +16,7 @@
worker_debug=True,
cores_per_worker=1,
working_dir="./rsync-workdir/",
+ encrypted=True,
provider=LocalProvider(
channel=LocalChannel(),
init_blocks=1,
diff --git a/parsl/tests/configs/local_adhoc.py b/parsl/tests/configs/local_adhoc.py
index 7ac6d7e781..96c08ee309 100644
--- a/parsl/tests/configs/local_adhoc.py
+++ b/parsl/tests/configs/local_adhoc.py
@@ -9,6 +9,7 @@ def fresh_config():
executors=[
HighThroughputExecutor(
label='AdHoc',
+ encrypted=True,
provider=AdHocProvider(
channels=[LocalChannel(), LocalChannel()]
)
diff --git a/parsl/tests/configs/midway.py b/parsl/tests/configs/midway.py
index b5362838e0..ca32edd2b1 100644
--- a/parsl/tests/configs/midway.py
+++ b/parsl/tests/configs/midway.py
@@ -13,6 +13,7 @@ def fresh_config():
label='Midway_HTEX_multinode',
worker_debug=False,
max_workers=1,
+ encrypted=True,
provider=SlurmProvider(
'broadwl', # Partition name, e.g 'broadwl'
launcher=SrunLauncher(),
diff --git a/parsl/tests/configs/nscc_singapore.py b/parsl/tests/configs/nscc_singapore.py
index d8abbd45ba..c78018dc55 100644
--- a/parsl/tests/configs/nscc_singapore.py
+++ b/parsl/tests/configs/nscc_singapore.py
@@ -17,6 +17,7 @@ def fresh_config():
worker_debug=False,
max_workers=1,
address=address_by_interface('ib0'),
+ encrypted=True,
provider=PBSProProvider(
launcher=MpiRunLauncher(),
# string to prepend to #PBS blocks in the submit
diff --git a/parsl/tests/configs/osg_htex.py b/parsl/tests/configs/osg_htex.py
index e4f1d4a98f..17250cca5d 100644
--- a/parsl/tests/configs/osg_htex.py
+++ b/parsl/tests/configs/osg_htex.py
@@ -14,6 +14,7 @@
HighThroughputExecutor(
label='OSG_HTEX',
max_workers=1,
+ encrypted=True,
provider=CondorProvider(
nodes_per_block=1,
init_blocks=4,
diff --git a/parsl/tests/configs/petrelkube.py b/parsl/tests/configs/petrelkube.py
index dbc2afec39..6e61cd101e 100644
--- a/parsl/tests/configs/petrelkube.py
+++ b/parsl/tests/configs/petrelkube.py
@@ -23,6 +23,7 @@ def fresh_config():
# Address for the pod worker to connect back
address=address_by_route(),
+ encrypted=True,
provider=KubernetesProvider(
namespace="dlhub-privileged",
diff --git a/parsl/tests/configs/summit.py b/parsl/tests/configs/summit.py
index ae0d009334..ed87366e55 100644
--- a/parsl/tests/configs/summit.py
+++ b/parsl/tests/configs/summit.py
@@ -21,6 +21,7 @@ def fresh_config():
# address=address_by_interface('ib0'), # This assumes Parsl is running on login node
worker_port_range=(50000, 55000),
max_workers=1,
+ encrypted=True,
provider=LSFProvider(
launcher=JsrunLauncher(),
walltime="00:10:00",
diff --git a/parsl/tests/configs/swan_htex.py b/parsl/tests/configs/swan_htex.py
index 8db9f5a975..4884703a2a 100644
--- a/parsl/tests/configs/swan_htex.py
+++ b/parsl/tests/configs/swan_htex.py
@@ -24,6 +24,7 @@
executors=[
HighThroughputExecutor(
label='swan_htex',
+ encrypted=True,
provider=TorqueProvider(
channel=SSHChannel(
hostname='swan.cray.com',
diff --git a/parsl/tests/configs/theta.py b/parsl/tests/configs/theta.py
index f2ce169390..36092a948e 100644
--- a/parsl/tests/configs/theta.py
+++ b/parsl/tests/configs/theta.py
@@ -12,6 +12,7 @@ def fresh_config():
HighThroughputExecutor(
label='theta_local_htex_multinode',
max_workers=1,
+ encrypted=True,
provider=CobaltProvider(
queue=user_opts['theta']['queue'],
account=user_opts['theta']['account'],
diff --git a/parsl/tests/manual_tests/htex_local.py b/parsl/tests/manual_tests/htex_local.py
index a2cd3fae7f..a5c0449be1 100644
--- a/parsl/tests/manual_tests/htex_local.py
+++ b/parsl/tests/manual_tests/htex_local.py
@@ -13,6 +13,7 @@
label="htex_local",
# worker_debug=True,
cores_per_worker=1,
+ encrypted=True,
provider=LocalProvider(
channel=LocalChannel(),
init_blocks=1,
diff --git a/parsl/tests/manual_tests/test_ad_hoc_htex.py b/parsl/tests/manual_tests/test_ad_hoc_htex.py
index d022686d6e..030658f243 100644
--- a/parsl/tests/manual_tests/test_ad_hoc_htex.py
+++ b/parsl/tests/manual_tests/test_ad_hoc_htex.py
@@ -15,6 +15,7 @@
label='AdHoc',
max_workers=2,
worker_logdir_root="/scratch/midway2/yadunand/parsl_scripts",
+ encrypted=True,
provider=AdHocProvider(
worker_init="source /scratch/midway2/yadunand/parsl_env_setup.sh",
channels=[SSHChannel(hostname=m,
diff --git a/parsl/tests/manual_tests/test_fan_in_out_htex_remote.py b/parsl/tests/manual_tests/test_fan_in_out_htex_remote.py
index 72dc4c7fd7..ce09e4bcad 100644
--- a/parsl/tests/manual_tests/test_fan_in_out_htex_remote.py
+++ b/parsl/tests/manual_tests/test_fan_in_out_htex_remote.py
@@ -16,6 +16,7 @@ def local_setup():
label="theta_htex",
# worker_debug=True,
cores_per_worker=4,
+ encrypted=True,
provider=CobaltProvider(
queue='debug-flat-quad',
account='CSC249ADCD01',
diff --git a/parsl/tests/manual_tests/test_memory_limits.py b/parsl/tests/manual_tests/test_memory_limits.py
index 9024939827..be4507808f 100644
--- a/parsl/tests/manual_tests/test_memory_limits.py
+++ b/parsl/tests/manual_tests/test_memory_limits.py
@@ -28,6 +28,7 @@ def test_simple(mem_per_worker):
mem_per_worker=mem_per_worker,
cores_per_worker=0.1,
suppress_failure=True,
+ encrypted=True,
provider=LocalProvider(
channel=LocalChannel(),
init_blocks=1,
diff --git a/parsl/tests/scaling_tests/htex_local.py b/parsl/tests/scaling_tests/htex_local.py
index ff6ed5b96b..1037d194c4 100644
--- a/parsl/tests/scaling_tests/htex_local.py
+++ b/parsl/tests/scaling_tests/htex_local.py
@@ -10,6 +10,7 @@
label="htex_local",
cores_per_worker=1,
max_workers=8,
+ encrypted=True,
provider=LocalProvider(
channel=LocalChannel(),
init_blocks=1,
diff --git a/parsl/tests/sites/test_affinity.py b/parsl/tests/sites/test_affinity.py
index b2b588aa83..db3760916f 100644
--- a/parsl/tests/sites/test_affinity.py
+++ b/parsl/tests/sites/test_affinity.py
@@ -18,6 +18,7 @@ def local_config():
max_workers=2,
cpu_affinity='block',
available_accelerators=2,
+ encrypted=True,
provider=LocalProvider(
channel=LocalChannel(),
init_blocks=1,
diff --git a/parsl/tests/sites/test_concurrent.py b/parsl/tests/sites/test_concurrent.py
index 01c72fd6b9..a80e61a626 100644
--- a/parsl/tests/sites/test_concurrent.py
+++ b/parsl/tests/sites/test_concurrent.py
@@ -17,6 +17,7 @@ def make_config():
max_workers=2,
heartbeat_period=2,
heartbeat_threshold=4,
+ encrypted=True,
)
],
strategy='none',
diff --git a/parsl/tests/sites/test_dynamic_executor.py b/parsl/tests/sites/test_dynamic_executor.py
index 9c04d8b126..c3d449a5cc 100644
--- a/parsl/tests/sites/test_dynamic_executor.py
+++ b/parsl/tests/sites/test_dynamic_executor.py
@@ -60,6 +60,7 @@ def test_dynamic_executor():
label='htex_local',
cores_per_worker=1,
max_workers=5,
+ encrypted=True,
provider=LocalProvider(
init_blocks=1,
max_blocks=1,
diff --git a/parsl/tests/sites/test_worker_info.py b/parsl/tests/sites/test_worker_info.py
index d9bb3e4c93..296ddfaf05 100644
--- a/parsl/tests/sites/test_worker_info.py
+++ b/parsl/tests/sites/test_worker_info.py
@@ -15,6 +15,7 @@ def local_config():
label="htex_Local",
worker_debug=True,
max_workers=4,
+ encrypted=True,
provider=LocalProvider(
channel=LocalChannel(),
init_blocks=1,
diff --git a/parsl/tests/test_htex/test_htex.py b/parsl/tests/test_htex/test_htex.py
new file mode 100644
index 0000000000..ffe693c69e
--- /dev/null
+++ b/parsl/tests/test_htex/test_htex.py
@@ -0,0 +1,46 @@
+import pathlib
+
+import pytest
+
+from parsl import curvezmq
+from parsl import HighThroughputExecutor
+
+
+@pytest.mark.local
+@pytest.mark.parametrize("encrypted", (True, False))
+@pytest.mark.parametrize("cert_dir_provided", (True, False))
+def test_htex_start_encrypted(
+ encrypted: bool, cert_dir_provided: bool, tmpd_cwd: pathlib.Path
+):
+ htex = HighThroughputExecutor(encrypted=encrypted)
+ htex.run_dir = str(tmpd_cwd)
+ if cert_dir_provided:
+ provided_base_dir = tmpd_cwd / "provided"
+ provided_base_dir.mkdir()
+ cert_dir = curvezmq.create_certificates(provided_base_dir)
+ htex.cert_dir = cert_dir
+ else:
+ cert_dir = curvezmq.create_certificates(htex.logdir)
+
+ if not encrypted and cert_dir_provided:
+ with pytest.raises(AttributeError) as pyt_e:
+ htex.start()
+ assert "change cert_dir to None" in str(pyt_e.value)
+ return
+
+ htex.start()
+
+ assert htex.encrypted is encrypted
+ if encrypted:
+ assert htex.cert_dir == cert_dir
+ assert htex.outgoing_q.zmq_context.cert_dir == cert_dir
+ assert htex.incoming_q.zmq_context.cert_dir == cert_dir
+ assert htex.command_client.zmq_context.cert_dir == cert_dir
+ assert isinstance(htex.outgoing_q.zmq_context, curvezmq.ClientContext)
+ assert isinstance(htex.incoming_q.zmq_context, curvezmq.ClientContext)
+ assert isinstance(htex.command_client.zmq_context, curvezmq.ClientContext)
+ else:
+ assert htex.cert_dir is None
+ assert htex.outgoing_q.zmq_context.cert_dir is None
+ assert htex.incoming_q.zmq_context.cert_dir is None
+ assert htex.command_client.zmq_context.cert_dir is None
diff --git a/parsl/tests/test_htex/test_htex_zmq_binding.py b/parsl/tests/test_htex/test_htex_zmq_binding.py
index 442f0431b6..41de9055b0 100644
--- a/parsl/tests/test_htex/test_htex_zmq_binding.py
+++ b/parsl/tests/test_htex/test_htex_zmq_binding.py
@@ -1,42 +1,82 @@
-import logging
+import pathlib
+from typing import Optional
+from unittest import mock
import psutil
import pytest
import zmq
+from parsl import curvezmq
from parsl.executors.high_throughput.interchange import Interchange
-def test_interchange_binding_no_address():
- ix = Interchange()
+@pytest.fixture
+def encrypted(request: pytest.FixtureRequest):
+ if hasattr(request, "param"):
+ return request.param
+ return True
+
+
+@pytest.fixture
+def cert_dir(encrypted: bool, tmpd_cwd: pathlib.Path):
+ if not encrypted:
+ return None
+ return curvezmq.create_certificates(tmpd_cwd)
+
+
+@pytest.mark.local
+@pytest.mark.parametrize("encrypted", (True, False), indirect=True)
+@mock.patch.object(curvezmq.ServerContext, "socket", return_value=mock.MagicMock())
+def test_interchange_curvezmq_sockets(
+ mock_socket: mock.MagicMock, cert_dir: Optional[str], encrypted: bool
+):
+ address = "127.0.0.1"
+ ix = Interchange(interchange_address=address, cert_dir=cert_dir)
+ assert isinstance(ix.zmq_context, curvezmq.ServerContext)
+ assert ix.zmq_context.encrypted is encrypted
+ assert mock_socket.call_count == 5
+
+
+@pytest.mark.local
+@pytest.mark.parametrize("encrypted", (True, False), indirect=True)
+def test_interchange_binding_no_address(cert_dir: Optional[str]):
+ ix = Interchange(cert_dir=cert_dir)
assert ix.interchange_address == "*"
-def test_interchange_binding_with_address():
+@pytest.mark.local
+@pytest.mark.parametrize("encrypted", (True, False), indirect=True)
+def test_interchange_binding_with_address(cert_dir: Optional[str]):
# Using loopback address
address = "127.0.0.1"
- ix = Interchange(interchange_address=address)
+ ix = Interchange(interchange_address=address, cert_dir=cert_dir)
assert ix.interchange_address == address
-def test_interchange_binding_with_non_ipv4_address():
+@pytest.mark.local
+@pytest.mark.parametrize("encrypted", (True, False), indirect=True)
+def test_interchange_binding_with_non_ipv4_address(cert_dir: Optional[str]):
# Confirm that a ipv4 address is required
address = "localhost"
with pytest.raises(zmq.error.ZMQError):
- Interchange(interchange_address=address)
+ Interchange(interchange_address=address, cert_dir=cert_dir)
-def test_interchange_binding_bad_address():
- """ Confirm that we raise a ZMQError when a bad address is supplied"""
+@pytest.mark.local
+@pytest.mark.parametrize("encrypted", (True, False), indirect=True)
+def test_interchange_binding_bad_address(cert_dir: Optional[str]):
+ """Confirm that we raise a ZMQError when a bad address is supplied"""
address = "550.0.0.0"
with pytest.raises(zmq.error.ZMQError):
- Interchange(interchange_address=address)
+ Interchange(interchange_address=address, cert_dir=cert_dir)
-def test_limited_interface_binding():
- """ When address is specified the worker_port would be bound to it rather than to 0.0.0.0"""
+@pytest.mark.local
+@pytest.mark.parametrize("encrypted", (True, False), indirect=True)
+def test_limited_interface_binding(cert_dir: Optional[str]):
+ """When address is specified the worker_port would be bound to it rather than to 0.0.0.0"""
address = "127.0.0.1"
- ix = Interchange(interchange_address=address)
+ ix = Interchange(interchange_address=address, cert_dir=cert_dir)
ix.worker_result_port
proc = psutil.Process()
conns = proc.connections(kind="tcp")
diff --git a/parsl/tests/test_regression/test_97_parallelism_0.py b/parsl/tests/test_regression/test_97_parallelism_0.py
index ad47e3971f..de06bcd9df 100644
--- a/parsl/tests/test_regression/test_97_parallelism_0.py
+++ b/parsl/tests/test_regression/test_97_parallelism_0.py
@@ -15,6 +15,7 @@ def local_config() -> Config:
label="htex_local",
worker_debug=True,
cores_per_worker=1,
+ encrypted=True,
provider=LocalProvider(
init_blocks=0,
min_blocks=0,
diff --git a/parsl/tests/test_scaling/test_block_error_handler.py b/parsl/tests/test_scaling/test_block_error_handler.py
index 9d680212e3..7d9ffedd17 100644
--- a/parsl/tests/test_scaling/test_block_error_handler.py
+++ b/parsl/tests/test_scaling/test_block_error_handler.py
@@ -11,7 +11,7 @@
@pytest.mark.local
def test_block_error_handler_false():
mock = Mock()
- htex = HighThroughputExecutor(block_error_handler=False)
+ htex = HighThroughputExecutor(block_error_handler=False, encrypted=True)
assert htex.block_error_handler is noop_error_handler
htex.set_bad_state_and_fail_all = mock
@@ -27,7 +27,7 @@ def test_block_error_handler_false():
@pytest.mark.local
def test_block_error_handler_mock():
handler_mock = Mock()
- htex = HighThroughputExecutor(block_error_handler=handler_mock)
+ htex = HighThroughputExecutor(block_error_handler=handler_mock, encrypted=True)
assert htex.block_error_handler is handler_mock
bad_jobs = {'1': JobStatus(JobState.FAILED),
@@ -43,6 +43,7 @@ def test_block_error_handler_mock():
@pytest.mark.local
def test_simple_error_handler():
htex = HighThroughputExecutor(block_error_handler=simple_error_handler,
+ encrypted=True,
provider=LocalProvider(init_blocks=3))
assert htex.block_error_handler is simple_error_handler
@@ -76,7 +77,7 @@ def test_simple_error_handler():
@pytest.mark.local
def test_windowed_error_handler():
- htex = HighThroughputExecutor(block_error_handler=windowed_error_handler)
+ htex = HighThroughputExecutor(block_error_handler=windowed_error_handler, encrypted=True)
assert htex.block_error_handler is windowed_error_handler
bad_state_mock = Mock()
@@ -110,7 +111,7 @@ def test_windowed_error_handler():
@pytest.mark.local
def test_windowed_error_handler_sorting():
- htex = HighThroughputExecutor(block_error_handler=windowed_error_handler)
+ htex = HighThroughputExecutor(block_error_handler=windowed_error_handler, encrypted=True)
assert htex.block_error_handler is windowed_error_handler
bad_state_mock = Mock()
@@ -136,7 +137,7 @@ def test_windowed_error_handler_sorting():
@pytest.mark.local
def test_windowed_error_handler_with_threshold():
error_handler = partial(windowed_error_handler, threshold=2)
- htex = HighThroughputExecutor(block_error_handler=error_handler)
+ htex = HighThroughputExecutor(block_error_handler=error_handler, encrypted=True)
assert htex.block_error_handler is error_handler
bad_state_mock = Mock()
diff --git a/parsl/tests/test_scaling/test_regression_1621.py b/parsl/tests/test_scaling/test_regression_1621.py
index 2a953c3fc7..bb367c379a 100644
--- a/parsl/tests/test_scaling/test_regression_1621.py
+++ b/parsl/tests/test_scaling/test_regression_1621.py
@@ -49,6 +49,7 @@ def test_one_block(tmpd_cwd):
address="127.0.0.1",
worker_debug=True,
cores_per_worker=1,
+ encrypted=True,
provider=oneshot_provider,
worker_logdir_root=str(tmpd_cwd)
)
diff --git a/parsl/tests/test_scaling/test_scale_down.py b/parsl/tests/test_scaling/test_scale_down.py
index 0a23a5acb5..6ebe441760 100644
--- a/parsl/tests/test_scaling/test_scale_down.py
+++ b/parsl/tests/test_scaling/test_scale_down.py
@@ -28,6 +28,7 @@ def local_config():
label="htex_local",
address="127.0.0.1",
max_workers=1,
+ encrypted=True,
provider=LocalProvider(
channel=LocalChannel(),
init_blocks=0,
From 431daefada930a7f2f1789f9d53b88b133cfc0fa Mon Sep 17 00:00:00 2001
From: Reid Mello <30907815+rjmello@users.noreply.github.com>
Date: Tue, 30 Jan 2024 11:32:59 -0500
Subject: [PATCH 3/3] Add section about HTEX encryption to docs
---
docs/userguide/execution.rst | 47 ++++++++++++++++++++++++++++++++++++
1 file changed, 47 insertions(+)
diff --git a/docs/userguide/execution.rst b/docs/userguide/execution.rst
index 0601b38802..1a61e40e73 100644
--- a/docs/userguide/execution.rst
+++ b/docs/userguide/execution.rst
@@ -342,3 +342,50 @@ The following code snippet shows how apps can specify suitable executors in the
def visualize(inputs=(), outputs=()):
bash_array = " ".join(inputs)
return "viz {} -o {}".format(bash_array, outputs[0])
+
+
+Encryption
+----------
+
+Users can enable encryption for the ``HighThroughputExecutor`` by setting its ``encrypted``
+initialization argument to ``True``.
+
+For example,
+
+.. code-block:: python
+
+    from parsl.config import Config
+    from parsl.executors import HighThroughputExecutor
+
+    config = Config(
+        executors=[
+            HighThroughputExecutor(
+                encrypted=True
+            )
+        ]
+    )
+
+Under the hood, we use `CurveZMQ <http://curvezmq.org/>`_ to encrypt all communication channels
+between the executor and related nodes.
+
+Encryption performance
+^^^^^^^^^^^^^^^^^^^^^^
+
+CurveZMQ depends on `libzmq <https://github.com/zeromq/libzmq>`_ and `libsodium <https://github.com/jedisct1/libsodium>`_,
+which `pyzmq <https://github.com/zeromq/pyzmq>`_ (a Parsl dependency) includes as part of its
+installation via ``pip``. This installation path should work on most systems, but users have
+reported significant performance degradation as a result.
+
+If you experience a significant performance hit after enabling encryption, we recommend installing
+``pyzmq`` with conda:
+
+.. code-block:: bash
+
+    conda install conda-forge::pyzmq
+
+Alternatively, you can `install libsodium <https://doc.libsodium.org/installation>`_, then
+`install libzmq <https://zeromq.org/download/>`_, then build ``pyzmq`` from source:
+
+.. code-block:: bash
+
+    pip3 install parsl --no-binary pyzmq