Skip to content

Commit

Permalink
Update EvaluatorServerConfig to contain zmq connection info
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Nov 13, 2024
1 parent c592a7d commit 6ddf210
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 27 deletions.
3 changes: 3 additions & 0 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def __init__(
# Set up ZeroMQ context and socket
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUSH)
# self.socket.setsockopt(zmq.LINGER, 0)
# self.socket.setsockopt(zmq.SNDTIMEO, self.CONNECTION_TIMEOUT * 1000)
self.socket.connect(url)

if cert:
client_public, client_secret = zmq.curve_keypair()
Expand Down
8 changes: 2 additions & 6 deletions src/ert/ensemble_evaluator/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,7 @@

from ._wait_for_evaluator import wait_for_evaluator
from .config import EvaluatorServerConfig
from .snapshot import (
EnsembleSnapshot,
FMStepSnapshot,
RealizationSnapshot,
)
from .snapshot import EnsembleSnapshot, FMStepSnapshot, RealizationSnapshot
from .state import (
ENSEMBLE_STATE_CANCELLED,
ENSEMBLE_STATE_FAILED,
Expand Down Expand Up @@ -282,7 +278,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
max_running=self._queue_config.max_running,
submit_sleep=self._queue_config.submit_sleep,
ens_id=self.id_,
ee_uri=self._config.dispatch_uri,
ee_uri=self._config.get_connection_info().push_pull_uri,
ee_cert=self._config.cert,
ee_token=self._config.token,
)
Expand Down
17 changes: 12 additions & 5 deletions src/ert/ensemble_evaluator/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,16 @@ def __init__(
custom_range=custom_port_range, custom_host=custom_host
)
host, port = self._socket_handle.getsockname()
self.protocol = "wss" if generate_cert else "ws"
self.url = f"{self.protocol}://{host}:{port}"
self.client_uri = f"{self.url}/client"
self.dispatch_uri = f"{self.url}/dispatch"
self.host = host
self.pub_sub_port = port
host, port = self._socket_handle.getsockname()
self.push_pull_port = port

# self.protocol = "wss" if generate_cert else "ws"
# self.url = f"{self.protocol}://{host}:{port}"
# self.client_uri = f"{self.url}/client"
# self.dispatch_uri = f"{self.url}/dispatch"

if generate_cert:
cert, key, pw = _generate_certificate(host)
else:
Expand All @@ -151,7 +157,8 @@ def get_socket(self) -> socket.socket:

def get_connection_info(self) -> EvaluatorConnectionInfo:
return EvaluatorConnectionInfo(
self.url,
f"tcp://{self.host}:{self.push_pull_port}",
f"tcp://{self.host}:{self.pub_sub_port}",
self.cert,
self.token,
)
Expand Down
8 changes: 5 additions & 3 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):

async def _initialize_zmq(self) -> None:
self._zmq_context = zmq.asyncio.Context() # type: ignore
self._listen_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PULL)
self._pull_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PULL)
self._pull_socket.bind(f"tcp://*:{self._config.push_pull_port}")
self._publisher_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PUB)
self._publisher_socket.bind(f"tcp://*:{self._config.pub_sub_port}")

async def _publisher(self) -> None:
while True:
Expand Down Expand Up @@ -203,7 +205,7 @@ def ensemble(self) -> Ensemble:

async def listen_for_messages(self) -> None:
while True:
sender, raw_msg = await self._listen_socket.recv_multipart()
sender, raw_msg = await self._pull_socket.recv_multipart()
sender = sender.decode("utf-8")
if sender == "client":
event = event_from_json(raw_msg)
Expand Down Expand Up @@ -247,7 +249,7 @@ async def _server(self) -> None:
event = EETerminated(ensemble=self._ensemble.id_)
await self._events_to_send.put(event)
await self._events_to_send.join()
self._listen_socket.close()
self._pull_socket.close()
self._publisher_socket.close()
logger.debug("Async server exiting.")

Expand Down
15 changes: 2 additions & 13 deletions src/ert/ensemble_evaluator/evaluator_connection_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,7 @@
class EvaluatorConnectionInfo:
"""Read only server-info"""

url: str
push_pull_uri: str
pub_sub_uri: str
cert: Optional[Union[str, bytes]] = None
token: Optional[str] = None

@property
def dispatch_uri(self) -> str:
return f"{self.url}/dispatch"

@property
def client_uri(self) -> str:
return f"{self.url}/client"

@property
def result_uri(self) -> str:
return f"{self.url}/result"
2 changes: 2 additions & 0 deletions src/ert/ensemble_evaluator/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ async def _receiver(self) -> None:
tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
tls.load_verify_locations(cadata=self._ee_con_info.cert)

self._listen_socket.connect(self._ee_con_info.pub_sub_uri)
self._push_socket.connect(self._ee_con_info.push_pull_uri)
while True:
try:
raw_msg = await self._listen_socket.recv_string()
Expand Down

0 comments on commit 6ddf210

Please sign in to comment.