Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

exp init: track data by default #6914

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion dvc/repo/experiments/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from rich.rule import Rule
from rich.syntax import Syntax

from dvc.exceptions import DvcException
from dvc.exceptions import CacheLinkError, DvcException
from dvc.stage import PipelineStage
from dvc.stage.serialize import to_pipeline_file
from dvc.types import OptStr
Expand Down Expand Up @@ -181,6 +181,32 @@ def _check_stage_exists(
)


def add_data(repo, init_stage, data_source):
    """Track ``data_source`` as its own stage alongside ``init_stage``.

    A single stage is created for ``data_source``; it is graph-checked
    together with ``init_stage`` against an updated copy of the repo index,
    then saved, committed and dumped.

    Args:
        repo: repository the new data stage is created in.
        init_stage: the stage being initialised; it only takes part in the
            graph check here and is not saved or dumped by this function.
        data_source: the target handed to ``create_stages``.

    A ``CacheLinkError`` raised while saving/committing is not propagated:
    the stage's relpath is recorded in the ``warn_link_failures`` list and
    the stage is still dumped.
    """
    from dvc.repo.add import (
        create_stages,
        translate_graph_error,
        warn_link_failures,
    )

    # One source in, exactly one stage expected out; the unpacking fails
    # loudly if ``create_stages`` yields anything else.
    [data_stage] = create_stages(repo, [data_source])
    candidates = [init_stage, data_stage]

    with translate_graph_error(candidates), ui.status(
        "Collecting stages from the workspace"
    ) as progress:
        # Stages already in the index that these candidates are meant to
        # replace get swapped out first, so the graph check sees only the
        # new definitions.
        updated_index = repo.index.update(candidates)
        progress.update("Checking graph")
        updated_index.check_graph()

    with warn_link_failures() as failed_links:
        try:
            data_stage.save()
            data_stage.commit()
        except CacheLinkError:
            failed_links.append(str(data_stage.relpath))
        # Dumped on both the success and the link-failure path.
        data_stage.dump()


def init(
repo: "Repo",
name: str = None,
Expand Down Expand Up @@ -256,6 +282,19 @@ def init(
):
scm = repo.scm
with _disable_logging(), scm.track_file_changes(autostage=True):
if "data" in context:
from dvc.repo import lock_repo

data = context["data"]
with lock_repo(repo):
add_data(repo, stage, data)

if interactive:
ui.write()
ui.write(
f"Tracking '[green]{data}[/green]' dependency", styled=True
)

stage.dump(update_lock=False)
stage.ignore_outs()
scm.track_file(params)
Expand Down