Skip to content

Commit

Permalink
WIP: full snapshot udpate needs to be sent to client on connection
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Nov 14, 2024
1 parent 932c425 commit 2d53e4e
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 17 deletions.
22 changes: 19 additions & 3 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import zmq.asyncio

from _ert.events import (
EESnapshot,
EESnapshotUpdate,
EETerminated,
EEUserCancel,
Expand Down Expand Up @@ -209,6 +210,21 @@ async def _failed_handler(self, events: Sequence[EnsembleFailed]) -> None:
def ensemble(self) -> Ensemble:
return self._ensemble

async def listen_for_clients(self) -> None:
while True:
event = await self._publisher_socket.recv()
# TODO change to router-dealer as this would inform all subscribers about the snapshot
if event[0] == 1:
print("Subscriber connected")
current_snapshot_dict = self._ensemble.snapshot.to_dict()
event: Event = EESnapshot(
snapshot=current_snapshot_dict, ensemble=self.ensemble.id_
)
await self._publisher_socket.send_string(event_to_json(event))

elif event[0] == 0:
print("Subscriber disconnected")

async def listen_for_messages(self) -> None:
while True:
sender, raw_msg = await self._pull_socket.recv_multipart()
Expand All @@ -224,8 +240,8 @@ async def listen_for_messages(self) -> None:
logger.debug("Client signalled done.")
self.stop()
elif sender == "dispatch":
print(f"Got dispatch {raw_msg=}")
event = dispatch_event_from_json(raw_msg)
# print(f"Got dispatch {event=}")
if event.ensemble != self.ensemble.id_:
logger.info(
"Got event from evaluator "
Expand All @@ -237,8 +253,8 @@ async def listen_for_messages(self) -> None:
await self.forward_checksum(event)
else:
await self._events.put(event)
if type(event) in [EnsembleSucceeded, EnsembleFailed]:
return
# if type(event) in [EnsembleSucceeded, EnsembleFailed]:
# return
else:
logger.info(f"Connection attempt to unknown sender: {sender}.")

Expand Down
13 changes: 6 additions & 7 deletions src/ert/ensemble_evaluator/monitor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import asyncio
import json
import logging
import ssl
import uuid
Expand Down Expand Up @@ -76,7 +75,7 @@ async def signal_cancel(self) -> None:

cancel_event = EEUserCancel(monitor=self._id)
await self._push_socket.send_multipart(
[b"client", event_to_json(cancel_event).encode()]
[b"client", event_to_json(cancel_event).encode("utf-8")]
)
logger.debug(f"monitor-{self._id} asked server to cancel")

Expand All @@ -86,7 +85,7 @@ async def signal_done(self) -> None:

done_event = EEUserDone(monitor=self._id)
await self._push_socket.send_multipart(
[b"client", event_to_json(done_event).encode()]
[b"client", event_to_json(done_event).encode("utf-8")]
)
logger.debug(f"monitor-{self._id} informed server monitor is done")

Expand Down Expand Up @@ -127,7 +126,7 @@ async def _receiver(self) -> None:
tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
tls.load_verify_locations(cadata=self._ee_con_info.cert)

self._listen_socket = self._zmq_context.socket(zmq.SUB)
self._listen_socket = self._zmq_context.socket(zmq.XSUB)
self._listen_socket.connect(self._ee_con_info.pub_sub_uri)
self._listen_socket.setsockopt_string(zmq.SUBSCRIBE, "")

Expand All @@ -138,11 +137,11 @@ async def _receiver(self) -> None:
while True:
try:
raw_msg = await self._listen_socket.recv_string()
raw_msg = json.loads(raw_msg)
# print(f"monitor-{self._id} received msg: {raw_msg}")
event = event_from_json(raw_msg)
print(f"monitor-{self._id} received event: {event}")
# print(f"monitor-{self._id} received event: {event}")
await self._event_queue.put(event)
except (zmq.ZMQError, asyncio.CancelledError) as exc:
except zmq.ZMQError as exc:
# Handle disconnection or other ZMQ errors (reconnect or log)
logger.debug(
f"ZeroMQ connection to EnsembleEvaluator went down, reconnecting: {exc}"
Expand Down
8 changes: 1 addition & 7 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,7 @@

import numpy as np

from _ert.events import (
EESnapshot,
EESnapshotUpdate,
EETerminated,
Event,
)
from _ert.events import EESnapshot, EESnapshotUpdate, EETerminated, Event
from ert.analysis import (
AnalysisEvent,
AnalysisStatusEvent,
Expand Down Expand Up @@ -507,7 +502,6 @@ async def run_monitor(
event,
iteration,
)

if event.snapshot.get(STATUS) in [
ENSEMBLE_STATE_STOPPED,
ENSEMBLE_STATE_FAILED,
Expand Down

0 comments on commit 2d53e4e

Please sign in to comment.