Skip to content

Commit

Permalink
Fix test_monitor
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Dec 1, 2024
1 parent 3377024 commit b1af1ad
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __enter__(self) -> Self:
self.loop.run_until_complete(self.reconnect())
return self

def term(self):
def term(self) -> None:
self._connected = False
self.socket.close()
self.context.term()
Expand Down
6 changes: 3 additions & 3 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ async def _failed_handler(self, events: Sequence[EnsembleFailed]) -> None:
def ensemble(self) -> Ensemble:
return self._ensemble

async def handle_client(self, dealer, frames):
async def handle_client(self, dealer: bytes, frames: list[bytes]) -> None:
for frame in frames:
raw_msg = frame.decode("utf-8")
if raw_msg == "CONNECT":
Expand Down Expand Up @@ -230,7 +230,7 @@ async def handle_client(self, dealer, frames):
logger.debug("Client signalled done.")
self.stop()

async def handle_dispatch(self, dealer, frames):
async def handle_dispatch(self, dealer: bytes, frames: list[bytes]) -> None:
for frame in frames:
raw_msg = frame.decode("utf-8")
if raw_msg == "CONNECT":
Expand Down Expand Up @@ -283,7 +283,7 @@ async def forward_checksum(self, event: Event) -> None:
await self._events_to_send.put(event)
await self._manifest_queue.put(event)

async def _server(self):
async def _server(self) -> None:
_zmq_context = zmq.asyncio.Context()
try:
print("INIT ZMQ ...")
Expand Down
6 changes: 4 additions & 2 deletions src/ert/ensemble_evaluator/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(self, ee_con_info: "EvaluatorConnectionInfo") -> None:
# zmq connection
self._zmq_context = zmq.asyncio.Context()
self._socket = self._zmq_context.socket(zmq.DEALER)
self._socket.setsockopt(zmq.LINGER, 0)
self._socket.setsockopt_string(zmq.IDENTITY, f"client-{self._id}")
if ee_con_info.token is not None:
client_public, client_secret = zmq.curve_keypair()
Expand Down Expand Up @@ -70,8 +71,8 @@ async def _term(self) -> None:
self._receiver_task,
return_exceptions=True,
)
self._socket.close()
self._zmq_context.term()
self._socket.close()
self._zmq_context.term()

async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
await self._term()
Expand Down Expand Up @@ -137,6 +138,7 @@ async def reconnect(self) -> None:
if ack.decode() != "ACK":
raise asyncio.TimeoutError("No Ack for connect")
except asyncio.TimeoutError:
print("NO CONNECTION")
logger.warning(
f"Failed to get acknowledgment on the monitor {self._id} connect!"
)
Expand Down
16 changes: 9 additions & 7 deletions tests/ert/unit_tests/ensemble_evaluator/test_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import zmq
import zmq.asyncio

import ert
import ert.ensemble_evaluator
from _ert.events import EEUserCancel, EEUserDone, event_from_json
from ert.ensemble_evaluator import Monitor
from ert.ensemble_evaluator.config import EvaluatorConnectionInfo
Expand All @@ -15,10 +13,11 @@
async def async_zmq_server(port, handler):
zmq_context = zmq.asyncio.Context() # type: ignore
router_socket = zmq_context.socket(zmq.ROUTER)
router_socket.setsockopt(zmq.LINGER, 0)
router_socket.bind(f"tcp://*:{port}")
await handler(router_socket)
router_socket.close()
zmq_context.term()
zmq_context.destroy()


async def test_no_connection_established(make_ee_config):
Expand Down Expand Up @@ -47,24 +46,27 @@ async def mock_event_handler(router_socket):
assert dealer.startswith("client-")
if frame == "CONNECT":
await router_socket.send_multipart(
[dealer.encode("utf-8"), b"", b"ACK_CONNECT"]
[dealer.encode("utf-8"), b"", b"ACK"]
)
connected = True
elif frame == "DISCONNECT":
connected = False
print(connected)
return
else:
event = event_from_json(frame)
print(f"{event=}")
assert connected
assert type(event) is EEUserDone

websocket_server_task = asyncio.create_task(
async_zmq_server(unused_tcp_port, mock_event_handler)
)
async with Monitor(ee_con_info) as monitor:
assert connected is True
await monitor.signal_done()
assert connected is False
await websocket_server_task
assert connected is False


# TODO: refactor
Expand Down Expand Up @@ -101,7 +103,7 @@ async def mock_event_handler(router_socket):
assert dealer.startswith("client-")
if frame == "CONNECT":
await router_socket.send_multipart(
[dealer.encode("utf-8"), b"", b"ACK_CONNECT"]
[dealer.encode("utf-8"), b"", b"ACK"]
)
connected = True
elif frame == "DISCONNECT":
Expand Down Expand Up @@ -146,7 +148,7 @@ async def mock_event_handler(router_socket):
frame = frame.decode("utf-8")
if frame == "CONNECT":
await router_socket.send_multipart(
[dealer.encode("utf-8"), b"", b"ACK_CONNECT"]
[dealer.encode("utf-8"), b"", b"ACK"]
)
await set_when_done.wait()

Expand Down

0 comments on commit b1af1ad

Please sign in to comment.