From 268360cc7b202c1f58dd5d34e3f8edab5639d33b Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Mon, 5 Dec 2022 11:11:47 +0800 Subject: [PATCH] Use `SIGINT` as the default signal in `queue kill` fix: #8624 1. Add a new flag `--force` for `queue kill` 2. Make `SIGINT` to be the default option and `SIGKILL` to be with `--force` 3. Add tests for `queue kill` 4. Bump dvc-task into 0.1.9 --- dvc/commands/queue/kill.py | 16 ++++++++++++++-- dvc/repo/experiments/queue/celery.py | 18 ++++++++++-------- dvc/stage/run.py | 2 +- pyproject.toml | 2 +- tests/unit/command/test_queue.py | 3 ++- .../unit/repo/experiments/queue/test_celery.py | 9 +++++---- 6 files changed, 33 insertions(+), 17 deletions(-) diff --git a/dvc/commands/queue/kill.py b/dvc/commands/queue/kill.py index 098c091194..c2c9fc4549 100644 --- a/dvc/commands/queue/kill.py +++ b/dvc/commands/queue/kill.py @@ -11,13 +11,18 @@ class CmdQueueKill(CmdBase): """Kill exp task in queue.""" def run(self): - self.repo.experiments.celery_queue.kill(revs=self.args.task) + self.repo.experiments.celery_queue.kill( + revs=self.args.task, force=self.args.force + ) return 0 def add_parser(queue_subparsers, parent_parser): - QUEUE_KILL_HELP = "Kill actively running experiment queue tasks." + QUEUE_KILL_HELP = ( + "Gracefully interrupt running experiment queue tasks " + "(equivalent to Ctrl-C)" + ) queue_kill_parser = queue_subparsers.add_parser( "kill", parents=[parent_parser], @@ -25,6 +30,13 @@ def add_parser(queue_subparsers, parent_parser): help=QUEUE_KILL_HELP, formatter_class=argparse.RawDescriptionHelpFormatter, ) + queue_kill_parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + help="Forcefully and immediately kill running experiment queue tasks", + ) queue_kill_parser.add_argument( "task", nargs="*", diff --git a/dvc/repo/experiments/queue/celery.py b/dvc/repo/experiments/queue/celery.py index a21615d1f8..e9131e1c52 100644 --- a/dvc/repo/experiments/queue/celery.py +++ b/dvc/repo/experiments/queue/celery.py @@ -298,12 +298,15 @@ def _get_running_task_ids(self) -> Set[str]: return running_task_ids def _try_to_kill_tasks( - self, to_kill: Dict[QueueEntry, str] + self, to_kill: Dict[QueueEntry, str], force: bool ) -> Dict[QueueEntry, str]: fail_to_kill_entries: Dict[QueueEntry, str] = {} for queue_entry, rev in to_kill.items(): try: - self.proc.kill(queue_entry.stash_rev) + if force: + self.proc.kill(queue_entry.stash_rev) + else: + self.proc.interrupt(queue_entry.stash_rev) logger.debug(f"Task {rev} had been killed.") except ProcessLookupError: fail_to_kill_entries[queue_entry] = rev @@ -331,20 +334,19 @@ def _mark_inactive_tasks_failure(self, remained_entries): if remained_revs: raise CannotKillTasksError(remained_revs) - def _kill_entries(self, entries: Dict[QueueEntry, str]): + def _kill_entries(self, entries: Dict[QueueEntry, str], force: bool): logger.debug( "Found active tasks: '%s' to kill", list(entries.values()), ) inactive_entries: Dict[QueueEntry, str] = self._try_to_kill_tasks( - entries + entries, force ) if inactive_entries: self._mark_inactive_tasks_failure(inactive_entries) - def kill(self, revs: Collection[str]) -> None: - + def kill(self, revs: Collection[str], force: bool = False) -> None: name_dict: Dict[ str, Optional[QueueEntry] ] = self.match_queue_entry_by_name(set(revs), self.iter_active()) @@ -360,7 +362,7 @@ def kill(self, revs: Collection[str]) -> None: raise UnresolvedQueueExpNamesError(missing_revs) if to_kill: - self._kill_entries(to_kill) + self._kill_entries(to_kill, force) def shutdown(self, kill: bool = False): self.celery.control.shutdown() @@ -369,7 +371,7 @@ def shutdown(self, kill: bool = False): for entry in self.iter_active(): to_kill[entry] = entry.name or entry.stash_rev if to_kill: - self._kill_entries(to_kill) + self._kill_entries(to_kill, True) def follow( self, diff --git a/dvc/stage/run.py b/dvc/stage/run.py index d4478cfac7..0b5f5379df 100644 --- a/dvc/stage/run.py +++ b/dvc/stage/run.py @@ -84,9 +84,9 @@ def _run(stage: "Stage", executable, cmd, checkpoint_func, **kwargs): threading.current_thread(), threading._MainThread, # type: ignore[attr-defined] ) + old_handler = None exec_cmd = _make_cmd(executable, cmd) - old_handler = None try: p = subprocess.Popen(exec_cmd, **kwargs) diff --git a/pyproject.toml b/pyproject.toml index 11c1214bce..e3fe16bf38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ dependencies = [ "typing-extensions>=3.7.4", "scmrepo==0.1.5", "dvc-render==0.0.17", - "dvc-task==0.1.8", + "dvc-task==0.1.9", "dvclive>=1.2.2", "dvc-data==0.28.4", "dvc-http==2.27.2", diff --git a/tests/unit/command/test_queue.py b/tests/unit/command/test_queue.py index 610ee3e6fc..8d138d1e51 100644 --- a/tests/unit/command/test_queue.py +++ b/tests/unit/command/test_queue.py @@ -92,6 +92,7 @@ def test_experiments_kill(dvc, scm, mocker): [ "queue", "kill", + "--force", "exp1", "exp2", ] @@ -105,7 +106,7 @@ def test_experiments_kill(dvc, scm, mocker): ) assert cmd.run() == 0 - m.assert_called_once_with(revs=["exp1", "exp2"]) + m.assert_called_once_with(revs=["exp1", "exp2"], force=True) def test_experiments_start(dvc, scm, mocker): diff --git a/tests/unit/repo/experiments/queue/test_celery.py b/tests/unit/repo/experiments/queue/test_celery.py index 2c389aff62..b0502af43d 100644 --- a/tests/unit/repo/experiments/queue/test_celery.py +++ b/tests/unit/repo/experiments/queue/test_celery.py @@ -46,7 +46,7 @@ def test_shutdown_with_kill(test_queue, mocker): shutdown_spy.assert_called_once() kill_spy.assert_called_once_with( - {mock_entry_foo: "foo", mock_entry_bar: "bar"} + {mock_entry_foo: "foo", mock_entry_bar: "bar"}, True ) @@ -78,7 +78,8 @@ def test_post_run_after_kill(test_queue): assert result_foo.get(timeout=10) == "foo" -def test_celery_queue_kill(test_queue, mocker): +@pytest.mark.parametrize("force", [True, False]) +def test_celery_queue_kill(test_queue, mocker, force): mock_entry_foo = mocker.Mock(stash_rev="foo") mock_entry_bar = mocker.Mock(stash_rev="bar") @@ -137,13 +138,13 @@ def kill_function(rev): kill_mock = mocker.patch.object( test_queue.proc, - "kill", + "kill" if force else "interrupt", side_effect=mocker.MagicMock(side_effect=kill_function), ) with pytest.raises( CannotKillTasksError, match="Task 'foobar' is initializing," ): - test_queue.kill(["bar", "foo", "foobar"]) + test_queue.kill(["bar", "foo", "foobar"], force=force) assert kill_mock.called_once_with(mock_entry_foo.stash_rev) assert kill_mock.called_once_with(mock_entry_bar.stash_rev) assert kill_mock.called_once_with(mock_entry_foobar.stash_rev)