diff --git a/src/ert/ensemble_evaluator/_builder/_legacy.py b/src/ert/ensemble_evaluator/_builder/_legacy.py index e2692680786..39046ad2be3 100644 --- a/src/ert/ensemble_evaluator/_builder/_legacy.py +++ b/src/ert/ensemble_evaluator/_builder/_legacy.py @@ -223,16 +223,12 @@ async def _evaluate_inner( # pylint: disable=too-many-branches event_creator(identifiers.EVTYPE_ENSEMBLE_STARTED, None) ) - if isinstance(queue, Scheduler): - queue.add_dispatch_information_to_jobs_file() - result = await queue.execute() - elif isinstance(queue, JobQueue): - min_required_realizations = ( - self.min_required_realizations if self.stop_long_running else 0 - ) - queue.add_dispatch_information_to_jobs_file() + min_required_realizations = ( + self.min_required_realizations if self.stop_long_running else 0 + ) - result = await queue.execute(min_required_realizations) + queue.add_dispatch_information_to_jobs_file() + result = await queue.execute(min_required_realizations) except Exception: logger.exception( diff --git a/src/ert/scheduler/job.py b/src/ert/scheduler/job.py index 90caa69fc29..6a25172e4f4 100644 --- a/src/ert/scheduler/job.py +++ b/src/ert/scheduler/job.py @@ -2,6 +2,7 @@ import asyncio import logging +import time import uuid from enum import Enum from pathlib import Path @@ -67,6 +68,8 @@ def __init__(self, scheduler: Scheduler, real: Realization) -> None: self._scheduler: Scheduler = scheduler self._callback_status_msg: str = "" self._requested_max_submit: Optional[int] = None + self._start_time: Optional[float] = None + self._end_time: Optional[float] = None @property def iens(self) -> int: @@ -76,6 +79,14 @@ def iens(self) -> int: def driver(self) -> Driver: return self._scheduler.driver + @property + def running_duration(self) -> float: + if self._start_time: + if self._end_time: + return self._end_time - self._start_time + return time.time() - self._start_time + return 0 + async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None: await sem.acquire() timeout_task: Optional[asyncio.Task[None]] = None @@ -88,6 +99,7 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None: await self._send(State.PENDING) await self.started.wait() + self._start_time = time.time() await self._send(State.RUNNING) if self.real.max_runtime is not None and self.real.max_runtime > 0: @@ -179,6 +191,10 @@ async def _send(self, state: State) -> None: if state in (State.FAILED, State.ABORTED): await self._handle_failure() + if state == State.COMPLETED: + self._end_time = time.time() + await self._scheduler.completed_jobs.put(self.iens) + status = STATE_TO_LEGACY[state] event = CloudEvent( { diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py index 3b5d83145d2..2459e7ad2d9 100644 --- a/src/ert/scheduler/scheduler.py +++ b/src/ert/scheduler/scheduler.py @@ -8,14 +8,7 @@ from collections import defaultdict from dataclasses import asdict from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Dict, - MutableMapping, - Optional, - Sequence, -) +from typing import TYPE_CHECKING, Any, Dict, MutableMapping, Optional, Sequence from pydantic.dataclasses import dataclass from websockets import Headers @@ -69,6 +62,10 @@ def __init__( } self._events: asyncio.Queue[Any] = asyncio.Queue() + self._average_job_runtime: float = 0 + self._completed_jobs_num: int = 0 + self.completed_jobs: asyncio.Queue[int] = asyncio.Queue() + self._cancelled = False self._max_submit = max_submit self._max_running = max_running @@ -83,8 +80,29 @@ def kill_all_jobs(self) -> None: for task in self._tasks.values(): task.cancel() - def stop_long_running_jobs(self, minimum_required_realizations: int) -> None: - pass + async def _update_avg_job_runtime(self) -> None: + while True: + job_id = await self.completed_jobs.get() + self._average_job_runtime = ( + self._average_job_runtime * self._completed_jobs_num + + self._jobs[job_id].running_duration + ) / (self._completed_jobs_num + 1) + self._completed_jobs_num += 1 + + async def _stop_long_running_jobs( + self, minimum_required_realizations: int, long_running_factor: float = 1.25 + ) -> None: + while True: + if self._completed_jobs_num >= minimum_required_realizations: + for job_id, task in self._tasks.items(): + if ( + self._jobs[job_id].running_duration + > long_running_factor * self._average_job_runtime + and not task.done() + ): + task.cancel() + await task + await asyncio.sleep(0.1) def set_realization(self, realization: Realization) -> None: self._jobs[realization.iens] = Job(self, realization) @@ -126,11 +144,19 @@ def add_dispatch_information_to_jobs_file(self) -> None: for job in self._jobs.values(): self._update_jobs_json(job.iens, job.real.run_arg.runpath) - async def execute(self, minimum_required_realizations: int = 0) -> str: + async def execute( + self, + min_required_realizations: int = 0, + ) -> str: async with background_tasks() as cancel_when_execute_is_done: cancel_when_execute_is_done(self._publisher()) cancel_when_execute_is_done(self._process_event_queue()) cancel_when_execute_is_done(self.driver.poll()) + if min_required_realizations > 0: + cancel_when_execute_is_done( + self._stop_long_running_jobs(min_required_realizations) + ) + cancel_when_execute_is_done(self._update_avg_job_runtime()) start = asyncio.Event() sem = asyncio.BoundedSemaphore(self._max_running) diff --git a/tests/unit_tests/scheduler/conftest.py b/tests/unit_tests/scheduler/conftest.py index 370524af461..0884b85af03 100644 --- a/tests/unit_tests/scheduler/conftest.py +++ b/tests/unit_tests/scheduler/conftest.py @@ -10,19 +10,26 @@ def __init__(self, init=None, wait=None, kill=None): self._mock_wait = wait self._mock_kill = kill - async def _init(self, *args, **kwargs): + async def _init(self, iens, *args, **kwargs): if self._mock_init is not None: - await self._mock_init(*args, **kwargs) + await self._mock_init(iens, *args, **kwargs) + return iens - async def _wait(self, *args): + async def _wait(self, iens): if self._mock_wait is not None: - result = await self._mock_wait() + if self._mock_wait.__code__.co_argcount > 0: + result = await self._mock_wait(iens) + else: + result = await self._mock_wait() return True if result is None else bool(result) return True - async def _kill(self, *args): + async def _kill(self, iens, *args): if self._mock_kill is not None: - await self._mock_kill() + if self._mock_kill.__code__.co_argcount > 0: + await self._mock_kill(iens) + else: + await self._mock_kill() @pytest.fixture diff --git a/tests/unit_tests/scheduler/test_scheduler.py b/tests/unit_tests/scheduler/test_scheduler.py index fe05e747f3c..260c47bc6d6 100644 --- a/tests/unit_tests/scheduler/test_scheduler.py +++ b/tests/unit_tests/scheduler/test_scheduler.py @@ -270,3 +270,34 @@ async def init(iens, *args, **kwargs): assert sch.is_active() await execute_task assert not sch.is_active() + + +@pytest.mark.timeout(6) +async def test_that_long_running_jobs_were_stopped(storage, tmp_path, mock_driver): + killed_iens = [] + + async def kill(iens): + nonlocal killed_iens + killed_iens.append(iens) + + async def wait(iens): + # all jobs with iens > 5 will sleep for 10 seconds and should be killed + if iens < 6: + await asyncio.sleep(0.5) + else: + await asyncio.sleep(10) + return True + + ensemble_size = 10 + ensemble = storage.create_experiment().create_ensemble( + name="foo", ensemble_size=ensemble_size + ) + realizations = [ + create_stub_realization(ensemble, tmp_path, iens) + for iens in range(ensemble_size) + ] + + sch = scheduler.Scheduler(mock_driver(wait=wait, kill=kill), realizations) + + assert await sch.execute(min_required_realizations=5) == EVTYPE_ENSEMBLE_STOPPED + assert killed_iens == [6, 7, 8, 9]