Skip to content

Commit

Permalink
Replace websockets with zmq in monitor
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Nov 12, 2024
1 parent 9285352 commit 455b3c4
Showing 1 changed file with 24 additions and 43 deletions.
67 changes: 24 additions & 43 deletions src/ert/ensemble_evaluator/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, Final, Optional, Union

from aiohttp import ClientError
from websockets import ConnectionClosed, Headers, WebSocketClientProtocol
from websockets.client import connect
import zmq.asyncio

from _ert.events import (
EETerminated,
Expand All @@ -16,7 +14,6 @@
event_from_json,
event_to_json,
)
from ert.ensemble_evaluator._wait_for_evaluator import wait_for_evaluator

if TYPE_CHECKING:
from ert.ensemble_evaluator.evaluator_connection_info import EvaluatorConnectionInfo
Expand All @@ -36,11 +33,16 @@ def __init__(self, ee_con_info: "EvaluatorConnectionInfo") -> None:
self._ee_con_info = ee_con_info
self._id = str(uuid.uuid1()).split("-", maxsplit=1)[0]
self._event_queue: asyncio.Queue[Union[Event, EventSentinel]] = asyncio.Queue()
self._connection: Optional[WebSocketClientProtocol] = None
self._receiver_task: Optional[asyncio.Task[None]] = None
self._connected: asyncio.Event = asyncio.Event()
self._connection_timeout: float = 120.0
self._receiver_timeout: float = 60.0
self._zmq_context = zmq.asyncio.Context() # type: ignore
self._listen_socket: zmq.asyncio.Socket = self._zmq_context.socket(
zmq.SUBSCRIBE
)
self._listen_socket.setsockopt_string(zmq.SUBSCRIBE, "")
self._push_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PUSH)

async def __aenter__(self) -> "Monitor":
self._receiver_task = asyncio.create_task(self._receiver())
Expand All @@ -65,27 +67,27 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None
return_exceptions=True,
)

if self._connection:
await self._connection.close()
self._listen_socket.close()
self._push_socket.close()

# Ask the evaluator to cancel: put the Monitor sentinel on the local event
# queue, then send an EEUserCancel event carrying this monitor's id.
# NOTE(review): this listing looks like a flattened diff with the +/- markers
# and indentation stripped. The `self._connection` lines are presumably the
# removed websocket path and the `self._push_socket` lines the zmq
# replacement -- confirm against the actual file before relying on either.
async def signal_cancel(self) -> None:
if not self._connection:
return
# Enqueue the sentinel locally before contacting the server.
await self._event_queue.put(Monitor._sentinel)
logger.debug(f"monitor-{self._id} asking server to cancel...")

cancel_event = EEUserCancel(monitor=self._id)
await self._connection.send(event_to_json(cancel_event))
# Two-frame message: a literal b"client" frame followed by the JSON event
# encoded to bytes (str.encode() defaults to UTF-8).
await self._push_socket.send_multipart(
[b"client", event_to_json(cancel_event).encode()]
)
logger.debug(f"monitor-{self._id} asked server to cancel")

# Tell the evaluator this monitor is finished: put the Monitor sentinel on the
# local event queue, then send an EEUserDone event carrying this monitor's id.
# NOTE(review): this listing looks like a flattened diff with the +/- markers
# and indentation stripped. The `self._connection` lines are presumably the
# removed websocket path and the `self._push_socket` lines the zmq
# replacement -- confirm against the actual file before relying on either.
async def signal_done(self) -> None:
if not self._connection:
return
# Enqueue the sentinel locally before contacting the server.
await self._event_queue.put(Monitor._sentinel)
logger.debug(f"monitor-{self._id} informing server monitor is done...")

done_event = EEUserDone(monitor=self._id)
await self._connection.send(event_to_json(done_event))
# Two-frame message: a literal b"client" frame followed by the JSON event
# encoded to bytes (str.encode() defaults to UTF-8).
await self._push_socket.send_multipart(
[b"client", event_to_json(done_event).encode()]
)
logger.debug(f"monitor-{self._id} informed server monitor is done")

async def track(
Expand Down Expand Up @@ -124,36 +126,15 @@ async def _receiver(self) -> None:
if self._ee_con_info.cert:
tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
tls.load_verify_locations(cadata=self._ee_con_info.cert)
headers = Headers()
if self._ee_con_info.token:
headers["token"] = self._ee_con_info.token

await wait_for_evaluator(
base_url=self._ee_con_info.url,
token=self._ee_con_info.token,
cert=self._ee_con_info.cert,
timeout=5,
)
async for conn in connect(
self._ee_con_info.client_uri,
ssl=tls,
extra_headers=headers,
max_size=2**26,
max_queue=500,
open_timeout=5,
ping_timeout=60,
ping_interval=60,
close_timeout=60,
):

while True:
try:
self._connection = conn
self._connected.set()
async for raw_msg in self._connection:
event = event_from_json(raw_msg)
await self._event_queue.put(event)
except (ConnectionRefusedError, ConnectionClosed, ClientError) as exc:
self._connection = None
self._connected.clear()
raw_msg = await self._listen_socket.recv_string()
event = event_from_json(raw_msg)
await self._event_queue.put(event)
except (zmq.ZMQError, asyncio.CancelledError) as exc:
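# Handle ZMQ errors by logging and retrying after a short pause; no explicit
# reconnect is performed in the lines visible here.
# NOTE(review): asyncio.CancelledError appears to be caught alongside
# zmq.ZMQError and the loop then continues, which would swallow task
# cancellation and keep the receiver running -- confirm whether cancellation
# should be re-raised instead.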
logger.debug(
f"Monitor connection to EnsembleEvaluator went down, reconnecting: {exc}"
f"ZeroMQ connection to EnsembleEvaluator went down, reconnecting: {exc}"
)
await asyncio.sleep(1)

0 comments on commit 455b3c4

Please sign in to comment.