Skip to content

Commit

Permalink
Fix test_monitor with zmq
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Nov 20, 2024
1 parent dec3956 commit b693ebf
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 88 deletions.
51 changes: 51 additions & 0 deletions tests/ert/unit_tests/ensemble_evaluator/conftest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import asyncio
import json
import os
import stat
from pathlib import Path
from unittest.mock import MagicMock, Mock

import pytest
import zmq
import zmq.asyncio

import ert.ensemble_evaluator
from _ert.async_utils import new_event_loop
from ert.config import QueueConfig, QueueSystem
from ert.config.ert_config import _forward_model_step_from_config_file
from ert.config.queue_config import LocalQueueOptions
Expand All @@ -18,6 +22,53 @@
from tests.ert import SnapshotBuilder


@pytest.fixture(name="zmq_server")
def _mock_zmq_server(host, port, messages, delay_startup=0):
    """Run a mock ZMQ ROUTER server until a ``"stop"`` frame is received.

    A fresh event loop is created and driven to completion, so this call
    blocks until the server has shut down.

    Args:
        host: Unused; the socket always binds to ``tcp://*:{port}``.
        port: TCP port the ROUTER socket binds to.
        messages: List that every decoded payload frame is appended to.
        delay_startup: Seconds to sleep before binding the socket.

    NOTE(review): this is registered as a pytest fixture although it blocks
    and returns ``None`` -- confirm how ``host``/``port``/``messages`` are
    meant to be supplied by pytest.
    """
    loop = new_event_loop()

    async def _handler(router_socket):
        while True:
            # The first two frames are discarded; presumably the sender
            # identity and the empty delimiter -- TODO confirm.
            _, __, *frames = await router_socket.recv_multipart()
            for frame in frames:
                raw_msg = frame.decode("utf-8")
                messages.append(raw_msg)
                if raw_msg == "stop":
                    # Return instead of ``break``: a ``break`` only leaves
                    # the inner loop and the server would never terminate.
                    return

    async def _run_server():
        await asyncio.sleep(delay_startup)
        zmq_context = zmq.asyncio.Context()  # type: ignore
        router_socket = zmq_context.socket(zmq.ROUTER)
        try:
            router_socket.bind(f"tcp://*:{port}")
            await _handler(router_socket)
        finally:
            # Release the port and the context even if the handler raises.
            router_socket.close(linger=0)
            zmq_context.destroy(linger=0)

    try:
        loop.run_until_complete(_run_server())
    finally:
        loop.close()


@pytest.fixture(name="async_zmq_server")
def _async_mock_zmq_server(port, handler, set_when_done):
    """Run a mock ZMQ ROUTER server that forwards each frame to ``handler``.

    A fresh event loop is created and driven to completion, so this call
    blocks until the server returns.

    Args:
        port: TCP port the ROUTER socket binds to (``tcp://*:{port}``).
        handler: Synchronous callable invoked as
            ``handler(sender_identity: str, frame: bytes)`` for every
            payload frame.
        set_when_done: When truthy, the server returns after the first
            multipart message has been dispatched.

    NOTE(review): ``set_when_done`` is only tested for truthiness. If callers
    pass an ``asyncio.Event`` this is always true and ``is_set()`` was
    probably intended -- confirm against the callers.
    """
    loop = new_event_loop()

    async def _run_server():
        zmq_context = zmq.asyncio.Context()  # type: ignore
        router_socket = zmq_context.socket(zmq.ROUTER)
        try:
            router_socket.bind(f"tcp://*:{port}")
            while True:
                # Second frame is discarded; presumably the empty delimiter.
                dealer, __, *frames = await router_socket.recv_multipart()
                for frame in frames:
                    handler(dealer.decode("utf-8"), frame)
                if set_when_done:
                    return
        finally:
            # Release the port and the context on every exit path.
            router_socket.close(linger=0)
            zmq_context.destroy(linger=0)

    try:
        loop.run_until_complete(_run_server())
    finally:
        loop.close()


@pytest.fixture
def snapshot():
return (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
import asyncio

import zmq
import zmq.asyncio

from _ert.async_utils import new_event_loop
from _ert.events import (
EnsembleStarted,
EnsembleSucceeded,
Expand All @@ -19,31 +13,6 @@
from ert.ensemble_evaluator._ensemble import ForwardModelStep, Realization


def _mock_zmq_server(host, port, messages, delay_startup=0):
    """Run a mock ZMQ ROUTER server until a ``"stop"`` frame is received.

    A fresh event loop is created and driven to completion, so this call
    blocks until the server has shut down.

    Args:
        host: Unused; the socket always binds to ``tcp://*:{port}``.
        port: TCP port the ROUTER socket binds to.
        messages: List that every decoded payload frame is appended to.
        delay_startup: Seconds to sleep before binding the socket.
    """
    loop = new_event_loop()
    done = loop.create_future()

    async def _handler(router_socket):
        while not done.done():
            # The first two frames are discarded; presumably the sender
            # identity and the empty delimiter -- TODO confirm.
            _, __, *frames = await router_socket.recv_multipart()
            for frame in frames:
                raw_msg = frame.decode("utf-8")
                messages.append(raw_msg)
                if raw_msg == "stop":
                    done.set_result(None)
                    break

    async def _run_server():
        await asyncio.sleep(delay_startup)
        zmq_context = zmq.asyncio.Context()  # type: ignore
        router_socket = zmq_context.socket(zmq.ROUTER)
        try:
            router_socket.bind(f"tcp://*:{port}")
            # The handler must actually run, otherwise ``done`` is never
            # resolved and the server hangs forever. Awaiting it directly
            # also surfaces any exception raised while receiving.
            await _handler(router_socket)
        finally:
            router_socket.close(linger=0)
            zmq_context.destroy(linger=0)

    try:
        loop.run_until_complete(_run_server())
    finally:
        loop.close()


class TestEnsemble(Ensemble):
__test__ = False

Expand Down
136 changes: 79 additions & 57 deletions tests/ert/unit_tests/ensemble_evaluator/test_monitor.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,33 @@
import asyncio
import logging
from http import HTTPStatus
from urllib.parse import urlparse

import pytest
from websockets.asyncio import server
from websockets.exceptions import ConnectionClosedOK
import zmq
import zmq.asyncio

from _ert.events import EEUserCancel, EEUserDone, event_from_json
from ert.ensemble_evaluator import Monitor
from ert.ensemble_evaluator.config import EvaluatorConnectionInfo

# async def _mock_ws(
# set_when_done: asyncio.Event, handler, ee_config: EvaluatorConnectionInfo
# ):
# async def process_request(connection, request):
# if request.path == "/healthcheck":
# return connection.respond(HTTPStatus.OK, "")

async def _mock_ws(
    set_when_done: asyncio.Event, handler, ee_config: EvaluatorConnectionInfo
):
    """Serve ``handler`` over a websocket until ``set_when_done`` is set.

    The server listens on the host and port parsed from ``ee_config.url``.

    Args:
        set_when_done: Event awaited inside the ``server.serve`` context;
            the server is torn down once it is set.
        handler: Connection handler passed straight to ``server.serve``.
        ee_config: Connection info whose ``url`` gives the bind address.
    """

    async def process_request(connection, request):
        # Answer health checks with an empty 200 response. Any other path
        # returns None -- presumably letting the websocket handshake
        # proceed; confirm against the websockets documentation.
        if request.path == "/healthcheck":
            return connection.respond(HTTPStatus.OK, "")
    # NOTE(review): the commented-out lines below duplicate the live
    # statements that follow; they look like diff residue -- confirm.
    # url = urlparse(ee_config.url)
    # async with server.serve(
    # handler, url.hostname, url.port, process_request=process_request
    # ):
    # await set_when_done.wait()

    url = urlparse(ee_config.url)
    async with server.serve(
        handler, url.hostname, url.port, process_request=process_request
    ):
        await set_when_done.wait()

async def async_zmq_server(port, handler):
    """Bind a ZMQ ROUTER socket on ``port`` and hand it to ``handler``.

    Args:
        port: TCP port the ROUTER socket binds to (``tcp://*:{port}``).
        handler: Coroutine function called as ``await handler(router_socket)``;
            the server lives for as long as the handler runs.

    The socket and context are released when the handler returns or raises,
    so the port is free again for the next test. Closing a socket the handler
    has already closed is expected to be a no-op (pyzmq ``Socket.close``).
    """
    zmq_context = zmq.asyncio.Context()  # type: ignore
    router_socket = zmq_context.socket(zmq.ROUTER)
    try:
        router_socket.bind(f"tcp://*:{port}")
        await handler(router_socket)
    finally:
        router_socket.close(linger=0)
        zmq_context.destroy(linger=0)


async def test_no_connection_established(make_ee_config):
Expand All @@ -38,69 +42,83 @@ async def test_no_connection_established(make_ee_config):


async def test_immediate_stop(unused_tcp_port):
    """Check that a monitor signalling done sends only an ``EEUserDone`` event.

    A mock ROUTER server records the CONNECT/DISCONNECT frames of the client
    and asserts that every other frame is an ``EEUserDone`` event received
    while the client is connected.
    """
    ee_con_info = EvaluatorConnectionInfo(f"tcp://127.0.0.1:{unused_tcp_port}")

    connected = False

    async def mock_event_handler(router_socket):
        nonlocal connected
        while True:
            dealer, _, *frames = await router_socket.recv_multipart()
            dealer = dealer.decode("utf-8")
            for frame in frames:
                frame = frame.decode("utf-8")
                assert dealer.startswith("client-")
                if frame == "CONNECT":
                    connected = True
                elif frame == "DISCONNECT":
                    connected = False
                    return
                else:
                    event = event_from_json(frame)
                    assert connected
                    assert type(event) is EEUserDone

    websocket_server_task = asyncio.create_task(
        async_zmq_server(unused_tcp_port, mock_event_handler)
    )
    async with Monitor(ee_con_info) as monitor:
        await monitor.signal_done()
    # Wait for the server first: the handler only returns after it has seen
    # DISCONNECT, so checking ``connected`` before this point would race with
    # the handler. Awaiting also re-raises any assertion that failed inside it.
    await websocket_server_task
    assert connected is False


# TODO: refactor
async def test_unexpected_close(unused_tcp_port):
    """Check that signalling done works when the server closes its socket.

    The mock server closes the ROUTER socket as soon as it is bound, without
    reading any frame; ``signal_done`` and leaving the monitor context are
    expected to complete without raising.
    """
    ee_con_info = EvaluatorConnectionInfo(f"tcp://127.0.0.1:{unused_tcp_port}")

    async def mock_event_handler(router_socket):
        # Simulate the evaluator going away before any message is handled.
        router_socket.close()

    websocket_server_task = asyncio.create_task(
        async_zmq_server(unused_tcp_port, mock_event_handler)
    )
    async with Monitor(ee_con_info) as monitor:
        await monitor.signal_done()

    await websocket_server_task


async def test_that_monitor_track_can_exit_without_terminated_event_from_evaluator(
unused_tcp_port, caplog
):
caplog.set_level(logging.ERROR)
ee_con_info = EvaluatorConnectionInfo(f"ws://127.0.0.1:{unused_tcp_port}")

set_when_done = asyncio.Event()

async def mock_ws_event_handler(websocket):
async for raw_msg in websocket:
event = event_from_json(raw_msg)
assert type(event) is EEUserCancel
break
await websocket.close()
ee_con_info = EvaluatorConnectionInfo(f"tcp://127.0.0.1:{unused_tcp_port}")

connected = False

async def mock_event_handler(router_socket):
nonlocal connected
while True:
dealer, _, *frames = await router_socket.recv_multipart()
dealer = dealer.decode("utf-8")
for frame in frames:
frame = frame.decode("utf-8")
assert dealer.startswith("client-")
if frame == "CONNECT":
connected = True
elif frame == "DISCONNECT":
connected = False
return
else:
event = event_from_json(frame)
assert connected
assert type(event) is EEUserCancel

websocket_server_task = asyncio.create_task(
_mock_ws(set_when_done, mock_ws_event_handler, ee_con_info)
async_zmq_server(unused_tcp_port, mock_event_handler)
)

async with Monitor(ee_con_info) as monitor:
monitor._receiver_timeout = 0.1
await monitor.signal_cancel()
Expand All @@ -112,7 +130,6 @@ async def mock_ws_event_handler(websocket):
"Evaluator did not send the TERMINATED event!"
) in caplog.messages, "Monitor receiver did not stop!"

set_when_done.set()
await websocket_server_task


Expand All @@ -121,11 +138,16 @@ async def test_that_monitor_can_emit_heartbeats(unused_tcp_port):
exit anytime. A heartbeat is a None event.
If the heartbeat is never sent, this test function will hang and then timeout."""
ee_con_info = EvaluatorConnectionInfo(f"ws://127.0.0.1:{unused_tcp_port}")
ee_con_info = EvaluatorConnectionInfo(f"tcp://127.0.0.1:{unused_tcp_port}")

set_when_done = asyncio.Event()

async def mock_event_handler(router_socket):
nonlocal set_when_done
await set_when_done.wait()

websocket_server_task = asyncio.create_task(
_mock_ws(set_when_done, None, ee_con_info)
async_zmq_server(unused_tcp_port, mock_event_handler)
)

async with Monitor(ee_con_info) as monitor:
Expand Down

0 comments on commit b693ebf

Please sign in to comment.