Skip to content

Commit

Permalink
Fixing staff...
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Nov 14, 2024
1 parent 2d53e4e commit 0a25ff4
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 49 deletions.
2 changes: 0 additions & 2 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@ def reconnect(self):

def send(self, message):
try:
# if self.token:
# message = f"{self.token}:{message}"
self.socket.send_multipart([b"dispatch", message.encode("utf-8")])
except zmq.ZMQError as e:
logger.warning(f"Failed to send message: {e}")
Expand Down
116 changes: 71 additions & 45 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):

self._loop: Optional[asyncio.AbstractEventLoop] = None

self._dispatchers_connected: asyncio.Queue[None] = asyncio.Queue()

self._events: asyncio.Queue[Event] = asyncio.Queue()
self._events_to_send: asyncio.Queue[Event] = asyncio.Queue()
self._manifest_queue: asyncio.Queue[Any] = asyncio.Queue()
Expand All @@ -81,20 +79,22 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):
self._batching_interval: int = 2
self._complete_batch: asyncio.Event = asyncio.Event()
self._zmq_context: zmq.asyncio.Context | None = None
self._clients_connected: asyncio.Queue[None] = asyncio.Queue()
self._dispatchers_connected: asyncio.Queue[None] = asyncio.Queue()

async def _initialize_zmq(self) -> None:
self._zmq_context = zmq.asyncio.Context() # type: ignore
try:
self._pull_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PULL)
self._pull_socket.bind(f"tcp://*:{self._config.push_pull_port}")
self._publisher_socket: zmq.asyncio.Socket = self._zmq_context.socket(
zmq.PUB
zmq.XPUB
)
self._publisher_socket.bind(f"tcp://*:{self._config.pub_sub_port}")
except zmq.error.ZMQError as e:
logger.error(f"ZMQ error: {e}")
raise
logger.error("ZMQ initialized")
logger.info("ZMQ initialized")

async def _publisher(self) -> None:
while True:
Expand Down Expand Up @@ -212,51 +212,71 @@ def ensemble(self) -> Ensemble:

async def listen_for_clients(self) -> None:
while True:
event = await self._publisher_socket.recv()
# TODO change to router-dealer as this would inform all subscribers about the snapshot
if event[0] == 1:
print("Subscriber connected")
current_snapshot_dict = self._ensemble.snapshot.to_dict()
event: Event = EESnapshot(
snapshot=current_snapshot_dict, ensemble=self.ensemble.id_
)
await self._publisher_socket.send_string(event_to_json(event))

elif event[0] == 0:
print("Subscriber disconnected")
try:
raw_msg = await self._publisher_socket.recv()
# this would inform all subscribers about the snapshot
if raw_msg[0] == 1:
await self._clients_connected.put(None)
current_snapshot_dict = self._ensemble.snapshot.to_dict()
event: Event = EESnapshot(
snapshot=current_snapshot_dict, ensemble=self.ensemble.id_
)
await self._publisher_socket.send_string(event_to_json(event))

elif raw_msg[0] == 0:
await self._clients_connected.get()
self._clients_connected.task_done()
except zmq.error.ZMQError as e:
if e.errno == zmq.ENOTSOCK:
logger.warning(
"Evaluator publisher closed, no new clients accepted"
)
else:
logger.error(f"Unexpected error when connecting new clients: {e}")
return
except asyncio.CancelledError:
return

async def listen_for_messages(self) -> None:
while True:
sender, raw_msg = await self._pull_socket.recv_multipart()
sender = sender.decode("utf-8")
raw_msg = raw_msg.decode("utf-8")
if sender == "client":
print(f"Got client {raw_msg=}")
event = event_from_json(raw_msg)
if type(event) is EEUserCancel:
logger.debug("Client asked to cancel.")
self._signal_cancel()
elif type(event) is EEUserDone:
logger.debug("Client signalled done.")
self.stop()
elif sender == "dispatch":
event = dispatch_event_from_json(raw_msg)
# print(f"Got dispatch {event=}")
if event.ensemble != self.ensemble.id_:
logger.info(
"Got event from evaluator "
f"{event.ensemble}. "
f"Ignoring since I am {self.ensemble.id_}"
try:
sender, raw_msg = await self._pull_socket.recv_multipart()
sender = sender.decode("utf-8")
raw_msg = raw_msg.decode("utf-8")
if sender == "client":
event = event_from_json(raw_msg)
if type(event) is EEUserCancel:
logger.debug("Client asked to cancel.")
self._signal_cancel()
elif type(event) is EEUserDone:
logger.debug("Client signalled done.")
self.stop()
elif sender == "dispatch":
event = dispatch_event_from_json(raw_msg)
if event.ensemble != self.ensemble.id_:
logger.info(
"Got event from evaluator "
f"{event.ensemble}. "
f"Ignoring since I am {self.ensemble.id_}"
)
continue
if type(event) is ForwardModelStepChecksum:
await self.forward_checksum(event)
else:
await self._events.put(event)
# if type(event) in [EnsembleSucceeded, EnsembleFailed]:
# return
else:
logger.info(f"Connection attempt to unknown sender: {sender}.")
except zmq.error.ZMQError as e:
if e.errno == zmq.ENOTSOCK:
logger.warning(
"Evaluator receiver closed, no new messages are received"
)
continue
if type(event) is ForwardModelStepChecksum:
await self.forward_checksum(event)
else:
await self._events.put(event)
# if type(event) in [EnsembleSucceeded, EnsembleFailed]:
# return
else:
logger.info(f"Connection attempt to unknown sender: {sender}.")
logger.error(f"Unexpected error when listening to messages: {e}")
except asyncio.CancelledError:
return

async def forward_checksum(self, event: Event) -> None:
# clients still need to receive events via ws
Expand All @@ -271,6 +291,7 @@ async def _server(self) -> None:
event = EETerminated(ensemble=self._ensemble.id_)
await self._events_to_send.put(event)
await self._events_to_send.join()
await self._clients_connected.join()
self._pull_socket.close()
self._publisher_socket.close()
logger.debug("Async server exiting.")
Expand Down Expand Up @@ -308,6 +329,7 @@ async def _start_running(self) -> None:
asyncio.create_task(self._process_event_buffer(), name="processing_task"),
asyncio.create_task(self._publisher(), name="publisher_task"),
asyncio.create_task(self.listen_for_messages(), name="listener_task"),
asyncio.create_task(self.listen_for_clients(), name="client_task"),
]

self._ee_tasks.append(
Expand Down Expand Up @@ -360,7 +382,11 @@ async def _monitor_and_handle_tasks(self) -> None:
if stop_timeout_task:
stop_timeout_task.cancel()
return
elif task.get_name() == "ensemble_task":
elif task.get_name() in [
"ensemble_task",
"listener_task",
"client_task",
]:
stop_timeout_task = asyncio.create_task(
self._wait_for_stopped_server()
)
Expand Down
4 changes: 2 additions & 2 deletions src/ert/ensemble_evaluator/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ async def _receiver(self) -> None:
tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
tls.load_verify_locations(cadata=self._ee_con_info.cert)

self._listen_socket = self._zmq_context.socket(zmq.XSUB)
self._listen_socket = self._zmq_context.socket(zmq.SUB)
self._listen_socket.connect(self._ee_con_info.pub_sub_uri)
self._listen_socket.setsockopt_string(zmq.SUBSCRIBE, "")

Expand All @@ -139,7 +139,7 @@ async def _receiver(self) -> None:
raw_msg = await self._listen_socket.recv_string()
# print(f"monitor-{self._id} received msg: {raw_msg}")
event = event_from_json(raw_msg)
# print(f"monitor-{self._id} received event: {event}")
print(f"monitor-{self._id} received event: {event}")
await self._event_queue.put(event)
except zmq.ZMQError as exc:
# Handle disconnection or other ZMQ errors (reconnect or log)
Expand Down

0 comments on commit 0a25ff4

Please sign in to comment.