Skip to content

Commit

Permalink
WIP: evaluator -> zmq
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Nov 12, 2024
1 parent 3d937f1 commit 9285352
Showing 1 changed file with 43 additions and 139 deletions.
182 changes: 43 additions & 139 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from __future__ import annotations

import asyncio
import datetime
import logging
import traceback
from contextlib import asynccontextmanager, contextmanager
from http import HTTPStatus
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Expand All @@ -22,15 +20,9 @@
get_args,
)

import websockets
import zmq.asyncio
from pydantic_core._pydantic_core import ValidationError
from websockets.datastructures import Headers, HeadersLike
from websockets.exceptions import ConnectionClosedError
from websockets.server import WebSocketServerProtocol

from _ert.events import (
EESnapshot,
EESnapshotUpdate,
EETerminated,
EEUserCancel,
Expand Down Expand Up @@ -71,7 +63,6 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):

self._loop: Optional[asyncio.AbstractEventLoop] = None

self._clients: Set[WebSocketServerProtocol] = set()
self._dispatchers_connected: asyncio.Queue[None] = asyncio.Queue()

self._events: asyncio.Queue[Event] = asyncio.Queue()
Expand All @@ -89,19 +80,16 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):
self._max_batch_size: int = 500
self._batching_interval: int = 2
self._complete_batch: asyncio.Event = asyncio.Event()
self._zmq_context: zmq.asyncio.Context | None = None

async def _initialize_zmq(self) -> None:
    """Create the ZMQ context plus the PULL and PUB sockets.

    The context and each socket are created exactly once. Creating them
    twice would leave the first context and its sockets unreachable and
    never closed.
    """
    self._zmq_context = zmq.asyncio.Context()  # type: ignore
    self._listen_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PULL)
    # Same PULL socket under its other name, so code that reads
    # `_receiver_socket` keeps working without a second socket being opened.
    self._receiver_socket: zmq.asyncio.Socket = self._listen_socket
    self._publisher_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PUB)
    # NOTE(review): neither socket is bound or connected here - confirm that
    # happens elsewhere before messages are expected to flow.

async def _publisher(self) -> None:
    """Publish every queued outgoing event on the ZMQ publisher socket.

    Loops forever: takes one event at a time from ``self._events_to_send``,
    sends it, and marks the queue item as done.
    """
    while True:
        event = await self._events_to_send.get()
        try:
            # The asyncio socket's send returns a future; await it so send
            # errors surface here instead of being dropped.
            # NOTE(review): if event_to_json already returns a JSON string,
            # send_json encodes it a second time - confirm subscribers expect
            # that, otherwise send_string is the better fit.
            await self._publisher_socket.send_json(event_to_json(event))
        finally:
            # Always balance the get() above so a join() on the queue cannot
            # hang on an item whose send failed.
            self._events_to_send.task_done()

Expand Down Expand Up @@ -213,139 +201,54 @@ async def _failed_handler(self, events: Sequence[EnsembleFailed]) -> None:
def ensemble(self) -> Ensemble:
    """Return the ensemble object held by this evaluator."""
    return self._ensemble

@contextmanager
def store_client(
    self, websocket: WebSocketServerProtocol
) -> Generator[None, None, None]:
    """Keep ``websocket`` in ``self._clients`` for the duration of the block.

    The client is removed again when the ``with`` body exits, including
    when it exits by raising, so a failing handler cannot leave a stale
    entry behind.
    """
    self._clients.add(websocket)
    try:
        yield
    finally:
        # discard() rather than remove(): cleanup must not raise a KeyError
        # that would mask the exception coming out of the with-body.
        self._clients.discard(websocket)

# Serve one websocket client: register it, send it the current snapshot,
# then iterate over the messages it sends.
async def handle_client(self, websocket: WebSocketServerProtocol) -> None:
# keep the websocket registered via store_client while it is being served
with self.store_client(websocket):
# the first message to the client is an EESnapshot built from the
# ensemble's snapshot dict and the ensemble id
current_snapshot_dict = self._ensemble.snapshot.to_dict()
event: Event = EESnapshot(
snapshot=current_snapshot_dict, ensemble=self.ensemble.id_
)
await websocket.send(event_to_json(event))

# NOTE(review): no loop body follows before the next `def` - presumably the
# message handling that belonged here is the code further down; confirm.
async for raw_msg in websocket:
# Read multipart messages from the listen socket forever and route them by
# the sender named in the first frame.
async def listen_for_messages(self) -> None:
while True:
# frame 0 is the sender name (bytes), frame 1 is the raw JSON payload
sender, raw_msg = await self._listen_socket.recv_multipart()
sender = sender.decode("utf-8")
if sender == "client":
event = event_from_json(raw_msg)
logger.debug(f"got message from client: {event}")
# exact type match: only EEUserCancel itself triggers a cancel
if type(event) is EEUserCancel:
# NOTE(review): `websocket` is not defined in this function; this line and
# the one after it log the same thing - presumably only one should remain.
logger.debug(f"Client {websocket.remote_address} asked to cancel.")
logger.debug("Client asked to cancel.")
self._signal_cancel()

elif type(event) is EEUserDone:
# NOTE(review): same duplication as above; `websocket` is undefined here.
logger.debug(f"Client {websocket.remote_address} signalled done.")
logger.debug("Client signalled done.")
self.stop()

@asynccontextmanager
async def count_dispatcher(self) -> AsyncIterator[None]:
    """Count one connected dispatcher for the duration of the block.

    A placeholder item is put on ``self._dispatchers_connected`` on entry
    and taken off again (and marked done) on exit. The exit step runs in a
    ``finally`` so a dispatcher whose handler raises is still uncounted;
    otherwise a ``join()`` on the queue would wait for it until timeout.
    """
    await self._dispatchers_connected.put(None)
    try:
        yield
    finally:
        await self._dispatchers_connected.get()
        self._dispatchers_connected.task_done()

async def handle_dispatch(self, websocket: WebSocketServerProtocol) -> None:
    """Consume events sent by one dispatcher over its websocket.

    Events for a different ensemble id are logged and skipped. Checksum
    events go to ``forward_checksum``; everything else is queued on
    ``self._events``. A message that fails validation closes the
    connection with code 1011. The handler returns after an
    EnsembleSucceeded or EnsembleFailed event.
    """
    final_event_types = (EnsembleSucceeded, EnsembleFailed)
    async with self.count_dispatcher():
        try:
            async for raw_msg in websocket:
                try:
                    event = dispatch_event_from_json(raw_msg)
                    if event.ensemble != self.ensemble.id_:
                        logger.info(
                            "Got event from evaluator "
                            f"{event.ensemble}. "
                            f"Ignoring since I am {self.ensemble.id_}"
                        )
                        continue
                    if type(event) is not ForwardModelStepChecksum:
                        await self._events.put(event)
                    else:
                        await self.forward_checksum(event)
                except ValidationError as ex:
                    logger.warning(
                        "cannot handle event - "
                        f"closing connection to dispatcher: {ex}"
                    )
                    await websocket.close(
                        code=1011, reason=f"failed handling message {raw_msg!r}"
                    )
                    return

                # a terminal ensemble event ends this dispatcher's stream
                if type(event) in final_event_types:
                    return
        except ConnectionClosedError as connection_error:
            # Dispatchers may close the connection abruptly in the case of
            # * flaky network (then the dispatcher will try to reconnect)
            # * job being killed due to MAX_RUNTIME
            # * job being killed by user
            logger.error(
                f"a dispatcher abruptly closed a websocket: {connection_error!s}"
            )
# NOTE(review): this branch continues an `if sender == "client":` chain whose
# head is not adjacent to it; `sender` and `raw_msg` are assigned there.
elif sender == "dispatch":
event = dispatch_event_from_json(raw_msg)
# events that carry a different ensemble id are logged and skipped
if event.ensemble != self.ensemble.id_:
logger.info(
"Got event from evaluator "
f"{event.ensemble}. "
f"Ignoring since I am {self.ensemble.id_}"
)
continue
# checksum events go through forward_checksum; all others are queued on
# self._events
if type(event) is ForwardModelStepChecksum:
await self.forward_checksum(event)
else:
await self._events.put(event)
# a terminal ensemble event ends the enclosing function
if type(event) in [EnsembleSucceeded, EnsembleFailed]:
return
else:
logger.info(f"Connection attempt to unknown sender: {sender}.")

async def forward_checksum(self, event: Event) -> None:
    """Queue ``event`` for sending to clients, then on the manifest queue."""
    # Order matters: the outgoing queue is fed before the manifest queue.
    for target_queue in (self._events_to_send, self._manifest_queue):
        await target_queue.put(event)

async def connection_handler(self, websocket: WebSocketServerProtocol) -> None:
    """Route a websocket to the client or dispatch handler.

    The first path segment after the leading slash selects the handler;
    any other value is logged and the connection is left unhandled.
    """
    path = websocket.path
    route = path.split("/")[1]
    if route not in ("client", "dispatch"):
        logger.info(f"Connection attempt to unknown path: {path}.")
        return
    handler = self.handle_client if route == "client" else self.handle_dispatch
    await handler(websocket)

async def process_request(
    self, path: str, request_headers: Headers
) -> Optional[Tuple[HTTPStatus, HeadersLike, bytes]]:
    """Check the request token and answer health checks.

    Returns:
        ``(UNAUTHORIZED, {}, b"")`` when the ``token`` header does not
        match ``self._config.token``; ``(OK, {}, b"")`` for an authorized
        request to ``/healthcheck``; ``None`` for any other authorized
        request.
    """
    import hmac

    expected = self._config.token
    supplied = request_headers.get("token")
    if isinstance(expected, str) and isinstance(supplied, str):
        # The header comes from an untrusted peer: compare in constant time
        # so the token cannot be probed via response timing. Encode first
        # because compare_digest rejects non-ASCII str.
        authorized = hmac.compare_digest(
            supplied.encode("utf-8"), expected.encode("utf-8")
        )
    else:
        # Missing header and/or no configured token: plain equality, so
        # "no token configured and none sent" still passes.
        authorized = supplied == expected
    if not authorized:
        return HTTPStatus.UNAUTHORIZED, {}, b""
    if path == "/healthcheck":
        return HTTPStatus.OK, {}, b""
    return None

# Run the evaluator server: signal that it has started, wait for the done
# signal, drain the event queues, emit EETerminated, then shut down.
# NOTE(review): this body appears to hold two variants back to back - one
# built on websockets.serve(...) and one built on the ZMQ sockets. Both set
# `_server_started`, wait on `_server_done` and run the same drain-and-terminate
# sequence; confirm which variant is live before changing anything here.
async def _server(self) -> None:
# --- websocket variant ---
async with websockets.serve(
self.connection_handler,
sock=self._config.get_socket(),
ssl=self._config.get_server_ssl_context(),
process_request=self.process_request,
max_queue=None,
max_size=2**26,
ping_timeout=60,
ping_interval=60,
close_timeout=60,
) as server:
self._server_started.set()
await self._server_done.wait()
server.close(close_connections=False)
if self._dispatchers_connected is not None:
logger.debug(
f"Got done signal. {self._dispatchers_connected.qsize()} "
"dispatchers to disconnect..."
)
try: # Wait for dispatchers to disconnect
await asyncio.wait_for(
self._dispatchers_connected.join(), timeout=20
)
except asyncio.TimeoutError:
logger.debug("Timed out waiting for dispatchers to disconnect")
else:
logger.debug("Got done signal. No dispatchers connected")

logger.debug("Sending termination-message to clients...")

# wait until queued events are handled and the current batch is complete,
# then queue EETerminated and wait until the outgoing queue is drained
await self._events.join()
await self._complete_batch.wait()
await self._batch_processing_queue.join()
event = EETerminated(ensemble=self._ensemble.id_)
await self._events_to_send.put(event)
await self._events_to_send.join()
# --- ZMQ variant ---
await self._initialize_zmq()
self._server_started.set()
await self._server_done.wait()

# same drain-and-terminate sequence as in the websocket variant
await self._events.join()
await self._complete_batch.wait()
await self._batch_processing_queue.join()
event = EETerminated(ensemble=self._ensemble.id_)
await self._events_to_send.put(event)
await self._events_to_send.join()
# NOTE(review): the sockets are closed only on the normal path and the ZMQ
# context is not terminated here - confirm cleanup happens elsewhere.
self._listen_socket.close()
self._publisher_socket.close()
logger.debug("Async server exiting.")

def stop(self) -> None:
Expand Down Expand Up @@ -379,6 +282,7 @@ async def _start_running(self) -> None:
),
asyncio.create_task(self._process_event_buffer(), name="processing_task"),
asyncio.create_task(self._publisher(), name="publisher_task"),
asyncio.create_task(self.listen_for_messages(), name="listener_task"),
]
# now we wait for the server to actually start
await self._server_started.wait()
Expand Down

0 comments on commit 9285352

Please sign in to comment.