Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

exp init: fixes #7534; simplifies/updates exp init --live #7703

Merged
merged 1 commit into from
May 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions dvc/commands/experiments/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,13 @@ class CmdExperimentsInit(CmdBase):
DEFAULT_METRICS = "metrics.json"
DEFAULT_PARAMS = "params.yaml"
PLOTS = "plots"
DVCLIVE = "dvclive"
DEFAULTS = {
"code": CODE,
"data": DATA,
"models": MODELS,
"metrics": DEFAULT_METRICS,
"params": DEFAULT_PARAMS,
"plots": PLOTS,
"live": DVCLIVE,
}

def run(self):
Expand Down Expand Up @@ -190,12 +188,11 @@ def add_parser(experiments_subparsers, parent_parser):
)
experiments_init_parser.add_argument(
"--live",
help="Path to log dvclive outputs for your experiments"
f" (default: {CmdExperimentsInit.DVCLIVE})",
help="Path to log dvclive outputs for your experiments",
)
experiments_init_parser.add_argument(
"--type",
choices=["default", "dl"],
choices=["default", "checkpoint"],
Comment on lines 194 to +195
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs doc update 🙂 (at least in exp init --type ref.)

default="default",
help="Select type of stage to create (default: %(default)s)",
)
Expand Down
25 changes: 11 additions & 14 deletions dvc/repo/experiments/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
"params": "Path to a [b]parameters[/b] file",
"metrics": "Path to a [b]metrics[/b] file",
"plots": "Path to a [b]plots[/b] file/directory",
"live": "Path to log [b]dvclive[/b] outputs",
}


Expand Down Expand Up @@ -79,15 +78,14 @@ def init_interactive(
defaults: Dict[str, str],
provided: Dict[str, str],
validator: Callable[[str, str], Union[str, Tuple[str, str]]] = None,
live: bool = False,
stream: Optional[TextIO] = None,
) -> Dict[str, str]:
command_prompts = lremove(provided.keys(), ["cmd"])
dependencies_prompts = lremove(provided.keys(), ["code", "data", "params"])
outputs_prompts = lremove(
provided.keys(),
["models"] + (["live"] if live else ["metrics", "plots"]),
)
output_keys = ["models"]
if "live" not in provided:
output_keys.extend(["metrics", "plots"])
outputs_prompts = lremove(provided.keys(), output_keys)

ret: Dict[str, str] = {}
if "cmd" in provided:
Expand Down Expand Up @@ -200,21 +198,16 @@ def init(
defaults = defaults.copy() if defaults else {}
overrides = overrides.copy() if overrides else {}

with_live = type == "dl"

if interactive:
defaults = init_interactive(
validator=partial(validate_prompts, repo),
defaults=defaults,
live=with_live,
provided=overrides,
stream=stream,
)
else:
if with_live:
# suppress `metrics`/`plots` if live is selected, unless
# it is also provided via overrides/cli.
# This makes output to be a checkpoint as well.
if "live" in overrides:
# suppress `metrics`/`plots` if live is selected.
defaults.pop("metrics", None)
defaults.pop("plots", None)
else:
Expand Down Expand Up @@ -251,7 +244,11 @@ def init(
metrics_no_cache=compact([context.get("metrics"), live_metrics]),
plots_no_cache=compact([context.get("plots"), live_plots]),
force=force,
**{"checkpoints" if with_live else "outs": compact([models])},
**{
"checkpoints"
if type == "checkpoint"
else "outs": compact([models])
},
)

with _disable_logging(), repo.scm_context(autostage=True, quiet=True):
Expand Down
69 changes: 50 additions & 19 deletions tests/func/experiments/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def test_init_simple(tmp_dir, scm, dvc, capsys):
CmdExperimentsInit.CODE: {"copy.py": ""},
"data": "data",
"params.yaml": '{"foo": 1}',
"dvclive": {},
"plots": {},
}
)
Expand Down Expand Up @@ -137,13 +136,11 @@ def test_init_interactive_when_no_path_prompts_need_to_be_asked(
"cmd": "cmd",
"deps": ["data", "src"],
"metrics": [
{"dvclive.json": {"cache": False}},
{"metrics.json": {"cache": False}},
],
"outs": ["models"],
"params": [{"params.yaml": None}],
"plots": [
{os.path.join("dvclive", "scalars"): {"cache": False}},
{"plots": {"cache": False}},
],
}
Expand Down Expand Up @@ -313,26 +310,18 @@ def test_init_default(tmp_dir, scm, dvc, interactive, overrides, inp, capsys):
"data\n"
"params.yaml\n"
"models\n"
"dvclive\n"
"y"
),
),
(
True,
{"cmd": "python script.py"},
io.StringIO(
"script.py\n"
"data\n"
"params.yaml\n"
"models\n"
"dvclive\n"
"y"
),
io.StringIO("script.py\n" "data\n" "params.yaml\n" "models\n" "y"),
),
(
True,
{"cmd": "python script.py", "models": "models"},
io.StringIO("script.py\ndata\nparams.yaml\ndvclive\ny"),
io.StringIO("script.py\ndata\nparams.yaml\ny"),
),
],
ids=[
Expand All @@ -345,11 +334,12 @@ def test_init_default(tmp_dir, scm, dvc, interactive, overrides, inp, capsys):
def test_init_interactive_live(
tmp_dir, scm, dvc, interactive, overrides, inp, capsys
):
overrides["live"] = "dvclive"

(tmp_dir / "params.yaml").dump({"foo": {"bar": 1}})

init(
dvc,
type="dl",
interactive=interactive,
defaults=CmdExperimentsInit.DEFAULTS,
overrides=overrides,
Expand All @@ -361,7 +351,7 @@ def test_init_interactive_live(
"cmd": "python script.py",
"deps": ["data", "script.py"],
"metrics": [{"dvclive.json": {"cache": False}}],
"outs": [{"models": {"checkpoint": True}}],
"outs": ["models"],
"params": [{"params.yaml": None}],
"plots": [
{os.path.join("dvclive", "scalars"): {"cache": False}}
Expand Down Expand Up @@ -393,13 +383,13 @@ def test_init_interactive_live(
(True, io.StringIO()),
],
)
def test_init_with_type_live_and_models_plots_provided(
def test_init_with_type_checkpoint_and_models_plots_provided(
tmp_dir, dvc, interactive, inp
):
(tmp_dir / "params.yaml").dump({"foo": 1})
init(
dvc,
type="dl",
type="checkpoint",
interactive=interactive,
stream=inp,
defaults=CmdExperimentsInit.DEFAULTS,
Expand All @@ -411,13 +401,11 @@ def test_init_with_type_live_and_models_plots_provided(
"cmd": "cmd",
"deps": ["data", "src"],
"metrics": [
{"dvclive.json": {"cache": False}},
{"m": {"cache": False}},
],
"outs": [{"models": {"checkpoint": True}}],
"params": [{"params.yaml": None}],
"plots": [
{os.path.join("dvclive", "scalars"): {"cache": False}},
{"p": {"cache": False}},
],
}
Expand Down Expand Up @@ -445,6 +433,49 @@ def test_init_with_type_default_and_live_provided(
defaults=CmdExperimentsInit.DEFAULTS,
overrides={"cmd": "cmd", "live": "live"},
)
assert (tmp_dir / "dvc.yaml").parse() == {
"stages": {
"train": {
"cmd": "cmd",
"deps": ["data", "src"],
"metrics": [
{"live.json": {"cache": False}},
],
"outs": ["models"],
"params": [{"params.yaml": None}],
"plots": [
{os.path.join("live", "scalars"): {"cache": False}},
],
}
}
}
assert (tmp_dir / "src").is_dir()
assert (tmp_dir / "data").is_dir()


@pytest.mark.parametrize(
"interactive, inp",
[
(False, None),
(True, io.StringIO()),
],
)
def test_init_with_live_and_metrics_plots_provided(
tmp_dir, dvc, interactive, inp
):
(tmp_dir / "params.yaml").dump({"foo": 1})
init(
dvc,
interactive=interactive,
stream=inp,
defaults=CmdExperimentsInit.DEFAULTS,
overrides={
"cmd": "cmd",
"live": "live",
"metrics": "metrics.json",
"plots": "plots",
},
)
assert (tmp_dir / "dvc.yaml").parse() == {
"stages": {
"train": {
Expand Down
8 changes: 3 additions & 5 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,6 @@ def test_experiments_init(dvc, scm, mocker, capsys, extra_args):
"metrics": "metrics.json",
"params": "params.yaml",
"plots": "plots",
"live": "dvclive",
},
overrides={"cmd": "cmd"},
interactive=False,
Expand Down Expand Up @@ -727,7 +726,6 @@ def test_experiments_init_config(dvc, scm, mocker):
"metrics": "metrics.json",
"params": "params.yaml",
"plots": "plots",
"live": "dvclive",
},
overrides={"cmd": "cmd"},
interactive=False,
Expand Down Expand Up @@ -782,11 +780,11 @@ def test_experiments_init_cmd_not_required_for_interactive_mode(dvc, mocker):
"extra_args, expected_kw",
[
(["--type", "default"], {"type": "default", "name": "train"}),
(["--type", "dl"], {"type": "dl", "name": "train"}),
(["--type", "checkpoint"], {"type": "checkpoint", "name": "train"}),
(["--force"], {"force": True, "name": "train"}),
(
["--name", "name", "--type", "dl"],
{"name": "name", "type": "dl"},
["--name", "name", "--type", "checkpoint"],
{"name": "name", "type": "checkpoint"},
),
(
[
Expand Down