Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Protect HTEX communication with CurveZMQ #3030

Merged
merged 4 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions docs/userguide/execution.rst
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,50 @@ The following code snippet shows how apps can specify suitable executors in the
def visualize(inputs=(), outputs=()):
bash_array = " ".join(inputs)
return "viz {} -o {}".format(bash_array, outputs[0])


Encryption
----------

Users can enable encryption for the ``HighThroughputExecutor`` by setting its ``encrypted``
initialization argument to ``True``.

For example,

.. code-block:: python

from parsl.config import Config
from parsl.executors import HighThroughputExecutor

config = Config(
executors=[
HighThroughputExecutor(
encrypted=True
)
]
)

Under the hood, we use `CurveZMQ <http://curvezmq.org/>`_ to encrypt all communication channels
between the executor and related nodes.

Encryption performance
^^^^^^^^^^^^^^^^^^^^^^

CurveZMQ depends on `libzmq <https://github.com/zeromq/libzmq>`_ and `libsodium <https://github.com/jedisct1/libsodium>`_,
which `pyzmq <https://github.com/zeromq/pyzmq>`_ (a Parsl dependency) includes as part of its
installation via ``pip``. This installation path should work on most systems, but users have
reported significant performance degradation as a result.

If you experience a significant performance hit after enabling encryption, we recommend installing
``pyzmq`` with conda:

.. code-block:: bash

conda install conda-forge::pyzmq

Alternatively, you can `install libsodium <https://doc.libsodium.org/installation>`_, then
`install libzmq <https://zeromq.org/download/>`_, then build ``pyzmq`` from source:

.. code-block:: bash

pip3 install parsl --no-binary pyzmq
202 changes: 202 additions & 0 deletions parsl/curvezmq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import os
from abc import ABCMeta, abstractmethod
from typing import Optional, Tuple, Union

import zmq
import zmq.auth
from zmq.auth.thread import ThreadAuthenticator


def create_certificates(base_dir: Union[str, os.PathLike]):
"""Create server and client certificates in a private directory.

This will overwrite existing certificate files.

Parameters
----------
base_dir : str | os.PathLike
Parent directory of the private certificates directory.
"""
cert_dir = os.path.join(base_dir, "certificates")
os.makedirs(cert_dir, mode=0o700, exist_ok=True)
rjmello marked this conversation as resolved.
Show resolved Hide resolved

zmq.auth.create_certificates(cert_dir, name="server")
zmq.auth.create_certificates(cert_dir, name="client")

return cert_dir


def _load_certificate(
cert_dir: Union[str, os.PathLike], name: str
) -> Tuple[bytes, bytes]:
if os.stat(cert_dir).st_mode & 0o777 != 0o700:
raise OSError(f"The certificates directory must be private: {cert_dir}")

# pyzmq creates secret key files with the '.key_secret' extension
# Ref: https://github.com/zeromq/pyzmq/blob/ae615d4097ccfbc6b5c17de60355cbe6e00a6065/zmq/auth/certs.py#L73
secret_key_file = os.path.join(cert_dir, f"{name}.key_secret")
public_key, secret_key = zmq.auth.load_certificate(secret_key_file)
rjmello marked this conversation as resolved.
Show resolved Hide resolved
if secret_key is None:
raise ValueError(f"No secret key found in {secret_key_file}")

return public_key, secret_key


class BaseContext(metaclass=ABCMeta):
"""Base CurveZMQ context"""

def __init__(self, cert_dir: Optional[Union[str, os.PathLike]]) -> None:
self.cert_dir = cert_dir
self._ctx = zmq.Context()

def __del__(self):
self.destroy()
rjmello marked this conversation as resolved.
Show resolved Hide resolved

@property
def encrypted(self):
"""Indicates whether encryption is enabled.

False (disabled) when self.cert_dir is set to None.
"""
return self.cert_dir is not None

@property
def closed(self):
return self._ctx.closed

@abstractmethod
def socket(self, socket_type: int, *args, **kwargs) -> zmq.Socket:
"""Create a socket associated with this context.

This method will apply all necessary certificates and socket options.

Parameters
----------
socket_type : int
The socket type, which can be any of the 0MQ socket types: REQ, REP,
PUB, SUB, PAIR, DEALER, ROUTER, PULL, PUSH, etc.

args:
passed to the zmq.Context.socket method.

kwargs:
passed to the zmq.Context.socket method.
"""
...

def term(self):
"""Terminate the context."""
self._ctx.term()

def destroy(self, linger: Optional[int] = None):
"""Close all sockets associated with this context and then terminate
the context.

.. warning::

destroy involves calling ``zmq_close()``, which is **NOT** threadsafe.
If there are active sockets in other threads, this must not be called.

Parameters
----------
linger : int, optional
If specified, set LINGER on sockets prior to closing them.
"""
self._ctx.destroy(linger)
rjmello marked this conversation as resolved.
Show resolved Hide resolved

def recreate(self, linger: Optional[int] = None):
"""Destroy then recreate the context.

Parameters
----------
linger : int, optional
If specified, set LINGER on sockets prior to closing them.
"""
self.destroy(linger)
self._ctx = zmq.Context()


class ServerContext(BaseContext):
"""CurveZMQ server context

We create server sockets via the `ctx.socket` method, which automatically
applies the necessary certificates and socket options.

We handle client certificate authentication in a separate dedicated thread.

Parameters
----------
cert_dir : str | os.PathLike | None
Path to the certificate directory. Setting this to None will disable encryption.
"""

def __init__(self, cert_dir: Optional[Union[str, os.PathLike]]) -> None:
super().__init__(cert_dir)
self.auth_thread = None
if self.encrypted:
self.auth_thread = self._start_auth_thread()

def _start_auth_thread(self) -> ThreadAuthenticator:
auth_thread = ThreadAuthenticator(self._ctx)
auth_thread.start()
# Only allow certs that are in the cert dir
assert self.cert_dir # For mypy
auth_thread.configure_curve(domain="*", location=self.cert_dir)
return auth_thread

def socket(self, socket_type: int, *args, **kwargs) -> zmq.Socket:
sock = self._ctx.socket(socket_type, *args, **kwargs)
if self.encrypted:
assert self.cert_dir # For mypy
_, secret_key = _load_certificate(self.cert_dir, name="server")
try:
# Only the clients need the server's public key to encrypt
# messages and verify the server's identity.
# Ref: http://curvezmq.org/page:read-the-docs
sock.setsockopt(zmq.CURVE_SECRETKEY, secret_key)
except zmq.ZMQError as e:
raise ValueError("Invalid CurveZMQ key format") from e
sock.setsockopt(zmq.CURVE_SERVER, True) # Must come before bind
return sock

def term(self):
if self.auth_thread:
self.auth_thread.stop()
super().term()

def destroy(self, linger: Optional[int] = None):
if self.auth_thread:
self.auth_thread.stop()
super().destroy(linger)

def recreate(self, linger: Optional[int] = None):
super().recreate(linger)
if self.auth_thread:
self.auth_thread = self._start_auth_thread()


class ClientContext(BaseContext):
"""CurveZMQ client context

We create client sockets via the `ctx.socket` method, which automatically
applies the necessary certificates and socket options.

Parameters
----------
cert_dir : str | os.PathLike | None
Path to the certificate directory. Setting this to None will disable encryption.
"""

def socket(self, socket_type: int, *args, **kwargs) -> zmq.Socket:
sock = self._ctx.socket(socket_type, *args, **kwargs)
if self.encrypted:
assert self.cert_dir # For mypy
public_key, secret_key = _load_certificate(self.cert_dir, name="client")
server_public_key, _ = _load_certificate(self.cert_dir, name="server")
try:
sock.setsockopt(zmq.CURVE_PUBLICKEY, public_key)
sock.setsockopt(zmq.CURVE_SECRETKEY, secret_key)
sock.setsockopt(zmq.CURVE_SERVERKEY, server_public_key)
except zmq.ZMQError as e:
raise ValueError("Invalid CurveZMQ key format") from e
return sock
53 changes: 43 additions & 10 deletions parsl/executors/high_throughput/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
UnsupportedFeatureError
)

from parsl import curvezmq
from parsl.executors.status_handling import BlockProviderExecutor
from parsl.providers.base import ExecutionProvider
from parsl.data_provider.staging import Staging
Expand Down Expand Up @@ -174,6 +175,9 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin):

worker_logdir_root : string
In case of a remote file system, specify the path to where logs will be kept.

encrypted : bool
Flag to enable/disable encryption (CurveZMQ). Default is False.
"""

@typeguard.typechecked
Expand All @@ -199,7 +203,8 @@ def __init__(self,
poll_period: int = 10,
address_probe_timeout: Optional[int] = None,
worker_logdir_root: Optional[str] = None,
block_error_handler: Union[bool, Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]] = True):
block_error_handler: Union[bool, Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]] = True,
encrypted: bool = False):

logger.debug("Initializing HighThroughputExecutor")

Expand Down Expand Up @@ -256,6 +261,8 @@ def __init__(self,
self.run_dir = '.'
self.worker_logdir_root = worker_logdir_root
self.cpu_affinity = cpu_affinity
self.encrypted = encrypted
self.cert_dir = None

if not launch_cmd:
launch_cmd = (
Expand All @@ -267,6 +274,7 @@ def __init__(self,
"--poll {poll_period} "
"--task_port={task_port} "
"--result_port={result_port} "
"--cert_dir {cert_dir} "
"--logdir={logdir} "
"--block_id={{block_id}} "
"--hb_period={heartbeat_period} "
Expand All @@ -280,6 +288,16 @@ def __init__(self,

radio_mode = "htex"

@property
def logdir(self):
return "{}/{}".format(self.run_dir, self.label)

@property
def worker_logdir(self):
if self.worker_logdir_root is not None:
return "{}/{}".format(self.worker_logdir_root, self.label)
return self.logdir

def initialize_scaling(self):
"""Compose the launch command and scale out the initial blocks.
"""
Expand All @@ -289,9 +307,6 @@ def initialize_scaling(self):
address_probe_timeout_string = ""
if self.address_probe_timeout:
address_probe_timeout_string = "--address_probe_timeout={}".format(self.address_probe_timeout)
worker_logdir = "{}/{}".format(self.run_dir, self.label)
if self.worker_logdir_root is not None:
worker_logdir = "{}/{}".format(self.worker_logdir_root, self.label)

l_cmd = self.launch_cmd.format(debug=debug_opts,
prefetch_capacity=self.prefetch_capacity,
Expand All @@ -306,7 +321,8 @@ def initialize_scaling(self):
heartbeat_period=self.heartbeat_period,
heartbeat_threshold=self.heartbeat_threshold,
poll_period=self.poll_period,
logdir=worker_logdir,
cert_dir=self.cert_dir,
logdir=self.worker_logdir,
cpu_affinity=self.cpu_affinity,
accelerators=" ".join(self.available_accelerators))
self.launch_cmd = l_cmd
Expand All @@ -327,9 +343,25 @@ def initialize_scaling(self):
def start(self):
"""Create the Interchange process and connect to it.
"""
self.outgoing_q = zmq_pipes.TasksOutgoing("127.0.0.1", self.interchange_port_range)
self.incoming_q = zmq_pipes.ResultsIncoming("127.0.0.1", self.interchange_port_range)
self.command_client = zmq_pipes.CommandClient("127.0.0.1", self.interchange_port_range)
if self.encrypted and self.cert_dir is None:
logger.debug("Creating CurveZMQ certificates")
self.cert_dir = curvezmq.create_certificates(self.logdir)
elif not self.encrypted and self.cert_dir:
raise AttributeError(
"The certificates directory path attribute (cert_dir) is defined, but the "
"encrypted attribute is set to False. You must either change cert_dir to "
"None or encrypted to True."
)

self.outgoing_q = zmq_pipes.TasksOutgoing(
curvezmq.ClientContext(self.cert_dir), "127.0.0.1", self.interchange_port_range
)
self.incoming_q = zmq_pipes.ResultsIncoming(
curvezmq.ClientContext(self.cert_dir), "127.0.0.1", self.interchange_port_range
)
self.command_client = zmq_pipes.CommandClient(
curvezmq.ClientContext(self.cert_dir), "127.0.0.1", self.interchange_port_range
)

self._queue_management_thread = None
self._start_queue_management_thread()
Expand Down Expand Up @@ -450,10 +482,11 @@ def _start_local_interchange_process(self):
"worker_port_range": self.worker_port_range,
"hub_address": self.hub_address,
"hub_port": self.hub_port,
"logdir": "{}/{}".format(self.run_dir, self.label),
"logdir": self.logdir,
"heartbeat_threshold": self.heartbeat_threshold,
"poll_period": self.poll_period,
"logging_level": logging.DEBUG if self.worker_debug else logging.INFO
"logging_level": logging.DEBUG if self.worker_debug else logging.INFO,
"cert_dir": self.cert_dir,
},
daemon=True,
name="HTEX-Interchange"
Expand Down
Loading
Loading