Skip to content

Commit

Permalink
Make monitor slightly advanced from Client
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Dec 5, 2024
1 parent d2d871c commit 9cfd2fd
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 131 deletions.
96 changes: 55 additions & 41 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,30 +28,35 @@ class Client:
CONNECTION_TIMEOUT = 60

def __enter__(self) -> Self:
self.loop.run_until_complete(self.reconnect())
self.loop.run_until_complete(self.__aenter__())
return self

def term(self) -> None:
self.socket.close()
self.context.term()
self.loop.close()

def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None:
self.send("DISCONNECT")
self.socket.disconnect(self.url)
self.term()
self.loop.run_until_complete(self.__aexit__(exc_type, exc_value, exc_traceback))
self.loop.close()

async def __aenter__(self) -> Self:
await self.reconnect()
await self.connect()
return self

async def __aexit__(
    self, exc_type: Any, exc_value: Any, exc_traceback: Any
) -> None:
    """Announce the disconnect to the evaluator and release the connection.

    Sends ``DISCONNECT`` and, whether or not that send succeeds, disconnects
    the socket, stops the background receiver task and calls ``term``.
    Any exception raised by the send still propagates to the caller.
    """
    try:
        await self._send("DISCONNECT")
    finally:
        # _send can fail (e.g. ClientConnectionError once its retries are
        # exhausted); without the ``finally`` the receiver task would keep
        # running and the socket would never be closed.
        if not self.socket.closed:
            # _send already calls term() when it is cancelled, and
            # disconnecting a closed socket would raise and mask that error.
            self.socket.disconnect(self.url)
        await self._term_receiver_task()
        self.term()

async def _term_receiver_task(self) -> None:
    """Cancel the background receiver task, if one is still running.

    The cancelled task is awaited so that it has actually finished before
    this coroutine returns; ``return_exceptions=True`` makes ``gather`` hand
    back the resulting ``CancelledError`` instead of raising it here.
    Nothing happens when no task was started or it has already finished.
    """
    task = self._receiver_task
    if task is None or task.done():
        return
    task.cancel()
    await asyncio.gather(task, return_exceptions=True)
    # Drop the handle so a later connect starts from a clean slate.
    self._receiver_task = None

def __init__(
self,
url: str,
Expand All @@ -67,17 +72,17 @@ def __init__(
self.url = url
self.token = token

# Set up ZeroMQ context and socket
# Set up ZeroMQ context and socket
self._ack_event: asyncio.Event = asyncio.Event()
self.context = zmq.asyncio.Context()
self.socket = self.context.socket(zmq.DEALER)
self.socket.setsockopt(zmq.LINGER, 0)
if dealer_name is None:
dispatch_id = f"dispatch-{uuid.uuid4().hex[:8]}"
self.dealer_id = f"dispatch-{uuid.uuid4().hex[:8]}"
else:
dispatch_id = dealer_name
self.dispatch_id = dispatch_id
self.socket.setsockopt_string(zmq.IDENTITY, dispatch_id)
print(f"{self.dispatch_id} {token}")
self.dealer_id = dealer_name
self.socket.setsockopt_string(zmq.IDENTITY, self.dealer_id)
print(f"Created: {self.dealer_id=} {token=} {self._connection_timeout=}")
if token is not None:
client_public, client_secret = zmq.curve_keypair()
self.socket.curve_secretkey = client_secret
Expand All @@ -86,33 +91,46 @@ def __init__(

self._max_retries = max_retries
self.loop = new_event_loop()
self._receiver_task: Optional[asyncio.Task[None]] = None

async def reconnect(self) -> None:
async def connect(self) -> None:
self.socket.connect(self.url)
print(f"{self.dispatch_id=} CONNECTING to {self.url=}")
await self._term_receiver_task()
self._receiver_task = asyncio.create_task(self._receiver())
try:
await self._send("CONNECT", max_retries=1)
_, ack = await asyncio.wait_for(
self.socket.recv_multipart(), timeout=self._connection_timeout
)
if ack.decode() != "ACK":
raise ClientConnectionError("No Ack for connect")
print(f"{self.dispatch_id=} CONNECTED to {self.url=}")
except asyncio.TimeoutError as exc:
logger.warning("Failed to get acknowledgment on dealer connect!")
except ClientConnectionError:
await self._term_receiver_task()
self.term()
raise ClientConnectionError(
"Connection to evaluator not established!"
) from exc
raise

def send(
    self, messages: str | list[str], max_retries: Optional[int] = None
) -> None:
    """Blocking counterpart of ``_send``.

    Builds the ``_send`` coroutine for ``messages`` and drives it to
    completion on this client's own event loop (``self.loop``).
    """
    delivery = self._send(messages, max_retries)
    self.loop.run_until_complete(delivery)

async def process_message(self, msg: str) -> None:
    """Handle one decoded, non-ACK message received from the socket.

    The base client ignores such messages; subclasses override this hook
    to act on them.
    """

async def _receiver(self) -> None:
    """Pump incoming frames from the socket in an endless loop.

    An ``ACK`` payload sets the acknowledgement event that ``_send`` waits
    on; any other payload is decoded as UTF-8 and handed to
    ``process_message``. On a ZeroMQ error the loop pauses for one second
    and connects to the same URL again before resuming. The loop only ends
    through cancellation or an exception that is not a ``zmq.ZMQError``.
    """
    while True:
        try:
            _, payload = await self.socket.recv_multipart()
            if payload != b"ACK":
                await self.process_message(payload.decode("utf-8"))
                continue
            self._ack_event.set()
        except zmq.ZMQError as exc:
            logger.debug(
                f"{self.dealer_id} connection to evaluator went down, reconnecting: {exc}"
            )
            await asyncio.sleep(1)
            self.socket.connect(self.url)

async def _send(
self, messages: str | list[str], max_retries: Optional[int] = None
) -> None:
self._ack_event.clear()
if isinstance(messages, str):
messages = [messages]

Expand All @@ -124,27 +142,21 @@ async def _send(
await self.socket.send_multipart(
[b""] + [message.encode("utf-8") for message in messages]
)

# Wait for acknowledgment
try:
_, ack = await asyncio.wait_for(
self.socket.recv_multipart(), timeout=self._connection_timeout
)
if ack.decode() == "ACK":
logger.info("Message acknowledged.")
print(f"message sent {messages=} from {self.dispatch_id=}")
return
logger.warning(
"Got acknowledgment but not the expected message. Resending."
await asyncio.wait_for(
self._ack_event.wait(), timeout=self._connection_timeout
)
return
except asyncio.TimeoutError:
logger.warning(
"Failed to get acknowledgment on the message. Resending."
)

except zmq.ZMQError as e:
logger.warning(f"ZMQ error occurred: {e}. Reconnecting...")
await self.reconnect()
except zmq.ZMQError as exc:
logger.debug(
f"{self.dealer_id} connection to evaluator went down, reconnecting: {exc}"
)
await asyncio.sleep(1)
self.socket.connect(self.url)
except asyncio.CancelledError:
self.term()
raise
Expand All @@ -155,4 +167,6 @@ async def _send(
await asyncio.sleep(backoff)
backoff = min(backoff * 2, 10) # Exponential backoff

raise ClientConnectionError("Failed to send message after retries.")
raise ClientConnectionError(
f"{self.dealer_id} Failed to send {messages=} after retries."
)
6 changes: 3 additions & 3 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ async def handle_client(self, dealer: bytes, frames: list[bytes]) -> None:
for frame in frames:
raw_msg = frame.decode("utf-8")
if raw_msg == "CONNECT":
await self._router_socket.send_multipart([dealer, b"", b"ACK"])
self._clients_connected.add(dealer)
self._clients_empty.clear()
current_snapshot_dict = self._ensemble.snapshot.to_dict()
Expand Down Expand Up @@ -234,7 +233,6 @@ async def handle_dispatch(self, dealer: bytes, frames: list[bytes]) -> None:
for frame in frames:
raw_msg = frame.decode("utf-8")
if raw_msg == "CONNECT":
print(f"GOT MESSAGE {raw_msg=} from {dealer=}")
self._dispatchers_connected.add(dealer)
self._dispatchers_connected.clear()
elif raw_msg == "DISCONNECT":
Expand All @@ -260,11 +258,13 @@ async def listen_for_messages(self) -> None:
while True:
try:
dealer, _, *frames = await self._router_socket.recv_multipart()
# print(f"GOT MESSAGE {frames=} from {dealer=}")
await self._router_socket.send_multipart([dealer, b"", b"ACK"])
sender = dealer.decode("utf-8")
if sender.startswith("client"):
await self.handle_client(dealer, frames)
elif sender.startswith("dispatch"):
await self._router_socket.send_multipart([dealer, b"", b"ACK"])
# await self._router_socket.send_multipart([dealer, b"", b"ACK"])
await self.handle_dispatch(dealer, frames)
else:
logger.info(f"Connection attempt to unknown sender: {sender}.")
Expand Down
105 changes: 18 additions & 87 deletions src/ert/ensemble_evaluator/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@

import asyncio
import logging
import ssl
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, Final, Optional, Union

import zmq.asyncio
from typing import TYPE_CHECKING, AsyncGenerator, Final, Optional, Union

from _ert.events import (
EETerminated,
Expand All @@ -16,6 +13,7 @@
event_from_json,
event_to_json,
)
from _ert.forward_model_runner.client import Client

if TYPE_CHECKING:
from ert.ensemble_evaluator.evaluator_connection_info import EvaluatorConnectionInfo
Expand All @@ -28,74 +26,41 @@ class EventSentinel:
pass


class Monitor:
class Monitor(Client):
_sentinel: Final = EventSentinel()

def __init__(self, ee_con_info: "EvaluatorConnectionInfo") -> None:
self._ee_con_info = ee_con_info
self._id = str(uuid.uuid1()).split("-", maxsplit=1)[0]
self._event_queue: asyncio.Queue[Union[Event, EventSentinel]] = asyncio.Queue()
self._receiver_task: Optional[asyncio.Task[None]] = None
self._connected: asyncio.Future[None] = asyncio.Future()
self._connection_timeout: float = 120.0
self._receiver_timeout: float = 60.0
# zmq connection
self._zmq_context = zmq.asyncio.Context()
self._socket = self._zmq_context.socket(zmq.DEALER)
self._socket.setsockopt(zmq.LINGER, 0)
self._socket.setsockopt_string(zmq.IDENTITY, f"client-{self._id}")
print(f"{self._id=} wiith {ee_con_info.token=}")
if ee_con_info.token is not None:
client_public, client_secret = zmq.curve_keypair()
self._socket.curve_secretkey = client_secret
self._socket.curve_publickey = client_public
self._socket.curve_serverkey = ee_con_info.token.encode("utf-8")

async def __aenter__(self) -> "Monitor":
try:
await self.reconnect()
except asyncio.TimeoutError as exc:
await self._term()
msg = "Couldn't establish connection with the ensemble evaluator!"
logger.error(msg)
raise RuntimeError(msg) from exc
self._receiver_task = asyncio.create_task(self._receiver())
return self

async def _term(self) -> None:
if self._receiver_task:
await self._socket.send_multipart([b"", b"DISCONNECT"])
self._socket.disconnect(self._ee_con_info.router_uri)
if not self._receiver_task.done():
self._receiver_task.cancel()
await asyncio.gather(
self._receiver_task,
return_exceptions=True,
)
self._socket.close()
self._zmq_context.term()

async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
await self._term()
super().__init__(
ee_con_info.router_uri,
ee_con_info.token,
ee_con_info.cert,
dealer_name=f"client-{self._id}",
)

async def process_message(self, msg: str) -> None:
    """Turn a JSON message from the evaluator into an event and enqueue it.

    The parsed event is put on this monitor's internal event queue.
    """
    await self._event_queue.put(event_from_json(msg))

async def signal_cancel(self) -> None:
await self._event_queue.put(Monitor._sentinel)
logger.debug(f"monitor-{self._id} asking server to cancel...")

cancel_event = EEUserCancel(monitor=self._id)
await self._socket.send_multipart(
[b"", event_to_json(cancel_event).encode("utf-8")]
)
await self._send(event_to_json(cancel_event))
logger.debug(f"monitor-{self._id} asked server to cancel")

async def signal_done(self) -> None:
await self._event_queue.put(Monitor._sentinel)
logger.debug(f"monitor-{self._id} informing server monitor is done...")

done_event = EEUserDone(monitor=self._id)
await self._socket.send_multipart(
[b"", event_to_json(done_event).encode("utf-8")]
)
await self._send(event_to_json(done_event))
logger.debug(f"monitor-{self._id} informed server monitor is done")

async def track(
Expand Down Expand Up @@ -128,37 +93,3 @@ async def track(
break
if event is not None:
self._event_queue.task_done()

async def reconnect(self) -> None:
self._socket.connect(self._ee_con_info.router_uri)
await self._socket.send_multipart([b"", b"CONNECT"])
try:
_, ack = await asyncio.wait_for(
self._socket.recv_multipart(), timeout=self._connection_timeout
)
if ack.decode() != "ACK":
raise asyncio.TimeoutError("No Ack for connect")
print(f"{self._id=} MONITOR CONNECTED")
except asyncio.TimeoutError:
print("NO CONNECTION")
logger.warning(
f"Failed to get acknowledgment on the monitor {self._id} connect!"
)
raise

async def _receiver(self) -> None:
tls: Optional[ssl.SSLContext] = None
if self._ee_con_info.cert:
tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
tls.load_verify_locations(cadata=self._ee_con_info.cert)
while True:
try:
_, raw_msg = await self._socket.recv_multipart()
event = event_from_json(raw_msg.decode("utf-8"))
await self._event_queue.put(event)
except zmq.ZMQError as exc:
# Handle disconnection or other ZMQ errors (reconnect or log)
logger.debug(
f"ZeroMQ connection to EnsembleEvaluator went down, reconnecting: {exc}"
)
await self.reconnect()

0 comments on commit 9cfd2fd

Please sign in to comment.