Skip to content

Commit

Permalink
Mypy: fix
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Jan 26, 2024
1 parent 2dbf0ac commit 90a4339
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
5 changes: 5 additions & 0 deletions src/ert/ensemble_evaluator/_builder/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from ._realization import Realization

if TYPE_CHECKING:
from asyncio import Task

from ..config import EvaluatorServerConfig

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -111,6 +113,9 @@ def __repr__(self) -> str:
def evaluate(self, config: "EvaluatorServerConfig") -> None:
pass

async def evaluate_async(self, config: EvaluatorServerConfig) -> "Task[Any]":
pass

def cancel(self) -> None:
pass

Expand Down
18 changes: 8 additions & 10 deletions src/ert/ensemble_evaluator/evaluator_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,14 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig, iter_: int

self._result = None

self._server_task: Optional[asyncio.Task] = None
self._dispatcher_task: Optional[asyncio.Task] = None
self._evaluator_task: Optional[asyncio.Task] = None
self._dispatcher_task: Optional[asyncio.Task[None]] = None

async def dispatcher(self):
async def dispatcher(self) -> None:
logger.debug("dispatcher started!!!!****")

event_handler = {}

def set_handler(event_types, function):
def set_handler(event_types: Set[str], function: Any) -> None:
for event_type in event_types:
event_handler[event_type] = function

Expand Down Expand Up @@ -346,6 +344,7 @@ async def evaluator_server(self) -> None:
# await asyncio.wait_for(self._dispatcher_task, timeout=20)
# except asyncio.TimeoutError:
# logger.debug("Timed out waiting for batcher to finish")
assert self._dispatcher_task is not None
self._dispatcher_task.cancel()
await self._dispatcher_task

Expand Down Expand Up @@ -386,20 +385,19 @@ async def _signal_cancel(self) -> None:
"""
if self._ensemble.cancellable:
logger.debug("Cancelling current ensemble")
assert self._loop is not None
self._loop.run_in_executor(None, self._ensemble.cancel)
else:
logger.debug("Stopping current ensemble")
await self._stop()

async def run_and_get_successful_realizations(self) -> List[int]:
self._loop = asyncio.get_running_loop()
self._server_task = asyncio.create_task(self.evaluator_server())
server_task = asyncio.create_task(self.evaluator_server())
self._dispatcher_task = asyncio.create_task(self.dispatcher())
self._evaluator_task = await self._ensemble.evaluate_async(self._config)
evaluator_task = await self._ensemble.evaluate_async(self._config)

await asyncio.gather(
self._server_task, self._evaluator_task, return_exceptions=True
)
await asyncio.gather(server_task, evaluator_task, return_exceptions=True)
logger.debug("Evaluator is done")
return self._ensemble.get_successful_realizations()

Expand Down

0 comments on commit 90a4339

Please sign in to comment.