From 2099444635c5675869f1965fb7b2ee95adf1f99e Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Mon, 4 Jul 2022 17:30:11 +0800 Subject: [PATCH] Refactor: Move `remove_revs` from `BaseStashQueue` to `expstash` --- dvc/repo/experiments/queue/base.py | 33 +++++-------------- dvc/repo/experiments/queue/remove.py | 26 +++++++++------ dvc/repo/experiments/stash.py | 14 +++++++- .../repo/experiments/queue/test_remove.py | 21 +++++------- 4 files changed, 46 insertions(+), 48 deletions(-) diff --git a/dvc/repo/experiments/queue/base.py b/dvc/repo/experiments/queue/base.py index a171cf4e9d..4c5309c175 100644 --- a/dvc/repo/experiments/queue/base.py +++ b/dvc/repo/experiments/queue/base.py @@ -152,28 +152,26 @@ def remove( if all_ or queued: return self.clear() - removed: List[str] = [] - to_remove: Dict[str, ExpStashEntry] = {} + name_to_remove: List[str] = [] + entry_to_remove: List[ExpStashEntry] = [] queue_entries = self.match_queue_entry_by_name( revs, self.iter_queued() ) for name, entry in queue_entries.items(): if entry: - to_remove[entry.stash_rev] = self.stash.stash_revs[ - entry.stash_rev - ] - removed.append(name) + entry_to_remove.append(self.stash.stash_revs[entry.stash_rev]) + name_to_remove.append(name) - self._remove_revs(to_remove, self.stash) - return removed + self.stash.remove_revs(entry_to_remove) + return name_to_remove def clear(self, **kwargs) -> List[str]: """Remove all entries from the queue.""" stash_revs = self.stash.stash_revs - removed = list(stash_revs) - self._remove_revs(stash_revs, self.stash) + name_to_remove = list(stash_revs) + self.stash.remove_revs(list(stash_revs.values())) - return removed + return name_to_remove def status(self) -> List[Dict[str, Any]]: """Show the status of exp tasks in queue""" @@ -217,19 +215,6 @@ def _format_entry( ) return result - @staticmethod - def _remove_revs(stash_revs: Mapping[str, ExpStashEntry], stash: ExpStash): - """Remove the specified entries from the queue by stash revision.""" - for index in sorted( - ( - entry.stash_index - for entry in stash_revs.values() - if entry.stash_index is not None - ), - reverse=True, - ): - stash.drop(index) - @abstractmethod def iter_queued(self) -> Generator[QueueEntry, None, None]: """Iterate over items in the queue.""" diff --git a/dvc/repo/experiments/queue/remove.py b/dvc/repo/experiments/queue/remove.py index bed3b05b7a..082666c315 100644 --- a/dvc/repo/experiments/queue/remove.py +++ b/dvc/repo/experiments/queue/remove.py @@ -27,7 +27,6 @@ def _remove_queued_tasks( Arguments: queue_entries: An iterable list of queued task to remove """ - # pylint: disable=protected-access stash_revs: Dict[str, "ExpStashEntry"] = {} for entry in queue_entries: if entry: @@ -36,11 +35,14 @@ def _remove_queued_tasks( ] try: - for msg, queue_entry in celery_queue._iter_queued(): + for ( + msg, + queue_entry, + ) in celery_queue._iter_queued(): # pylint: disable=protected-access if queue_entry.stash_rev in stash_revs: celery_queue.celery.reject(msg.delivery_tag) finally: - celery_queue._remove_revs(stash_revs, celery_queue.stash) + celery_queue.stash.remove_revs(list(stash_revs.values())) def _remove_done_tasks( @@ -52,21 +54,25 @@ def _remove_done_tasks( Arguments: queue_entries: An iterable list of done task to remove """ - # pylint: disable=protected-access from celery.result import AsyncResult - failed_stash_revs: Dict[str, "ExpStashEntry"] = {} + failed_stash_revs: List["ExpStashEntry"] = [] queue_entry_set: Set["QueueEntry"] = set() for entry in queue_entries: if entry: queue_entry_set.add(entry) if entry.stash_rev in celery_queue.failed_stash.stash_revs: - failed_stash_revs[ - entry.stash_rev - ] = celery_queue.failed_stash.stash_revs[entry.stash_rev] + failed_stash_revs.append( + celery_queue.failed_stash.stash_revs[entry.stash_rev] + ) try: - for msg, queue_entry in celery_queue._iter_processed(): + for ( + msg, + queue_entry, + ) in ( + celery_queue._iter_processed() # pylint: disable=protected-access + ): if queue_entry not in queue_entry_set: continue task_id = msg.headers["id"] @@ -75,7 +81,7 @@ def _remove_done_tasks( result.forget() celery_queue.celery.purge(msg.delivery_tag) finally: - celery_queue._remove_revs(failed_stash_revs, celery_queue.failed_stash) + celery_queue.failed_stash.remove_revs(failed_stash_revs) def _get_names(entries: Iterable[Union["QueueEntry", "QueueDoneResult"]]): diff --git a/dvc/repo/experiments/stash.py b/dvc/repo/experiments/stash.py index b15a70a304..7678c8d28a 100644 --- a/dvc/repo/experiments/stash.py +++ b/dvc/repo/experiments/stash.py @@ -1,5 +1,5 @@ import re -from typing import Dict, NamedTuple, Optional +from typing import Dict, Iterable, NamedTuple, Optional from scmrepo.git import Stash @@ -60,3 +60,15 @@ def format_message( ) branch_msg = f":{branch}" if branch else "" return f"{msg}{branch_msg}" + + def remove_revs(self, stash_revs: Iterable[ExpStashEntry]): + """Remove the specified entries from the queue by stash revision.""" + for index in sorted( + ( + entry.stash_index + for entry in stash_revs + if entry.stash_index is not None + ), + reverse=True, + ): + self.drop(index) diff --git a/tests/unit/repo/experiments/queue/test_remove.py b/tests/unit/repo/experiments/queue/test_remove.py index 7fae111ef7..2a0f1c5e6a 100644 --- a/tests/unit/repo/experiments/queue/test_remove.py +++ b/tests/unit/repo/experiments/queue/test_remove.py @@ -28,20 +28,17 @@ def test_remove_queued(test_queue, mocker): mocker.patch.object(test_queue, "_iter_queued", return_value=msg_iter) mocker.patch.object(test_queue, "iter_queued", return_value=entry_iter) - remove_revs_mocker = mocker.patch.object(test_queue, "_remove_revs") + remove_revs_mocker = mocker.patch.object(test_queue.stash, "remove_revs") reject_mocker = mocker.patch.object(test_queue.celery, "reject") assert test_queue.remove(["queue2"]) == ["queue2"] reject_mocker.assert_called_once_with("msg_queue2") - remove_revs_mocker.assert_called_once_with( - {"queue2": stash_dict["queue2"]}, test_queue.stash - ) - + remove_revs_mocker.assert_called_once_with([stash_dict["queue2"]]) remove_revs_mocker.reset_mock() reject_mocker.reset_mock() assert test_queue.remove([], queued=True) == queued_test - remove_revs_mocker.assert_called_once_with(stash_dict, test_queue.stash) + remove_revs_mocker.assert_called_once_with(list(stash_dict.values())) reject_mocker.assert_has_calls( [call("msg_queue1"), call("msg_queue2"), call("msg_queue3")] ) @@ -91,16 +88,16 @@ def test_remove_done(test_queue, mocker): mocker.patch.object(test_queue, "iter_failed", return_value=failed_iter) mocker.patch("celery.result.AsyncResult", return_value=mocker.Mock()) - remove_revs_mocker = mocker.patch.object(test_queue, "_remove_revs") + remove_revs_mocker = mocker.patch.object( + test_queue.failed_stash, "remove_revs" + ) purge_mocker = mocker.patch.object(test_queue.celery, "purge") assert test_queue.remove(["failed3", "success2"]) == [ "failed3", "success2", ] - remove_revs_mocker.assert_called_once_with( - {"failed3": stash_dict["failed3"]}, test_queue.failed_stash - ) + remove_revs_mocker.assert_called_once_with([stash_dict["failed3"]]) purge_mocker.assert_has_calls([call("msg_failed3"), call("msg_success2")]) remove_revs_mocker.reset_mock() @@ -120,9 +117,7 @@ def test_remove_done(test_queue, mocker): ], any_order=True, ) - remove_revs_mocker.assert_called_once_with( - stash_dict, test_queue.failed_stash - ) + remove_revs_mocker.assert_called_once_with(list(stash_dict.values())) def test_remove_all(test_queue, mocker):