diff --git a/dvc/command/update.py b/dvc/command/update.py index c3ef60ec9c..718f08cdae 100644 --- a/dvc/command/update.py +++ b/dvc/command/update.py @@ -10,12 +10,15 @@ class CmdUpdate(CmdBase): def run(self): ret = 0 - for target in self.args.targets: - try: - self.repo.update(target, self.args.rev) - except DvcException: - logger.exception("failed to update '{}'.".format(target)) - ret = 1 + try: + self.repo.update( + targets=self.args.targets, + rev=self.args.rev, + recursive=self.args.recursive, + ) + except DvcException: + logger.exception("failed update data") + ret = 1 return ret @@ -37,4 +40,11 @@ def add_parser(subparsers, parent_parser): help="Git revision (e.g. SHA, branch, tag)", metavar="", ) + update_parser.add_argument( + "-R", + "--recursive", + action="store_true", + default=False, + help="Update all stages in the specified directory.", + ) update_parser.set_defaults(func=CmdUpdate) diff --git a/dvc/repo/update.py b/dvc/repo/update.py index 5ade233669..d0dac9a979 100644 --- a/dvc/repo/update.py +++ b/dvc/repo/update.py @@ -1,14 +1,23 @@ +from ..dvcfile import Dvcfile from . import locked @locked -def update(self, target, rev=None): - from ..dvcfile import Dvcfile +def update(self, targets, rev=None, recursive=False): + if not targets: + targets = [None] - dvcfile = Dvcfile(self, target) - stage = dvcfile.stage - stage.update(rev) + if isinstance(targets, str): + targets = [targets] - dvcfile.dump(stage) + stages = set() + for target in targets: + stages.update(self.collect(target, recursive=recursive)) - return stage + for stage in stages: + stage.update(rev) + dvcfile = Dvcfile(self, stage.path) + dvcfile.dump(stage) + stages.add(stage) + + return list(stages) diff --git a/tests/func/test_repro.py b/tests/func/test_repro.py index 10652b14d1..bc6bbcbebf 100644 --- a/tests/func/test_repro.py +++ b/tests/func/test_repro.py @@ -956,12 +956,12 @@ def test(self, mock_prompt): self.assertNotEqual(self.dvc.status(), {}) - self.dvc.update(import_stage.path) + self.dvc.update([import_stage.path]) self.assertTrue(os.path.exists("import")) self.assertTrue(filecmp.cmp("import", self.BAR, shallow=False)) self.assertEqual(self.dvc.status([import_stage.path]), {}) - self.dvc.update(import_remote_stage.path) + self.dvc.update([import_remote_stage.path]) self.assertEqual(self.dvc.status([import_remote_stage.path]), {}) stages = self.dvc.reproduce(cmd_stage.addressing) diff --git a/tests/func/test_update.py b/tests/func/test_update.py index 2da66f3479..6a91df6ffd 100644 --- a/tests/func/test_update.py +++ b/tests/func/test_update.py @@ -26,7 +26,7 @@ def test_update_import(tmp_dir, dvc, erepo_dir, cached): assert old_rev != new_rev - dvc.update(stage.path) + dvc.update([stage.path]) assert (tmp_dir / "version").read_text() == "updated" stage = Dvcfile(dvc, stage.path).stage @@ -65,7 +65,7 @@ def test_update_import_after_remote_updates_to_dvc(tmp_dir, dvc, erepo_dir): assert changed_dep[0].startswith("version ") assert changed_dep[1] == "update available" - dvc.update(stage.path) + dvc.update([stage.path]) assert dvc.status([stage.path]) == {} @@ -106,7 +106,7 @@ def test_update_before_and_after_dvc_init(tmp_dir, dvc, git_dir): ] } - dvc.update(stage.path) + dvc.update([stage.path]) assert (tmp_dir / "file").read_text() == "second version" assert dvc.status([stage.path]) == {} @@ -127,7 +127,7 @@ def test_update_import_url(tmp_dir, dvc, tmp_path_factory): src.write_text("updated file content") assert dvc.status([stage.path]) == {} - dvc.update(stage.path) + dvc.update([stage.path]) assert dvc.status([stage.path]) == {} assert dst.is_file() @@ -149,7 +149,7 @@ def test_update_rev(tmp_dir, dvc, scm, git_dir): git_dir.scm_gen({"foo": "foobar foo"}, commit="branch2 commit") branch2_head = git_dir.scm.get_rev() - stage = dvc.update("foo.dvc", rev="branch1") + stage = dvc.update(["foo.dvc"], rev="branch1")[0] assert stage.deps[0].def_repo == { "url": fspath(git_dir), "rev": "branch1", @@ -158,7 +158,7 @@ def test_update_rev(tmp_dir, dvc, scm, git_dir): with open(fspath_py35(tmp_dir / "foo")) as f: assert "foobar" == f.read() - stage = dvc.update("foo.dvc", rev="branch2") + stage = dvc.update(["foo.dvc"], rev="branch2")[0] assert stage.deps[0].def_repo == { "url": fspath(git_dir), "rev": "branch2", @@ -166,3 +166,60 @@ def test_update_rev(tmp_dir, dvc, scm, git_dir): } with open(fspath_py35(tmp_dir / "foo")) as f: assert "foobar foo" == f.read() + + +def test_update_recursive(tmp_dir, dvc, erepo_dir): + with erepo_dir.branch("branch", new=True), erepo_dir.chdir(): + erepo_dir.scm_gen( + {"foo1": "text1", "foo2": "text2", "foo3": "text3"}, + commit="add foo files", + ) + old_rev = erepo_dir.scm.get_rev() + + tmp_dir.gen({"dir": {"subdir": {}}}) + stage1 = dvc.imp( + fspath(erepo_dir), "foo1", os.path.join("dir", "foo1"), rev="branch", + ) + stage2 = dvc.imp( + fspath(erepo_dir), + "foo2", + os.path.join("dir", "subdir", "foo2"), + rev="branch", + ) + stage3 = dvc.imp( + fspath(erepo_dir), + "foo3", + os.path.join("dir", "subdir", "foo3"), + rev="branch", + ) + + assert (tmp_dir / os.path.join("dir", "foo1")).read_text() == "text1" + assert ( + tmp_dir / os.path.join("dir", "subdir", "foo2") + ).read_text() == "text2" + assert ( + tmp_dir / os.path.join("dir", "subdir", "foo3") + ).read_text() == "text3" + + assert stage1.deps[0].def_repo["rev_lock"] == old_rev + assert stage2.deps[0].def_repo["rev_lock"] == old_rev + assert stage3.deps[0].def_repo["rev_lock"] == old_rev + + with erepo_dir.branch("branch", new=False), erepo_dir.chdir(): + erepo_dir.scm_gen( + {"foo1": "updated1", "foo2": "updated2", "foo3": "updated3"}, + "", + "update foo content", + ) + new_rev = erepo_dir.scm.get_rev() + + assert old_rev != new_rev + + dvc.update(["dir"], recursive=True) + + stage1 = Dvcfile(dvc, stage1.path).stage + stage2 = Dvcfile(dvc, stage2.path).stage + stage3 = Dvcfile(dvc, stage3.path).stage + assert stage1.deps[0].def_repo["rev_lock"] == new_rev + assert stage2.deps[0].def_repo["rev_lock"] == new_rev + assert stage3.deps[0].def_repo["rev_lock"] == new_rev diff --git a/tests/unit/command/test_update.py b/tests/unit/command/test_update.py index ff8e77942e..087ab44832 100644 --- a/tests/unit/command/test_update.py +++ b/tests/unit/command/test_update.py @@ -1,20 +1,17 @@ -import pytest - from dvc.cli import parse_args from dvc.command.update import CmdUpdate -@pytest.mark.parametrize( - "command,rev", [(["update"], None), (["update", "--rev", "REV"], "REV")] -) -def test_update(dvc, mocker, command, rev): - targets = ["target1", "target2", "target3"] - cli_args = parse_args(command + targets) +def test_update(dvc, mocker): + cli_args = parse_args( + ["update", "target1", "target2", "--rev", "REV", "--recursive"] + ) assert cli_args.func == CmdUpdate cmd = cli_args.func(cli_args) m = mocker.patch("dvc.repo.Repo.update") assert cmd.run() == 0 - calls = [mocker.call(target, rev) for target in targets] - m.assert_has_calls(calls) + m.assert_called_once_with( + targets=["target1", "target2"], rev="REV", recursive=True, + )