diff --git a/my/core/common.py b/my/core/common.py index 582708a5..b99b339d 100644 --- a/my/core/common.py +++ b/my/core/common.py @@ -157,21 +157,21 @@ def caller() -> str: for src in sources: if src.parts[0] == '~': src = src.expanduser() - if src.is_dir(): + # note: glob handled first, because e.g. on Windows asterisk makes is_dir unhappy + gs = str(src) + if '*' in gs: + if glob != DEFAULT_GLOB: + warnings.warn(f"{caller()}: treating {gs} as glob path. Explicit glob={glob} argument is ignored!") + paths.extend(map(Path, do_glob(gs))) + elif src.is_dir(): gp: Iterable[Path] = src.glob(glob) # todo not sure if should be recursive? paths.extend(gp) else: - ss = str(src) - if '*' in ss: - if glob != DEFAULT_GLOB: - warnings.warn(f"{caller()}: treating {ss} as glob path. Explicit glob={glob} argument is ignored!") - paths.extend(map(Path, do_glob(ss))) - else: - if not src.is_file(): - # todo not sure, might be race condition? - raise RuntimeError(f"Expected '{src}' to exist") - # todo assert matches glob?? - paths.append(src) + if not src.is_file(): + # todo not sure, might be race condition? + raise RuntimeError(f"Expected '{src}' to exist") + # todo assert matches glob?? + paths.append(src) if sort: paths = list(sorted(paths)) diff --git a/my/core/compat.py b/my/core/compat.py index 3a97242e..eb973faf 100644 --- a/my/core/compat.py +++ b/my/core/compat.py @@ -60,3 +60,7 @@ def _get_dal(cfg, module_name: str): from typing import Union # erm.. I guess as long as it's not crashing, whatever... Literal = Union + + +import os +windows = os.name == 'nt' diff --git a/tests/get_files.py b/tests/get_files.py index aa71e7b6..a81b34f7 100644 --- a/tests/get_files.py +++ b/tests/get_files.py @@ -1,10 +1,37 @@ +import os from pathlib import Path -from my.common import get_files +from typing import TYPE_CHECKING + +from my.core.compat import windows +from my.core.common import get_files import pytest # type: ignore -def test_single_file(): + # hack to replace all /tmp with 'real' tmp dir + # not ideal, but makes tests more concise +def _get_files(x, *args, **kwargs): + import my.core.common as C + def repl(x): + if isinstance(x, str): + return x.replace('/tmp', TMP) + elif isinstance(x, Path): + assert x.parts[:2] == (os.sep, 'tmp') # meh + return Path(TMP) / Path(*x.parts[2:]) + else: + # iterable? + return [repl(i) for i in x] + + x = repl(x) + res = C.get_files(x, *args, **kwargs) + return tuple(Path(str(i).replace(TMP, '/tmp')) for i in res) # hack back for asserts.. + + +if not TYPE_CHECKING: + get_files = _get_files + + +def test_single_file() -> None: ''' Regular file path is just returned as is. ''' @@ -27,12 +54,13 @@ def test_single_file(): "if the path starts with ~, we expand it" - assert get_files('~/.bashrc') == ( - Path('~').expanduser() / '.bashrc', - ) + if not windows: # windows dowsn't have bashrc.. ugh + assert get_files('~/.bashrc') == ( + Path('~').expanduser() / '.bashrc', + ) -def test_multiple_files(): +def test_multiple_files() -> None: ''' If you pass a directory/multiple directories, it flattens the contents ''' @@ -57,7 +85,7 @@ def test_multiple_files(): ) -def test_explicit_glob(): +def test_explicit_glob() -> None: ''' You can pass a glob to restrict the extensions ''' @@ -78,7 +106,7 @@ def test_explicit_glob(): assert get_files('/tmp/hpi_test', glob='file_*.zip') == expected -def test_implicit_glob(): +def test_implicit_glob() -> None: ''' Asterisc in the path results in globing too. ''' @@ -98,7 +126,7 @@ def test_implicit_glob(): ) -def test_no_files(): +def test_no_files() -> None: ''' Test for empty matches. They work, but should result in warning ''' @@ -112,7 +140,10 @@ def test_no_files(): # TODO not sure if should uniquify if the filenames end up same? # TODO not sure about the symlinks? and hidden files? -test_path = Path('/tmp/hpi_test') +import tempfile +TMP = tempfile.gettempdir() +test_path = Path(TMP) / 'hpi_test' + def setup(): teardown() test_path.mkdir() @@ -125,6 +156,8 @@ def teardown(): def create(f: str) -> None: + # in test body easier to use /tmp regardless the OS... + f = f.replace('/tmp', TMP) if f.endswith('/'): Path(f).mkdir() else: