Skip to content

Commit

Permalink
More updates
Browse files: browse the repository at this point in the history
  • Loading branch information
xjules committed Jan 19, 2024
1 parent 4167bec commit da4d11d
Showing 1 changed file with 32 additions and 31 deletions.
63 changes: 32 additions & 31 deletions src/ert/ensemble_evaluator/evaluator_async.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -101,6 +101,8 @@ def set_handler(event_types, function):
event = await self._events.get()
logger.debug(f"EVENT-logging: {event}")
await event_handler[event["type"]]([event])
print(f"DEBUG: event processed {event}!!!!")
logger.debug(f"DEBUG: event processed {event}!!!!")

@property
def config(self) -> EvaluatorServerConfig:
Expand All @@ -112,21 +114,21 @@ def ensemble(self) -> Ensemble:

async def _fm_handler(self, events: List[CloudEvent]) -> None:
async with self._snapshot_mutex:
future = asyncio.run_coroutine_threadsafe(
self.ensemble.update_snapshot(events), self._loop
snapshot_update_event = self._loop.run_in_executor(
None, self.ensemble.update_snapshot, events
)
snapshot_update_event = future.result()

await self._send_snapshot_update(snapshot_update_event)

async def _started_handler(self, events: List[CloudEvent]) -> None:
if self.ensemble.status != ENSEMBLE_STATE_FAILED:
async with self._snapshot_mutex:
future = asyncio.run_coroutine_threadsafe(
self.ensemble.update_snapshot(events), self._loop
print("DEBUG: STARTED!!!!!!!")
snapshot_update_event = self._loop.run_in_executor(
None, self.ensemble.update_snapshot, events
)
snapshot_update_event = future.result()
print("DEBUG: STARTED - snapshot updated!!!!!!!")
await self._send_snapshot_update(snapshot_update_event)
print("DEBUG: STARTED - snapshot sent!!!!!!!")

async def _stopped_handler(self, events: List[CloudEvent]) -> None:
if self.ensemble.status != ENSEMBLE_STATE_FAILED:
Expand All @@ -140,19 +142,17 @@ async def _stopped_handler(self, events: List[CloudEvent]) -> None:
logger.info(
f"Ensemble ran with maximum memory usage for a single realization job: {max_memory_usage}"
)
future = asyncio.run_coroutine_threadsafe(
self.ensemble.update_snapshot(events), self._loop
snapshot_update_event = self._loop.run_in_executor(
None, self.ensemble.update_snapshot, events
)
snapshot_update_event = future.result()
await self._send_snapshot_update(snapshot_update_event)

async def _cancelled_handler(self, events: List[CloudEvent]) -> None:
if self.ensemble.status != ENSEMBLE_STATE_FAILED:
async with self._snapshot_mutex:
future = asyncio.run_coroutine_threadsafe(
self.ensemble.update_snapshot(events), self._loop
snapshot_update_event = self._loop.run_in_executor(
None, self.ensemble.update_snapshot, events
)
snapshot_update_event = future.result()
await self._send_snapshot_update(snapshot_update_event)
await self._stop()

Expand All @@ -166,22 +166,24 @@ async def _failed_handler(self, events: List[CloudEvent]) -> None:
# create a fake event because that's currently the only
# api for setting state in the ensemble
if len(events) == 0:
events = [self._create_cloud_event(EVTYPE_ENSEMBLE_FAILED)]
events = [await self._create_cloud_event(EVTYPE_ENSEMBLE_FAILED)]
async with self._snapshot_mutex:
future = asyncio.run_coroutine_threadsafe(
self.ensemble.update_snapshot(events), self._loop
snapshot_update_event = self._loop.run_in_executor(
None, self.ensemble.update_snapshot, events
)
snapshot_update_event = future.result()
await self._send_snapshot_update(snapshot_update_event)
self._signal_cancel() # let ensemble know it should stop
await self._signal_cancel() # let ensemble know it should stop

async def _send_snapshot_update(
self, snapshot_update_event: PartialSnapshot
) -> None:
message = self._create_cloud_message(
print(f"DEBUG: {self._clients=}")
message = await self._create_cloud_message(
EVTYPE_EE_SNAPSHOT_UPDATE,
snapshot_update_event.to_dict(),
)
print(f"DEBUG: {message=}")
logger.debug(f"DEBUG: sending {message=} to {self._clients=}")
if message and self._clients:
# Note return_exceptions=True in gather. This fire-and-forget
# approach is currently how we deal with failures when trying
Expand All @@ -190,12 +192,13 @@ async def _send_snapshot_update(
# to re-establish it. Thus, it becomes the responsibility of
# the client to re-connect if necessary, in which case the first
# update it receives will be a full snapshot.
print(f"DEBUG: sending {message=} to {self._clients=}")
await asyncio.gather(
*[client.send(message) for client in self._clients],
return_exceptions=True,
)

def _create_cloud_event(
async def _create_cloud_event(
self,
event_type: str,
data: Optional[Dict[str, Any]] = None,
Expand All @@ -217,18 +220,16 @@ def _create_cloud_event(
data,
)

def _create_cloud_message(
async def _create_cloud_message(
self,
event_type: str,
data: Optional[Dict[str, Any]] = None,
extra_attrs: Optional[Dict[str, Any]] = None,
data_marshaller: Optional[Callable[[Any], Any]] = evaluator_marshaller,
) -> str:
"""Creates the CloudEvent and returns the serialized json-string"""
return to_json(
self._create_cloud_event(event_type, data, extra_attrs),
data_marshaller=data_marshaller,
).decode()
event = await self._create_cloud_event(event_type, data, extra_attrs)
return to_json(event, data_marshaller=data_marshaller).decode()

@contextmanager
def store_client(
Expand All @@ -244,7 +245,7 @@ async def handle_client(
with self.store_client(websocket):
async with self._snapshot_mutex:
current_snapshot_dict = self._ensemble.snapshot.to_dict()
event = self._create_cloud_message(
event = await self._create_cloud_message(
EVTYPE_EE_SNAPSHOT, current_snapshot_dict
)
await websocket.send(event)
Expand All @@ -256,7 +257,7 @@ async def handle_client(
logger.debug(f"got message from client: {client_event}")
if client_event["type"] == EVTYPE_EE_USER_CANCEL:
logger.debug(f"Client {websocket.remote_address} asked to cancel.")
self._signal_cancel()
await self._signal_cancel()

elif client_event["type"] == EVTYPE_EE_USER_DONE:
logger.debug(f"Client {websocket.remote_address} signalled done.")
Expand Down Expand Up @@ -384,7 +385,7 @@ async def evaluator_server(self) -> None:
terminated_data = cloudpickle.dumps(self._result)

logger.debug("Sending termination-message to clients...")
message = self._create_cloud_message(
message = await self._create_cloud_message(
EVTYPE_EE_TERMINATED,
data=terminated_data,
extra_attrs=terminated_attrs,
Expand All @@ -406,7 +407,7 @@ async def _stop(self) -> None:
self._dispatcher_task.cancel()
await self._dispatcher_task

def _signal_cancel(self) -> None:
async def _signal_cancel(self) -> None:
"""
This is just a wrapper around logic for whether to signal cancel via
a cancellable ensemble or to use internal stop-mechanism directly
Expand All @@ -417,10 +418,10 @@ def _signal_cancel(self) -> None:
"""
if self._ensemble.cancellable:
logger.debug("Cancelling current ensemble")
self._ensemble.cancel()
self._loop.run_in_executor(None, self._ensemble.cancel)
else:
logger.debug("Stopping current ensemble")
asyncio.run_coroutine_threadsafe(self._stop(), self._loop)
await self._stop()

async def run_and_get_successful_realizations(self) -> List[int]:
self._loop = asyncio.get_running_loop()
Expand Down

0 comments on commit da4d11d

Please sign in to comment.