From b1af1ad869f472da1120e3c61aa311b8657e47d7 Mon Sep 17 00:00:00 2001 From: Julius Parulek Date: Sun, 1 Dec 2024 20:47:09 +0100 Subject: [PATCH] Fix test_monitor --- src/_ert/forward_model_runner/client.py | 2 +- src/ert/ensemble_evaluator/evaluator.py | 6 +++--- src/ert/ensemble_evaluator/monitor.py | 6 ++++-- .../ensemble_evaluator/test_monitor.py | 16 +++++++++------- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index ccb326aadc3..2f935fd092c 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -31,7 +31,7 @@ def __enter__(self) -> Self: self.loop.run_until_complete(self.reconnect()) return self - def term(self): + def term(self) -> None: self._connected = False self.socket.close() self.context.term() diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index d35f5c6efb0..8ed91291888 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -202,7 +202,7 @@ async def _failed_handler(self, events: Sequence[EnsembleFailed]) -> None: def ensemble(self) -> Ensemble: return self._ensemble - async def handle_client(self, dealer, frames): + async def handle_client(self, dealer: bytes, frames: list[bytes]) -> None: for frame in frames: raw_msg = frame.decode("utf-8") if raw_msg == "CONNECT": @@ -230,7 +230,7 @@ async def handle_client(self, dealer, frames): logger.debug("Client signalled done.") self.stop() - async def handle_dispatch(self, dealer, frames): + async def handle_dispatch(self, dealer: bytes, frames: list[bytes]) -> None: for frame in frames: raw_msg = frame.decode("utf-8") if raw_msg == "CONNECT": @@ -283,7 +283,7 @@ async def forward_checksum(self, event: Event) -> None: await self._events_to_send.put(event) await self._manifest_queue.put(event) - async def _server(self): + async def _server(self) -> None: _zmq_context = zmq.asyncio.Context() try: print("INIT ZMQ ...") diff --git a/src/ert/ensemble_evaluator/monitor.py b/src/ert/ensemble_evaluator/monitor.py index 84a2775818d..29a9d789194 100644 --- a/src/ert/ensemble_evaluator/monitor.py +++ b/src/ert/ensemble_evaluator/monitor.py @@ -42,6 +42,7 @@ def __init__(self, ee_con_info: "EvaluatorConnectionInfo") -> None: # zmq connection self._zmq_context = zmq.asyncio.Context() self._socket = self._zmq_context.socket(zmq.DEALER) + self._socket.setsockopt(zmq.LINGER, 0) self._socket.setsockopt_string(zmq.IDENTITY, f"client-{self._id}") if ee_con_info.token is not None: client_public, client_secret = zmq.curve_keypair() @@ -70,8 +71,8 @@ async def _term(self) -> None: self._receiver_task, return_exceptions=True, ) - self._socket.close() - self._zmq_context.term() + self._socket.close() + self._zmq_context.term() async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: await self._term() @@ -137,6 +138,7 @@ async def reconnect(self) -> None: if ack.decode() != "ACK": raise asyncio.TimeoutError("No Ack for connect") except asyncio.TimeoutError: + print("NO CONNECTION") logger.warning( f"Failed to get acknowledgment on the monitor {self._id} connect!" ) diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py b/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py index 3d67bab84ff..f3b478d13cd 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py @@ -5,8 +5,6 @@ import zmq import zmq.asyncio -import ert -import ert.ensemble_evaluator from _ert.events import EEUserCancel, EEUserDone, event_from_json from ert.ensemble_evaluator import Monitor from ert.ensemble_evaluator.config import EvaluatorConnectionInfo @@ -15,10 +13,11 @@ async def async_zmq_server(port, handler): zmq_context = zmq.asyncio.Context() # type: ignore router_socket = zmq_context.socket(zmq.ROUTER) + router_socket.setsockopt(zmq.LINGER, 0) router_socket.bind(f"tcp://*:{port}") await handler(router_socket) router_socket.close() - zmq_context.term() + zmq_context.destroy() async def test_no_connection_established(make_ee_config): @@ -47,14 +46,16 @@ async def mock_event_handler(router_socket): assert dealer.startswith("client-") if frame == "CONNECT": await router_socket.send_multipart( - [dealer.encode("utf-8"), b"", b"ACK_CONNECT"] + [dealer.encode("utf-8"), b"", b"ACK"] ) connected = True elif frame == "DISCONNECT": connected = False + print(connected) return else: event = event_from_json(frame) + print(f"{event=}") assert connected assert type(event) is EEUserDone @@ -62,9 +63,10 @@ async def mock_event_handler(router_socket): async_zmq_server(unused_tcp_port, mock_event_handler) ) async with Monitor(ee_con_info) as monitor: + assert connected is True await monitor.signal_done() - assert connected is False await websocket_server_task + assert connected is False # TODO: refactor @@ -101,7 +103,7 @@ async def mock_event_handler(router_socket): assert dealer.startswith("client-") if frame == "CONNECT": await router_socket.send_multipart( - [dealer.encode("utf-8"), b"", b"ACK_CONNECT"] + [dealer.encode("utf-8"), b"", b"ACK"] ) connected = True elif frame == "DISCONNECT": @@ -146,7 +148,7 @@ async def mock_event_handler(router_socket): frame = frame.decode("utf-8") if frame == "CONNECT": await router_socket.send_multipart( - [dealer.encode("utf-8"), b"", b"ACK_CONNECT"] + [dealer.encode("utf-8"), b"", b"ACK"] ) await set_when_done.wait()