diff --git a/.dvc/.gitignore b/.dvc/.gitignore index 5f594f74ac..53a678bba6 100644 --- a/.dvc/.gitignore +++ b/.dvc/.gitignore @@ -9,3 +9,4 @@ /pkg /repos /tmp +/experiments diff --git a/dvc/cli.py b/dvc/cli.py index 44af80d6aa..1392332040 100644 --- a/dvc/cli.py +++ b/dvc/cli.py @@ -15,6 +15,7 @@ data_sync, destroy, diff, + experiments, freeze, gc, get, @@ -77,6 +78,7 @@ update, git_hook, plots, + experiments, ] diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py new file mode 100644 index 0000000000..3ec817bc70 --- /dev/null +++ b/dvc/command/experiments.py @@ -0,0 +1,358 @@ +import argparse +import io +import logging +from collections import OrderedDict + +from dvc.command.base import CmdBase, append_doc_link, fix_subparsers +from dvc.command.metrics import DEFAULT_PRECISION +from dvc.exceptions import DvcException + +logger = logging.getLogger(__name__) + + +def _update_names(names, items): + from flatten_json import flatten + + for name, item in items: + if isinstance(item, dict): + item = flatten(item, ".") + names.update(item.keys()) + else: + names.add(name) + + +def _collect_names(all_experiments): + metric_names = set() + param_names = set() + + for _, experiments in all_experiments.items(): + for exp in experiments.values(): + _update_names(metric_names, exp.get("metrics", {}).items()) + _update_names(param_names, exp.get("params", {}).items()) + + return sorted(metric_names), sorted(param_names) + + +def _collect_rows( + base_rev, experiments, metric_names, param_names, precision=None +): + from flatten_json import flatten + + if precision is None: + precision = DEFAULT_PRECISION + + def _round(val): + if isinstance(val, float): + return round(val, precision) + + return val + + def _extend(row, names, items): + for fname, item in items: + if isinstance(item, dict): + item = flatten(item, ".") + else: + item = {fname: item} + for name in names: + if name in item: + row.append(str(_round(item[name]))) + else: + row.append("-") + + for i, (rev, exp) in enumerate(experiments.items()): + row = [] + style = None + if rev == "baseline": + row.append(f"{base_rev}") + style = "bold" + elif i < len(experiments) - 1: + row.append(f"├── {rev[:7]}") + else: + row.append(f"└── {rev[:7]}") + + _extend(row, metric_names, exp.get("metrics", {}).items()) + _extend(row, param_names, exp.get("params", {}).items()) + + yield row, style + + +def _show_experiments(all_experiments, console, precision=None): + from rich.table import Table + from dvc.scm.git import Git + + metric_names, param_names = _collect_names(all_experiments) + + table = Table(row_styles=["white", "bright_white"]) + table.add_column("Experiment", header_style="black on grey93") + for name in metric_names: + table.add_column( + name, justify="right", header_style="black on cornsilk1" + ) + for name in param_names: + table.add_column( + name, justify="left", header_style="black on light_cyan1" + ) + + for base_rev, experiments in all_experiments.items(): + if Git.is_sha(base_rev): + base_rev = base_rev[:7] + + for row, style, in _collect_rows( + base_rev, + experiments, + metric_names, + param_names, + precision=precision, + ): + table.add_row(*row, style=style) + + console.print(table) + + +class CmdExperimentsShow(CmdBase): + def run(self): + from rich.console import Console + from dvc.utils.pager import pager + + if not self.repo.experiments: + return 0 + + try: + all_experiments = self.repo.experiments.show( + all_branches=self.args.all_branches, + all_tags=self.args.all_tags, + all_commits=self.args.all_commits, + ) + + # Note: rich does not currently include a native way to force + # infinite width for use with a pager + console = Console( + file=io.StringIO(), force_terminal=True, width=9999 + ) + + _show_experiments(all_experiments, console) + + pager(console.file.getvalue()) + except DvcException: + logger.exception("failed to show experiments") + return 1 + + return 0 + + +class CmdExperimentsCheckout(CmdBase): + def run(self): + if not self.repo.experiments: + return 0 + + self.repo.experiments.checkout( + self.args.experiment, force=self.args.force + ) + + return 0 + + +def _show_diff( + diff, title="", markdown=False, no_path=False, old=False, precision=None +): + from dvc.utils.diff import table + + if precision is None: + precision = DEFAULT_PRECISION + + def _round(val): + if isinstance(val, float): + return round(val, precision) + + return val + + rows = [] + for fname, diff_ in diff.items(): + sorted_diff = OrderedDict(sorted(diff_.items())) + for item, change in sorted_diff.items(): + row = [] if no_path else [fname] + row.append(item) + if old: + row.append(_round(change.get("old"))) + row.append(_round(change["new"])) + row.append(_round(change.get("diff", "diff not supported"))) + rows.append(row) + + header = [] if no_path else ["Path"] + header.append(title) + if old: + header.extend(["Old", "New"]) + else: + header.append("Value") + header.append("Change") + + return table(header, rows, markdown) + + +class CmdExperimentsDiff(CmdBase): + def run(self): + if not self.repo.experiments: + return 0 + + try: + diff = self.repo.experiments.diff( + a_rev=self.args.a_rev, + b_rev=self.args.b_rev, + all=self.args.all, + ) + + if self.args.show_json: + import json + + logger.info(json.dumps(diff)) + else: + diffs = [("metrics", "Metric"), ("params", "Param")] + for key, title in diffs: + table = _show_diff( + diff[key], + title=title, + markdown=self.args.show_md, + no_path=self.args.no_path, + old=self.args.old, + precision=self.args.precision, + ) + if table: + logger.info(table) + logger.info("") + + except DvcException: + logger.exception("failed to show experiments diff") + return 1 + + return 0 + + +def add_parser(subparsers, parent_parser): + EXPERIMENTS_HELP = "Commands to display and compare experiments." + + experiments_parser = subparsers.add_parser( + "experiments", + parents=[parent_parser], + description=append_doc_link(EXPERIMENTS_HELP, "experiments"), + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + experiments_subparsers = experiments_parser.add_subparsers( + dest="cmd", + help="Use `dvc experiments CMD --help` to display " + "command-specific help.", + ) + + fix_subparsers(experiments_subparsers) + + EXPERIMENTS_SHOW_HELP = "Print experiments." + experiments_show_parser = experiments_subparsers.add_parser( + "show", + parents=[parent_parser], + description=append_doc_link(EXPERIMENTS_SHOW_HELP, "experiments/show"), + help=EXPERIMENTS_SHOW_HELP, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + experiments_show_parser.add_argument( + "-a", + "--all-branches", + action="store_true", + default=False, + help="Show metrics for all branches.", + ) + experiments_show_parser.add_argument( + "-T", + "--all-tags", + action="store_true", + default=False, + help="Show metrics for all tags.", + ) + experiments_show_parser.add_argument( + "--all-commits", + action="store_true", + default=False, + help="Show metrics for all commits.", + ) + experiments_show_parser.set_defaults(func=CmdExperimentsShow) + + EXPERIMENTS_CHECKOUT_HELP = "Checkout experiments." + experiments_checkout_parser = experiments_subparsers.add_parser( + "checkout", + parents=[parent_parser], + description=append_doc_link( + EXPERIMENTS_CHECKOUT_HELP, "experiments/checkout" + ), + help=EXPERIMENTS_CHECKOUT_HELP, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + experiments_checkout_parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + help="Overwrite your current workspace with changes from the " + "experiment.", + ) + experiments_checkout_parser.add_argument( + "experiment", help="Checkout this experiment.", + ) + experiments_checkout_parser.set_defaults(func=CmdExperimentsCheckout) + + EXPERIMENTS_DIFF_HELP = ( + "Show changes between experiments in the DVC repository." + ) + experiments_diff_parser = experiments_subparsers.add_parser( + "diff", + parents=[parent_parser], + description=append_doc_link(EXPERIMENTS_DIFF_HELP, "experiments/diff"), + help=EXPERIMENTS_DIFF_HELP, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + experiments_diff_parser.add_argument( + "a_rev", nargs="?", help="Old experiment to compare (defaults to HEAD)" + ) + experiments_diff_parser.add_argument( + "b_rev", + nargs="?", + help="New experiment to compare (defaults to the current workspace)", + ) + experiments_diff_parser.add_argument( + "--all", + action="store_true", + default=False, + help="Show unchanged metrics/params as well.", + ) + experiments_diff_parser.add_argument( + "--show-json", + action="store_true", + default=False, + help="Show output in JSON format.", + ) + experiments_diff_parser.add_argument( + "--show-md", + action="store_true", + default=False, + help="Show tabulated output in the Markdown format (GFM).", + ) + experiments_diff_parser.add_argument( + "--old", + action="store_true", + default=False, + help="Show old metric/param value.", + ) + experiments_diff_parser.add_argument( + "--no-path", + action="store_true", + default=False, + help="Don't show metric/param path.", + ) + experiments_diff_parser.add_argument( + "--precision", + type=int, + help=( + "Round metrics/params to `n` digits precision after the decimal " + f"point. Rounds to {DEFAULT_PRECISION} digits by default." + ), + metavar="", + ) + experiments_diff_parser.set_defaults(func=CmdExperimentsDiff) diff --git a/dvc/command/repro.py b/dvc/command/repro.py index 986b9f0b48..2b06ec03e5 100644 --- a/dvc/command/repro.py +++ b/dvc/command/repro.py @@ -40,6 +40,7 @@ def run(self): downstream=self.args.downstream, recursive=self.args.recursive, force_downstream=self.args.force_downstream, + experiment=self.args.experiment, ) if len(stages) == 0: @@ -166,4 +167,11 @@ def add_parser(subparsers, parent_parser): default=False, help="Start from the specified stages when reproducing pipelines.", ) + repro_parser.add_argument( + "-e", + "--experiment", + action="store_true", + default=False, + help=argparse.SUPPRESS, + ) repro_parser.set_defaults(func=CmdRepro) diff --git a/dvc/config.py b/dvc/config.py index 3b083597a7..e1c9d93284 100644 --- a/dvc/config.py +++ b/dvc/config.py @@ -127,6 +127,7 @@ class RelPath(str): Optional("analytics", default=True): Bool, Optional("hardlink_lock", default=False): Bool, Optional("no_scm", default=False): Bool, + Optional("experiments", default=False): Bool, }, "cache": { "local": str, diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py index a1c3bc7554..46b57a6054 100644 --- a/dvc/repo/__init__.py +++ b/dvc/repo/__init__.py @@ -74,6 +74,7 @@ def __init__(self, root_dir=None, scm=None, rev=None): from dvc.scm import SCM from dvc.cache import Cache from dvc.data_cloud import DataCloud + from dvc.repo.experiments import Experiments from dvc.repo.metrics import Metrics from dvc.repo.plots import Plots from dvc.repo.params import Params @@ -126,6 +127,11 @@ def __init__(self, root_dir=None, scm=None, rev=None): self.plots = Plots(self) self.params = Params(self) + try: + self.experiments = Experiments(self) + except NotImplementedError: + self.experiments = None + self._ignore() @property @@ -190,7 +196,12 @@ def unprotect(self, target): return self.cache.local.tree.unprotect(PathInfo(target)) def _ignore(self): - flist = [self.config.files["local"], self.tmp_dir] + flist = [ + self.config.files["local"], + self.tmp_dir, + ] + if self.experiments: + flist.append(self.experiments.exp_dir) if path_isin(self.cache.local.cache_dir, self.root_dir): flist += [self.cache.local.cache_dir] diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py new file mode 100644 index 0000000000..dcbe550354 --- /dev/null +++ b/dvc/repo/experiments/__init__.py @@ -0,0 +1,197 @@ +import logging +import os +import tempfile +from contextlib import contextmanager + +from funcy import cached_property + +from dvc.exceptions import DvcException +from dvc.scm.git import Git +from dvc.stage.serialize import to_lockfile +from dvc.utils import dict_sha256, env2bool, relpath +from dvc.utils.fs import remove + +logger = logging.getLogger(__name__) + + +class UnchangedExperimentError(DvcException): + pass + + +class Experiments: + """Class that manages experiments in a DVC repo. + + Args: + repo (dvc.repo.Repo): repo instance that these experiments belong to. + """ + + EXPERIMENTS_DIR = "experiments" + + def __init__(self, repo): + if not ( + env2bool("DVC_TEST") + or repo.config["core"].get("experiments", False) + ): + raise NotImplementedError + + self.repo = repo + + @cached_property + def exp_dir(self): + return os.path.join(self.repo.dvc_dir, self.EXPERIMENTS_DIR) + + @cached_property + def scm(self): + """Experiments clone scm instance.""" + if os.path.exists(self.exp_dir): + return Git(self.exp_dir) + return self._init_clone() + + @cached_property + def exp_dvc_dir(self): + dvc_dir = relpath(self.repo.dvc_dir, self.repo.scm.root_dir) + return os.path.join(self.exp_dir, dvc_dir) + + @cached_property + def exp_dvc(self): + """Return clone dvc Repo instance.""" + from dvc.repo import Repo + + return Repo(self.exp_dvc_dir) + + @staticmethod + def exp_hash(stages): + exp_data = {} + for stage in stages: + exp_data.update(to_lockfile(stage)) + return dict_sha256(exp_data) + + @contextmanager + def chdir(self): + cwd = os.getcwd() + os.chdir(self.exp_dvc.root_dir) + yield + os.chdir(cwd) + + def _init_clone(self): + src_dir = self.repo.scm.root_dir + logger.debug("Initializing experiments clone") + git = Git.clone(src_dir, self.exp_dir) + self._config_clone() + return git + + def _config_clone(self): + dvc_dir = relpath(self.repo.dvc_dir, self.repo.scm.root_dir) + local_config = os.path.join(self.exp_dir, dvc_dir, "config.local") + cache_dir = self.repo.cache.local.cache_dir + logger.debug("Writing experiments local config '%s'", local_config) + with open(local_config, "w") as fobj: + fobj.write(f"[cache]\n dir = {cache_dir}") + + def _scm_checkout(self, rev): + self.scm.repo.git.reset(hard=True) + if not Git.is_sha(rev) or not self.scm.has_rev(rev): + self.scm.fetch(all=True) + logger.debug("Checking out base experiment commit '%s'", rev) + self.scm.checkout(rev) + + def _patch_exp(self): + """Create a patch based on the current (parent) workspace and apply it + to the experiment workspace. + """ + tmp = tempfile.NamedTemporaryFile(delete=False).name + try: + self.repo.scm.repo.git.diff(patch=True, output=tmp) + if os.path.getsize(tmp): + logger.debug("Patching experiment workspace") + self.scm.repo.git.apply(tmp) + else: + raise UnchangedExperimentError( + "Experiment identical to baseline commit." + ) + finally: + remove(tmp) + + def _commit(self, stages, check_exists=True, branch=True, rev=None): + """Commit stages as an experiment and return the commit SHA.""" + hash_ = self.exp_hash(stages) + exp_name = f"{rev[:7]}-{hash_}" + if branch: + if check_exists and exp_name in self.scm.list_branches(): + logger.debug("Using existing experiment branch '%s'", exp_name) + return self.scm.resolve_rev(exp_name) + self.scm.checkout(exp_name, create_new=True) + logger.debug("Commit new experiment branch '%s'", exp_name) + self.scm.repo.git.add(A=True) + self.scm.commit(f"Add experiment {exp_name}") + return self.scm.get_rev() + + def _reproduce(self, *args, **kwargs): + """Run `dvc repro` inside the experiments workspace.""" + with self.chdir(): + return self.exp_dvc.reproduce(*args, **kwargs) + + def new(self, *args, workspace=True, **kwargs): + """Create a new experiment. + + Experiment will be reproduced and checked out into the user's + workspace. + """ + rev = self.repo.scm.get_rev() + self._scm_checkout(rev) + if workspace: + try: + self._patch_exp() + except UnchangedExperimentError as exc: + logger.info("Reproducing existing experiment '%s'.", rev[:7]) + raise exc + else: + # configure params via command line here + pass + self.exp_dvc.checkout() + stages = self._reproduce(*args, **kwargs) + exp_rev = self._commit(stages, rev=rev) + self.checkout_exp(exp_rev, force=True) + logger.info("Generated experiment '%s'.", exp_rev[:7]) + return stages + + def checkout_exp(self, rev, force=False): + """Checkout an experiment to the user's workspace.""" + from git.exc import GitCommandError + from dvc.repo.checkout import _checkout as dvc_checkout + + if force: + self.repo.scm.repo.git.reset(hard=True) + logger.debug(f"checkout {rev}") + self._scm_checkout(rev) + + tmp = tempfile.NamedTemporaryFile(delete=False).name + self.scm.repo.head.commit.diff("HEAD~1", patch=True, output=tmp) + try: + if os.path.getsize(tmp): + logger.debug("Patching local workspace") + self.repo.scm.repo.git.apply(tmp, reverse=True) + dvc_checkout(self.repo) + except GitCommandError: + raise DvcException( + "Checkout failed, experiment contains changes which " + "conflict with your current workspace. To overwrite " + "your workspace, use `dvc experiments checkout --force`." + ) + finally: + remove(tmp) + + def checkout(self, *args, **kwargs): + from dvc.repo.experiments.checkout import checkout + + return checkout(self.repo, *args, **kwargs) + + def diff(self, *args, **kwargs): + from dvc.repo.experiments.diff import diff + + return diff(self.repo, *args, **kwargs) + + def show(self, *args, **kwargs): + from dvc.repo.experiments.show import show + + return show(self.repo, *args, **kwargs) diff --git a/dvc/repo/experiments/checkout.py b/dvc/repo/experiments/checkout.py new file mode 100644 index 0000000000..a1e2e1f1a4 --- /dev/null +++ b/dvc/repo/experiments/checkout.py @@ -0,0 +1,17 @@ +import logging + +from dvc.repo import locked +from dvc.repo.scm_context import scm_context + +logger = logging.getLogger(__name__) + + +@locked +@scm_context +def checkout(repo, rev, *args, **kwargs): + repo.experiments.checkout_exp(rev, *args, **kwargs) + logger.info( + "Changes for experiment '%s' have been applied to your current " + "workspace.", + rev, + ) diff --git a/dvc/repo/experiments/diff.py b/dvc/repo/experiments/diff.py new file mode 100644 index 0000000000..8dd5538977 --- /dev/null +++ b/dvc/repo/experiments/diff.py @@ -0,0 +1,36 @@ +import logging + +from dvc.utils.diff import diff as _diff +from dvc.utils.diff import format_dict + +logger = logging.getLogger(__name__) + + +def diff(repo, *args, a_rev=None, b_rev=None, **kwargs): + from dvc.repo.experiments.show import _collect_experiment + + if repo.scm.no_commits: + return {} + + if a_rev: + with repo.experiments.chdir(): + old = _collect_experiment(repo.experiments.exp_dvc, a_rev) + else: + old = _collect_experiment(repo, "HEAD") + + if b_rev: + with repo.experiments.chdir(): + new = _collect_experiment(repo.experiments.exp_dvc, b_rev) + else: + new = _collect_experiment(repo, "workspace") + + with_unchanged = kwargs.pop("all", False) + + return { + key: _diff( + format_dict(old[key]), + format_dict(new[key]), + with_unchanged=with_unchanged, + ) + for key in ["metrics", "params"] + } diff --git a/dvc/repo/experiments/show.py b/dvc/repo/experiments/show.py new file mode 100644 index 0000000000..bbb2c706ff --- /dev/null +++ b/dvc/repo/experiments/show.py @@ -0,0 +1,64 @@ +import logging +import re +from collections import OrderedDict, defaultdict + +from dvc.repo import locked +from dvc.repo.metrics.show import _collect_metrics, _read_metrics +from dvc.repo.params.show import _collect_configs, _read_params + +logger = logging.getLogger(__name__) + + +EXP_RE = re.compile(r"(?P[a-f0-9]{7})-(?P[a-f0-9]+)") + + +def _collect_experiment(repo, branch): + res = defaultdict(dict) + for rev in repo.brancher(revs=[branch]): + configs = _collect_configs(repo) + params = _read_params(repo, configs, rev) + if params: + res["params"] = params + + metrics = _collect_metrics(repo, None, False) + vals = _read_metrics(repo, metrics, rev) + if vals: + res["metrics"] = vals + + return res + + +@locked +def show( + repo, all_branches=False, all_tags=False, revs=None, all_commits=False +): + res = defaultdict(OrderedDict) + + if revs is None: + revs = [repo.scm.get_rev()] + + revs = set( + repo.brancher( + revs=revs, + all_branches=all_branches, + all_tags=all_tags, + all_commits=all_commits, + ) + ) + + for rev in revs: + res[rev]["baseline"] = _collect_experiment(repo, rev) + + for exp_branch in repo.experiments.scm.list_branches(): + m = re.match(EXP_RE, exp_branch) + if m: + rev = repo.scm.resolve_rev(m.group("rev_sha")) + if rev in revs: + exp_rev = repo.experiments.scm.resolve_rev(exp_branch) + with repo.experiments.chdir(): + experiment = _collect_experiment( + repo.experiments.exp_dvc, exp_branch + ) + res[rev][exp_rev] = experiment + + return res diff --git a/dvc/repo/reproduce.py b/dvc/repo/reproduce.py index 1c33e23b40..e127c4ea42 100644 --- a/dvc/repo/reproduce.py +++ b/dvc/repo/reproduce.py @@ -1,6 +1,7 @@ import logging from dvc.exceptions import InvalidArgumentError, ReproductionError +from dvc.repo.experiments import UnchangedExperimentError from dvc.repo.scm_context import scm_context from . import locked @@ -69,6 +70,19 @@ def reproduce( "Neither `target` nor `--all-pipelines` are specified." ) + experiment = kwargs.pop("experiment", False) + if experiment and self.experiments: + try: + return self.experiments.new( + target=target, + recursive=recursive, + all_pipelines=all_pipelines, + **kwargs + ) + except UnchangedExperimentError: + # If experiment contains no changes, just run regular repro + pass + interactive = kwargs.get("interactive", False) if not interactive: kwargs["interactive"] = self.config["core"].get("interactive", False) diff --git a/setup.py b/setup.py index f0646c7a2c..332b69df96 100644 --- a/setup.py +++ b/setup.py @@ -79,6 +79,7 @@ def run(self): "pygtrie==2.3.2", "dpath>=2.0.1,<3", "shtab>=1.1.0,<2", + "rich>=3.0.5", ] diff --git a/tests/func/experiments/__init__.py b/tests/func/experiments/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py new file mode 100644 index 0000000000..33d41c6d21 --- /dev/null +++ b/tests/func/experiments/test_experiments.py @@ -0,0 +1,24 @@ +from tests.func.test_repro_multistage import COPY_SCRIPT + + +def test_new_simple(tmp_dir, scm, dvc, mocker): + tmp_dir.gen("copy.py", COPY_SCRIPT) + tmp_dir.gen("params.yaml", "foo: 1") + stage = dvc.run( + cmd="python copy.py params.yaml metrics.yaml", + metrics_no_cache=["metrics.yaml"], + params=["foo"], + name="copy-file", + ) + scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.yaml", "metrics.yaml"]) + scm.commit("init") + + tmp_dir.gen("params.yaml", "foo: 2") + + new_mock = mocker.spy(dvc.experiments, "new") + dvc.reproduce(stage.addressing, experiment=True) + + new_mock.assert_called_once() + assert ( + tmp_dir / ".dvc" / "experiments" / "metrics.yaml" + ).read_text() == "foo: 2" diff --git a/tests/func/experiments/test_show.py b/tests/func/experiments/test_show.py new file mode 100644 index 0000000000..bfcb595e49 --- /dev/null +++ b/tests/func/experiments/test_show.py @@ -0,0 +1,19 @@ +from tests.func.test_repro_multistage import COPY_SCRIPT + + +def test_show_simple(tmp_dir, scm, dvc): + tmp_dir.gen("copy.py", COPY_SCRIPT) + tmp_dir.gen("params.yaml", "foo: 1") + dvc.run( + cmd="python copy.py params.yaml metrics.yaml", + metrics_no_cache=["metrics.yaml"], + params=["foo"], + single_stage=True, + ) + + assert dvc.experiments.show()["workspace"] == { + "baseline": { + "metrics": {"metrics.yaml": {"foo": 1}}, + "params": {"params.yaml": {"foo": 1}}, + } + } diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py new file mode 100644 index 0000000000..8b1048147d --- /dev/null +++ b/tests/unit/command/test_experiments.py @@ -0,0 +1,51 @@ +from dvc.cli import parse_args +from dvc.command.experiments import CmdExperimentsDiff, CmdExperimentsShow + + +def test_experiments_diff(dvc, mocker): + cli_args = parse_args( + [ + "experiments", + "diff", + "HEAD~10", + "HEAD~1", + "--all", + "--show-json", + "--show-md", + "--old", + "--precision", + "10", + ] + ) + assert cli_args.func == CmdExperimentsDiff + + cmd = cli_args.func(cli_args) + m = mocker.patch("dvc.repo.experiments.diff.diff", return_value={}) + + assert cmd.run() == 0 + + m.assert_called_once_with( + cmd.repo, a_rev="HEAD~10", b_rev="HEAD~1", all=True + ) + + +def test_experiments_show(dvc, mocker): + cli_args = parse_args( + [ + "experiments", + "show", + "--all-tags", + "--all-branches", + "--all-commits", + ] + ) + assert cli_args.func == CmdExperimentsShow + + cmd = cli_args.func(cli_args) + m = mocker.patch("dvc.repo.experiments.show.show", return_value={}) + + assert cmd.run() == 0 + + m.assert_called_once_with( + cmd.repo, all_tags=True, all_branches=True, all_commits=True + ) diff --git a/tests/unit/command/test_repro.py b/tests/unit/command/test_repro.py index d29c027661..fbcf1271fc 100644 --- a/tests/unit/command/test_repro.py +++ b/tests/unit/command/test_repro.py @@ -14,6 +14,7 @@ "single_item": False, "recursive": False, "force_downstream": False, + "experiment": False, }