Skip to content

Commit

Permalink
exp: Generate a human-readable name beforehand.
Browse files Browse the repository at this point in the history
- Move `hash_exp` from experiment completion to experiment creation.

Change implementation to rely on `deps` and `params`(including overrides).
The existing implementation generated the hash based on the lock of stages executed during the experiment.
Hopefully, this doesn't break anything (it doesn't break tests).

- Add `get_random_name` utils.
Use `hash_exp` result as `seed`.

- Remove logic using `exp_hash` for completion checks.
Replace `exp_hash: Optional[str]` with `completed: bool = False`.

Closes #8650
  • Loading branch information
daavoo committed Dec 5, 2022
1 parent 74a13a4 commit f6f771d
Show file tree
Hide file tree
Showing 11 changed files with 104 additions and 87 deletions.
32 changes: 12 additions & 20 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
ExpRefInfo,
)
from .stash import ExpStashEntry
from .utils import exp_refs_by_rev, unlocked_repo
from .utils import check_ref_format, exp_refs_by_rev, unlocked_repo

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -198,12 +198,12 @@ def reproduce_celery(
result = self.celery_queue.get_result(entry)
except FileNotFoundError:
result = None
if result is None or result.exp_hash is None:
if result is None or not result.completed:
name = entry.name or entry.stash_rev[:7]
failed.append(name)
elif result.ref_info:
exp_rev = self.scm.get_ref(str(result.ref_info))
results[exp_rev] = result.exp_hash
results[exp_rev] = result.completed
except KeyboardInterrupt:
ui.write(
"Experiment(s) are still executing in the background. To "
Expand Down Expand Up @@ -236,17 +236,6 @@ def _log_reproduced(self, revs: Iterable[str], tmp_dir: bool = False):
"\tdvc exp branch <exp> <branch>\n"
)

def _validate_new_ref(self, exp_ref: ExpRefInfo):
from .utils import check_ref_format

if not exp_ref.name:
return

check_ref_format(self.scm, exp_ref)

if self.scm.get_ref(str(exp_ref)):
raise ExperimentExistsError(exp_ref.name)

def new(
self,
queue: BaseStashQueue,
Expand All @@ -265,13 +254,16 @@ def new(

name = kwargs.get("name", None)
baseline_sha = kwargs.get("baseline_rev") or self.repo.scm.get_rev()
exp_ref = ExpRefInfo(baseline_sha=baseline_sha, name=name)

try:
self._validate_new_ref(exp_ref)
except ExperimentExistsError as err:
if not (kwargs.get("force", False) or kwargs.get("reset", False)):
raise err
if name:
exp_ref = ExpRefInfo(baseline_sha=baseline_sha, name=name)
check_ref_format(self.scm, exp_ref)
exists = self.scm.get_ref(str(exp_ref))
force = kwargs.get("force", False)
reset = kwargs.get("reset", False)
if exists and not (force or reset):
raise ExperimentExistsError(exp_ref.name)

return queue.put(*args, **kwargs)

def _resume_checkpoint(
Expand Down
59 changes: 20 additions & 39 deletions dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@

from dvc.env import DVC_EXP_AUTO_PUSH, DVC_EXP_GIT_REMOTE
from dvc.exceptions import DvcException
from dvc.stage.serialize import to_lockfile
from dvc.ui import ui
from dvc.utils import dict_sha256, env2bool, relpath
from dvc.utils import env2bool, relpath
from dvc.utils.fs import remove

from ..exceptions import (
Expand All @@ -44,15 +43,14 @@

from dvc.repo import Repo
from dvc.repo.experiments.stash import ExpStashEntry
from dvc.stage import PipelineStage

logger = logging.getLogger(__name__)


class ExecutorResult(NamedTuple):
exp_hash: Optional[str]
ref_info: Optional["ExpRefInfo"]
force: bool
completed: bool = False


class TaskStatus(IntEnum):
Expand All @@ -74,7 +72,7 @@ class ExecutorInfo:
dvc_dir: str
name: Optional[str] = None
wdir: Optional[str] = None
result_hash: Optional[str] = None
completed: bool = False
result_ref: Optional[str] = None
result_force: bool = False
status: TaskStatus = TaskStatus.PENDING
Expand All @@ -91,12 +89,12 @@ def asdict(self):

@property
def result(self) -> Optional["ExecutorResult"]:
if self.result_hash is None:
if not self.completed:
return None
return ExecutorResult(
self.result_hash,
ExpRefInfo.from_ref(self.result_ref) if self.result_ref else None,
self.result_force,
self.completed,
)

def dump_json(self, filename: str):
Expand Down Expand Up @@ -189,7 +187,7 @@ def collect_cache(
def info(self) -> "ExecutorInfo":
if self.result is not None:
result_dict: Dict[str, Any] = {
"result_hash": self.result.exp_hash,
"completed": self.result.completed,
"result_ref": (
str(self.result.ref_info) if self.result.ref_info else None
),
Expand All @@ -211,15 +209,15 @@ def info(self) -> "ExecutorInfo":

@classmethod
def from_info(cls: Type[_T], info: "ExecutorInfo") -> _T:
if info.result_hash:
if info.completed:
result: Optional["ExecutorResult"] = ExecutorResult(
info.result_hash,
(
ExpRefInfo.from_ref(info.result_ref)
if info.result_ref
else None
),
info.result_force,
completed=info.completed,
)
else:
result = None
Expand Down Expand Up @@ -262,16 +260,6 @@ def _from_stash_entry(
)
return executor

@staticmethod
def hash_exp(stages: Iterable["PipelineStage"]) -> str:
from dvc.stage import PipelineStage

exp_data = {}
for stage in stages:
if isinstance(stage, PipelineStage):
exp_data.update(to_lockfile(stage))
return dict_sha256(exp_data)

def cleanup(self, infofile: str):
if infofile is not None:
info = ExecutorInfo.load_json(infofile)
Expand Down Expand Up @@ -399,10 +387,9 @@ def reproduce(
) -> "ExecutorResult":
"""Run dvc repro and return the result.
Returns tuple of (exp_hash, exp_ref, force) where exp_hash is the
experiment hash (or None on error), exp_ref is the experiment ref,
and force is a bool specifying whether or not this experiment
should force overwrite any existing duplicates.
Returns tuple of (exp_ref, forcem, completed) where exp_ref is the
experiment ref, and force is a bool specifying whether or not this
experiment should force overwrite any existing duplicates.
"""
from dvc.repo.checkout import checkout as dvc_checkout
from dvc.repo.reproduce import reproduce as dvc_reproduce
Expand All @@ -423,7 +410,7 @@ def filter_pipeline(stages):
[stage for stage in stages if isinstance(stage, PipelineStage)]
)

exp_hash: Optional[str] = None
completed: bool = False
exp_ref: Optional["ExpRefInfo"] = None
repro_force: bool = False

Expand Down Expand Up @@ -488,35 +475,34 @@ def filter_pipeline(stages):
checkpoint_func=checkpoint_func,
**kwargs,
)

exp_hash = cls.hash_exp(stages)
if not repro_dry:
ref, exp_ref, repro_force = cls._repro_commit(
dvc,
info,
stages,
exp_hash,
checkpoint_reset,
auto_push,
git_remote,
repro_force,
)
info.result_hash = exp_hash
completed = True
info.completed = completed
info.result_ref = ref
info.result_force = repro_force

# ideally we would return stages here like a normal repro() call, but
# stages is not currently picklable and cannot be returned across
# multiprocessing calls
return ExecutorResult(exp_hash, exp_ref, repro_force)
return ExecutorResult(
completed=completed, ref_info=exp_ref, force=repro_force
)

@classmethod
def _repro_commit(
cls,
dvc,
info,
stages,
exp_hash,
checkpoint_reset,
auto_push,
git_remote,
Expand All @@ -531,7 +517,6 @@ def _repro_commit(
repro_force = True
cls.commit(
dvc.scm,
exp_hash,
exp_name=info.name,
force=repro_force,
checkpoint=is_checkpoint,
Expand Down Expand Up @@ -647,13 +632,10 @@ def checkpoint_callback(
scm: "Git",
name: Optional[str],
force: bool,
unchanged: Iterable["PipelineStage"],
stages: Iterable["PipelineStage"],
):
try:
exp_hash = cls.hash_exp(list(stages) + list(unchanged))
exp_rev = cls.commit(
scm, exp_hash, exp_name=name, force=force, checkpoint=True
scm, exp_name=name, force=force, checkpoint=True
)

if env2bool(DVC_EXP_AUTO_PUSH):
Expand All @@ -667,7 +649,6 @@ def checkpoint_callback(
def commit(
cls,
scm: "Git",
exp_hash: str,
exp_name: Optional[str] = None,
force: bool = False,
checkpoint: bool = False,
Expand All @@ -685,7 +666,7 @@ def commit(
logger.debug("Commit to current experiment branch '%s'", branch)
else:
baseline_rev = scm.get_ref(EXEC_BASELINE)
name = exp_name if exp_name else f"exp-{exp_hash[:5]}"
name = exp_name
ref_info = ExpRefInfo(baseline_rev, name)
branch = str(ref_info)
old_ref = None
Expand All @@ -701,7 +682,7 @@ def commit(
logger.debug("Commit to new experiment branch '%s'", branch)

scm.add([], update=True)
scm.commit(f"dvc: commit experiment {exp_hash}", no_verify=True)
scm.commit(f"dvc: commit experiment {exp_name}", no_verify=True)
new_rev = scm.get_rev()
if check_conflict:
new_rev = cls._raise_ref_conflict(scm, branch, new_rev, checkpoint)
Expand Down
13 changes: 6 additions & 7 deletions dvc/repo/experiments/executor/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,6 @@ def save(
) -> ExecutorResult:
from dvc.repo import Repo

exp_hash: Optional[str] = None
exp_ref: Optional[ExpRefInfo] = None

dvc = Repo(os.path.join(info.root_dir, info.dvc_dir))
Expand All @@ -270,13 +269,11 @@ def save(
os.chdir(dvc.root_dir)

try:
stages = dvc.commit([], force=force)
exp_hash = cls.hash_exp(stages)
dvc.commit([], force=force)
if include_untracked:
dvc.scm.add(include_untracked)
cls.commit(
dvc.scm,
exp_hash,
exp_name=info.name,
force=force,
)
Expand All @@ -291,9 +288,9 @@ def save(
"\t%s",
", ".join(untracked),
)
info.result_hash = exp_hash
info.completed = True
info.result_ref = ref
info.result_force = False
info.result_force = force
info.status = TaskStatus.SUCCESS
except DvcException:
info.status = TaskStatus.FAILED
Expand All @@ -302,4 +299,6 @@ def save(
dvc.close()
os.chdir(old_cwd)

return ExecutorResult(ref, exp_ref, info.result_force)
return ExecutorResult(
exp_ref, info.result_force, completed=info.completed
)
4 changes: 2 additions & 2 deletions dvc/repo/experiments/executor/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,8 @@ def reproduce(
fs.fs.execute("; ".join(cmd), stdout=stdout, stderr=stderr)
with fs.open(infofile) as fobj:
result_info = ExecutorInfo.from_dict(json.load(fobj))
if result_info.result_hash:
if result_info.completed:
return result_info.result
except ProcessError:
pass
return ExecutorResult(None, None, False)
return ExecutorResult(None, False, False)
24 changes: 17 additions & 7 deletions dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,14 @@
from ..executor.local import WorkspaceExecutor
from ..refs import ExpRefInfo
from ..stash import ExpStash, ExpStashEntry
from ..utils import EXEC_PID_DIR, EXEC_TMP_DIR, exp_refs_by_rev, get_exp_rwlock
from ..utils import (
EXEC_PID_DIR,
EXEC_TMP_DIR,
exp_refs_by_rev,
get_exp_rwlock,
get_random_name,
hash_exp,
)
from .utils import get_remote_executor_refs

if TYPE_CHECKING:
Expand Down Expand Up @@ -281,7 +288,7 @@ def logs(
output.
"""

def _stash_exp(
def _stash_exp( # noqa: C901
self,
*args,
params: Optional[Dict[str, List[str]]] = None,
Expand Down Expand Up @@ -395,6 +402,9 @@ def _stash_exp(
name,
)

if not name:
name = get_random_name(seed=hash_exp(self.repo))

# save additional repro command line arguments
run_env = {
DVC_EXP_BASELINE_REV: baseline_rev,
Expand Down Expand Up @@ -614,8 +624,8 @@ def collect_git(
exp: "Experiments",
executor: BaseExecutor,
exec_result: ExecutorResult,
) -> Dict[str, str]:
results = {}
) -> Dict[str, bool]:
results: Dict[str, bool] = {}

def on_diverged(ref: str, checkpoint: bool):
ref_info = ExpRefInfo.from_ref(ref)
Expand All @@ -634,9 +644,9 @@ def on_diverged(ref: str, checkpoint: bool):
):
exp_rev = exp.scm.get_ref(ref)
if exp_rev:
assert exec_result.exp_hash
assert exec_result.completed
logger.debug("Collected experiment '%s'.", exp_rev[:7])
results[exp_rev] = exec_result.exp_hash
results[exp_rev] = exec_result.completed

return results

Expand All @@ -646,7 +656,7 @@ def collect_executor(
exp: "Experiments",
executor: BaseExecutor,
exec_result: ExecutorResult,
) -> Dict[str, str]:
) -> Dict[str, bool]:
results = cls.collect_git(exp, executor, exec_result)

if exec_result.ref_info is not None:
Expand Down
Loading

0 comments on commit f6f771d

Please sign in to comment.