diff --git a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/zmq_pipes.py b/compute_endpoint/globus_compute_endpoint/engines/high_throughput/zmq_pipes.py index f41f3ecdc..3880994f1 100644 --- a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/zmq_pipes.py +++ b/compute_endpoint/globus_compute_endpoint/engines/high_throughput/zmq_pipes.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import ipaddress import logging import time @@ -10,6 +11,41 @@ log = logging.getLogger(__name__) +def _zmq_canonicalize_address(addr: str | int) -> str: + try: + ip = ipaddress.ip_address(addr) + except ValueError: + # Not a valid IPv4 or IPv6 address + if isinstance(addr, int): + # If it was an integer, then it's just plain invalid + raise + + # Otherwise, it was likely a hostname; let another layer deal with it + return addr + + if ip.version == 4: + return str(ip) # like "12.34.56.78" + elif ip.version == 6: + return f"[{ip}]" # like "[::1]" + + +def _zmq_create_socket_port(context: zmq.Context, ip_address: str | int, port_range): + """ + Utility method with logic shared by all the pipes + """ + sock = context.socket(zmq.DEALER) + sock.set_hwm(0) + # This option should work for both IPv4 and IPv6...? + # May not work until Parsl is updated? + sock.setsockopt(zmq.IPV6, True) + + port = sock.bind_to_random_port( + f"tcp://{_zmq_canonicalize_address(ip_address)}", + min_port=port_range[0], + max_port=port_range[1], + ) + return sock, port + class CommandClient: """CommandClient""" @@ -25,12 +61,8 @@ def __init__(self, ip_address, port_range): """ self.context = zmq.Context() - self.zmq_socket = self.context.socket(zmq.DEALER) - self.zmq_socket.set_hwm(0) - self.port = self.zmq_socket.bind_to_random_port( - f"tcp://{ip_address}", - min_port=port_range[0], - max_port=port_range[1], + self.zmq_socket, self.port = _zmq_create_socket_port( + self.context, ip_address, port_range ) def run(self, message): @@ -66,12 +98,8 @@ def __init__(self, ip_address, port_range): """ self.context = zmq.Context() - self.zmq_socket = self.context.socket(zmq.DEALER) - self.zmq_socket.set_hwm(0) - self.port = self.zmq_socket.bind_to_random_port( - f"tcp://{ip_address}", - min_port=port_range[0], - max_port=port_range[1], + self.zmq_socket, self.port = _zmq_create_socket_port( + self.context, ip_address, port_range ) self.poller = zmq.Poller() self.poller.register(self.zmq_socket, zmq.POLLOUT) @@ -141,12 +169,8 @@ def __init__(self, ip_address, port_range): """ self.context = zmq.Context() - self.results_receiver = self.context.socket(zmq.DEALER) - self.results_receiver.set_hwm(0) - self.port = self.results_receiver.bind_to_random_port( - f"tcp://{ip_address}", - min_port=port_range[0], - max_port=port_range[1], + self.results_receiver, self.port = _zmq_create_socket_port( + self.context, ip_address, port_range ) def get(self, block=True, timeout=None):