Skip to content

Commit

Permalink
Let zmq select ports itself
Browse files Browse the repository at this point in the history
zmq will select a port each time ensemble evaluator starts. This way
there will be no time where other programs can pick up the target
port before zmq does.
Removed large parts of net_utils and its tests that were no longer
needed.
  • Loading branch information
JHolba committed Feb 7, 2025
1 parent 7437d71 commit 825a648
Show file tree
Hide file tree
Showing 21 changed files with 168 additions and 599 deletions.
5 changes: 4 additions & 1 deletion src/ert/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ def run_cli(args: Namespace, plugin_manager: ErtPluginManager | None = None) ->

use_ipc_protocol = model.queue_system == QueueSystem.LOCAL
evaluator_server_config = EvaluatorServerConfig(
custom_port_range=args.port_range, use_ipc_protocol=use_ipc_protocol
port_range=None
if args.port_range is None
else (min(args.port_range), max(args.port_range) + 1),
use_ipc_protocol=use_ipc_protocol,
)

if model.check_if_runpath_exists():
Expand Down
4 changes: 2 additions & 2 deletions src/ert/ensemble_evaluator/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ async def evaluate(
ce_unary_send_method_name,
partialmethod(
self.__class__.send_event,
self._config.get_connection_info().router_uri,
self._config.get_uri(),
token=self._config.token,
),
)
Expand Down Expand Up @@ -267,7 +267,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
max_running=self._queue_config.max_running,
submit_sleep=self._queue_config.submit_sleep,
ens_id=self.id_,
ee_uri=self._config.get_connection_info().router_uri,
ee_uri=self._config.get_uri(),
ee_token=self._config.token,
)
logger.info(
Expand Down
53 changes: 26 additions & 27 deletions src/ert/ensemble_evaluator/config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import logging
import socket
import uuid
import warnings

import zmq

from ert.shared import find_available_socket
from ert.shared import get_machine_name as ert_shared_get_machine_name

from .evaluator_connection_info import EvaluatorConnectionInfo
from ert.shared.net_utils import get_ip_address

logger = logging.getLogger(__name__)

Expand All @@ -25,39 +22,41 @@ def get_machine_name() -> str:
class EvaluatorServerConfig:
def __init__(
self,
custom_port_range: range | None = None,
port_range: tuple[int, int] | None = None,
use_token: bool = True,
custom_host: str | None = None,
host: str | None = None,
use_ipc_protocol: bool = True,
) -> None:
self.host: str | None = None
self.host: str | None = host
self.router_port: int | None = None
self.url = f"ipc:///tmp/socket-{uuid.uuid4().hex[:8]}"
self.token: str | None = None
self._socket_handle: socket.socket | None = None

self.server_public_key: bytes | None = None
self.server_secret_key: bytes | None = None
if not use_ipc_protocol:
self._socket_handle = find_available_socket(
custom_range=custom_port_range,
custom_host=custom_host,
will_close_then_reopen_socket=True,
)
self.host, self.router_port = self._socket_handle.getsockname()
self.url = f"tcp://{self.host}:{self.router_port}"
self.use_ipc_protocol: bool = use_ipc_protocol

if port_range is None:
port_range = (51820, 51840 + 1)
else:
if port_range[0] > port_range[1]:
raise ValueError("Minimum port in range is higher than maximum port")

if port_range[0] == port_range[1]:
port_range = (port_range[0], port_range[0] + 1)

self.min_port = port_range[0]
self.max_port = port_range[1]

if use_ipc_protocol:
self.uri = f"ipc:///tmp/socket-{uuid.uuid4().hex[:8]}"
elif self.host is None:
self.host = get_ip_address()

if use_token:
self.server_public_key, self.server_secret_key = zmq.curve_keypair()
self.token = self.server_public_key.decode("utf-8")

def get_socket(self) -> socket.socket | None:
if self._socket_handle:
return self._socket_handle.dup()
return None
def get_uri(self) -> str:
if not self.use_ipc_protocol:
return f"tcp://{self.host}:{self.router_port}"

def get_connection_info(self) -> EvaluatorConnectionInfo:
return EvaluatorConnectionInfo(
self.url,
self.token,
)
return self.uri
13 changes: 9 additions & 4 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,17 @@ async def _server(self) -> None:
self._router_socket.curve_publickey = self._config.server_public_key
self._router_socket.curve_server = True

if self._config.router_port:
self._router_socket.bind(f"tcp://*:{self._config.router_port}")
if self._config.use_ipc_protocol:
self._router_socket.bind(self._config.get_uri())
else:
self._router_socket.bind(self._config.url)
self._config.router_port = self._router_socket.bind_to_random_port(
"tcp://*",
min_port=self._config.min_port,
max_port=self._config.max_port,
)

self._server_started.set_result(None)
except zmq.error.ZMQError as e:
except zmq.error.ZMQBaseError as e:
logger.error(f"ZMQ error encountered {e} during evaluator initialization")
self._server_started.set_exception(e)
zmq_context.destroy(linger=0)
Expand Down
9 changes: 0 additions & 9 deletions src/ert/ensemble_evaluator/evaluator_connection_info.py

This file was deleted.

14 changes: 3 additions & 11 deletions src/ert/ensemble_evaluator/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import uuid
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Final
from typing import Final

from _ert.events import (
EETerminated,
Expand All @@ -16,10 +16,6 @@
)
from _ert.forward_model_runner.client import Client

if TYPE_CHECKING:
from ert.ensemble_evaluator.evaluator_connection_info import EvaluatorConnectionInfo


logger = logging.getLogger(__name__)


Expand All @@ -30,15 +26,11 @@ class EventSentinel:
class Monitor(Client):
_sentinel: Final = EventSentinel()

def __init__(self, ee_con_info: EvaluatorConnectionInfo) -> None:
def __init__(self, uri: str, token: str | None = None) -> None:
self._id = str(uuid.uuid1()).split("-", maxsplit=1)[0]
self._event_queue: asyncio.Queue[Event | EventSentinel] = asyncio.Queue()
self._receiver_timeout: float = 60.0
super().__init__(
ee_con_info.router_uri,
ee_con_info.token,
dealer_name=f"client-{self._id}",
)
super().__init__(uri, token, dealer_name=f"client-{self._id}")

async def process_message(self, msg: str) -> None:
event = event_from_json(msg)
Expand Down
7 changes: 1 addition & 6 deletions src/ert/gui/simulation/run_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,13 +354,8 @@ def run_experiment(self, restart: bool = False) -> None:
self._snapshot_model.reset()
self._tab_widget.clear()

port_range = None
use_ipc_protocol = False
if self._run_model.queue_system == QueueSystem.LOCAL:
port_range = range(49152, 51819)
use_ipc_protocol = True
evaluator_server_config = EvaluatorServerConfig(
custom_port_range=port_range, use_ipc_protocol=use_ipc_protocol
use_ipc_protocol=self._run_model.queue_system == QueueSystem.LOCAL
)

def run() -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ async def run_monitor(
) -> bool:
try:
logger.debug("connecting to new monitor...")
async with Monitor(ee_config.get_connection_info()) as monitor:
async with Monitor(ee_config.get_uri(), ee_config.token) as monitor:
logger.debug("connected")
async for event in monitor.track(heartbeat_interval=0.1):
if type(event) in {
Expand Down
8 changes: 2 additions & 6 deletions src/ert/services/_storage_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,7 @@ def run_server(
if args is None:
args = parse_args()

if "ERT_STORAGE_TOKEN" in os.environ:
authtoken = os.environ["ERT_STORAGE_TOKEN"]
else:
if (authtoken := os.environ.get("ERT_STORAGE_TOKEN")) is None:
authtoken = generate_authtoken()
os.environ["ERT_STORAGE_TOKEN"] = authtoken

Expand All @@ -106,9 +104,7 @@ def run_server(
config_args.update(reload=True, reload_dirs=[os.path.dirname(ert_shared_path)])
os.environ["ERT_STORAGE_DEBUG"] = "1"

sock = find_available_socket(
custom_host=args.host, custom_range=range(51850, 51870)
)
sock = find_available_socket(host=args.host, port_range=range(51850, 51870 + 1))
connection_info = _create_connection_info(sock, authtoken)

# Appropriated from uvicorn.main:run
Expand Down
61 changes: 21 additions & 40 deletions src/ert/shared/net_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,8 @@ def get_machine_name() -> str:


def find_available_socket(
custom_host: str | None = None,
custom_range: range | None = None,
will_close_then_reopen_socket: bool = False,
host: str | None = None,
port_range: range = range(51820, 51840 + 1),
) -> socket.socket:
"""
The default and recommended approach here is to return a bound socket to the
Expand All @@ -70,49 +69,30 @@ def find_available_socket(
See e.g. implementation and comments in EvaluatorServerConfig
"""
current_host = custom_host if custom_host is not None else _get_ip_address()
current_range = (
custom_range if custom_range is not None else range(51820, 51840 + 1)
)
current_host = host if host is not None else get_ip_address()

if current_range.start == current_range.stop:
ports = list(range(current_range.start, current_range.stop + 1))
if port_range.start == port_range.stop:
ports = list(range(port_range.start, port_range.stop + 1))
else:
ports = list(range(current_range.start, current_range.stop))
ports = list(range(port_range.start, port_range.stop))

random.shuffle(ports)
for port in ports:
try:
return _bind_socket(
host=current_host,
port=port,
will_close_then_reopen_socket=will_close_then_reopen_socket,
)
except PortAlreadyInUseException:
continue

raise NoPortsInRangeException(f"No available ports in range {current_range}.")
raise NoPortsInRangeException(f"No available ports in range {port_range}.")


def _bind_socket(
host: str, port: int, will_close_then_reopen_socket: bool = False
) -> socket.socket:
def _bind_socket(host: str, port: int) -> socket.socket:
try:
family = get_family(host=host)
sock = socket.socket(family=family, type=socket.SOCK_STREAM)

# Setting flags like SO_REUSEADDR and/or SO_REUSEPORT may have
# undesirable side-effects but we allow it if caller insists. Refer to
# comment on find_available_socket()
#
# See e.g. https://stackoverflow.com/a/14388707 for an extensive
# explanation of these flags, in particular the part about TIME_WAIT

if will_close_then_reopen_socket:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
else:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 0)

sock.bind((host, port))
return sock
except socket.gaierror as err_info:
Expand All @@ -139,18 +119,19 @@ def get_family(host: str) -> socket.AddressFamily:


# See https://stackoverflow.com/a/28950776
def _get_ip_address() -> str:
def get_ip_address() -> str:
try:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.settimeout(0)
# try pinging a reserved, internal address in order
# to determine IP representing the default route
s.connect(("10.255.255.255", 1))
retval = s.getsockname()[0]
try:
s.settimeout(0)
# try pinging a reserved, internal address in order
# to determine IP representing the default route
s.connect(("10.255.255.255", 1))
address = s.getsockname()[0]
finally:
s.close()
except BaseException:
logger.warning("Cannot determine ip-address. Fallback to localhost...")
retval = "127.0.0.1"
finally:
s.close()
logger.debug(f"ip-address: {retval}")
return retval
logger.warning("Cannot determine ip-address. Falling back to localhost.")
address = "127.0.0.1"
logger.debug(f"ip-address: {address}")
return address
2 changes: 1 addition & 1 deletion src/everest/detached/jobs/everserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def main():
evaluator_server_config = EvaluatorServerConfig()
else:
evaluator_server_config = EvaluatorServerConfig(
custom_port_range=range(49152, 51819), use_ipc_protocol=False
port_range=(49152, 51819), use_ipc_protocol=False
)

run_model.run_experiment(evaluator_server_config)
Expand Down
2 changes: 1 addition & 1 deletion tests/ert/unit_tests/ensemble_evaluator/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,6 @@ def _dump_forward_model(forward_model, index):
@pytest.fixture(name="make_ee_config")
def make_ee_config_fixture():
def _ee_config(**kwargs):
return EvaluatorServerConfig(custom_port_range=range(1024, 65535), **kwargs)
return EvaluatorServerConfig(**kwargs)

return _ee_config
Loading

0 comments on commit 825a648

Please sign in to comment.