Skip to content

Commit

Permalink
Implement CurveZMQ in HTEX
Browse files Browse the repository at this point in the history
The interchange serves as a CurveZMQ server, while the executor and
various managers serve as CurveZMQ clients. Thus, all communication
between these entities is now encrypted.

The HTEX `start` method generates new certs for each run in a private
`certificates/` directory. We generate a single shared client cert
because all clients will have access to this dir.

We enable encryption by default, but users can disable it by setting the
`encrypted` initialization argument for the HTEX.
  • Loading branch information
rjmello committed Jan 26, 2024
1 parent 79187c0 commit d81ceb4
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 52 deletions.
53 changes: 43 additions & 10 deletions parsl/executors/high_throughput/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
UnsupportedFeatureError
)

from parsl import curvezmq
from parsl.executors.status_handling import BlockProviderExecutor
from parsl.providers.base import ExecutionProvider
from parsl.data_provider.staging import Staging
Expand Down Expand Up @@ -174,6 +175,9 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin):
worker_logdir_root : string
In case of a remote file system, specify the path to where logs will be kept.
encrypted : bool
Flag to enable/disable encryption. Default is True.
"""

@typeguard.typechecked
Expand All @@ -199,7 +203,8 @@ def __init__(self,
poll_period: int = 10,
address_probe_timeout: Optional[int] = None,
worker_logdir_root: Optional[str] = None,
block_error_handler: Union[bool, Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]] = True):
block_error_handler: Union[bool, Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]] = True,
encrypted: bool = True):

logger.debug("Initializing HighThroughputExecutor")

Expand Down Expand Up @@ -256,6 +261,8 @@ def __init__(self,
self.run_dir = '.'
self.worker_logdir_root = worker_logdir_root
self.cpu_affinity = cpu_affinity
self.encrypted = encrypted
self.cert_dir = None

if not launch_cmd:
launch_cmd = (
Expand All @@ -267,6 +274,7 @@ def __init__(self,
"--poll {poll_period} "
"--task_port={task_port} "
"--result_port={result_port} "
"--cert_dir {cert_dir} "
"--logdir={logdir} "
"--block_id={{block_id}} "
"--hb_period={heartbeat_period} "
Expand All @@ -280,6 +288,16 @@ def __init__(self,

radio_mode = "htex"

@property
def logdir(self):
return "{}/{}".format(self.run_dir, self.label)

@property
def worker_logdir(self):
if self.worker_logdir_root is not None:
return "{}/{}".format(self.worker_logdir_root, self.label)
return self.logdir

def initialize_scaling(self):
"""Compose the launch command and scale out the initial blocks.
"""
Expand All @@ -289,9 +307,6 @@ def initialize_scaling(self):
address_probe_timeout_string = ""
if self.address_probe_timeout:
address_probe_timeout_string = "--address_probe_timeout={}".format(self.address_probe_timeout)
worker_logdir = "{}/{}".format(self.run_dir, self.label)
if self.worker_logdir_root is not None:
worker_logdir = "{}/{}".format(self.worker_logdir_root, self.label)

l_cmd = self.launch_cmd.format(debug=debug_opts,
prefetch_capacity=self.prefetch_capacity,
Expand All @@ -306,7 +321,8 @@ def initialize_scaling(self):
heartbeat_period=self.heartbeat_period,
heartbeat_threshold=self.heartbeat_threshold,
poll_period=self.poll_period,
logdir=worker_logdir,
cert_dir=self.cert_dir,
logdir=self.worker_logdir,
cpu_affinity=self.cpu_affinity,
accelerators=" ".join(self.available_accelerators))
self.launch_cmd = l_cmd
Expand All @@ -327,9 +343,25 @@ def initialize_scaling(self):
def start(self):
"""Create the Interchange process and connect to it.
"""
self.outgoing_q = zmq_pipes.TasksOutgoing("127.0.0.1", self.interchange_port_range)
self.incoming_q = zmq_pipes.ResultsIncoming("127.0.0.1", self.interchange_port_range)
self.command_client = zmq_pipes.CommandClient("127.0.0.1", self.interchange_port_range)
if self.encrypted and self.cert_dir is None:
logger.debug("Creating CurveZMQ certificates")
self.cert_dir = curvezmq.create_certificates(self.logdir)
elif not self.encrypted and self.cert_dir:
raise AttributeError(
"The certificates directory path attribute (cert_dir) is defined, but the "
"encrypted attribute is set to False. You must either change cert_dir to "
"None or encrypted to True."
)

self.outgoing_q = zmq_pipes.TasksOutgoing(
curvezmq.ClientContext(self.cert_dir), "127.0.0.1", self.interchange_port_range
)
self.incoming_q = zmq_pipes.ResultsIncoming(
curvezmq.ClientContext(self.cert_dir), "127.0.0.1", self.interchange_port_range
)
self.command_client = zmq_pipes.CommandClient(
curvezmq.ClientContext(self.cert_dir), "127.0.0.1", self.interchange_port_range
)

self._queue_management_thread = None
self._start_queue_management_thread()
Expand Down Expand Up @@ -450,10 +482,11 @@ def _start_local_interchange_process(self):
"worker_port_range": self.worker_port_range,
"hub_address": self.hub_address,
"hub_port": self.hub_port,
"logdir": "{}/{}".format(self.run_dir, self.label),
"logdir": self.logdir,
"heartbeat_threshold": self.heartbeat_threshold,
"poll_period": self.poll_period,
"logging_level": logging.DEBUG if self.worker_debug else logging.INFO
"logging_level": logging.DEBUG if self.worker_debug else logging.INFO,
"cert_dir": self.cert_dir,
},
daemon=True,
name="HTEX-Interchange"
Expand Down
21 changes: 14 additions & 7 deletions parsl/executors/high_throughput/interchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from typing import cast, Any, Dict, NoReturn, Sequence, Set, Optional, Tuple

from parsl import curvezmq
from parsl.utils import setproctitle
from parsl.version import VERSION as PARSL_VERSION
from parsl.serialize import serialize as serialize_object
Expand Down Expand Up @@ -79,6 +80,7 @@ def __init__(self,
logdir: str = ".",
logging_level: int = logging.INFO,
poll_period: int = 10,
cert_dir: Optional[str] = "certificates",
) -> None:
"""
Parameters
Expand Down Expand Up @@ -120,7 +122,10 @@ def __init__(self,
poll_period : int
The main thread polling period, in milliseconds. Default: 10ms
cert_dir : str | None
Path to the certificate directory. Default: 'certificates'
"""
self.cert_dir = cert_dir
self.logdir = logdir
os.makedirs(self.logdir, exist_ok=True)

Expand All @@ -134,15 +139,15 @@ def __init__(self,

logger.info("Attempting connection to client at {} on ports: {},{},{}".format(
client_address, client_ports[0], client_ports[1], client_ports[2]))
self.context = zmq.Context()
self.task_incoming = self.context.socket(zmq.DEALER)
self.zmq_context = curvezmq.ServerContext(self.cert_dir)
self.task_incoming = self.zmq_context.socket(zmq.DEALER)
self.task_incoming.set_hwm(0)
self.task_incoming.connect("tcp://{}:{}".format(client_address, client_ports[0]))
self.results_outgoing = self.context.socket(zmq.DEALER)
self.results_outgoing = self.zmq_context.socket(zmq.DEALER)
self.results_outgoing.set_hwm(0)
self.results_outgoing.connect("tcp://{}:{}".format(client_address, client_ports[1]))

self.command_channel = self.context.socket(zmq.REP)
self.command_channel = self.zmq_context.socket(zmq.REP)
self.command_channel.connect("tcp://{}:{}".format(client_address, client_ports[2]))
logger.info("Connected to client")

Expand All @@ -155,9 +160,9 @@ def __init__(self,
self.worker_ports = worker_ports
self.worker_port_range = worker_port_range

self.task_outgoing = self.context.socket(zmq.ROUTER)
self.task_outgoing = self.zmq_context.socket(zmq.ROUTER)
self.task_outgoing.set_hwm(0)
self.results_incoming = self.context.socket(zmq.ROUTER)
self.results_incoming = self.zmq_context.socket(zmq.ROUTER)
self.results_incoming.set_hwm(0)

if self.worker_ports:
Expand Down Expand Up @@ -241,7 +246,8 @@ def task_puller(self) -> NoReturn:
def _create_monitoring_channel(self) -> Optional[zmq.Socket]:
if self.hub_address and self.hub_port:
logger.info("Connecting to monitoring")
hub_channel = self.context.socket(zmq.DEALER)
# This is a one-off because monitoring is unencrypted
hub_channel = zmq.Context().socket(zmq.DEALER)
hub_channel.set_hwm(0)
hub_channel.connect("tcp://{}:{}".format(self.hub_address, self.hub_port))
logger.info("Monitoring enabled and connected to hub")
Expand Down Expand Up @@ -379,6 +385,7 @@ def start(self) -> None:
self.expire_bad_managers(interesting_managers, hub_channel)
self.process_tasks_to_send(interesting_managers)

self.zmq_context.destroy()
delta = time.time() - start
logger.info("Processed {} tasks in {} seconds".format(self.count, delta))
logger.warning("Exiting")
Expand Down
22 changes: 15 additions & 7 deletions parsl/executors/high_throughput/process_worker_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from multiprocessing.managers import DictProxy
from multiprocessing.sharedctypes import Synchronized

from parsl import curvezmq
from parsl.process_loggers import wrap_with_logs

from parsl.version import VERSION as PARSL_VERSION
from parsl.app.errors import RemoteExceptionWrapper
from parsl.executors.high_throughput.errors import WorkerLost
Expand Down Expand Up @@ -63,7 +63,8 @@ def __init__(self, *,
heartbeat_period,
poll_period,
cpu_affinity,
available_accelerators: Sequence[str]):
available_accelerators: Sequence[str],
cert_dir: Optional[str]):
"""
Parameters
----------
Expand Down Expand Up @@ -118,6 +119,8 @@ def __init__(self, *,
available_accelerators: list of str
List of accelerators available to the workers.
cert_dir : str | None
Path to the certificate directory.
"""

logger.info("Manager started")
Expand All @@ -137,15 +140,16 @@ def __init__(self, *,
print("Failed to find a viable address to connect to interchange. Exiting")
exit(5)

self.context = zmq.Context()
self.task_incoming = self.context.socket(zmq.DEALER)
self.cert_dir = cert_dir
self.zmq_context = curvezmq.ClientContext(self.cert_dir)
self.task_incoming = self.zmq_context.socket(zmq.DEALER)
self.task_incoming.setsockopt(zmq.IDENTITY, uid.encode('utf-8'))
# Linger is set to 0, so that the manager can exit even when there might be
# messages in the pipe
self.task_incoming.setsockopt(zmq.LINGER, 0)
self.task_incoming.connect(task_q_url)

self.result_outgoing = self.context.socket(zmq.DEALER)
self.result_outgoing = self.zmq_context.socket(zmq.DEALER)
self.result_outgoing.setsockopt(zmq.IDENTITY, uid.encode('utf-8'))
self.result_outgoing.setsockopt(zmq.LINGER, 0)
self.result_outgoing.connect(result_q_url)
Expand Down Expand Up @@ -468,7 +472,7 @@ def start(self):

self.task_incoming.close()
self.result_outgoing.close()
self.context.term()
self.zmq_context.term()
delta = time.time() - start
logger.info("process_worker_pool ran for {} seconds".format(delta))
return
Expand Down Expand Up @@ -703,6 +707,8 @@ def start_file_logger(filename, rank, name='parsl', level=logging.DEBUG, format_
help="Enable logging at DEBUG level")
parser.add_argument("-a", "--addresses", default='',
help="Comma separated list of addresses at which the interchange could be reached")
parser.add_argument("--cert_dir", required=True,
help="Path to certificate directory.")
parser.add_argument("-l", "--logdir", default="process_worker_pool_logs",
help="Process worker pool log directory")
parser.add_argument("-u", "--uid", default=str(uuid.uuid4()).split('-')[-1],
Expand Down Expand Up @@ -746,6 +752,7 @@ def start_file_logger(filename, rank, name='parsl', level=logging.DEBUG, format_

logger.info("Python version: {}".format(sys.version))
logger.info("Debug logging: {}".format(args.debug))
logger.info("Certificates dir: {}".format(args.cert_dir))
logger.info("Log dir: {}".format(args.logdir))
logger.info("Manager ID: {}".format(args.uid))
logger.info("Block ID: {}".format(args.block_id))
Expand Down Expand Up @@ -777,7 +784,8 @@ def start_file_logger(filename, rank, name='parsl', level=logging.DEBUG, format_
heartbeat_period=int(args.hb_period),
poll_period=int(args.poll),
cpu_affinity=args.cpu_affinity,
available_accelerators=args.available_accelerators)
available_accelerators=args.available_accelerators,
cert_dir=None if args.cert_dir == "None" else args.cert_dir)
manager.start()

except Exception:
Expand Down
Loading

0 comments on commit d81ceb4

Please sign in to comment.