diff --git a/dvc/data_cloud.py b/dvc/data_cloud.py index 93539d7621..e81a622585 100644 --- a/dvc/data_cloud.py +++ b/dvc/data_cloud.py @@ -101,7 +101,7 @@ def transfer(self, source, jobs=None, remote=None, command=None): remote (dvc.remote.base.BaseRemote): optional remote to compare cache to. By default remote from core.remote config option is used. - command (bool): the command which is benefitting from this function + command (str): the command which is benefitting from this function (to be used for reporting better error messages). """ from dvc.tree import get_cloud_tree diff --git a/dvc/repo/add.py b/dvc/repo/add.py index 8862d6865f..eeee0160b7 100644 --- a/dvc/repo/add.py +++ b/dvc/repo/add.py @@ -41,9 +41,11 @@ def add( # noqa: C901 targets = ensure_list(targets) + to_cache = kwargs.get("out") and not to_remote invalid_opt = None - if to_remote: - message = "{option} can't be used with --to-remote" + if to_remote or to_cache: + message = "{option} can't be used with " + message += "--to-remote" if to_remote else "-o" if len(targets) != 1: invalid_opt = "multiple targets" elif no_commit: @@ -52,9 +54,7 @@ def add( # noqa: C901 invalid_opt = "--recursive option" else: message = "{option} can't be used without --to-remote" - if kwargs.get("out"): - invalid_opt = "--out" - elif kwargs.get("remote"): + if kwargs.get("remote"): invalid_opt = "--remote" elif kwargs.get("jobs"): invalid_opt = "--jobs" @@ -88,12 +88,7 @@ def add( # noqa: C901 ) stages = _create_stages( - repo, - sub_targets, - fname, - pbar=pbar, - to_remote=to_remote, - **kwargs, + repo, sub_targets, fname, pbar=pbar, **kwargs, ) try: @@ -125,6 +120,7 @@ def add( # noqa: C901 no_commit, pbar, to_remote, + to_cache, **kwargs, ) ) @@ -148,25 +144,35 @@ def add( # noqa: C901 def _process_stages( - repo, sub_targets, stages, no_commit, pbar, to_remote, **kwargs + repo, sub_targets, stages, no_commit, pbar, to_remote, to_cache, **kwargs ): link_failures = [] from dvc.dvcfile import Dvcfile from ..output.base import OutputDoesNotExistError - if to_remote: + if to_remote or to_cache: # Already verified in the add() - assert len(stages) == 1 - assert len(sub_targets) == 1 - - [stage] = stages - stage.outs[0].hash_info = repo.cloud.transfer( - sub_targets[0], - jobs=kwargs.get("jobs"), - remote=kwargs.get("remote"), - command="add", - ) + (stage,) = stages + (target,) = sub_targets + (out,) = stage.outs + + if to_remote: + out.hash_info = repo.cloud.transfer( + target, + jobs=kwargs.get("jobs"), + remote=kwargs.get("remote"), + command="add", + ) + else: + from dvc.tree import get_cloud_tree + + from_tree = get_cloud_tree(repo, url=target) + out.hash_info = out.cache.transfer( + from_tree, from_tree.path_info, jobs=kwargs.get("jobs"), + ) + out.checkout() + Dvcfile(repo, stage.path).dump(stage) return link_failures @@ -219,7 +225,6 @@ def _create_stages( repo, targets, fname, - to_remote=False, pbar=None, external=False, glob=False, @@ -238,8 +243,8 @@ def _create_stages( disable=len(expanded_targets) < LARGE_DIR_SIZE, unit="file", ): - if to_remote: - out = resolve_output(out, kwargs.get("out")) + if kwargs.get("out"): + out = resolve_output(out, kwargs["out"]) path, wdir, out = resolve_paths(repo, out) stage = create_stage( Stage, diff --git a/tests/func/test_add.py b/tests/func/test_add.py index 77dcd9e3be..5d43bc697a 100644 --- a/tests/func/test_add.py +++ b/tests/func/test_add.py @@ -26,6 +26,7 @@ from dvc.output.base import OutputAlreadyTrackedError, OutputIsStageFileError from dvc.repo import Repo as DvcRepo from dvc.stage import Stage +from dvc.stage.exceptions import StagePathNotFoundError from dvc.system import System from dvc.tree.local import LocalTree from dvc.utils import LARGE_DIR_SIZE, file_md5, relpath @@ -1021,3 +1022,118 @@ def test_add_to_remote(tmp_dir, dvc, local_cloud, local_remote): def test_add_to_remote_invalid_combinations(dvc, invalid_opt, kwargs): with pytest.raises(InvalidArgumentError, match=invalid_opt): dvc.add(to_remote=True, **kwargs) + + +def test_add_to_cache_dir(tmp_dir, dvc, local_cloud): + local_cloud.gen({"data": {"foo": "foo", "bar": "bar"}}) + + (stage,) = dvc.add(str(local_cloud / "data"), out="data") + assert len(stage.deps) == 0 + assert len(stage.outs) == 1 + + data = tmp_dir / "data" + assert data.read_text() == {"foo": "foo", "bar": "bar"} + assert (tmp_dir / "data.dvc").exists() + + shutil.rmtree(data) + status = dvc.checkout(str(data)) + assert status["added"] == ["data" + os.sep] + assert data.read_text() == {"foo": "foo", "bar": "bar"} + + +def test_add_to_cache_file(tmp_dir, dvc, local_cloud): + local_cloud.gen("foo", "foo") + + (stage,) = dvc.add(str(local_cloud / "foo"), out="foo") + assert len(stage.deps) == 0 + assert len(stage.outs) == 1 + + foo = tmp_dir / "foo" + assert foo.read_text() == "foo" + assert (tmp_dir / "foo.dvc").exists() + + foo.unlink() + status = dvc.checkout(str(foo)) + assert status["added"] == ["foo"] + assert foo.read_text() == "foo" + + +def test_add_to_cache_different_name(tmp_dir, dvc, local_cloud): + local_cloud.gen({"data": {"foo": "foo", "bar": "bar"}}) + + dvc.add(str(local_cloud / "data"), out="not_data") + + not_data = tmp_dir / "not_data" + assert not_data.read_text() == {"foo": "foo", "bar": "bar"} + assert (tmp_dir / "not_data.dvc").exists() + + assert not (tmp_dir / "data").exists() + assert not (tmp_dir / "data.dvc").exists() + + shutil.rmtree(not_data) + dvc.checkout(str(not_data)) + assert not_data.read_text() == {"foo": "foo", "bar": "bar"} + assert not (tmp_dir / "data").exists() + + +def test_add_to_cache_not_exists(tmp_dir, dvc, local_cloud): + local_cloud.gen({"data": {"foo": "foo", "bar": "bar"}}) + + dest_dir = tmp_dir / "dir" / "that" / "does" / "not" / "exist" + with pytest.raises(StagePathNotFoundError): + dvc.add(str(local_cloud / "data"), out=str(dest_dir)) + + dest_dir.parent.mkdir(parents=True) + dvc.add(str(local_cloud / "data"), out=str(dest_dir)) + + assert dest_dir.read_text() == {"foo": "foo", "bar": "bar"} + assert dest_dir.with_suffix(".dvc").exists() + + +@pytest.mark.parametrize( + "invalid_opt, kwargs", + [ + ("multiple targets", {"targets": ["foo", "bar", "baz"]}), + ("--no-commit", {"targets": ["foo"], "no_commit": True}), + ("--recursive", {"targets": ["foo"], "recursive": True},), + ], +) +def test_add_to_cache_invalid_combinations(dvc, invalid_opt, kwargs): + with pytest.raises(InvalidArgumentError, match=invalid_opt): + dvc.add(out="bar", **kwargs) + + +@pytest.mark.parametrize( + "workspace", + [ + pytest.lazy_fixture("local_cloud"), + pytest.lazy_fixture("s3"), + pytest.lazy_fixture("gs"), + pytest.lazy_fixture("hdfs"), + pytest.param( + pytest.lazy_fixture("ssh"), + marks=pytest.mark.skipif( + os.name == "nt", reason="disabled on windows" + ), + ), + pytest.lazy_fixture("http"), + ], + indirect=True, +) +def test_add_to_cache_from_remote(tmp_dir, dvc, workspace): + workspace.gen("foo", "foo") + + url = "remote://workspace/foo" + dvc.add(url, out="foo") + + foo = tmp_dir / "foo" + assert foo.read_text() == "foo" + assert (tmp_dir / "foo.dvc").exists() + + # Change the contents of the remote location, in order to + # ensure it retrieves file from the cache and not re-fetches it + (workspace / "foo").write_text("bar") + + foo.unlink() + dvc.checkout(str(foo)) + assert foo.read_text() == "foo" diff --git a/tests/unit/command/test_add.py b/tests/unit/command/test_add.py index 1c4f5266af..4b09d5b5d1 100644 --- a/tests/unit/command/test_add.py +++ b/tests/unit/command/test_add.py @@ -87,11 +87,7 @@ def test_add_to_remote_invalid_combinations(mocker, caplog): expected_msg = "multiple targets can't be used with --to-remote" assert expected_msg in caplog.text - for option, value in ( - ("--remote", "remote"), - ("--out", "bar"), - ("--jobs", "4"), - ): + for option, value in (("--remote", "remote"), ("--jobs", "4")): cli_args = parse_args(["add", "foo", option, value]) cmd = cli_args.func(cli_args) @@ -99,3 +95,16 @@ def test_add_to_remote_invalid_combinations(mocker, caplog): assert cmd.run() == 1 expected_msg = f"{option} can't be used without --to-remote" assert expected_msg in caplog.text + + +def test_add_to_cache_invalid_combinations(mocker, caplog): + cli_args = parse_args( + ["add", "s3://bucket/foo", "s3://bucket/bar", "-o", "foo"] + ) + assert cli_args.func == CmdAdd + + cmd = cli_args.func(cli_args) + with caplog.at_level(logging.ERROR, logger="dvc"): + assert cmd.run() == 1 + expected_msg = "multiple targets can't be used with -o" + assert expected_msg in caplog.text