Skip to content

Commit

Permalink
Refactor: Move remove_revs from BaseStashQueue to expstash
Browse files Browse the repository at this point in the history
  • Loading branch information
karajan1001 authored and pmrowla committed Jul 5, 2022
1 parent 4cc806c commit 2099444
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 48 deletions.
33 changes: 9 additions & 24 deletions dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,28 +152,26 @@ def remove(
if all_ or queued:
return self.clear()

removed: List[str] = []
to_remove: Dict[str, ExpStashEntry] = {}
name_to_remove: List[str] = []
entry_to_remove: List[ExpStashEntry] = []
queue_entries = self.match_queue_entry_by_name(
revs, self.iter_queued()
)
for name, entry in queue_entries.items():
if entry:
to_remove[entry.stash_rev] = self.stash.stash_revs[
entry.stash_rev
]
removed.append(name)
entry_to_remove.append(self.stash.stash_revs[entry.stash_rev])
name_to_remove.append(name)

self._remove_revs(to_remove, self.stash)
return removed
self.stash.remove_revs(entry_to_remove)
return name_to_remove

def clear(self, **kwargs) -> List[str]:
"""Remove all entries from the queue."""
stash_revs = self.stash.stash_revs
removed = list(stash_revs)
self._remove_revs(stash_revs, self.stash)
name_to_remove = list(stash_revs)
self.stash.remove_revs(list(stash_revs.values()))

return removed
return name_to_remove

def status(self) -> List[Dict[str, Any]]:
"""Show the status of exp tasks in queue"""
Expand Down Expand Up @@ -217,19 +215,6 @@ def _format_entry(
)
return result

@staticmethod
def _remove_revs(stash_revs: Mapping[str, ExpStashEntry], stash: ExpStash):
"""Remove the specified entries from the queue by stash revision."""
for index in sorted(
(
entry.stash_index
for entry in stash_revs.values()
if entry.stash_index is not None
),
reverse=True,
):
stash.drop(index)

@abstractmethod
def iter_queued(self) -> Generator[QueueEntry, None, None]:
"""Iterate over items in the queue."""
Expand Down
26 changes: 16 additions & 10 deletions dvc/repo/experiments/queue/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def _remove_queued_tasks(
Arguments:
queue_entries: An iterable list of queued task to remove
"""
# pylint: disable=protected-access
stash_revs: Dict[str, "ExpStashEntry"] = {}
for entry in queue_entries:
if entry:
Expand All @@ -36,11 +35,14 @@ def _remove_queued_tasks(
]

try:
for msg, queue_entry in celery_queue._iter_queued():
for (
msg,
queue_entry,
) in celery_queue._iter_queued(): # pylint: disable=protected-access
if queue_entry.stash_rev in stash_revs:
celery_queue.celery.reject(msg.delivery_tag)
finally:
celery_queue._remove_revs(stash_revs, celery_queue.stash)
celery_queue.stash.remove_revs(list(stash_revs.values()))


def _remove_done_tasks(
Expand All @@ -52,21 +54,25 @@ def _remove_done_tasks(
Arguments:
queue_entries: An iterable list of done task to remove
"""
# pylint: disable=protected-access
from celery.result import AsyncResult

failed_stash_revs: Dict[str, "ExpStashEntry"] = {}
failed_stash_revs: List["ExpStashEntry"] = []
queue_entry_set: Set["QueueEntry"] = set()
for entry in queue_entries:
if entry:
queue_entry_set.add(entry)
if entry.stash_rev in celery_queue.failed_stash.stash_revs:
failed_stash_revs[
entry.stash_rev
] = celery_queue.failed_stash.stash_revs[entry.stash_rev]
failed_stash_revs.append(
celery_queue.failed_stash.stash_revs[entry.stash_rev]
)

try:
for msg, queue_entry in celery_queue._iter_processed():
for (
msg,
queue_entry,
) in (
celery_queue._iter_processed() # pylint: disable=protected-access
):
if queue_entry not in queue_entry_set:
continue
task_id = msg.headers["id"]
Expand All @@ -75,7 +81,7 @@ def _remove_done_tasks(
result.forget()
celery_queue.celery.purge(msg.delivery_tag)
finally:
celery_queue._remove_revs(failed_stash_revs, celery_queue.failed_stash)
celery_queue.failed_stash.remove_revs(failed_stash_revs)


def _get_names(entries: Iterable[Union["QueueEntry", "QueueDoneResult"]]):
Expand Down
14 changes: 13 additions & 1 deletion dvc/repo/experiments/stash.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Dict, NamedTuple, Optional
from typing import Dict, Iterable, NamedTuple, Optional

from scmrepo.git import Stash

Expand Down Expand Up @@ -60,3 +60,15 @@ def format_message(
)
branch_msg = f":{branch}" if branch else ""
return f"{msg}{branch_msg}"

def remove_revs(self, stash_revs: Iterable[ExpStashEntry]):
"""Remove the specified entries from the queue by stash revision."""
for index in sorted(
(
entry.stash_index
for entry in stash_revs
if entry.stash_index is not None
),
reverse=True,
):
self.drop(index)
21 changes: 8 additions & 13 deletions tests/unit/repo/experiments/queue/test_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,17 @@ def test_remove_queued(test_queue, mocker):
mocker.patch.object(test_queue, "_iter_queued", return_value=msg_iter)
mocker.patch.object(test_queue, "iter_queued", return_value=entry_iter)

remove_revs_mocker = mocker.patch.object(test_queue, "_remove_revs")
remove_revs_mocker = mocker.patch.object(test_queue.stash, "remove_revs")
reject_mocker = mocker.patch.object(test_queue.celery, "reject")

assert test_queue.remove(["queue2"]) == ["queue2"]
reject_mocker.assert_called_once_with("msg_queue2")
remove_revs_mocker.assert_called_once_with(
{"queue2": stash_dict["queue2"]}, test_queue.stash
)

remove_revs_mocker.assert_called_once_with([stash_dict["queue2"]])
remove_revs_mocker.reset_mock()
reject_mocker.reset_mock()

assert test_queue.remove([], queued=True) == queued_test
remove_revs_mocker.assert_called_once_with(stash_dict, test_queue.stash)
remove_revs_mocker.assert_called_once_with(list(stash_dict.values()))
reject_mocker.assert_has_calls(
[call("msg_queue1"), call("msg_queue2"), call("msg_queue3")]
)
Expand Down Expand Up @@ -91,16 +88,16 @@ def test_remove_done(test_queue, mocker):
mocker.patch.object(test_queue, "iter_failed", return_value=failed_iter)
mocker.patch("celery.result.AsyncResult", return_value=mocker.Mock())

remove_revs_mocker = mocker.patch.object(test_queue, "_remove_revs")
remove_revs_mocker = mocker.patch.object(
test_queue.failed_stash, "remove_revs"
)
purge_mocker = mocker.patch.object(test_queue.celery, "purge")

assert test_queue.remove(["failed3", "success2"]) == [
"failed3",
"success2",
]
remove_revs_mocker.assert_called_once_with(
{"failed3": stash_dict["failed3"]}, test_queue.failed_stash
)
remove_revs_mocker.assert_called_once_with([stash_dict["failed3"]])
purge_mocker.assert_has_calls([call("msg_failed3"), call("msg_success2")])

remove_revs_mocker.reset_mock()
Expand All @@ -120,9 +117,7 @@ def test_remove_done(test_queue, mocker):
],
any_order=True,
)
remove_revs_mocker.assert_called_once_with(
stash_dict, test_queue.failed_stash
)
remove_revs_mocker.assert_called_once_with(list(stash_dict.values()))


def test_remove_all(test_queue, mocker):
Expand Down

0 comments on commit 2099444

Please sign in to comment.