Skip to content

Commit

Permalink
remove 127.0.0.1 usage and enable ipv6 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiGlobus committed Dec 2, 2024
1 parent 02d4c98 commit c04a101
Show file tree
Hide file tree
Showing 21 changed files with 157 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def __init__(
if linger is not None:
self.zmq_socket.setsockopt(zmq.LINGER, linger)

self.zmq_socket.setsockopt(zmq.IPV6, True)

# all zmq setsockopt calls must be done before bind/connect is called
if self.mode == "server":
self.zmq_socket.bind(f"tcp://*:{port}")
Expand Down Expand Up @@ -121,7 +123,7 @@ def setup_server_auth(self):
# Start an authenticator for this context.
self.auth = ThreadAuthenticator(self.context)
self.auth.start()
self.auth.allow("127.0.0.1")
self.auth.allow("::1")
# Tell the authenticator how to handle CURVE requests

if not self.ironhouse:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import multiprocessing
import os
import queue
import socket
import threading
import time
import typing as t
Expand Down Expand Up @@ -299,14 +300,15 @@ def __init__(
self.endpoint_id = endpoint_id
self._task_counter = 0

try:
ipaddress.ip_address(address=address)
except Exception:
log.critical(
f"Invalid address supplied: {address}. "
"Please use a valid IPv4 or IPv6 address"
if not HighThroughputEngine.is_hostname_or_ip(address):
err_msg = (
# yes, suggesting `=` formatter, so it's clear which argument.
f"Invalid address: {address=}\n\n"
"Expecting an interface name, hostname, IPv4 address, or IPv6 address."
)
raise
log.critical(err_msg)
raise ValueError(err_msg)

self.address = address
self.worker_ports = worker_ports
self.worker_port_range = worker_port_range
Expand Down Expand Up @@ -376,14 +378,10 @@ def start(
self.run_dir = run_dir
self.endpoint_id = endpoint_id

self.outgoing_q = zmq_pipes.TasksOutgoing(
"127.0.0.1", self.interchange_port_range
)
self.incoming_q = zmq_pipes.ResultsIncoming(
"127.0.0.1", self.interchange_port_range
)
self.outgoing_q = zmq_pipes.TasksOutgoing("::1", self.interchange_port_range)
self.incoming_q = zmq_pipes.ResultsIncoming("::1", self.interchange_port_range)
self.command_client = zmq_pipes.CommandClient(
"127.0.0.1", self.interchange_port_range
"::1", self.interchange_port_range
)

self.is_alive = True
Expand Down Expand Up @@ -419,6 +417,27 @@ def start(

return self.outgoing_q.port, self.incoming_q.port, self.command_client.port

@staticmethod
def is_hostname_or_ip(hostname_or_ip: str) -> bool:
"""
Utility method to verify that the input is a valid hostname or
IP address.
"""
if not hostname_or_ip:
return False
else:
try:
socket.gethostbyname(hostname_or_ip)
return True
except socket.gaierror:
# Not a hostname, now check IP
pass
try:
ipaddress.ip_address(address=hostname_or_ip)
except ValueError:
return False
return True

def _start_local_interchange_process(self):
"""Starts the interchange process locally
Expand All @@ -431,7 +450,7 @@ def _start_local_interchange_process(self):
name="Engine-Interchange",
args=(comm_q,),
kwargs={
"client_address": "127.0.0.1", # engine and ix are on the same node
"client_address": "127.0.0.1", # engine and ix are on same node
"client_ports": (
self.outgoing_q.port,
self.incoming_q.port,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,11 @@ def __init__(
client_address : str
The ip address at which the parsl client can be reached.
Default: "127.0.0.1"
Default: "localhost"
interchange_address : str
The ip address at which the workers will be able to reach the Interchange.
Default: "127.0.0.1"
Default: "localhost"
client_ports : tuple[int, int, int]
The ports at which the client can be reached
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ class Manager:

def __init__(
self,
task_q_url="tcp://127.0.0.1:50097",
result_q_url="tcp://127.0.0.1:50098",
task_q_url="tcp://localhost:50097",
result_q_url="tcp://localhost:50098",
max_queue_size=10,
cores_per_worker=1,
available_accelerators: list[str] | None = None,
Expand Down Expand Up @@ -171,6 +171,7 @@ def __init__(
# Linger is set to 0, so that the manager can exit even when there might be
# messages in the pipe
self.task_incoming.setsockopt(zmq.LINGER, 0)
self.task_incoming.setsockopt(zmq.IPV6, True)
self.task_incoming.connect(task_q_url)

self.logdir = logdir
Expand All @@ -179,6 +180,7 @@ def __init__(
self.result_outgoing = self.context.socket(zmq.DEALER)
self.result_outgoing.setsockopt(zmq.IDENTITY, uid.encode("utf-8"))
self.result_outgoing.setsockopt(zmq.LINGER, 0)
self.result_outgoing.setsockopt(zmq.IPV6, True)
self.result_outgoing.connect(result_q_url)

log.info("Manager connected")
Expand Down Expand Up @@ -213,7 +215,8 @@ def __init__(

self.funcx_task_socket = self.context.socket(zmq.ROUTER)
self.funcx_task_socket.set_hwm(0)
self.address = "127.0.0.1"
self.funcx_task_socket.setsockopt(zmq.IPV6, True)
self.address = "localhost"
self.worker_port = self.funcx_task_socket.bind_to_random_port(
"tcp://*",
min_port=self.internal_worker_port_range[0],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ class Worker:
Worker id string
address : str
Address at which the manager might be reached. This is usually 127.0.0.1
Address at which the manager might be reached. This is usually the ipv4
or ipv6 loopback address 127.0.0.1 or ::1
port : int
Port at which the manager can be reached
Expand Down Expand Up @@ -79,6 +80,7 @@ def __init__(

self.task_socket = self.context.socket(zmq.DEALER)
self.task_socket.setsockopt(zmq.IDENTITY, self.identity)
self.task_socket.setsockopt(zmq.IPV6, True)

log.info(f"Trying to connect to : tcp://{self.address}:{self.port}")
self.task_socket.connect(f"tcp://{self.address}:{self.port}")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#!/usr/bin/env python3

from __future__ import annotations

import ipaddress
import logging
import time

Expand All @@ -10,6 +13,42 @@
log = logging.getLogger(__name__)


def _zmq_canonicalize_address(addr: str | int) -> str:
try:
ip = ipaddress.ip_address(addr)
except ValueError:
# Not a valid IPv4 or IPv6 address
if isinstance(addr, int):
# If it was an integer, then it's just plain invalid
raise

# Otherwise, it was likely a hostname; let another layer deal with it
return addr

if ip.version == 4:
return str(ip) # like "12.34.56.78"
elif ip.version == 6:
return f"[{ip}]" # like "[::1]"


def _zmq_create_socket_port(context: zmq.Context, ip_address: str | int, port_range):
    """
    Utility method with logic shared by all the pipes: create a DEALER
    socket and bind it to a random port on the given address.

    :param context: ZMQ context used to create the socket
    :param ip_address: address to bind to; passed through
        ``_zmq_canonicalize_address`` before being placed in the ``tcp://``
        endpoint
    :param port_range: indexable pair; ``port_range[0]`` is used as the
        minimum port and ``port_range[1]`` as the maximum port
    :returns: tuple of (bound socket, port number it was bound to)
    :raises Exception: anything raised while configuring or binding the
        socket is re-raised after the socket has been closed
    """
    sock = context.socket(zmq.DEALER)
    try:
        sock.set_hwm(0)
        # This option should work for both IPv4 and IPv6...?
        # May not work until Parsl is updated?
        # NOTE: socket options must be set before the bind below.
        sock.setsockopt(zmq.IPV6, True)

        port = sock.bind_to_random_port(
            f"tcp://{_zmq_canonicalize_address(ip_address)}",
            min_port=port_range[0],
            max_port=port_range[1],
        )
    except Exception:
        # Don't leak the socket if configuration or binding failed; the
        # caller never receives a handle it could close itself.
        sock.close(linger=0)
        raise
    return sock, port


class CommandClient:
"""CommandClient"""

Expand All @@ -24,13 +63,10 @@ def __init__(self, ip_address, port_range):
Port range for the comms between client and interchange
"""

self.context = zmq.Context()
self.zmq_socket = self.context.socket(zmq.DEALER)
self.zmq_socket.set_hwm(0)
self.port = self.zmq_socket.bind_to_random_port(
f"tcp://{ip_address}",
min_port=port_range[0],
max_port=port_range[1],
self.zmq_socket, self.port = _zmq_create_socket_port(
self.context, ip_address, port_range
)

def run(self, message):
Expand Down Expand Up @@ -66,12 +102,8 @@ def __init__(self, ip_address, port_range):
"""
self.context = zmq.Context()
self.zmq_socket = self.context.socket(zmq.DEALER)
self.zmq_socket.set_hwm(0)
self.port = self.zmq_socket.bind_to_random_port(
f"tcp://{ip_address}",
min_port=port_range[0],
max_port=port_range[1],
self.zmq_socket, self.port = _zmq_create_socket_port(
self.context, ip_address, port_range
)
self.poller = zmq.Poller()
self.poller.register(self.zmq_socket, zmq.POLLOUT)
Expand Down Expand Up @@ -141,12 +173,8 @@ def __init__(self, ip_address, port_range):
"""
self.context = zmq.Context()
self.results_receiver = self.context.socket(zmq.DEALER)
self.results_receiver.set_hwm(0)
self.port = self.results_receiver.bind_to_random_port(
f"tcp://{ip_address}",
min_port=port_range[0],
max_port=port_range[1],
self.results_receiver, self.port = _zmq_create_socket_port(
self.context, ip_address, port_range
)

def get(self, block=True, timeout=None):
Expand Down
4 changes: 2 additions & 2 deletions compute_endpoint/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _runner(engine_type: t.Type[GlobusComputeEngineBase], **kwargs):
k = dict(max_workers=2)
elif engine_type is engines.GlobusComputeEngine:
k = dict(
address="127.0.0.1",
address="::1",
heartbeat_period=engine_heartbeat,
heartbeat_threshold=2,
job_status_kwargs=dict(max_idletime=0, strategy_period=0.1),
Expand All @@ -153,7 +153,7 @@ def _runner(engine_type: t.Type[GlobusComputeEngineBase], **kwargs):
"""

k = dict(
address="127.0.0.1",
address="::1",
heartbeat_period=engine_heartbeat,
heartbeat_threshold=1,
mpi_launcher="mpiexec",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,10 +394,10 @@ def test_with_funcx_config(self, mocker):
mock_interchange.return_value.stop.return_value = None

mock_optionals = {}
mock_optionals["interchange_address"] = "127.0.0.1"
mock_optionals["interchange_address"] = "::1"

mock_funcx_config = {}
mock_funcx_config["endpoint_address"] = "127.0.0.1"
mock_funcx_config["endpoint_address"] = "::1"

manager = Endpoint(funcx_dir=os.getcwd())
manager.name = "test"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
def gc_engine_scaling(tmp_path):
ep_id = uuid.uuid4()
engine = GlobusComputeEngine(
address="127.0.0.1",
address="::1",
heartbeat_period=1,
heartbeat_threshold=2,
provider=LocalProvider(
Expand All @@ -37,7 +37,7 @@ def gc_engine_scaling(tmp_path):
def gc_engine_non_scaling(tmp_path):
ep_id = uuid.uuid4()
engine = GlobusComputeEngine(
address="127.0.0.1",
address="::1",
heartbeat_period=1,
heartbeat_threshold=2,
provider=LocalProvider(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,5 @@ def test_repeated_fail(mock_gce, ez_pack_task):


def test_default_retries_is_0():
engine = GlobusComputeEngine(address="127.0.0.1")
engine = GlobusComputeEngine(address="localhost")
assert engine.max_retries_on_system_failure == 0, "Users must knowingly opt-in"
2 changes: 1 addition & 1 deletion compute_endpoint/tests/unit/test_bad_endpoint_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
_MOCK_BASE = "globus_compute_endpoint.engines.high_throughput.engine."


@pytest.mark.parametrize("address", ("localhost", "login1.theta.alcf.anl.gov", "*"))
@pytest.mark.parametrize("address", ("example", "a.b.c.d.e", "*"))
def test_invalid_address(address, htex_warns):
with mock.patch(f"{_MOCK_BASE}log") as mock_log:
with pytest.raises(ValueError):
Expand Down
2 changes: 1 addition & 1 deletion compute_endpoint/tests/unit/test_cli_behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def test_config_yaml_display_none(run_line, mock_command_ensure, display_name):
run_line(config_cmd)

conf_dict = dict(yaml.safe_load(conf.read_text()))
conf_dict["engine"]["address"] = "127.0.0.1" # avoid unnecessary DNS lookup
conf_dict["engine"]["address"] = "::1" # avoid unnecessary DNS lookup
conf = load_config_yaml(yaml.safe_dump(conf_dict))

assert conf.display_name is None, conf.display_name
Expand Down
8 changes: 5 additions & 3 deletions compute_endpoint/tests/unit/test_endpoint_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

@pytest.fixture
def config_dict():
return {"engine": {"type": "GlobusComputeEngine", "address": "127.0.0.1"}}
return {"engine": {"type": "GlobusComputeEngine", "address": "localhost"}}


@pytest.fixture
Expand Down Expand Up @@ -140,7 +140,9 @@ def test_conditional_engine_strategy(
):
config_dict["engine"]["type"] = engine_type
config_dict["engine"]["strategy"] = strategy
config_dict["engine"]["address"] = "127.0.0.1"
config_dict["engine"]["address"] = (
"::1" if engine_type != "HighThroughputEngine" else "127.0.0.1"
)

if engine_type == "GlobusComputeEngine":
if isinstance(strategy, str) or strategy is None:
Expand Down Expand Up @@ -173,7 +175,7 @@ def test_provider_container_compatibility(
):
config_dict["engine"]["container_uri"] = "docker://ubuntu"
config_dict["engine"]["provider"] = {"type": provider_type}
config_dict["engine"]["address"] = "127.0.0.1"
config_dict["engine"]["address"] = "::1"

if compatible:
UserEndpointConfigModel(**config_dict)
Expand Down
4 changes: 2 additions & 2 deletions compute_endpoint/tests/unit/test_endpoint_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def test_endpoint_get_metadata(mocker, engine_cls):

k = {}
if engine_cls is GlobusComputeEngine:
k["address"] = "127.0.0.1"
k["address"] = "::1"
executors = [engine_cls(**k)]
test_config = UserEndpointConfig(executors=executors)
test_config.source_content = "foo: bar"
Expand Down Expand Up @@ -720,7 +720,7 @@ def test_always_prints_endpoint_id_to_terminal(mocker, mock_ep_data, mock_reg_in
def test_serialize_config_field_types():
fns = [str(uuid.uuid4()) for _ in range(5)]

ep_config = UserEndpointConfig(executors=[GlobusComputeEngine(address="127.0.0.1")])
ep_config = UserEndpointConfig(executors=[GlobusComputeEngine(address="::1")])
ep_config._hidden_attr = "123"
ep_config.rando_attr = "howdy"
ep_config.allowed_functions = fns
Expand Down
Loading

0 comments on commit c04a101

Please sign in to comment.