Skip to content

Commit

Permalink
Replace connect, disconnect and ack strings with constants
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Dec 13, 2024
1 parent 126441f commit 73624e2
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 27 deletions.
11 changes: 8 additions & 3 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ class ClientConnectionClosedOK(Exception):
pass


CONNECT_MSG = "CONNECT"
DISCONNECT_MSG = "DISCONNECT"
ACK_MSG = b"ACK"


class Client:
DEFAULT_MAX_RETRIES = 5
DEFAULT_ACK_TIMEOUT = 5
Expand All @@ -46,7 +51,7 @@ async def __aexit__(
self, exc_type: Any, exc_value: Any, exc_traceback: Any
) -> None:
try:
await self._send("DISCONNECT")
await self._send(DISCONNECT_MSG)
except ClientConnectionError:
logger.error("No ack for dealer disconnection. Connection is down!")
finally:
Expand Down Expand Up @@ -96,7 +101,7 @@ async def connect(self) -> None:
await self._term_receiver_task()
self._receiver_task = asyncio.create_task(self._receiver())
try:
await self._send("CONNECT", retries=1)
await self._send(CONNECT_MSG, retries=1)
except ClientConnectionError:
await self._term_receiver_task()
self.term()
Expand All @@ -112,7 +117,7 @@ async def _receiver(self) -> None:
while True:
try:
_, raw_msg = await self.socket.recv_multipart()
if raw_msg == b"ACK":
if raw_msg == ACK_MSG:
self._ack_event.set()
else:
await self.process_message(raw_msg.decode("utf-8"))
Expand Down
12 changes: 6 additions & 6 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
event_from_json,
event_to_json,
)
from _ert.forward_model_runner.client import ACK_MSG, CONNECT_MSG, DISCONNECT_MSG
from ert.ensemble_evaluator import identifiers as ids

from ._ensemble import FMStepSnapshot
Expand Down Expand Up @@ -187,7 +188,7 @@ def ensemble(self) -> Ensemble:

async def handle_client(self, dealer: bytes, frame: bytes) -> None:
raw_msg = frame.decode("utf-8")
if raw_msg == "CONNECT":
if raw_msg == CONNECT_MSG:
self._clients_connected.add(dealer)
self._clients_empty.clear()
current_snapshot_dict = self._ensemble.snapshot.to_dict()
Expand All @@ -198,7 +199,7 @@ async def handle_client(self, dealer: bytes, frame: bytes) -> None:
await self._router_socket.send_multipart(
[dealer, b"", event_to_json(event).encode("utf-8")]
)
elif raw_msg == "DISCONNECT":
elif raw_msg == DISCONNECT_MSG:
self._clients_connected.discard(dealer)
if not self._clients_connected:
self._clients_empty.set()
Expand All @@ -213,10 +214,10 @@ async def handle_client(self, dealer: bytes, frame: bytes) -> None:

async def handle_dispatch(self, dealer: bytes, frame: bytes) -> None:
raw_msg = frame.decode("utf-8")
if raw_msg == "CONNECT":
if raw_msg == CONNECT_MSG:
self._dispatchers_connected.add(dealer)
self._dispatchers_empty.clear()
elif raw_msg == "DISCONNECT":
elif raw_msg == DISCONNECT_MSG:
self._dispatchers_connected.discard(dealer)
if not self._dispatchers_connected:
self._dispatchers_empty.set()
Expand All @@ -239,8 +240,7 @@ async def listen_for_messages(self) -> None:
while True:
try:
dealer, _, frame = await self._router_socket.recv_multipart()
# print(f"GOT MESSAGE {frames=} from {dealer=}")
await self._router_socket.send_multipart([dealer, b"", b"ACK"])
await self._router_socket.send_multipart([dealer, b"", ACK_MSG])
sender = dealer.decode("utf-8")
if sender.startswith("client"):
await self.handle_client(dealer, frame)
Expand Down
25 changes: 15 additions & 10 deletions tests/ert/unit_tests/ensemble_evaluator/test_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import zmq.asyncio

from _ert.events import EEUserCancel, EEUserDone, event_from_json
from _ert.forward_model_runner.client import ClientConnectionError
from _ert.forward_model_runner.client import (
ACK_MSG,
CONNECT_MSG,
DISCONNECT_MSG,
ClientConnectionError,
)
from ert.ensemble_evaluator import Monitor
from ert.ensemble_evaluator.config import EvaluatorConnectionInfo

Expand Down Expand Up @@ -39,14 +44,14 @@ async def mock_event_handler(router_socket):
nonlocal connected
while True:
dealer, _, *frames = await router_socket.recv_multipart()
await router_socket.send_multipart([dealer, b"", b"ACK"])
await router_socket.send_multipart([dealer, b"", ACK_MSG])
dealer = dealer.decode("utf-8")
for frame in frames:
frame = frame.decode("utf-8")
assert dealer.startswith("client-")
if frame == "CONNECT":
if frame == CONNECT_MSG:
connected = True
elif frame == "DISCONNECT":
elif frame == DISCONNECT_MSG:
connected = False
return
else:
Expand Down Expand Up @@ -74,11 +79,11 @@ async def test_unexpected_close_after_connection_successful(

async def mock_event_handler(router_socket):
dealer, _, frame = await router_socket.recv_multipart()
await router_socket.send_multipart([dealer, b"", b"ACK"])
await router_socket.send_multipart([dealer, b"", ACK_MSG])
dealer = dealer.decode("utf-8")
assert dealer.startswith("client-")
frame = frame.decode("utf-8")
assert frame == "CONNECT"
assert frame == CONNECT_MSG
router_socket.close()

websocket_server_task = asyncio.create_task(
Expand All @@ -103,14 +108,14 @@ async def mock_event_handler(router_socket):
nonlocal connected
while True:
dealer, _, *frames = await router_socket.recv_multipart()
await router_socket.send_multipart([dealer, b"", b"ACK"])
await router_socket.send_multipart([dealer, b"", ACK_MSG])
dealer = dealer.decode("utf-8")
for frame in frames:
frame = frame.decode("utf-8")
assert dealer.startswith("client-")
if frame == "CONNECT":
if frame == CONNECT_MSG:
connected = True
elif frame == "DISCONNECT":
elif frame == DISCONNECT_MSG:
connected = False
return
else:
Expand Down Expand Up @@ -147,7 +152,7 @@ async def mock_event_handler(router_socket):
while True:
try:
dealer, _, __ = await router_socket.recv_multipart()
await router_socket.send_multipart([dealer, b"", b"ACK"])
await router_socket.send_multipart([dealer, b"", ACK_MSG])
except asyncio.CancelledError:
break

Expand Down
19 changes: 11 additions & 8 deletions tests/ert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import zmq
import zmq.asyncio

from _ert.forward_model_runner.client import ACK_MSG, CONNECT_MSG, DISCONNECT_MSG
from _ert.threading import ErtThread
from ert.scheduler.event import FinishedEvent, StartedEvent

Expand Down Expand Up @@ -75,10 +76,11 @@ async def _handler(router_socket):
)

print(f"{dealer=} {frame=} {signal_value=}")
if frame in [b"CONNECT", b"DISCONNECT"] or signal_value == 0:
await router_socket.send_multipart([dealer, b"", b"ACK"])
if frame not in [b"CONNECT", b"DISCONNECT"] and signal_value != 1:
messages.append(frame.decode("utf-8"))
frame = frame.decode("utf-8")
if frame in [CONNECT_MSG, DISCONNECT_MSG] or signal_value == 0:
await router_socket.send_multipart([dealer, b"", ACK_MSG])
if frame not in [CONNECT_MSG, DISCONNECT_MSG] and signal_value != 1:
messages.append(frame)

zmq_context = zmq.asyncio.Context()
router_socket = zmq_context.socket(zmq.ROUTER)
Expand Down Expand Up @@ -123,10 +125,11 @@ async def _handler(router_socket):
signal_value = signal_queue.get(timeout=0.1)

print(f"{dealer=} {frame=} {signal_value=}")
if frame in [b"CONNECT", b"DISCONNECT"] or signal_value == 0:
await router_socket.send_multipart([dealer, b"", b"ACK"])
if frame not in [b"CONNECT", b"DISCONNECT"] and signal_value != 1:
messages.append(frame.decode("utf-8"))
frame = frame.decode("utf-8")
if frame in [CONNECT_MSG, DISCONNECT_MSG] or signal_value == 0:
await router_socket.send_multipart([dealer, b"", ACK_MSG])
if frame not in [CONNECT_MSG, DISCONNECT_MSG] and signal_value != 1:
messages.append(frame)

except asyncio.CancelledError:
break
Expand Down

0 comments on commit 73624e2

Please sign in to comment.