Skip to content

Commit

Permalink
Adding support for IPV6 addresses (#3704)
Browse files Browse the repository at this point in the history
# Description

This PR adds a new option: `HighThoughputExecutor(loopback_address: str
= "127.0.0.1")` which can be used to specify the internal address used
by HTEX for communication between the executor and the interchange. In
addition, all ZMQ sockets are now are set to having IPv6 enabled.

The test config `htex_local` has been updated to use
`loopback_address="::1"` for testing.

# Changed Behaviour

* IPv6 support is enabled on all HTEX ZMQ components.
* HTEX now supports a `loopback_address` which allows configuring the
address used for internal communication

Fixes # (issue)

## Type of change

Choose which options apply, and delete the ones which do not apply.

- New feature
- Update to human readable text: Documentation/error messages/comments
  • Loading branch information
yadudoc authored Nov 25, 2024
1 parent 07dfb42 commit 1c3e509
Show file tree
Hide file tree
Showing 11 changed files with 83 additions and 25 deletions.
20 changes: 19 additions & 1 deletion parsl/addresses.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
so some experimentation will probably be needed to choose the correct one.
"""

import ipaddress
import logging
import platform
import socket
Expand All @@ -17,7 +18,7 @@
except ImportError:
fcntl = None # type: ignore[assignment]
import struct
from typing import Callable, List, Set
from typing import Callable, List, Set, Union

import psutil
import typeguard
Expand Down Expand Up @@ -156,3 +157,20 @@ def get_any_address() -> str:
if addr == '':
raise Exception('Cannot find address of the local machine.')
return addr


def tcp_url(address: str, port: Union[str, int, None] = None) -> str:
"""Construct a tcp url safe for IPv4 and IPv6"""
if address == "*":
return "tcp://*"

ip_addr = ipaddress.ip_address(address)

port_suffix = f":{port}" if port else ""

if ip_addr.version == 6 and port_suffix:
url = f"tcp://[{address}]{port_suffix}"
else:
url = f"tcp://{address}{port_suffix}"

return url
4 changes: 4 additions & 0 deletions parsl/curvezmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ def socket(self, socket_type: int, *args, **kwargs) -> zmq.Socket:
except zmq.ZMQError as e:
raise ValueError("Invalid CurveZMQ key format") from e
sock.setsockopt(zmq.CURVE_SERVER, True) # Must come before bind

# This flag enables IPV6 in addition to IPV4
sock.setsockopt(zmq.IPV6, True)
return sock

def term(self):
Expand Down Expand Up @@ -202,4 +205,5 @@ def socket(self, socket_type: int, *args, **kwargs) -> zmq.Socket:
sock.setsockopt(zmq.CURVE_SERVERKEY, server_public_key)
except zmq.ZMQError as e:
raise ValueError("Invalid CurveZMQ key format") from e
sock.setsockopt(zmq.IPV6, True)
return sock
19 changes: 14 additions & 5 deletions parsl/executors/high_throughput/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,19 @@
address : string
An address to connect to the main Parsl process which is reachable from the network in which
workers will be running. This field expects an IPv4 address (xxx.xxx.xxx.xxx).
workers will be running. This field expects an IPv4 or IPv6 address.
Most login nodes on clusters have several network interfaces available, only some of which
can be reached from the compute nodes. This field can be used to limit the executor to listen
only on a specific interface, and limiting connections to the internal network.
By default, the executor will attempt to enumerate and connect through all possible addresses.
Setting an address here overrides the default behavior.
default=None
loopback_address: string
Specify address used for internal communication between executor and interchange.
Supports IPv4 and IPv6 addresses
default=127.0.0.1
worker_ports : (int, int)
Specify the ports to be used by workers to connect to Parsl. If this option is specified,
worker_port_range will not be honored.
Expand Down Expand Up @@ -224,6 +229,7 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin, UsageIn
Parsl will create names as integers starting with 0.
default: empty list
"""

@typeguard.typechecked
Expand All @@ -233,6 +239,7 @@ def __init__(self,
launch_cmd: Optional[str] = None,
interchange_launch_cmd: Optional[Sequence[str]] = None,
address: Optional[str] = None,
loopback_address: str = "127.0.0.1",
worker_ports: Optional[Tuple[int, int]] = None,
worker_port_range: Optional[Tuple[int, int]] = (54000, 55000),
interchange_port_range: Optional[Tuple[int, int]] = (55000, 56000),
Expand Down Expand Up @@ -268,6 +275,8 @@ def __init__(self,
self.address = address
self.address_probe_timeout = address_probe_timeout
self.manager_selector = manager_selector
self.loopback_address = loopback_address

if self.address:
self.all_addresses = address
else:
Expand Down Expand Up @@ -408,13 +417,13 @@ def start(self):
)

self.outgoing_q = zmq_pipes.TasksOutgoing(
"127.0.0.1", self.interchange_port_range, self.cert_dir
self.loopback_address, self.interchange_port_range, self.cert_dir
)
self.incoming_q = zmq_pipes.ResultsIncoming(
"127.0.0.1", self.interchange_port_range, self.cert_dir
self.loopback_address, self.interchange_port_range, self.cert_dir
)
self.command_client = zmq_pipes.CommandClient(
"127.0.0.1", self.interchange_port_range, self.cert_dir
self.loopback_address, self.interchange_port_range, self.cert_dir
)

self._result_queue_thread = None
Expand Down Expand Up @@ -515,7 +524,7 @@ def _start_local_interchange_process(self) -> None:
get the worker task and result ports that the interchange has bound to.
"""

interchange_config = {"client_address": "127.0.0.1",
interchange_config = {"client_address": self.loopback_address,
"client_ports": (self.outgoing_q.port,
self.incoming_q.port,
self.command_client.port),
Expand Down
15 changes: 8 additions & 7 deletions parsl/executors/high_throughput/interchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import zmq

from parsl import curvezmq
from parsl.addresses import tcp_url
from parsl.app.errors import RemoteExceptionWrapper
from parsl.executors.high_throughput.errors import ManagerLost, VersionMismatch
from parsl.executors.high_throughput.manager_record import ManagerRecord
Expand Down Expand Up @@ -115,13 +116,13 @@ def __init__(self,
self.zmq_context = curvezmq.ServerContext(self.cert_dir)
self.task_incoming = self.zmq_context.socket(zmq.DEALER)
self.task_incoming.set_hwm(0)
self.task_incoming.connect("tcp://{}:{}".format(client_address, client_ports[0]))
self.task_incoming.connect(tcp_url(client_address, client_ports[0]))
self.results_outgoing = self.zmq_context.socket(zmq.DEALER)
self.results_outgoing.set_hwm(0)
self.results_outgoing.connect("tcp://{}:{}".format(client_address, client_ports[1]))
self.results_outgoing.connect(tcp_url(client_address, client_ports[1]))

self.command_channel = self.zmq_context.socket(zmq.REP)
self.command_channel.connect("tcp://{}:{}".format(client_address, client_ports[2]))
self.command_channel.connect(tcp_url(client_address, client_ports[2]))
logger.info("Connected to client")

self.run_id = run_id
Expand All @@ -144,14 +145,14 @@ def __init__(self,
self.worker_task_port = self.worker_ports[0]
self.worker_result_port = self.worker_ports[1]

self.task_outgoing.bind(f"tcp://{self.interchange_address}:{self.worker_task_port}")
self.results_incoming.bind(f"tcp://{self.interchange_address}:{self.worker_result_port}")
self.task_outgoing.bind(tcp_url(self.interchange_address, self.worker_task_port))
self.results_incoming.bind(tcp_url(self.interchange_address, self.worker_result_port))

else:
self.worker_task_port = self.task_outgoing.bind_to_random_port(f"tcp://{self.interchange_address}",
self.worker_task_port = self.task_outgoing.bind_to_random_port(tcp_url(self.interchange_address),
min_port=worker_port_range[0],
max_port=worker_port_range[1], max_tries=100)
self.worker_result_port = self.results_incoming.bind_to_random_port(f"tcp://{self.interchange_address}",
self.worker_result_port = self.results_incoming.bind_to_random_port(tcp_url(self.interchange_address),
min_port=worker_port_range[0],
max_port=worker_port_range[1], max_tries=100)

Expand Down
2 changes: 2 additions & 0 deletions parsl/executors/high_throughput/mpi_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self,
launch_cmd: Optional[str] = None,
interchange_launch_cmd: Optional[str] = None,
address: Optional[str] = None,
loopback_address: str = "127.0.0.1",
worker_ports: Optional[Tuple[int, int]] = None,
worker_port_range: Optional[Tuple[int, int]] = (54000, 55000),
interchange_port_range: Optional[Tuple[int, int]] = (55000, 56000),
Expand Down Expand Up @@ -78,6 +79,7 @@ def __init__(self,
launch_cmd=launch_cmd,
interchange_launch_cmd=interchange_launch_cmd,
address=address,
loopback_address=loopback_address,
worker_ports=worker_ports,
worker_port_range=worker_port_range,
interchange_port_range=interchange_port_range,
Expand Down
8 changes: 4 additions & 4 deletions parsl/executors/high_throughput/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import zmq
from zmq.utils.monitor import recv_monitor_message

from parsl.addresses import get_all_addresses
from parsl.addresses import get_all_addresses, tcp_url

logger = logging.getLogger(__name__)

Expand All @@ -32,7 +32,8 @@ def probe_addresses(addresses, task_port, timeout=120):
for addr in addresses:
socket = context.socket(zmq.DEALER)
socket.setsockopt(zmq.LINGER, 0)
url = "tcp://{}:{}".format(addr, task_port)
socket.setsockopt(zmq.IPV6, True)
url = tcp_url(addr, task_port)
logger.debug("Trying to connect back on {}".format(url))
socket.connect(url)
addr_map[addr] = {'sock': socket,
Expand Down Expand Up @@ -71,8 +72,7 @@ def __init__(self, addresses, port):

address = probe_addresses(addresses, port)
print("Viable address :", address)
self.task_incoming.connect("tcp://{}:{}".format(address, port))
print("Here")
self.task_incoming.connect(tcp_url(address, port))

def heartbeat(self):
""" Send heartbeat to the incoming task queue
Expand Down
5 changes: 3 additions & 2 deletions parsl/executors/high_throughput/process_worker_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import zmq

from parsl import curvezmq
from parsl.addresses import tcp_url
from parsl.app.errors import RemoteExceptionWrapper
from parsl.executors.execute_task import execute_task
from parsl.executors.high_throughput.errors import WorkerLost
Expand Down Expand Up @@ -159,8 +160,8 @@ def __init__(self, *,
raise Exception("No viable address found")
else:
logger.info("Connection to Interchange successful on {}".format(ix_address))
task_q_url = "tcp://{}:{}".format(ix_address, task_port)
result_q_url = "tcp://{}:{}".format(ix_address, result_port)
task_q_url = tcp_url(ix_address, task_port)
result_q_url = tcp_url(ix_address, result_port)
logger.info("Task url : {}".format(task_q_url))
logger.info("Result url : {}".format(result_q_url))
except Exception:
Expand Down
9 changes: 5 additions & 4 deletions parsl/executors/high_throughput/zmq_pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import zmq

from parsl import curvezmq
from parsl.addresses import tcp_url
from parsl.errors import InternalConsistencyError
from parsl.executors.high_throughput.errors import (
CommandClientBadError,
Expand Down Expand Up @@ -52,11 +53,11 @@ def create_socket_and_bind(self):
self.zmq_socket = self.zmq_context.socket(zmq.REQ)
self.zmq_socket.setsockopt(zmq.LINGER, 0)
if self.port is None:
self.port = self.zmq_socket.bind_to_random_port("tcp://{}".format(self.ip_address),
self.port = self.zmq_socket.bind_to_random_port(tcp_url(self.ip_address),
min_port=self.port_range[0],
max_port=self.port_range[1])
else:
self.zmq_socket.bind("tcp://{}:{}".format(self.ip_address, self.port))
self.zmq_socket.bind(tcp_url(self.ip_address, self.port))

def run(self, message, max_retries=3, timeout_s=None):
""" This function needs to be fast at the same time aware of the possibility of
Expand Down Expand Up @@ -146,7 +147,7 @@ def __init__(self, ip_address, port_range, cert_dir: Optional[str] = None):
self.zmq_context = curvezmq.ClientContext(cert_dir)
self.zmq_socket = self.zmq_context.socket(zmq.DEALER)
self.zmq_socket.set_hwm(0)
self.port = self.zmq_socket.bind_to_random_port("tcp://{}".format(ip_address),
self.port = self.zmq_socket.bind_to_random_port(tcp_url(ip_address),
min_port=port_range[0],
max_port=port_range[1])
self.poller = zmq.Poller()
Expand Down Expand Up @@ -202,7 +203,7 @@ def __init__(self, ip_address, port_range, cert_dir: Optional[str] = None):
self.zmq_context = curvezmq.ClientContext(cert_dir)
self.results_receiver = self.zmq_context.socket(zmq.DEALER)
self.results_receiver.set_hwm(0)
self.port = self.results_receiver.bind_to_random_port("tcp://{}".format(ip_address),
self.port = self.results_receiver.bind_to_random_port(tcp_url(ip_address),
min_port=port_range[0],
max_port=port_range[1])

Expand Down
1 change: 1 addition & 0 deletions parsl/tests/configs/htex_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def fresh_config():
executors=[
HighThroughputExecutor(
label="htex_local",
loopback_address="::1",
worker_debug=True,
cores_per_worker=1,
encrypted=True,
Expand Down
5 changes: 3 additions & 2 deletions parsl/tests/test_htex/test_zmq_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_interchange_binding_with_non_ipv4_address(cert_dir: Optional[str]):
def test_interchange_binding_bad_address(cert_dir: Optional[str]):
"""Confirm that we raise a ZMQError when a bad address is supplied"""
address = "550.0.0.0"
with pytest.raises(zmq.error.ZMQError):
with pytest.raises(ValueError):
make_interchange(interchange_address=address, cert_dir=cert_dir)


Expand All @@ -103,4 +103,5 @@ def test_limited_interface_binding(cert_dir: Optional[str]):

matched_conns = [conn for conn in conns if conn.laddr.port == ix.worker_result_port]
assert len(matched_conns) == 1
assert matched_conns[0].laddr.ip == address
# laddr.ip can return ::ffff:127.0.0.1 when using IPv6
assert address in matched_conns[0].laddr.ip
20 changes: 20 additions & 0 deletions parsl/tests/unit/test_address.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest

from parsl.addresses import tcp_url


@pytest.mark.local
@pytest.mark.parametrize("address, port,expected", [
("127.0.0.1", 55001, "tcp://127.0.0.1:55001"),
("127.0.0.1", "55001", "tcp://127.0.0.1:55001"),
("127.0.0.1", None, "tcp://127.0.0.1"),
("::1", "55001", "tcp://[::1]:55001"),
("::ffff:127.0.0.1", 55001, "tcp://[::ffff:127.0.0.1]:55001"),
("::ffff:127.0.0.1", None, "tcp://::ffff:127.0.0.1"),
("::ffff:127.0.0.1", None, "tcp://::ffff:127.0.0.1"),
("*", None, "tcp://*"),
])
def test_tcp_url(address, port, expected):
"""Confirm valid address generation"""
result = tcp_url(address, port)
assert result == expected

0 comments on commit 1c3e509

Please sign in to comment.