Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for IPV6 addresses #3704

Merged
merged 8 commits into from
Nov 25, 2024
Merged
21 changes: 20 additions & 1 deletion parsl/addresses.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
so some experimentation will probably be needed to choose the correct one.
"""

import ipaddress
import logging
import platform
import socket
Expand All @@ -17,7 +18,7 @@
except ImportError:
fcntl = None # type: ignore[assignment]
import struct
from typing import Callable, List, Set
from typing import Callable, List, Set, Union

import psutil
import typeguard
Expand Down Expand Up @@ -156,3 +157,21 @@ def get_any_address() -> str:
if addr == '':
raise Exception('Cannot find address of the local machine.')
return addr


def tcp_url(address: str, port: Union[str, int, None] = None) -> str:
"""Construct a tcp url safe for IPv4 and IPv6"""
stripped_address = address.strip('[]')
yadudoc marked this conversation as resolved.
Show resolved Hide resolved
if address == "*":
return "tcp://*"

ip_addr = ipaddress.ip_address(stripped_address)

port_suffix = f":{port}" if port else ""

if ip_addr.version == 6 and port_suffix:
url = f"tcp://[{stripped_address}]{port_suffix}"
else:
url = f"tcp://{stripped_address}{port_suffix}"

return url
4 changes: 4 additions & 0 deletions parsl/curvezmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ def socket(self, socket_type: int, *args, **kwargs) -> zmq.Socket:
except zmq.ZMQError as e:
raise ValueError("Invalid CurveZMQ key format") from e
sock.setsockopt(zmq.CURVE_SERVER, True) # Must come before bind

# This flag enables IPV6 in addition to IPV4
sock.setsockopt(zmq.IPV6, True)
return sock

def term(self):
Expand Down Expand Up @@ -202,4 +205,5 @@ def socket(self, socket_type: int, *args, **kwargs) -> zmq.Socket:
sock.setsockopt(zmq.CURVE_SERVERKEY, server_public_key)
except zmq.ZMQError as e:
raise ValueError("Invalid CurveZMQ key format") from e
sock.setsockopt(zmq.IPV6, True)
return sock
19 changes: 14 additions & 5 deletions parsl/executors/high_throughput/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,19 @@

address : string
An address to connect to the main Parsl process which is reachable from the network in which
workers will be running. This field expects an IPv4 address (xxx.xxx.xxx.xxx).
workers will be running. This field expects an IPv4 or IPv6 address.
Most login nodes on clusters have several network interfaces available, only some of which
can be reached from the compute nodes. This field can be used to limit the executor to listen
only on a specific interface, and limiting connections to the internal network.
By default, the executor will attempt to enumerate and connect through all possible addresses.
Setting an address here overrides the default behavior.
default=None

loopback_address: string
Specify address used for internal communication between executor and interchange.
Supports IPv4 and IPv6 addresses
default=127.0.0.1

worker_ports : (int, int)
Specify the ports to be used by workers to connect to Parsl. If this option is specified,
worker_port_range will not be honored.
Expand Down Expand Up @@ -224,6 +229,7 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin, UsageIn
Parsl will create names as integers starting with 0.

default: empty list

"""

@typeguard.typechecked
Expand All @@ -233,6 +239,7 @@ def __init__(self,
launch_cmd: Optional[str] = None,
interchange_launch_cmd: Optional[Sequence[str]] = None,
address: Optional[str] = None,
loopback_address: str = "127.0.0.1",
worker_ports: Optional[Tuple[int, int]] = None,
worker_port_range: Optional[Tuple[int, int]] = (54000, 55000),
interchange_port_range: Optional[Tuple[int, int]] = (55000, 56000),
Expand Down Expand Up @@ -268,6 +275,8 @@ def __init__(self,
self.address = address
self.address_probe_timeout = address_probe_timeout
self.manager_selector = manager_selector
self.loopback_address = loopback_address

if self.address:
self.all_addresses = address
else:
Expand Down Expand Up @@ -408,13 +417,13 @@ def start(self):
)

self.outgoing_q = zmq_pipes.TasksOutgoing(
"127.0.0.1", self.interchange_port_range, self.cert_dir
self.loopback_address, self.interchange_port_range, self.cert_dir
)
self.incoming_q = zmq_pipes.ResultsIncoming(
"127.0.0.1", self.interchange_port_range, self.cert_dir
self.loopback_address, self.interchange_port_range, self.cert_dir
)
self.command_client = zmq_pipes.CommandClient(
"127.0.0.1", self.interchange_port_range, self.cert_dir
self.loopback_address, self.interchange_port_range, self.cert_dir
)

self._result_queue_thread = None
Expand Down Expand Up @@ -515,7 +524,7 @@ def _start_local_interchange_process(self) -> None:
get the worker task and result ports that the interchange has bound to.
"""

interchange_config = {"client_address": "127.0.0.1",
interchange_config = {"client_address": self.loopback_address,
"client_ports": (self.outgoing_q.port,
self.incoming_q.port,
self.command_client.port),
Expand Down
15 changes: 8 additions & 7 deletions parsl/executors/high_throughput/interchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import zmq

from parsl import curvezmq
from parsl.addresses import tcp_url
from parsl.app.errors import RemoteExceptionWrapper
from parsl.executors.high_throughput.errors import ManagerLost, VersionMismatch
from parsl.executors.high_throughput.manager_record import ManagerRecord
Expand Down Expand Up @@ -115,13 +116,13 @@ def __init__(self,
self.zmq_context = curvezmq.ServerContext(self.cert_dir)
self.task_incoming = self.zmq_context.socket(zmq.DEALER)
self.task_incoming.set_hwm(0)
self.task_incoming.connect("tcp://{}:{}".format(client_address, client_ports[0]))
self.task_incoming.connect(tcp_url(client_address, client_ports[0]))
self.results_outgoing = self.zmq_context.socket(zmq.DEALER)
self.results_outgoing.set_hwm(0)
self.results_outgoing.connect("tcp://{}:{}".format(client_address, client_ports[1]))
self.results_outgoing.connect(tcp_url(client_address, client_ports[1]))

self.command_channel = self.zmq_context.socket(zmq.REP)
self.command_channel.connect("tcp://{}:{}".format(client_address, client_ports[2]))
self.command_channel.connect(tcp_url(client_address, client_ports[2]))
logger.info("Connected to client")

self.run_id = run_id
Expand All @@ -144,14 +145,14 @@ def __init__(self,
self.worker_task_port = self.worker_ports[0]
self.worker_result_port = self.worker_ports[1]

self.task_outgoing.bind(f"tcp://{self.interchange_address}:{self.worker_task_port}")
self.results_incoming.bind(f"tcp://{self.interchange_address}:{self.worker_result_port}")
self.task_outgoing.bind(tcp_url(self.interchange_address, self.worker_task_port))
self.results_incoming.bind(tcp_url(self.interchange_address, self.worker_result_port))

else:
self.worker_task_port = self.task_outgoing.bind_to_random_port(f"tcp://{self.interchange_address}",
self.worker_task_port = self.task_outgoing.bind_to_random_port(tcp_url(self.interchange_address),
min_port=worker_port_range[0],
max_port=worker_port_range[1], max_tries=100)
self.worker_result_port = self.results_incoming.bind_to_random_port(f"tcp://{self.interchange_address}",
self.worker_result_port = self.results_incoming.bind_to_random_port(tcp_url(self.interchange_address),
min_port=worker_port_range[0],
max_port=worker_port_range[1], max_tries=100)

Expand Down
2 changes: 2 additions & 0 deletions parsl/executors/high_throughput/mpi_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self,
launch_cmd: Optional[str] = None,
interchange_launch_cmd: Optional[str] = None,
address: Optional[str] = None,
loopback_address: str = "127.0.0.1",
worker_ports: Optional[Tuple[int, int]] = None,
worker_port_range: Optional[Tuple[int, int]] = (54000, 55000),
interchange_port_range: Optional[Tuple[int, int]] = (55000, 56000),
Expand Down Expand Up @@ -78,6 +79,7 @@ def __init__(self,
launch_cmd=launch_cmd,
interchange_launch_cmd=interchange_launch_cmd,
address=address,
loopback_address=loopback_address,
worker_ports=worker_ports,
worker_port_range=worker_port_range,
interchange_port_range=interchange_port_range,
Expand Down
8 changes: 4 additions & 4 deletions parsl/executors/high_throughput/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import zmq
from zmq.utils.monitor import recv_monitor_message

from parsl.addresses import get_all_addresses
from parsl.addresses import get_all_addresses, tcp_url

logger = logging.getLogger(__name__)

Expand All @@ -32,7 +32,8 @@ def probe_addresses(addresses, task_port, timeout=120):
for addr in addresses:
socket = context.socket(zmq.DEALER)
socket.setsockopt(zmq.LINGER, 0)
url = "tcp://{}:{}".format(addr, task_port)
socket.setsockopt(zmq.IPV6, True)
url = tcp_url(addr, task_port)
logger.debug("Trying to connect back on {}".format(url))
socket.connect(url)
addr_map[addr] = {'sock': socket,
Expand Down Expand Up @@ -71,8 +72,7 @@ def __init__(self, addresses, port):

address = probe_addresses(addresses, port)
print("Viable address :", address)
self.task_incoming.connect("tcp://{}:{}".format(address, port))
print("Here")
self.task_incoming.connect(tcp_url(address, port))

def heartbeat(self):
""" Send heartbeat to the incoming task queue
Expand Down
5 changes: 3 additions & 2 deletions parsl/executors/high_throughput/process_worker_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import zmq

from parsl import curvezmq
from parsl.addresses import tcp_url
from parsl.app.errors import RemoteExceptionWrapper
from parsl.executors.execute_task import execute_task
from parsl.executors.high_throughput.errors import WorkerLost
Expand Down Expand Up @@ -159,8 +160,8 @@ def __init__(self, *,
raise Exception("No viable address found")
else:
logger.info("Connection to Interchange successful on {}".format(ix_address))
task_q_url = "tcp://{}:{}".format(ix_address, task_port)
result_q_url = "tcp://{}:{}".format(ix_address, result_port)
task_q_url = tcp_url(ix_address, task_port)
result_q_url = tcp_url(ix_address, result_port)
logger.info("Task url : {}".format(task_q_url))
logger.info("Result url : {}".format(result_q_url))
except Exception:
Expand Down
9 changes: 5 additions & 4 deletions parsl/executors/high_throughput/zmq_pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import zmq

from parsl import curvezmq
from parsl.addresses import tcp_url
from parsl.errors import InternalConsistencyError
from parsl.executors.high_throughput.errors import (
CommandClientBadError,
Expand Down Expand Up @@ -52,11 +53,11 @@ def create_socket_and_bind(self):
self.zmq_socket = self.zmq_context.socket(zmq.REQ)
self.zmq_socket.setsockopt(zmq.LINGER, 0)
if self.port is None:
self.port = self.zmq_socket.bind_to_random_port("tcp://{}".format(self.ip_address),
self.port = self.zmq_socket.bind_to_random_port(tcp_url(self.ip_address),
min_port=self.port_range[0],
max_port=self.port_range[1])
else:
self.zmq_socket.bind("tcp://{}:{}".format(self.ip_address, self.port))
self.zmq_socket.bind(tcp_url(self.ip_address, self.port))

def run(self, message, max_retries=3, timeout_s=None):
""" This function needs to be fast at the same time aware of the possibility of
Expand Down Expand Up @@ -146,7 +147,7 @@ def __init__(self, ip_address, port_range, cert_dir: Optional[str] = None):
self.zmq_context = curvezmq.ClientContext(cert_dir)
self.zmq_socket = self.zmq_context.socket(zmq.DEALER)
self.zmq_socket.set_hwm(0)
self.port = self.zmq_socket.bind_to_random_port("tcp://{}".format(ip_address),
self.port = self.zmq_socket.bind_to_random_port(tcp_url(ip_address),
min_port=port_range[0],
max_port=port_range[1])
self.poller = zmq.Poller()
Expand Down Expand Up @@ -202,7 +203,7 @@ def __init__(self, ip_address, port_range, cert_dir: Optional[str] = None):
self.zmq_context = curvezmq.ClientContext(cert_dir)
self.results_receiver = self.zmq_context.socket(zmq.DEALER)
self.results_receiver.set_hwm(0)
self.port = self.results_receiver.bind_to_random_port("tcp://{}".format(ip_address),
self.port = self.results_receiver.bind_to_random_port(tcp_url(ip_address),
min_port=port_range[0],
max_port=port_range[1])

Expand Down
1 change: 1 addition & 0 deletions parsl/tests/configs/htex_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def fresh_config():
executors=[
HighThroughputExecutor(
label="htex_local",
loopback_address="::1",
worker_debug=True,
cores_per_worker=1,
encrypted=True,
Expand Down
3 changes: 2 additions & 1 deletion parsl/tests/test_htex/test_zmq_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,5 @@ def test_limited_interface_binding(cert_dir: Optional[str]):

matched_conns = [conn for conn in conns if conn.laddr.port == ix.worker_result_port]
assert len(matched_conns) == 1
assert matched_conns[0].laddr.ip == address
# laddr.ip can return ::ffff:127.0.0.1 when using IPv6
assert address in matched_conns[0].laddr.ip
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this always exactly the string "127.0.0.1" or the string "::ffff:127.0.0.1"?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In all my testing it is.

19 changes: 19 additions & 0 deletions parsl/tests/unit/test_address.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest

from parsl.addresses import tcp_url


@pytest.mark.local
@pytest.mark.parametrize("address, port,expected", [
("127.0.0.1", 55001, "tcp://127.0.0.1:55001"),
("127.0.0.1", "55001", "tcp://127.0.0.1:55001"),
("127.0.0.1", None, "tcp://127.0.0.1"),
("::1", "55001", "tcp://[::1]:55001"),
("::ffff:127.0.0.1", 55001, "tcp://[::ffff:127.0.0.1]:55001"),
("::ffff:127.0.0.1", None, "tcp://::ffff:127.0.0.1"),
("[::ffff:127.0.0.1]", None, "tcp://::ffff:127.0.0.1"),
])
def test_tcp_url(address, port, expected):
"""Confirm valid address generation"""
result = tcp_url(address, port)
assert result == expected
Loading