From 495aabef01385bdd8037a94269091a81372e075a Mon Sep 17 00:00:00 2001
From: xjules
Date: Wed, 17 Jan 2024 16:07:17 +0100
Subject: [PATCH] Make EnsembleEvaluator async

A new asyncio event loop is set up in base_run_model to run the entire EnsembleEvaluatorAsync.
---
 src/ert/ensemble_evaluator/__init__.py        |   2 +
 .../ensemble_evaluator/_builder/_ensemble.py  |   2 +
 .../ensemble_evaluator/_builder/_legacy.py    |  27 ++
 src/ert/ensemble_evaluator/evaluator_async.py | 423 ++++++++++++++++++
 .../ensemble_evaluator/evaluator_tracker.py   |   1 +
 src/ert/job_queue/queue.py                    |   2 +-
 src/ert/run_models/base_run_model.py          |  36 +-
 src/ert/scheduler/scheduler.py                |   1 +
 test-data/poly_example/poly.ert               |   2 +-
 9 files changed, 485 insertions(+), 11 deletions(-)
 create mode 100644 src/ert/ensemble_evaluator/evaluator_async.py

diff --git a/src/ert/ensemble_evaluator/__init__.py b/src/ert/ensemble_evaluator/__init__.py
index ea093217a94..52f33d19e48 100644
--- a/src/ert/ensemble_evaluator/__init__.py
+++ b/src/ert/ensemble_evaluator/__init__.py
@@ -7,6 +7,7 @@
 )
 from .config import EvaluatorServerConfig
 from .evaluator import EnsembleEvaluator
+from .evaluator_async import EnsembleEvaluatorAsync
 from .evaluator_tracker import EvaluatorTracker
 from .event import EndEvent, FullSnapshotEvent, SnapshotUpdateEvent
 from .monitor import Monitor
@@ -17,6 +18,7 @@
     "Ensemble",
     "EnsembleBuilder",
     "EnsembleEvaluator",
+    "EnsembleEvaluatorAsync",
     "EvaluatorServerConfig",
     "EvaluatorTracker",
     "ForwardModel",
diff --git a/src/ert/ensemble_evaluator/_builder/_ensemble.py b/src/ert/ensemble_evaluator/_builder/_ensemble.py
index 16ec3463818..be1ed65bffa 100644
--- a/src/ert/ensemble_evaluator/_builder/_ensemble.py
+++ b/src/ert/ensemble_evaluator/_builder/_ensemble.py
@@ -148,7 +148,9 @@ async def send_cloudevent(
         retries: int = 10,
     ) -> None:
         async with Client(url, token, cert, max_retries=retries) as client:
+            print(f"DEBUG before send {event=}")
             await client._send(to_json(event, data_marshaller=evaluator_marshaller))
+            print(f"DEBUG after send {event=}")
 
     def get_successful_realizations(self) -> List[int]:
         return self._snapshot.get_successful_realizations()
diff --git a/src/ert/ensemble_evaluator/_builder/_legacy.py b/src/ert/ensemble_evaluator/_builder/_legacy.py
index c0735d32acd..60d9fc94202 100644
--- a/src/ert/ensemble_evaluator/_builder/_legacy.py
+++ b/src/ert/ensemble_evaluator/_builder/_legacy.py
@@ -121,6 +121,32 @@ def evaluate(self, config: EvaluatorServerConfig) -> None:
 
         threading.Thread(target=self._evaluate, name="LegacyEnsemble").start()
 
+    async def evaluate_async(self, config: EvaluatorServerConfig) -> asyncio.Task[Any]:
+        if not config:
+            raise ValueError("no config for evaluator")
+        self._config = config
+        await wait_for_evaluator(
+            base_url=self._config.url,
+            token=self._config.token,
+            cert=self._config.cert,
+        )
+        ce_unary_send_method_name = "_ce_unary_send"
+        setattr(
+            self.__class__,
+            ce_unary_send_method_name,
+            partialmethod(
+                self.__class__.send_cloudevent,
+                self._config.dispatch_uri,
+                token=self._config.token,
+                cert=self._config.cert,
+            ),
+        )
+        return asyncio.create_task(
+            self._evaluate_inner(
+                cloudevent_unary_send=getattr(self, ce_unary_send_method_name)
+            )
+        )
+
     def _evaluate(self) -> None:
         """
         This method is executed on a separate thread, i.e. in parallel
@@ -244,6 +270,7 @@ async def _evaluate_inner(  # pylint: disable=too-many-branches
                 await send_timeout_future
 
         # Dispatch final result from evaluator - FAILED, CANCEL or STOPPED
+        print(f"DEBUG final event!@!!!!!! 
{result=}") await cloudevent_unary_send(event_creator(result, None)) @property diff --git a/src/ert/ensemble_evaluator/evaluator_async.py b/src/ert/ensemble_evaluator/evaluator_async.py new file mode 100644 index 00000000000..cc95d1640d3 --- /dev/null +++ b/src/ert/ensemble_evaluator/evaluator_async.py @@ -0,0 +1,423 @@ +import asyncio +import logging +import pickle +from contextlib import asynccontextmanager, contextmanager +from http import HTTPStatus +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Generator, + List, + Optional, + Set, + Tuple, +) + +import cloudevents.exceptions +import cloudpickle +import websockets +from cloudevents.conversion import to_json +from cloudevents.http import CloudEvent, from_json +from websockets.datastructures import Headers, HeadersLike +from websockets.exceptions import ConnectionClosedError +from websockets.legacy.server import WebSocketServerProtocol + +from ert.serialization import evaluator_marshaller, evaluator_unmarshaller + +from ._builder import Ensemble +from .config import EvaluatorServerConfig +from .identifiers import ( + EVGROUP_FM_ALL, + EVTYPE_EE_SNAPSHOT, + EVTYPE_EE_SNAPSHOT_UPDATE, + EVTYPE_EE_TERMINATED, + EVTYPE_EE_USER_CANCEL, + EVTYPE_EE_USER_DONE, + EVTYPE_ENSEMBLE_CANCELLED, + EVTYPE_ENSEMBLE_FAILED, + EVTYPE_ENSEMBLE_STARTED, + EVTYPE_ENSEMBLE_STOPPED, +) +from .snapshot import PartialSnapshot +from .state import ( + ENSEMBLE_STATE_CANCELLED, + ENSEMBLE_STATE_FAILED, + ENSEMBLE_STATE_STOPPED, +) + +logger = logging.getLogger(__name__) +logger.debug = print + +_MAX_UNSUCCESSFUL_CONNECTION_ATTEMPTS = 3 + + +class EnsembleEvaluatorAsync: + def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig, iter_: int): + # Without information on the iteration, the events emitted from the + # evaluator are ambiguous. In the future, an experiment authority* will + # "own" the evaluators and can add iteration information to events they + # emit. In the meantime, it is added here. 
+ # * https://github.com/equinor/ert/issues/1250 + self._iter: int = iter_ + self._config: EvaluatorServerConfig = config + self._ensemble: Ensemble = ensemble + + self._loop: Optional[asyncio.AbstractEventLoop] = None + self._done: asyncio.Future[bool] = asyncio.Future() + + self._clients: Set[WebSocketServerProtocol] = set() + self._dispatchers_connected: Optional[asyncio.Queue[None]] = None + self._snapshot_mutex = asyncio.Lock() + + self._events: asyncio.Queue[CloudEvent] = asyncio.Queue() + + self._result = None + + self._server_task: Optional[asyncio.Task] = None + self._dispatcher_task: Optional[asyncio.Task] = None + self._evaluator_task: Optional[asyncio.Task] = None + + async def dispatcher(self): + logger.debug("dispatcher started!!!!****") + + event_handler = {} + # raise ValueError("TEST exception") + + def set_handler(event_types, function): + for event_type in event_types: + event_handler[event_type] = function + + for e_type, f in ( + (EVGROUP_FM_ALL, self._fm_handler), + ({EVTYPE_ENSEMBLE_STARTED}, self._started_handler), + ({EVTYPE_ENSEMBLE_STOPPED}, self._stopped_handler), + ({EVTYPE_ENSEMBLE_CANCELLED}, self._cancelled_handler), + ({EVTYPE_ENSEMBLE_FAILED}, self._failed_handler), + ): + set_handler(e_type, f) + + logger.debug("dispatcher started!!!!****") + while True: + event = await self._events.get() + logger.debug(f"EVENT-logging: {event}") + await event_handler[event["type"]]([event]) + print(f"DEBUG: event processed {event}!!!!") + logger.debug(f"DEBUG: event processed {event}!!!!") + + @property + def config(self) -> EvaluatorServerConfig: + return self._config + + @property + def ensemble(self) -> Ensemble: + return self._ensemble + + async def _fm_handler(self, events: List[CloudEvent]) -> None: + await self._send_snapshot_update(self.ensemble.update_snapshot(events)) + + async def _started_handler(self, events: List[CloudEvent]) -> None: + if self.ensemble.status != ENSEMBLE_STATE_FAILED: + await self._send_snapshot_update(self.ensemble.update_snapshot(events)) + print("DEBUG: STARTED - snapshot sent!!!!!!!") + + async def _stopped_handler(self, events: List[CloudEvent]) -> None: + if self.ensemble.status != ENSEMBLE_STATE_FAILED: + self._result = events[0].data # normal termination + async with self._snapshot_mutex: + max_memory_usage = -1 + for job in self.ensemble.snapshot.get_all_forward_models().values(): + memory_usage = job.max_memory_usage or "-1" + if int(memory_usage) > max_memory_usage: + max_memory_usage = int(memory_usage) + logger.info( + f"Ensemble ran with maximum memory usage for a single realization job: {max_memory_usage}" + ) + await self._send_snapshot_update(self.ensemble.update_snapshot(events)) + + async def _cancelled_handler(self, events: List[CloudEvent]) -> None: + if self.ensemble.status != ENSEMBLE_STATE_FAILED: + await self._send_snapshot_update(self.ensemble.update_snapshot(events)) + await self._stop() + + async def _failed_handler(self, events: List[CloudEvent]) -> None: + if self.ensemble.status not in ( + ENSEMBLE_STATE_STOPPED, + ENSEMBLE_STATE_CANCELLED, + ): + # if list is empty this call is not triggered by an + # event, but as a consequence of some bad state + # create a fake event because that's currently the only + # api for setting state in the ensemble + if len(events) == 0: + events = [await self._create_cloud_event(EVTYPE_ENSEMBLE_FAILED)] + await self._send_snapshot_update(self.ensemble.update_snapshot(events)) + await self._signal_cancel() # let ensemble know it should stop + + async def 
_send_snapshot_update( + self, snapshot_update_event: PartialSnapshot + ) -> None: + print(f"DEBUG: {self._clients=}") + message = await self._create_cloud_message( + EVTYPE_EE_SNAPSHOT_UPDATE, + snapshot_update_event.to_dict(), + ) + print(f"DEBUG: {message=}") + logger.debug(f"DEBUG: sending {message=} to {self._clients=}") + if message and self._clients: + # Note return_exceptions=True in gather. This fire-and-forget + # approach is currently how we deal with failures when trying + # to send udates to clients. Rationale is that if sending to + # the client fails, the websocket is down and we have no way + # to re-establish it. Thus, it becomes the responsibility of + # the client to re-connect if necessary, in which case the first + # update it receives will be a full snapshot. + print(f"DEBUG: sending {message=} to {self._clients=}") + await asyncio.gather( + *[client.send(message) for client in self._clients], + return_exceptions=True, + ) + + async def _create_cloud_event( + self, + event_type: str, + data: Optional[Dict[str, Any]] = None, + extra_attrs: Optional[Dict[str, Any]] = None, + ) -> CloudEvent: + """Returns a CloudEvent with the given properties""" + if isinstance(data, dict): + data["iter"] = self._iter + if extra_attrs is None: + extra_attrs = {} + + attrs = { + "type": event_type, + "source": f"/ert/ensemble/{self.ensemble.id_}", + } + attrs.update(extra_attrs) + return CloudEvent( + attrs, + data, + ) + + async def _create_cloud_message( + self, + event_type: str, + data: Optional[Dict[str, Any]] = None, + extra_attrs: Optional[Dict[str, Any]] = None, + data_marshaller: Optional[Callable[[Any], Any]] = evaluator_marshaller, + ) -> str: + """Creates the CloudEvent and returns the serialized json-string""" + event = await self._create_cloud_event(event_type, data, extra_attrs) + return to_json(event, data_marshaller=data_marshaller).decode() + + @contextmanager + def store_client( + self, websocket: WebSocketServerProtocol + ) -> Generator[None, None, None]: + self._clients.add(websocket) + yield + self._clients.remove(websocket) + + async def handle_client( + self, websocket: WebSocketServerProtocol, path: str + ) -> None: + with self.store_client(websocket): + async with self._snapshot_mutex: + current_snapshot_dict = self._ensemble.snapshot.to_dict() + event = await self._create_cloud_message( + EVTYPE_EE_SNAPSHOT, current_snapshot_dict + ) + await websocket.send(event) + + async for message in websocket: + client_event = from_json( + message, data_unmarshaller=evaluator_unmarshaller + ) + logger.debug(f"got message from client: {client_event}") + if client_event["type"] == EVTYPE_EE_USER_CANCEL: + logger.debug(f"Client {websocket.remote_address} asked to cancel.") + await self._signal_cancel() + + elif client_event["type"] == EVTYPE_EE_USER_DONE: + logger.debug(f"Client {websocket.remote_address} signalled done.") + await self._stop() + + @asynccontextmanager + async def count_dispatcher(self) -> AsyncIterator[None]: + # do this here (not in __init__) to ensure the queue + # is created on the right event-loop + if self._dispatchers_connected is None: + self._dispatchers_connected = asyncio.Queue() + + await self._dispatchers_connected.put(None) + yield + await self._dispatchers_connected.get() + self._dispatchers_connected.task_done() + + async def handle_dispatch( + self, websocket: WebSocketServerProtocol, path: str + ) -> None: + async with self.count_dispatcher(): + try: + async for msg in websocket: + try: + event = from_json(msg, 
data_unmarshaller=evaluator_unmarshaller) + except cloudevents.exceptions.DataUnmarshallerError: + event = from_json(msg, data_unmarshaller=pickle.loads) + if self._get_ens_id(event["source"]) != self.ensemble.id_: + logger.info( + "Got event from evaluator " + f"{self._get_ens_id(event['source'])} " + f"with source {event['source']}, " + f"ignoring since I am {self.ensemble.id_}" + ) + continue + try: + # await self._dispatcher.handle_event(event) + await self._events.put(event) + except BaseException as ex: + # Exceptions include asyncio.InvalidStateError, and + # anything that self._*_handler() can raise (updates + # snapshots) + logger.warning( + "cannot handle event - " + f"closing connection to dispatcher: {ex}" + ) + await websocket.close( + code=1011, reason=f"failed handling {event}" + ) + return + + if event["type"] in [ + EVTYPE_ENSEMBLE_STOPPED, + EVTYPE_ENSEMBLE_FAILED, + ]: + return + except ConnectionClosedError as connection_error: + # Dispatchers my close the connection apruptly in the case of + # * flaky network (then the dispatcher will try to reconnect) + # * job being killed due to MAX_RUNTIME + # * job being killed by user + logger.error( + f"a dispatcher abruptly closed a websocket: {str(connection_error)}" + ) + + async def connection_handler( + self, websocket: WebSocketServerProtocol, path: str + ) -> None: + elements = path.split("/") + if elements[1] == "client": + await self.handle_client(websocket, path) + elif elements[1] == "dispatch": + await self.handle_dispatch(websocket, path) + else: + logger.info(f"Connection attempt to unknown path: {path}.") + + async def process_request( + self, path: str, request_headers: Headers + ) -> Optional[Tuple[HTTPStatus, HeadersLike, bytes]]: + if request_headers.get("token") != self._config.token: + return HTTPStatus.UNAUTHORIZED, {}, b"" + if path == "/healthcheck": + return HTTPStatus.OK, {}, b"" + return None + + async def evaluator_server(self) -> None: + async with websockets.serve( + self.connection_handler, + sock=self._config.get_socket(), + ssl=self._config.get_server_ssl_context(), + process_request=self.process_request, + max_queue=None, + max_size=2**26, + ping_timeout=60, + ping_interval=60, + close_timeout=60, + ): + logger.debug("Server started!!!") + await self._done + if self._dispatchers_connected is not None: + logger.debug( + f"Got done signal. {self._dispatchers_connected.qsize()} " + "dispatchers to disconnect..." + ) + try: # Wait for dispatchers to disconnect + await asyncio.wait_for( + self._dispatchers_connected.join(), timeout=20 + ) + except asyncio.TimeoutError: + logger.debug("Timed out waiting for dispatchers to disconnect") + else: + logger.debug("Got done signal. 
No dispatchers connected") + + # logger.debug("Waiting for batcher to finish...") + # try: + # await asyncio.wait_for(self._dispatcher_task, timeout=20) + # except asyncio.TimeoutError: + # logger.debug("Timed out waiting for batcher to finish") + self._dispatcher_task.cancel() + await self._dispatcher_task + + terminated_attrs: Dict[str, str] = {} + terminated_data = None + if self._result: + terminated_attrs["datacontenttype"] = "application/octet-stream" + terminated_data = cloudpickle.dumps(self._result) + + logger.debug("Sending termination-message to clients...") + message = await self._create_cloud_message( + EVTYPE_EE_TERMINATED, + data=terminated_data, + extra_attrs=terminated_attrs, + data_marshaller=cloudpickle.dumps, + ) + if self._clients: + # See note about return_exceptions=True above + await asyncio.gather( + *[client.send(message) for client in self._clients], + return_exceptions=True, + ) + + logger.debug("Async server exiting.") + + async def _stop(self) -> None: + if not self._done.done(): + self._done.set_result(None) + # if self._dispatcher_task: + # self._dispatcher_task.cancel() + # await self._dispatcher_task + + async def _signal_cancel(self) -> None: + """ + This is just a wrapper around logic for whether to signal cancel via + a cancellable ensemble or to use internal stop-mechanism directly + + I.e. if the ensemble can be cancelled, it is, otherwise cancel + is signalled internally. In both cases the evaluator waits for + the cancel-message to arrive before it shuts down properly. + """ + if self._ensemble.cancellable: + logger.debug("Cancelling current ensemble") + self._loop.run_in_executor(None, self._ensemble.cancel) + else: + logger.debug("Stopping current ensemble") + await self._stop() + + async def run_and_get_successful_realizations(self) -> List[int]: + self._loop = asyncio.get_running_loop() + self._server_task = asyncio.create_task(self.evaluator_server()) + self._dispatcher_task = asyncio.create_task(self.dispatcher()) + self._evaluator_task = await self._ensemble.evaluate_async(self._config) + + await asyncio.gather( + self._server_task, self._evaluator_task, return_exceptions=True + ) + logger.debug("Evaluator is done") + return self._ensemble.get_successful_realizations() + + @staticmethod + def _get_ens_id(source: str) -> str: + # the ens_id will be found at /ert/ensemble/ens_id/... 
+ return source.split("/")[3] diff --git a/src/ert/ensemble_evaluator/evaluator_tracker.py b/src/ert/ensemble_evaluator/evaluator_tracker.py index 5dbd239420d..d222de9db25 100644 --- a/src/ert/ensemble_evaluator/evaluator_tracker.py +++ b/src/ert/ensemble_evaluator/evaluator_tracker.py @@ -122,6 +122,7 @@ def track( ) -> Iterator[Union[FullSnapshotEvent, SnapshotUpdateEvent, EndEvent]]: while True: event = self._work_queue.get() + print(f"DEBUG: {event=}") if isinstance(event, str): with contextlib.suppress(GeneratorExit): # consumers may exit at this point, make sure the last diff --git a/src/ert/job_queue/queue.py b/src/ert/job_queue/queue.py index 918cfbb5adb..9d38c28231f 100644 --- a/src/ert/job_queue/queue.py +++ b/src/ert/job_queue/queue.py @@ -291,7 +291,7 @@ async def execute( ) -> str: self._changes_to_publish = asyncio.Queue() asyncio.create_task(self._jobqueue_publisher()) - + print("DEBUG: runnning jobqueue") try: await self._changes_to_publish.put(self._differ.snapshot()) while True: diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index 401aad6e779..83a0b49c06c 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import logging import os import shutil @@ -22,6 +23,7 @@ import numpy as np from ert.analysis import AnalysisEvent, AnalysisStatusEvent, AnalysisTimeEvent +from ert.async_utils import get_event_loop, new_event_loop from ert.cli import MODULE_MODE from ert.config import ErtConfig, HookRuntime, QueueSystem from ert.enkf_main import EnKFMain, _seed_sequence, create_run_path @@ -29,18 +31,17 @@ Ensemble, EnsembleBuilder, EnsembleEvaluator, + EnsembleEvaluatorAsync, EvaluatorServerConfig, RealizationBuilder, ) from ert.libres_facade import LibresFacade from ert.run_context import RunContext from ert.runpaths import Runpaths +from ert.shared.feature_toggling import FeatureToggling from ert.storage import StorageAccessor -from .event import ( - RunModelStatusEvent, - RunModelTimeEvent, -) +from .event import RunModelStatusEvent, RunModelTimeEvent event_logger = logging.getLogger("ert.event_log") @@ -370,11 +371,28 @@ def run_ensemble_evaluator( ) -> List[int]: ensemble = self._build_ensemble(run_context) - successful_realizations = EnsembleEvaluator( - ensemble, - ee_config, - run_context.iteration, - ).run_and_get_successful_realizations() + if FeatureToggling.is_enabled("scheduler"): + print("Running AsyncEE!") + try: + asyncio.set_event_loop(new_event_loop()) + successful_realizations = get_event_loop().run_until_complete( + EnsembleEvaluatorAsync( + ensemble, + ee_config, + run_context.iteration, + ).run_and_get_successful_realizations() + ) + except Exception as exc: + print(f"{exc=}") + event_logger.error(f"Exception in AsyncEE: {exc}") + print(f"Exception in AsyncEE: {exc}") + raise + else: + successful_realizations = EnsembleEvaluator( + ensemble, + ee_config, + run_context.iteration, + ).run_and_get_successful_realizations() return successful_realizations diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py index c12468f4b8c..ca1ffc9ba60 100644 --- a/src/ert/scheduler/scheduler.py +++ b/src/ert/scheduler/scheduler.py @@ -180,6 +180,7 @@ async def execute( # We need to store the loop due to when calling # cancel jobs from another thread self._loop = asyncio.get_running_loop() + print("DEBUG running scheduler") async with background_tasks() as cancel_when_execute_is_done: 
             cancel_when_execute_is_done(self._publisher())
             cancel_when_execute_is_done(self._process_event_queue())
diff --git a/test-data/poly_example/poly.ert b/test-data/poly_example/poly.ert
index 490a54ee46d..a43ff1d18e9 100644
--- a/test-data/poly_example/poly.ert
+++ b/test-data/poly_example/poly.ert
@@ -7,7 +7,7 @@ RUNPATH poly_out/realization-<IENS>/iter-<ITER>
 OBS_CONFIG observations
 
-NUM_REALIZATIONS 100
+NUM_REALIZATIONS 5
 
 MIN_REALIZATIONS 1
 
 GEN_KW COEFFS coeff_priors
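
For reviewers: the heart of this patch is the task layout in
EnsembleEvaluatorAsync.run_and_get_successful_realizations. A websocket server
task, a dispatcher task that consumes CloudEvents from an asyncio.Queue, and the
ensemble's evaluate_async task all run on one event loop and are awaited together
with asyncio.gather. The sketch below illustrates only that orchestration pattern
in plain asyncio; MiniEvaluator, its string events, and main() are hypothetical
stand-ins for illustration, not ERT APIs.

import asyncio
import contextlib


class MiniEvaluator:
    """Toy model of the server / dispatcher / ensemble task trio used in this patch."""

    def __init__(self) -> None:
        self._events: "asyncio.Queue[str]" = asyncio.Queue()
        self._done = asyncio.Event()

    async def server(self) -> None:
        # Stand-in for the websocket server task: run until the done-signal arrives.
        await self._done.wait()
        print("server: done signal received, shutting down")

    async def dispatcher(self) -> None:
        # Single consumer of the event queue, mirroring EnsembleEvaluatorAsync.dispatcher.
        while True:
            event = await self._events.get()
            print(f"dispatcher: handling {event!r}")
            if event == "ensemble_stopped":
                self._done.set()

    async def ensemble(self) -> None:
        # Stand-in for LegacyEnsemble.evaluate_async: emit a few events, then stop.
        for event in ("fm_running", "fm_success", "ensemble_stopped"):
            await asyncio.sleep(0.1)
            await self._events.put(event)

    async def run(self) -> None:
        server_task = asyncio.create_task(self.server())
        dispatcher_task = asyncio.create_task(self.dispatcher())
        ensemble_task = asyncio.create_task(self.ensemble())
        # Wait for the server and the ensemble; the dispatcher loops forever,
        # so cancel it explicitly once the others are done.
        await asyncio.gather(server_task, ensemble_task, return_exceptions=True)
        dispatcher_task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await dispatcher_task


async def main() -> None:
    await MiniEvaluator().run()


if __name__ == "__main__":
    asyncio.run(main())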