Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #10593 -- add --keep option for dvc experiments remove #10633

Merged
merged 25 commits into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
725b02d
Add keep_selected parameter, and corresponding code to keep only the …
rmic Nov 23, 2024
e47cc72
test keep_selected_by_name
rmic Nov 23, 2024
df6a7ed
test keep_selected_by_rev
rmic Nov 23, 2024
df2a1ee
test keep_selected multiple, by name
rmic Nov 23, 2024
4f32f20
test keep all by name
rmic Nov 23, 2024
f913ffb
test keep by rev, with num=2
rmic Nov 23, 2024
c3a46da
added option to cli
rmic Nov 23, 2024
4e84b2d
refactoring to meet pr needs
rmic Nov 23, 2024
7f050ed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2024
36a0e2b
fixed test_experiments to add keep_selected=False to remove tests
rmic Nov 23, 2024
87abac1
rename parameter to match cli option
rmic Nov 24, 2024
ca279b9
follow the normal path, then invert the selection before removing
rmic Nov 24, 2024
a903507
fixed tests for list ordering + fixed test with non existent name, it…
rmic Nov 24, 2024
cb43ca3
changed cli option comment
rmic Nov 24, 2024
3184250
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2024
dea54cb
fixed typing issue
rmic Nov 24, 2024
421c3f0
updated parameter name
rmic Nov 24, 2024
0a76842
removed handling queued experiments (since --queue would remove them …
rmic Nov 26, 2024
783b67a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 26, 2024
b83108a
code simplification, added __eq__ and __hash__ to be able to compare …
rmic Nov 27, 2024
61fda46
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 27, 2024
c6c67b9
fixed linting issues
rmic Nov 28, 2024
9397e83
- --keep and --queue together raise an InvalidArgumentError
rmic Nov 29, 2024
9365b8a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 29, 2024
fe5c5d7
re-run gh tests. Some tests which did not involve my changes started …
rmic Nov 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions dvc/commands/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ def add_parser(subparsers, parent_parser):
hide_subparsers_from_help(experiments_subparsers)


def add_keep_selection_flag(experiments_subcmd_parser):
    """Register the ``--keep`` switch on an experiments sub-command parser.

    The option is a boolean flag (``store_true``) that defaults to ``False``;
    when given, the parsed namespace exposes ``keep=True``.

    Args:
        experiments_subcmd_parser: parser object exposing ``add_argument``
            (e.g. an ``argparse.ArgumentParser`` for the sub-command).
    """
    experiments_subcmd_parser.add_argument(
        "--keep",
        action="store_true",
        default=False,
        help="Keep the selected experiments instead of removing them.",
    )


def add_rev_selection_flags(
experiments_subcmd_parser, command: str, default: bool = True
):
Expand Down
4 changes: 3 additions & 1 deletion dvc/commands/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def run(self):
num=self.args.num,
queue=self.args.queue,
git_remote=self.args.git_remote,
keep=self.args.keep,
)
if removed:
ui.write(f"Removed experiments: {humanize.join(map(repr, removed))}")
Expand All @@ -44,7 +45,7 @@ def run(self):


def add_parser(experiments_subparsers, parent_parser):
from . import add_rev_selection_flags
from . import add_keep_selection_flag, add_rev_selection_flags

EXPERIMENTS_REMOVE_HELP = "Remove experiments."
experiments_remove_parser = experiments_subparsers.add_parser(
Expand All @@ -57,6 +58,7 @@ def add_parser(experiments_subparsers, parent_parser):
)
remove_group = experiments_remove_parser.add_mutually_exclusive_group()
add_rev_selection_flags(experiments_remove_parser, "Remove", False)
add_keep_selection_flag(experiments_remove_parser)
remove_group.add_argument(
"--queue", action="store_true", help="Remove all queued experiments."
)
Expand Down
9 changes: 9 additions & 0 deletions dvc/repo/experiments/refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,12 @@ def from_ref(cls, ref: str):
baseline_sha = parts[2] + parts[3]
name = parts[4] if len(parts) == 5 else None
return cls(baseline_sha, name)

def __eq__(self, other):
    """Return whether ``other`` refers to the same experiment ref.

    Two refs are equal when both ``baseline_sha`` and ``name`` match.
    For objects that are not ``ExpRefInfo`` this returns ``NotImplemented``
    so Python can try the reflected comparison; ``ref == other`` still
    evaluates to ``False`` when neither side can compare.
    """
    if not isinstance(other, ExpRefInfo):
        return NotImplemented

    return self.baseline_sha == other.baseline_sha and self.name == other.name

def __hash__(self):
    """Hash on the ``(baseline_sha, name)`` pair, the fields ``__eq__`` compares."""
    return hash((self.baseline_sha, self.name))
14 changes: 13 additions & 1 deletion dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dvc.repo.scm_context import scm_context
from dvc.scm import Git, iter_revs

from .exceptions import UnresolvedExpNamesError
from .exceptions import InvalidArgumentError, UnresolvedExpNamesError
from .utils import exp_refs, exp_refs_by_baseline, push_refspec

if TYPE_CHECKING:
Expand All @@ -30,10 +30,16 @@ def remove( # noqa: C901, PLR0912
num: int = 1,
queue: bool = False,
git_remote: Optional[str] = None,
keep: bool = False,
) -> list[str]:
removed: list[str] = []

if all([keep, queue]):
raise InvalidArgumentError("Cannot use both `--keep` and `--queue`.")

if not any([exp_names, queue, all_commits, rev]):
return removed

celery_queue: LocalCeleryQueue = repo.experiments.celery_queue

if queue:
Expand All @@ -43,6 +49,7 @@ def remove( # noqa: C901, PLR0912

exp_ref_list: list[ExpRefInfo] = []
queue_entry_list: list[QueueEntry] = []

if exp_names:
results: dict[str, ExpRefAndQueueEntry] = (
celery_queue.get_ref_and_entry_by_names(exp_names, git_remote)
Expand Down Expand Up @@ -70,6 +77,10 @@ def remove( # noqa: C901, PLR0912
exp_ref_list.extend(exp_refs(repo.scm, git_remote))
removed = [ref.name for ref in exp_ref_list]

if keep:
exp_ref_list = list(set(exp_refs(repo.scm, git_remote)) - set(exp_ref_list))
removed = [ref.name for ref in exp_ref_list]

if exp_ref_list:
_remove_commited_exps(repo.scm, exp_ref_list, git_remote)

Expand All @@ -83,6 +94,7 @@ def remove( # noqa: C901, PLR0912

removed_refs = [str(r) for r in exp_ref_list]
notify_refs_to_studio(repo, git_remote, removed=removed_refs)

return removed


Expand Down
82 changes: 82 additions & 0 deletions tests/func/experiments/test_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,85 @@ def test_remove_multi_rev(tmp_dir, scm, dvc, exp_stage):

assert scm.get_ref(str(baseline_exp_ref)) is None
assert scm.get_ref(str(new_exp_ref)) is None


@pytest.mark.parametrize(
    "keep, expected_removed",
    [
        [["exp1"], ["exp2", "exp3"]],
        [["exp1", "exp2"], ["exp3"]],
        [["exp1", "exp2", "exp3"], []],
        [[], []],  # remove does nothing if no experiments are specified
    ],
)
def test_keep_selected_by_name(tmp_dir, scm, dvc, exp_stage, keep, expected_removed):
    """With ``keep=True``, experiments selected by name survive; the rest go."""
    # Setup: run one experiment per name appearing in either list (exp1..expN)
    # and remember each experiment's ref so its existence can be checked later.
    refs = {}
    for i in range(1, len(keep) + len(expected_removed) + 1):
        results = dvc.experiments.run(
            exp_stage.addressing, params=[f"foo={i}"], name=f"exp{i}"
        )
        refs[f"exp{i}"] = first(exp_refs_by_rev(scm, first(results)))
        assert scm.get_ref(str(refs[f"exp{i}"])) is not None

    removed = dvc.experiments.remove(exp_names=keep, keep=True)
    # Compare sorted: the order of the returned names is not relied upon.
    assert sorted(removed) == sorted(expected_removed)

    # Everything reported as removed must be gone from the repository ...
    for exp in expected_removed:
        assert scm.get_ref(str(refs[exp])) is None

    # ... and everything that was asked to be kept must still be there.
    for exp in keep:
        assert scm.get_ref(str(refs[exp])) is not None


def test_keep_selected_by_nonexistent_name(tmp_dir, scm, dvc, exp_stage):
    """An unresolvable experiment name raises even when ``keep=True``."""
    # A nonexistent name should raise an error.
    with pytest.raises(UnresolvedExpNamesError):
        dvc.experiments.remove(exp_names=["nonexistent"], keep=True)


@pytest.mark.parametrize(
    "num_exps, rev, num, expected_removed",
    [
        [2, "exp1", 1, ["exp2"]],
        [3, "exp3", 1, ["exp1", "exp2"]],
        [3, "exp3", 2, ["exp1"]],
        [3, "exp3", 3, []],
        [3, "exp2", 2, ["exp3"]],
        [4, "exp2", 2, ["exp3", "exp4"]],
        [4, "exp4", 2, ["exp1", "exp2"]],
        [1, None, 1, []],  # remove does nothing if no experiments are specified
    ],
)
def test_keep_selected_by_rev(
    tmp_dir, scm, dvc, exp_stage, num_exps, rev, num, expected_removed
):
    """With ``keep=True``, experiments selected via ``rev``/``num`` survive."""
    refs = {}
    revs = {}
    # Setup: create a fresh commit before each experiment so that every
    # experiment is associated with its own revision.
    for i in range(1, num_exps + 1):
        scm.commit(f"commit{i}")
        results = dvc.experiments.run(
            exp_stage.addressing, params=[f"foo={i}"], name=f"exp{i}"
        )
        refs[f"exp{i}"] = first(exp_refs_by_rev(scm, first(results)))
        revs[f"exp{i}"] = scm.get_rev()
        assert scm.get_ref(str(refs[f"exp{i}"])) is not None

    # Select by revision (revs.get(None) is None, i.e. nothing selected) and
    # invert the selection with keep=True.
    removed = dvc.experiments.remove(rev=revs.get(rev), num=num, keep=True)
    # Compare sorted: the order of the returned names is not relied upon.
    assert sorted(removed) == sorted(expected_removed)

    # Removed experiments must no longer have a ref ...
    for exp in expected_removed:
        assert scm.get_ref(str(refs[exp])) is None

    # ... while every other experiment must still be present.
    for exp, ref in refs.items():
        if exp not in expected_removed:
            assert scm.get_ref(str(ref)) is not None


def test_remove_with_queue_and_keep(tmp_dir, scm, dvc, exp_stage):
    """Combining ``queue=True`` with ``keep=True`` is rejected."""
    # This should raise an exception, until decided otherwise.
    with pytest.raises(InvalidArgumentError):
        dvc.experiments.remove(queue=True, keep=True)
2 changes: 2 additions & 0 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def test_experiments_remove_flag(dvc, scm, mocker, capsys, caplog):
num=2,
queue=False,
git_remote="myremote",
keep=False,
)


Expand All @@ -410,6 +411,7 @@ def test_experiments_remove_special(dvc, scm, mocker, capsys, caplog):
num=1,
queue=False,
git_remote="myremote",
keep=False,
)


Expand Down
Loading