diff --git a/tests/ert/unit_tests/ensemble_evaluator/conftest.py b/tests/ert/unit_tests/ensemble_evaluator/conftest.py index eda4a55b27a..64059518df7 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/conftest.py +++ b/tests/ert/unit_tests/ensemble_evaluator/conftest.py @@ -1,3 +1,4 @@ +import asyncio import json import os import stat @@ -5,8 +6,11 @@ from unittest.mock import MagicMock, Mock import pytest +import zmq +import zmq.asyncio import ert.ensemble_evaluator +from _ert.async_utils import new_event_loop from ert.config import QueueConfig, QueueSystem from ert.config.ert_config import _forward_model_step_from_config_file from ert.config.queue_config import LocalQueueOptions @@ -18,6 +22,53 @@ from tests.ert import SnapshotBuilder +@pytest.fixture(name="zmq_server") +def _mock_zmq_server(host, port, messages, delay_startup=0): + loop = new_event_loop() + done = loop.create_future() + + async def _handler(router_socket): + while True: + _, __, *frames = await router_socket.recv_multipart() + for frame in frames: + raw_msg = frame.decode("utf-8") + messages.append(raw_msg) + if raw_msg == "stop": + done.set_result(None) + break + + async def _run_server(): + await asyncio.sleep(delay_startup) + zmq_context = zmq.asyncio.Context() # type: ignore + router_socket = zmq_context.socket(zmq.ROUTER) + router_socket.bind(f"tcp://*:{port}") + handler_task = asyncio.create_task(_handler(router_socket)) + await handler_task + router_socket.close() + + loop.run_until_complete(_run_server()) + loop.close() + + +@pytest.fixture(name="async_zmq_server") +def _async_mock_zmq_server(port, handler, set_when_done): + loop = new_event_loop() + + async def _run_server(): + zmq_context = zmq.asyncio.Context() # type: ignore + router_socket = zmq_context.socket(zmq.ROUTER) + router_socket.bind(f"tcp://*:{port}") + while True: + dealer, __, *frames = await router_socket.recv_multipart() + for frame in frames: + handler(dealer.decode("utf-8"), frame) + if set_when_done: + return + + loop.run_until_complete(_run_server()) + loop.close() + + @pytest.fixture def snapshot(): return ( diff --git a/tests/ert/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py b/tests/ert/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py index 7abe430a1a2..dd065c7e684 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py +++ b/tests/ert/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py @@ -1,9 +1,3 @@ -import asyncio - -import zmq -import zmq.asyncio - -from _ert.async_utils import new_event_loop from _ert.events import ( EnsembleStarted, EnsembleSucceeded, @@ -19,31 +13,6 @@ from ert.ensemble_evaluator._ensemble import ForwardModelStep, Realization -def _mock_zmq_server(host, port, messages, delay_startup=0): - loop = new_event_loop() - done = loop.create_future() - - async def _handler(router_socket): - while True: - _, __, *frames = await router_socket.recv_multipart() - for frame in frames: - raw_msg = frame.decode("utf-8") - messages.append(raw_msg) - if raw_msg == "stop": - done.set_result(None) - break - - async def _run_server(): - await asyncio.sleep(delay_startup) - zmq_context = zmq.asyncio.Context() # type: ignore - router_socket = zmq_context.socket(zmq.ROUTER) - router_socket.bind(f"tcp://*:{port}") - await done - - loop.run_until_complete(_run_server()) - loop.close() - - class TestEnsemble(Ensemble): __test__ = False diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py b/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py index 2a201a46c21..e8a3af11c41 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py @@ -1,29 +1,33 @@ import asyncio import logging -from http import HTTPStatus -from urllib.parse import urlparse import pytest -from websockets.asyncio import server -from websockets.exceptions import ConnectionClosedOK +import zmq +import zmq.asyncio from _ert.events import EEUserCancel, EEUserDone, event_from_json from ert.ensemble_evaluator import Monitor from ert.ensemble_evaluator.config import EvaluatorConnectionInfo +# async def _mock_ws( +# set_when_done: asyncio.Event, handler, ee_config: EvaluatorConnectionInfo +# ): +# async def process_request(connection, request): +# if request.path == "/healthcheck": +# return connection.respond(HTTPStatus.OK, "") -async def _mock_ws( - set_when_done: asyncio.Event, handler, ee_config: EvaluatorConnectionInfo -): - async def process_request(connection, request): - if request.path == "/healthcheck": - return connection.respond(HTTPStatus.OK, "") +# url = urlparse(ee_config.url) +# async with server.serve( +# handler, url.hostname, url.port, process_request=process_request +# ): +# await set_when_done.wait() - url = urlparse(ee_config.url) - async with server.serve( - handler, url.hostname, url.port, process_request=process_request - ): - await set_when_done.wait() + +async def async_zmq_server(port, handler): + zmq_context = zmq.asyncio.Context() # type: ignore + router_socket = zmq_context.socket(zmq.ROUTER) + router_socket.bind(f"tcp://*:{port}") + await handler(router_socket) async def test_no_connection_established(make_ee_config): @@ -38,48 +42,50 @@ async def test_no_connection_established(make_ee_config): async def test_immediate_stop(unused_tcp_port): - ee_con_info = EvaluatorConnectionInfo(f"ws://127.0.0.1:{unused_tcp_port}") - - set_when_done = asyncio.Event() - - async def mock_ws_event_handler(websocket): - async for raw_msg in websocket: - event = event_from_json(raw_msg) - assert type(event) is EEUserDone - break - await websocket.close() + ee_con_info = EvaluatorConnectionInfo(f"tcp://127.0.0.1:{unused_tcp_port}") + + connected = False + + async def mock_event_handler(router_socket): + nonlocal connected + while True: + dealer, _, *frames = await router_socket.recv_multipart() + dealer = dealer.decode("utf-8") + for frame in frames: + frame = frame.decode("utf-8") + assert dealer.startswith("client-") + if frame == "CONNECT": + connected = True + elif frame == "DISCONNECT": + connected = False + return + else: + event = event_from_json(frame) + assert connected + assert type(event) is EEUserDone websocket_server_task = asyncio.create_task( - _mock_ws(set_when_done, mock_ws_event_handler, ee_con_info) + async_zmq_server(unused_tcp_port, mock_event_handler) ) async with Monitor(ee_con_info) as monitor: await monitor.signal_done() - set_when_done.set() + assert connected is False await websocket_server_task +# TODO: refactor async def test_unexpected_close(unused_tcp_port): - ee_con_info = EvaluatorConnectionInfo(f"ws://127.0.0.1:{unused_tcp_port}") - - set_when_done = asyncio.Event() - socket_closed = asyncio.Event() + ee_con_info = EvaluatorConnectionInfo(f"tcp://127.0.0.1:{unused_tcp_port}") - async def mock_ws_event_handler(websocket): - await websocket.close() - socket_closed.set() + async def mock_event_handler(router_socket): + router_socket.close() websocket_server_task = asyncio.create_task( - _mock_ws(set_when_done, mock_ws_event_handler, ee_con_info) + async_zmq_server(unused_tcp_port, mock_event_handler) ) async with Monitor(ee_con_info) as monitor: - # this expects Event send to fail - # but no attempt on resubmitting - # since connection closed via websocket.close - with pytest.raises(ConnectionClosedOK): - await socket_closed.wait() - await monitor.signal_done() - - set_when_done.set() + await monitor.signal_done() + await websocket_server_task @@ -87,20 +93,32 @@ async def test_that_monitor_track_can_exit_without_terminated_event_from_evaluat unused_tcp_port, caplog ): caplog.set_level(logging.ERROR) - ee_con_info = EvaluatorConnectionInfo(f"ws://127.0.0.1:{unused_tcp_port}") - - set_when_done = asyncio.Event() - - async def mock_ws_event_handler(websocket): - async for raw_msg in websocket: - event = event_from_json(raw_msg) - assert type(event) is EEUserCancel - break - await websocket.close() + ee_con_info = EvaluatorConnectionInfo(f"tcp://127.0.0.1:{unused_tcp_port}") + + connected = False + + async def mock_event_handler(router_socket): + nonlocal connected + while True: + dealer, _, *frames = await router_socket.recv_multipart() + dealer = dealer.decode("utf-8") + for frame in frames: + frame = frame.decode("utf-8") + assert dealer.startswith("client-") + if frame == "CONNECT": + connected = True + elif frame == "DISCONNECT": + connected = False + return + else: + event = event_from_json(frame) + assert connected + assert type(event) is EEUserCancel websocket_server_task = asyncio.create_task( - _mock_ws(set_when_done, mock_ws_event_handler, ee_con_info) + async_zmq_server(unused_tcp_port, mock_event_handler) ) + async with Monitor(ee_con_info) as monitor: monitor._receiver_timeout = 0.1 await monitor.signal_cancel() @@ -112,7 +130,6 @@ async def mock_ws_event_handler(websocket): "Evaluator did not send the TERMINATED event!" ) in caplog.messages, "Monitor receiver did not stop!" - set_when_done.set() await websocket_server_task @@ -121,11 +138,16 @@ async def test_that_monitor_can_emit_heartbeats(unused_tcp_port): exit anytime. A heartbeat is a None event. If the heartbeat is never sent, this test function will hang and then timeout.""" - ee_con_info = EvaluatorConnectionInfo(f"ws://127.0.0.1:{unused_tcp_port}") + ee_con_info = EvaluatorConnectionInfo(f"tcp://127.0.0.1:{unused_tcp_port}") set_when_done = asyncio.Event() + + async def mock_event_handler(router_socket): + nonlocal set_when_done + await set_when_done.wait() + websocket_server_task = asyncio.create_task( - _mock_ws(set_when_done, None, ee_con_info) + async_zmq_server(unused_tcp_port, mock_event_handler) ) async with Monitor(ee_con_info) as monitor: