diff --git a/src/smbclient/shutil.py b/src/smbclient/shutil.py index edd18659..c9d7aefe 100644 --- a/src/smbclient/shutil.py +++ b/src/smbclient/shutil.py @@ -281,8 +281,6 @@ def copytree( source path and the destination path as arguments. By default copy() is used, but any function that supports the same signature (like copy()) can be used. - In this current form, copytree() only supports remote to remote copies over SMB, or remote to local copies. - :param src: The source directory to copy. :param dst: The destination directory to copy to. :param symlinks: Whether to attempt to copy a symlink from the source tree to the dest tree, if False the symlink @@ -295,7 +293,10 @@ def copytree( :param kwargs: Common arguments used to build the SMB Session for any UNC paths. :return: The dst path. """ - dir_entries = list(scandir(src, **kwargs)) + if is_remote_path(src): + dir_entries = list(scandir(src, **kwargs)) + else: + dir_entries = list(os.scandir(src)) if is_remote_path(dst): makedirs(dst, exist_ok=dirs_exist_ok, **kwargs) @@ -316,6 +317,9 @@ def copytree( try: if dir_entry.is_symlink(): + if not isinstance(dir_entry, SMBDirEntry): + raise AssertionError("copytree doesn't yet support symlinks for local to remote operations") + link_target = readlink(src_path, **kwargs) if symlinks: symlink(link_target, dst_path, **kwargs) diff --git a/tests/test_smbclient_shutil.py b/tests/test_smbclient_shutil.py index be23d8ad..5b8fd56a 100644 --- a/tests/test_smbclient_shutil.py +++ b/tests/test_smbclient_shutil.py @@ -1176,6 +1176,33 @@ def test_copytree_with_local_dst(smb_share, tmp_path): assert fd.read() == "file3.txt" +def test_copytree_with_local_src(smb_share, tmp_path): + src_dirname = str(tmp_path / "source") + dst_dirname = "%s\\target" % smb_share + + os.makedirs(os.path.join(src_dirname, "dir1", "subdir1")) + with open(os.path.join(src_dirname, "file1.txt"), mode="w") as fd: + fd.write("file1.txt") + with open(os.path.join(src_dirname, "dir1", "file2.txt"), mode="w") as fd: + fd.write("file2.txt") + with open(os.path.join(src_dirname, "dir1", "subdir1", "file3.txt"), mode="w") as fd: + fd.write("file3.txt") + + actual = copytree(src_dirname, dst_dirname) + assert actual == dst_dirname + + assert sorted(list(listdir(dst_dirname))) == ["dir1", "file1.txt"] + assert sorted(list(listdir("%s\\dir1" % dst_dirname))) == ["file2.txt", "subdir1"] + assert sorted(list(listdir("%s\\dir1\\subdir1" % dst_dirname))) == ["file3.txt"] + + with open_file("%s\\file1.txt" % dst_dirname) as fd: + assert fd.read() == "file1.txt" + with open_file("%s\\dir1\\file2.txt" % dst_dirname) as fd: + assert fd.read() == "file2.txt" + with open_file("%s\\dir1\\subdir1\\file3.txt" % dst_dirname) as fd: + assert fd.read() == "file3.txt" + + @pytest.mark.skipif( os.name != "nt" and not os.environ.get("SMB_FORCE", False), reason="Samba does not update timestamps" )