Skip to content

Commit

Permalink
UPdate zmq mocks
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Dec 6, 2024
1 parent 9be8ace commit 0884d51
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 118 deletions.
2 changes: 1 addition & 1 deletion src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ async def _send(
return
except asyncio.TimeoutError:
logger.warning(
"Failed to get acknowledgment on the message. Resending."
f"{self.dealer_id} failed to get acknowledgment on the {messages}. Resending."
)
except zmq.ZMQError as exc:
logger.debug(
Expand Down
5 changes: 5 additions & 0 deletions src/_ert/forward_model_runner/reporting/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ def __init__(self, evaluator_url, token=None, cert_path=None):
self._event_publisher_thread = ErtThread(target=self._event_publisher)
self._done = False

def stop(self):
self._event_queue.put(Event._sentinel)
if self._event_publisher_thread.is_alive():
self._event_publisher_thread.join()

def _event_publisher(self):
logger.debug("Publishing event.")
with Client(
Expand Down
95 changes: 45 additions & 50 deletions tests/ert/unit_tests/ensemble_evaluator/conftest.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import asyncio
import json
import os
import stat
from pathlib import Path
from unittest.mock import MagicMock, Mock

import pytest
import zmq
import zmq.asyncio

import ert.ensemble_evaluator
from _ert.async_utils import new_event_loop
from ert.config import QueueConfig, QueueSystem
from ert.config.ert_config import _forward_model_step_from_config_file
from ert.config.queue_config import LocalQueueOptions
Expand All @@ -21,52 +17,51 @@
from ert.storage import Ensemble
from tests.ert import SnapshotBuilder


@pytest.fixture(name="zmq_server")
def _mock_zmq_server(host, port, messages, delay_startup=0):
loop = new_event_loop()
done = loop.create_future()

async def _handler(router_socket):
while True:
_, __, *frames = await router_socket.recv_multipart()
for frame in frames:
raw_msg = frame.decode("utf-8")
messages.append(raw_msg)
if raw_msg == "stop":
done.set_result(None)
break

async def _run_server():
await asyncio.sleep(delay_startup)
zmq_context = zmq.asyncio.Context() # type: ignore
router_socket = zmq_context.socket(zmq.ROUTER)
router_socket.bind(f"tcp://*:{port}")
handler_task = asyncio.create_task(_handler(router_socket))
await handler_task
router_socket.close()

loop.run_until_complete(_run_server())
loop.close()


@pytest.fixture(name="async_zmq_server")
def _async_mock_zmq_server(port, handler, set_when_done):
loop = new_event_loop()

async def _run_server():
zmq_context = zmq.asyncio.Context() # type: ignore
router_socket = zmq_context.socket(zmq.ROUTER)
router_socket.bind(f"tcp://*:{port}")
while True:
dealer, __, *frames = await router_socket.recv_multipart()
for frame in frames:
handler(dealer.decode("utf-8"), frame)
if set_when_done:
return

loop.run_until_complete(_run_server())
loop.close()
# @pytest.fixture(name="zmq_server")
# def _mock_zmq_server(host, port, messages, delay_startup=0):
# loop = new_event_loop()
# done = loop.create_future()

# async def _handler(router_socket):
# while True:
# _, __, *frames = await router_socket.recv_multipart()
# for frame in frames:
# raw_msg = frame.decode("utf-8")
# messages.append(raw_msg)
# if raw_msg == "stop":
# done.set_result(None)
# break

# async def _run_server():
# await asyncio.sleep(delay_startup)
# zmq_context = zmq.asyncio.Context() # type: ignore
# router_socket = zmq_context.socket(zmq.ROUTER)
# router_socket.bind(f"tcp://*:{port}")
# handler_task = asyncio.create_task(_handler(router_socket))
# await handler_task
# router_socket.close()

# loop.run_until_complete(_run_server())
# loop.close()


# @pytest.fixture(name="async_zmq_server")
# def _async_mock_zmq_server(port, handler, set_when_done):
# loop = new_event_loop()

# async def _run_server():
# zmq_context = zmq.asyncio.Context() # type: ignore
# router_socket = zmq_context.socket(zmq.ROUTER)
# router_socket.bind(f"tcp://*:{port}")
# while True:
# dealer, __, *frames = await router_socket.recv_multipart()
# for frame in frames:
# handler(dealer.decode("utf-8"), frame)
# if set_when_done:
# return

# loop.run_until_complete(_run_server())
# loop.close()


@pytest.fixture
Expand Down
37 changes: 13 additions & 24 deletions tests/ert/unit_tests/ensemble_evaluator/test_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import zmq.asyncio

from _ert.events import EEUserCancel, EEUserDone, event_from_json
from _ert.forward_model_runner.client import ClientConnectionError
from ert.ensemble_evaluator import Monitor
from ert.ensemble_evaluator.config import EvaluatorConnectionInfo

Expand All @@ -24,9 +25,7 @@ async def test_no_connection_established(make_ee_config):
ee_config = make_ee_config()
monitor = Monitor(ee_config.get_connection_info())
monitor._connection_timeout = 0.1
with pytest.raises(
RuntimeError, match="Couldn't establish connection with the ensemble evaluator!"
):
with pytest.raises(ClientConnectionError):
async with monitor:
pass

Expand All @@ -40,22 +39,18 @@ async def mock_event_handler(router_socket):
nonlocal connected
while True:
dealer, _, *frames = await router_socket.recv_multipart()
await router_socket.send_multipart([dealer, b"", b"ACK"])
dealer = dealer.decode("utf-8")
for frame in frames:
frame = frame.decode("utf-8")
assert dealer.startswith("client-")
if frame == "CONNECT":
await router_socket.send_multipart(
[dealer.encode("utf-8"), b"", b"ACK"]
)
connected = True
elif frame == "DISCONNECT":
connected = False
print(connected)
return
else:
event = event_from_json(frame)
print(f"{event=}")
assert connected
assert type(event) is EEUserDone

Expand Down Expand Up @@ -97,14 +92,12 @@ async def mock_event_handler(router_socket):
nonlocal connected
while True:
dealer, _, *frames = await router_socket.recv_multipart()
await router_socket.send_multipart([dealer, b"", b"ACK"])
dealer = dealer.decode("utf-8")
for frame in frames:
frame = frame.decode("utf-8")
assert dealer.startswith("client-")
if frame == "CONNECT":
await router_socket.send_multipart(
[dealer.encode("utf-8"), b"", b"ACK"]
)
connected = True
elif frame == "DISCONNECT":
connected = False
Expand Down Expand Up @@ -139,18 +132,13 @@ async def test_that_monitor_can_emit_heartbeats(unused_tcp_port):
If the heartbeat is never sent, this test function will hang and then timeout."""
ee_con_info = EvaluatorConnectionInfo(f"tcp://127.0.0.1:{unused_tcp_port}")

set_when_done = asyncio.Event()

async def mock_event_handler(router_socket):
dealer, _, *frames = await router_socket.recv_multipart()
dealer = dealer.decode("utf-8")
for frame in frames:
frame = frame.decode("utf-8")
if frame == "CONNECT":
await router_socket.send_multipart(
[dealer.encode("utf-8"), b"", b"ACK"]
)
await set_when_done.wait()
while True:
try:
dealer, _, __ = await router_socket.recv_multipart()
await router_socket.send_multipart([dealer, b"", b"ACK"])
except asyncio.CancelledError:
break

websocket_server_task = asyncio.create_task(
async_zmq_server(unused_tcp_port, mock_event_handler)
Expand All @@ -161,5 +149,6 @@ async def mock_event_handler(router_socket):
if event is None:
break

set_when_done.set() # shuts down websocket server
await websocket_server_task
if not websocket_server_task.done():
websocket_server_task.cancel()
asyncio.gather(websocket_server_task, return_exceptions=True)
12 changes: 7 additions & 5 deletions tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from _ert.forward_model_runner.reporting import Event, Interactive
from _ert.forward_model_runner.reporting.message import Finish, Init
from _ert.threading import ErtThread
from tests.ert.utils import _mock_ws_thread, wait_until
from tests.ert.utils import mock_zmq_thread, wait_until

from .test_event_reporter import _wait_until

Expand Down Expand Up @@ -302,7 +302,7 @@ def test_retry_of_jobs_json_file_read(unused_tcp_port, tmp_path, monkeypatch, ca
jobs_json = json.dumps(
{
"ens_id": "_id_",
"dispatch_url": f"ws://localhost:{unused_tcp_port}",
"dispatch_url": f"tcp://localhost:{unused_tcp_port}",
"jobList": [],
}
)
Expand All @@ -316,7 +316,7 @@ def create_jobs_file_after_lock():
(tmp_path / JOBS_FILE).write_text(jobs_json)
lock.release()

with _mock_ws_thread("localhost", unused_tcp_port, []):
with mock_zmq_thread("localhost", unused_tcp_port, []):
thread = ErtThread(target=create_jobs_file_after_lock)
thread.start()
main(args=["script.py", str(tmp_path)])
Expand Down Expand Up @@ -347,7 +347,9 @@ def test_setup_reporters(is_interactive_run, ens_id):
def test_job_dispatch_kills_itself_after_unsuccessful_job(unused_tcp_port):
host = "localhost"
port = unused_tcp_port
jobs_json = json.dumps({"ens_id": "_id_", "dispatch_url": f"ws://localhost:{port}"})
jobs_json = json.dumps(
{"ens_id": "_id_", "dispatch_url": f"tcp://localhost:{port}"}
)

with (
patch("_ert.forward_model_runner.cli.os.killpg") as mock_killpg,
Expand All @@ -361,7 +363,7 @@ def test_job_dispatch_kills_itself_after_unsuccessful_job(unused_tcp_port):
]
mock_getpgid.return_value = 17

with _mock_ws_thread(host, port, []):
with mock_zmq_thread(host, port, []):
main(["script.py"])

mock_killpg.assert_called_with(17, signal.SIGKILL)
Expand Down
2 changes: 1 addition & 1 deletion tests/ert/unit_tests/scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ async def kill():
async def test_add_dispatch_information_to_jobs_file(
storage, tmp_path: Path, mock_driver
):
test_ee_uri = "ws://test_ee_uri.com/121/"
test_ee_uri = "tcp://test_ee_uri.com/121/"
test_ens_id = "test_ens_id121"
test_ee_token = "test_ee_token_t0k€n121"
test_ee_cert = "test_ee_cert121.pem"
Expand Down
78 changes: 41 additions & 37 deletions tests/ert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
import asyncio
import contextlib
import time
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING

import zmq
import zmq.asyncio

from _ert.forward_model_runner.client import Client
from _ert.threading import ErtThread
from ert.scheduler.event import FinishedEvent, StartedEvent

Expand Down Expand Up @@ -62,37 +60,11 @@ def wait_until(func, interval=0.5, timeout=30):
)


def mock_zmq_server(messages, port):
loop = asyncio.new_event_loop()

async def _handler(router_socket):
while True:
dealer, __, *frames = await router_socket.recv_multipart()
if dealer.decode("utf-8").startswith("dispatch"):
await router_socket.send_multipart([dealer, b"", b"ACK"])
for frame in frames:
raw_msg = frame.decode("utf-8")
messages.append(raw_msg)
if raw_msg == "stop":
return

async def _run_server():
zmq_context = zmq.asyncio.Context() # type: ignore
router_socket = zmq_context.socket(zmq.ROUTER)
router_socket.bind(f"tcp://*:{port}")
await _handler(router_socket)
router_socket.close()

loop.run_until_complete(_run_server())
loop.close()


async def async_mock_zmq_server(messages, port, server_started):
async def _handler(router_socket):
while True:
dealer, __, *frames = await router_socket.recv_multipart()
if dealer.decode("utf-8").startswith("dispatch"):
await router_socket.send_multipart([dealer, b"", b"ACK"])
await router_socket.send_multipart([dealer, b"", b"ACK"])
for frame in frames:
raw_msg = frame.decode("utf-8")
messages.append(raw_msg)
Expand All @@ -110,19 +82,51 @@ async def _handler(router_socket):

@contextlib.contextmanager
def mock_zmq_thread(host, port, messages):
mock_ws_thread = ErtThread(
target=partial(mock_zmq_server, messages=messages),
args=(port,),
loop = None
handler_task = None

def mock_zmq_server(messages, port):
nonlocal loop, handler_task
loop = asyncio.new_event_loop()

async def _handler(router_socket):
while True:
try:
dealer, __, *frames = await router_socket.recv_multipart()
await router_socket.send_multipart([dealer, b"", b"ACK"])
for frame in frames:
raw_msg = frame.decode("utf-8")
messages.append(raw_msg)
except asyncio.CancelledError:
break

async def _run_server():
zmq_context = zmq.asyncio.Context() # type: ignore
router_socket = zmq_context.socket(zmq.ROUTER)
router_socket.bind(f"tcp://*:{port}")
nonlocal handler_task
handler_task = asyncio.create_task(_handler(router_socket))
await handler_task
router_socket.close()

loop.run_until_complete(_run_server())
loop.close()

mock_zmq_thread = ErtThread(
target=lambda: mock_zmq_server(messages, port),
)
mock_ws_thread.start()
mock_zmq_thread.start()
try:
yield
# Make sure to join the thread even if an exception occurs
finally:
url = f"tcp://{host}:{port}"
with Client(url) as client:
client.send("stop")
mock_ws_thread.join()
# url = f"tcp://{host}:{port}"
# with Client(url) as client:
# client.send("stop")
# # Cancel the handler task explicitly
if handler_task and not handler_task.done():
loop.call_soon_threadsafe(handler_task.cancel)
mock_zmq_thread.join()
messages.pop()


Expand Down

0 comments on commit 0884d51

Please sign in to comment.