Skip to content

Commit

Permalink
WIP: working state for cli
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Jan 22, 2024
1 parent da4d11d commit 367ac5f
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 42 deletions.
2 changes: 2 additions & 0 deletions src/ert/ensemble_evaluator/_builder/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ async def send_cloudevent(
retries: int = 10,
) -> None:
async with Client(url, token, cert, max_retries=retries) as client:
print(f"DEBUG before send {event=}")
await client._send(to_json(event, data_marshaller=evaluator_marshaller))
print(f"DEBUG after send {event=}")

def get_successful_realizations(self) -> List[int]:
return self._snapshot.get_successful_realizations()
Expand Down
1 change: 1 addition & 0 deletions src/ert/ensemble_evaluator/_builder/_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
await send_timeout_future

# Dispatch final result from evaluator - FAILED, CANCEL or STOPPED
print(f"DEBUG final event!@!!!!!! {result=}")
await cloudevent_unary_send(event_creator(result, None))

@property
Expand Down
60 changes: 22 additions & 38 deletions src/ert/ensemble_evaluator/evaluator_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
)

logger = logging.getLogger(__name__)
logger.debug = print

_MAX_UNSUCCESSFUL_CONNECTION_ATTEMPTS = 3

Expand Down Expand Up @@ -78,10 +79,11 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig, iter_: int
self._dispatcher_task: Optional[asyncio.Task] = None
self._evaluator_task: Optional[asyncio.Task] = None

async def batching_dispatcher(self):
async def dispatcher(self):
logger.debug("dispatcher started!!!!****")

event_handler = {}
# raise ValueError("TEST exception")

def set_handler(event_types, function):
for event_type in event_types:
Expand Down Expand Up @@ -113,21 +115,11 @@ def ensemble(self) -> Ensemble:
return self._ensemble

async def _fm_handler(self, events: List[CloudEvent]) -> None:
async with self._snapshot_mutex:
snapshot_update_event = self._loop.run_in_executor(
None, self.ensemble.update_snapshot, events
)
await self._send_snapshot_update(snapshot_update_event)
await self._send_snapshot_update(self.ensemble.update_snapshot(events))

async def _started_handler(self, events: List[CloudEvent]) -> None:
if self.ensemble.status != ENSEMBLE_STATE_FAILED:
async with self._snapshot_mutex:
print("DEBUG: STARTED!!!!!!!")
snapshot_update_event = self._loop.run_in_executor(
None, self.ensemble.update_snapshot, events
)
print("DEBUG: STARTED - snapshot updated!!!!!!!")
await self._send_snapshot_update(snapshot_update_event)
await self._send_snapshot_update(self.ensemble.update_snapshot(events))
print("DEBUG: STARTED - snapshot sent!!!!!!!")

async def _stopped_handler(self, events: List[CloudEvent]) -> None:
Expand All @@ -142,18 +134,11 @@ async def _stopped_handler(self, events: List[CloudEvent]) -> None:
logger.info(
f"Ensemble ran with maximum memory usage for a single realization job: {max_memory_usage}"
)
snapshot_update_event = self._loop.run_in_executor(
None, self.ensemble.update_snapshot, events
)
await self._send_snapshot_update(snapshot_update_event)
await self._send_snapshot_update(self.ensemble.update_snapshot(events))

async def _cancelled_handler(self, events: List[CloudEvent]) -> None:
if self.ensemble.status != ENSEMBLE_STATE_FAILED:
async with self._snapshot_mutex:
snapshot_update_event = self._loop.run_in_executor(
None, self.ensemble.update_snapshot, events
)
await self._send_snapshot_update(snapshot_update_event)
await self._send_snapshot_update(self.ensemble.update_snapshot(events))
await self._stop()

async def _failed_handler(self, events: List[CloudEvent]) -> None:
Expand All @@ -167,11 +152,7 @@ async def _failed_handler(self, events: List[CloudEvent]) -> None:
# api for setting state in the ensemble
if len(events) == 0:
events = [await self._create_cloud_event(EVTYPE_ENSEMBLE_FAILED)]
async with self._snapshot_mutex:
snapshot_update_event = self._loop.run_in_executor(
None, self.ensemble.update_snapshot, events
)
await self._send_snapshot_update(snapshot_update_event)
await self._send_snapshot_update(self.ensemble.update_snapshot(events))
await self._signal_cancel() # let ensemble know it should stop

async def _send_snapshot_update(
Expand Down Expand Up @@ -371,12 +352,13 @@ async def evaluator_server(self) -> None:
else:
logger.debug("Got done signal. No dispatchers connected")

logger.debug("Waiting for batcher to finish...")
try:
await asyncio.wait_for(self._dispatcher_task, timeout=20)
except asyncio.TimeoutError:
logger.debug("Timed out waiting for batcher to finish")
self._dispatcher_task.cancel()
# logger.debug("Waiting for batcher to finish...")
# try:
# await asyncio.wait_for(self._dispatcher_task, timeout=20)
# except asyncio.TimeoutError:
# logger.debug("Timed out waiting for batcher to finish")
self._dispatcher_task.cancel()
await self._dispatcher_task

terminated_attrs: Dict[str, str] = {}
terminated_data = None
Expand All @@ -403,9 +385,9 @@ async def evaluator_server(self) -> None:
async def _stop(self) -> None:
if not self._done.done():
self._done.set_result(None)
if self._dispatcher_task:
self._dispatcher_task.cancel()
await self._dispatcher_task
# if self._dispatcher_task:
# self._dispatcher_task.cancel()
# await self._dispatcher_task

async def _signal_cancel(self) -> None:
"""
Expand All @@ -426,10 +408,12 @@ async def _signal_cancel(self) -> None:
async def run_and_get_successful_realizations(self) -> List[int]:
self._loop = asyncio.get_running_loop()
self._server_task = asyncio.create_task(self.evaluator_server())
self._dispatcher_task = asyncio.create_task(self.batching_dispatcher())
self._dispatcher_task = asyncio.create_task(self.dispatcher())
self._evaluator_task = await self._ensemble.evaluate_async(self._config)

await self._server_task
await asyncio.gather(
self._server_task, self._evaluator_task, return_exceptions=True
)
logger.debug("Evaluator is done")
return self._ensemble.get_successful_realizations()

Expand Down
1 change: 1 addition & 0 deletions src/ert/ensemble_evaluator/evaluator_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def track(
) -> Iterator[Union[FullSnapshotEvent, SnapshotUpdateEvent, EndEvent]]:
while True:
event = self._work_queue.get()
print(f"DEBUG: {event=}")
if isinstance(event, str):
with contextlib.suppress(GeneratorExit):
# consumers may exit at this point, make sure the last
Expand Down
10 changes: 6 additions & 4 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import logging
import os
import shutil
Expand All @@ -22,7 +23,7 @@
import numpy as np

from ert.analysis import AnalysisEvent, AnalysisStatusEvent, AnalysisTimeEvent
from ert.async_utils import get_event_loop
from ert.async_utils import get_event_loop, new_event_loop
from ert.cli import MODULE_MODE
from ert.config import ErtConfig, HookRuntime, QueueSystem
from ert.enkf_main import EnKFMain, _seed_sequence, create_run_path
Expand Down Expand Up @@ -371,8 +372,9 @@ def run_ensemble_evaluator(
ensemble = self._build_ensemble(run_context)

if FeatureToggling.is_enabled("scheduler"):
event_logger.info("Running AsyncEE!")
print("Running AsyncEE!")
try:
asyncio.set_event_loop(new_event_loop())
successful_realizations = get_event_loop().run_until_complete(
EnsembleEvaluatorAsync(
ensemble,
Expand All @@ -381,10 +383,10 @@ def run_ensemble_evaluator(
).run_and_get_successful_realizations()
)
except Exception as exc:
print(f"{exc=}")
event_logger.error(f"Exception in AsyncEE: {exc}")
print(f"Exception in AsyncEE: {exc}")
raise
finally:
get_event_loop().close()
else:
successful_realizations = EnsembleEvaluator(
ensemble,
Expand Down

0 comments on commit 367ac5f

Please sign in to comment.