Skip to content

Commit

Permalink
exp: Generate a human-readable name beforehand.
Browse files Browse the repository at this point in the history
- Add `get_random_exp_name` to utils.
Generates a name (`f"{adjective}-{noun}"`) by randomly choosing an (adjective, noun) pair.
Receives `scm` and `baseline_rev` and retries the random choice until a non-existing ref is found.

Closes iterative#8650
  • Loading branch information
daavoo authored and karajan1001 committed Dec 10, 2022
1 parent 334f297 commit ac6c99d
Show file tree
Hide file tree
Showing 12 changed files with 95 additions and 36 deletions.
15 changes: 13 additions & 2 deletions dvc/commands/queue/kill.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,31 @@ class CmdQueueKill(CmdBase):
"""Kill exp task in queue."""

def run(self):
self.repo.experiments.celery_queue.kill(revs=self.args.task)
self.repo.experiments.celery_queue.kill(
revs=self.args.task, force=self.args.force
)

return 0


def add_parser(queue_subparsers, parent_parser):
QUEUE_KILL_HELP = "Kill actively running experiment queue tasks."
QUEUE_KILL_HELP = "Send SIGINT(Ctrl-C) to running experiment queue tasks."
queue_kill_parser = queue_subparsers.add_parser(
"kill",
parents=[parent_parser],
description=append_doc_link(QUEUE_KILL_HELP, "queue/kill"),
help=QUEUE_KILL_HELP,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
queue_kill_parser.add_argument(
"-f",
"--force",
action="store_true",
default=False,
help="Send SIGKILL (kill -9) instead to running experiment queue "
"tasks. (The default `queue kill` will terminate more gracefully,"
" collecting and cleaning up all resources)",
)
queue_kill_parser.add_argument(
"task",
nargs="*",
Expand Down
27 changes: 9 additions & 18 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
ExpRefInfo,
)
from .stash import ExpStashEntry
from .utils import exp_refs_by_rev, unlocked_repo
from .utils import check_ref_format, exp_refs_by_rev, unlocked_repo

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -236,17 +236,6 @@ def _log_reproduced(self, revs: Iterable[str], tmp_dir: bool = False):
"\tdvc exp branch <exp> <branch>\n"
)

def _validate_new_ref(self, exp_ref: ExpRefInfo):
from .utils import check_ref_format

if not exp_ref.name:
return

check_ref_format(self.scm, exp_ref)

if self.scm.get_ref(str(exp_ref)):
raise ExperimentExistsError(exp_ref.name)

def new(
self,
queue: BaseStashQueue,
Expand All @@ -265,13 +254,15 @@ def new(

name = kwargs.get("name", None)
baseline_sha = kwargs.get("baseline_rev") or self.repo.scm.get_rev()
exp_ref = ExpRefInfo(baseline_sha=baseline_sha, name=name)

try:
self._validate_new_ref(exp_ref)
except ExperimentExistsError as err:
if not (kwargs.get("force", False) or kwargs.get("reset", False)):
raise err
if name:
exp_ref = ExpRefInfo(baseline_sha=baseline_sha, name=name)
check_ref_format(self.scm, exp_ref)
force = kwargs.get("force", False)
reset = kwargs.get("reset", False)
if self.scm.get_ref(str(exp_ref)) and not (force or reset):
raise ExperimentExistsError(exp_ref.name)

return queue.put(*args, **kwargs)

def _resume_checkpoint(
Expand Down
15 changes: 11 additions & 4 deletions dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@
from ..executor.local import WorkspaceExecutor
from ..refs import ExpRefInfo
from ..stash import ExpStash, ExpStashEntry
from ..utils import EXEC_PID_DIR, EXEC_TMP_DIR, exp_refs_by_rev, get_exp_rwlock
from ..utils import (
EXEC_PID_DIR,
EXEC_TMP_DIR,
exp_refs_by_rev,
get_exp_rwlock,
get_random_exp_name,
)
from .utils import get_remote_executor_refs

if TYPE_CHECKING:
Expand Down Expand Up @@ -281,7 +287,7 @@ def logs(
output.
"""

def _stash_exp(
def _stash_exp( # noqa: C901
self,
*args,
params: Optional[Dict[str, List[str]]] = None,
Expand Down Expand Up @@ -399,8 +405,9 @@ def _stash_exp(
run_env = {
DVC_EXP_BASELINE_REV: baseline_rev,
}
if name:
run_env[DVC_EXP_NAME] = name
if not name:
name = get_random_exp_name(self.scm, baseline_rev)
run_env[DVC_EXP_NAME] = name
if resume_rev:
run_env[DVCLIVE_RESUME] = "1"
self._pack_args(*args, run_env=run_env, **kwargs)
Expand Down
11 changes: 7 additions & 4 deletions dvc/repo/experiments/queue/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,15 @@ def _get_running_task_ids(self) -> Set[str]:
return running_task_ids

def _try_to_kill_tasks(
self, to_kill: Dict[QueueEntry, str]
self, to_kill: Dict[QueueEntry, str], force: bool
) -> Dict[QueueEntry, str]:
fail_to_kill_entries: Dict[QueueEntry, str] = {}
for queue_entry, rev in to_kill.items():
try:
self.proc.kill(queue_entry.stash_rev)
if force:
self.proc.kill(queue_entry.stash_rev)
else:
self.proc.interrupt(queue_entry.stash_rev)
logger.debug(f"Task {rev} had been killed.")
except ProcessLookupError:
fail_to_kill_entries[queue_entry] = rev
Expand Down Expand Up @@ -333,7 +336,7 @@ def _mark_inactive_tasks_failure(self, remained_entries):
if remained_revs:
raise CannotKillTasksError(remained_revs)

def kill(self, revs: Collection[str]) -> None:
def kill(self, revs: Collection[str], force: bool = False) -> None:
name_dict: Dict[
str, Optional[QueueEntry]
] = self.match_queue_entry_by_name(set(revs), self.iter_active())
Expand All @@ -349,7 +352,7 @@ def kill(self, revs: Collection[str]) -> None:
raise UnresolvedQueueExpNamesError(missing_revs)

fail_to_kill_entries: Dict[QueueEntry, str] = self._try_to_kill_tasks(
to_kill
to_kill, force
)

if fail_to_kill_entries:
Expand Down
Loading

0 comments on commit ac6c99d

Please sign in to comment.