Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improved handling of glob when reading files #961

Merged
merged 4 commits into from
May 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions fsspec/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,8 @@ def strip_protocol(urlpath):


def expand_paths_if_needed(paths, mode, num, fs, name_function):
"""Expand paths if they have a ``*`` in them.
"""Expand paths if they have a ``*`` in them (write mode) or any of ``*?[]``
in them (read mode).

:param paths: list of paths
mode: str
Expand All @@ -549,23 +550,32 @@ def expand_paths_if_needed(paths, mode, num, fs, name_function):
"""
expanded_paths = []
paths = list(paths)
if "w" in mode and sum([1 for p in paths if "*" in p]) > 1:
raise ValueError("When writing data, only one filename mask can be specified.")
elif "w" in mode:

if "w" in mode: # write mode
if sum([1 for p in paths if "*" in p]) > 1:
raise ValueError(
"When writing data, only one filename mask can be specified."
)
num = max(num, len(paths))
for curr_path in paths:
if "*" in curr_path:
if "w" in mode:

for curr_path in paths:
if "*" in curr_path:
# expand using name_function
expanded_paths.extend(_expand_paths(curr_path, name_function, num))
else:
expanded_paths.append(curr_path)
# if we generated more paths than asked for, trim the list
if len(expanded_paths) > num:
expanded_paths = expanded_paths[:num]

else: # read mode
for curr_path in paths:
if has_magic(curr_path):
# expand using glob
expanded_paths.extend(fs.glob(curr_path))
else:
expanded_paths.append(curr_path)
# if we generated more paths than asked for, trim the list
if "w" in mode and len(expanded_paths) > num:
expanded_paths = expanded_paths[:num]
else:
expanded_paths.append(curr_path)

return expanded_paths


Expand Down
28 changes: 26 additions & 2 deletions fsspec/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
OpenFile,
OpenFiles,
_expand_paths,
expand_paths_if_needed,
get_compression,
open_files,
open_local,
Expand Down Expand Up @@ -45,6 +46,30 @@ def test_expand_paths(path, name_function, num, out):
assert _expand_paths(path, name_function, num) == out


@pytest.mark.parametrize(
    "create_files, path, out",
    [
        # literal path without any glob character is passed through untouched
        [["apath"], "apath", ["apath"]],
        # ``*`` matches any run of characters
        [["apath1"], "apath*", ["apath1"]],
        [["apath1", "apath2"], "apath*", ["apath1", "apath2"]],
        # ``[...]`` matches one character out of the listed set
        [["apath1", "apath2"], "apath[1]", ["apath1"]],
        # ``?`` matches exactly one character, so "apath11" must not match
        [["apath1", "apath11"], "apath?", ["apath1"]],
    ],
)
def test_expand_paths_if_needed_in_read_mode(create_files, path, out):
    """Check that read mode expands ``*``, ``?`` and ``[]`` through the filesystem.

    ``create_files`` are created inside a fresh temporary directory, ``path``
    (relative to that directory) is handed to ``expand_paths_if_needed`` in
    ``"r"`` mode, and the basenames of the result are compared with ``out``.
    """
    # TemporaryDirectory removes the directory and its files when the test
    # finishes, so repeated parametrized runs do not leave litter behind.
    with tempfile.TemporaryDirectory() as d:
        for name in create_files:
            # Close each handle explicitly instead of relying on garbage
            # collection to flush and release it.
            with open(os.path.join(d, name), "w") as f:
                f.write("test")

        path = os.path.join(d, path)

        fs = fsspec.filesystem("file")
        # num=0 and name_function=None: neither is used by the read-mode branch.
        res = expand_paths_if_needed([path], "r", 0, fs, None)
        assert [os.path.basename(p) for p in res] == out


def test_expand_error():
with pytest.raises(ValueError):
_expand_paths("*.*", None, 1)
Expand Down Expand Up @@ -172,8 +197,7 @@ def test_url_kwargs_chain(ftp_writable):
f.write(data)

with fsspec.open(
f"simplecache::ftp://{username}:{password}@{host}:{port}//afile",
"rb",
f"simplecache::ftp://{username}:{password}@{host}:{port}//afile", "rb"
) as f:
assert f.read() == data

Expand Down