diff --git a/cached_path/_cached_path.py b/cached_path/_cached_path.py index 738e98e..4121f28 100644 --- a/cached_path/_cached_path.py +++ b/cached_path/_cached_path.py @@ -4,7 +4,7 @@ import tarfile import tempfile from pathlib import Path -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union from urllib.parse import urlparse from zipfile import ZipFile, is_zipfile @@ -60,7 +60,7 @@ def _is_archive(file_path: PathOrStr, url_or_filename) -> bool: def cached_path( url_or_filename: PathOrStr, cache_dir: Optional[PathOrStr] = None, - extract_archive: bool = False, + extract_archive: Union[bool, Literal["auto"]] = "auto", force_extract: bool = False, quiet: bool = False, progress: Optional["Progress"] = None, @@ -104,12 +104,12 @@ def cached_path( cached_path("hf://epwalsh/bert-xsmall-dummy/pytorch_model.bin") - For paths or URLs that point to a tarfile or zipfile, you can append the path - to a specific file within the archive to the ``url_or_filename``, preceeded by a "!". - The archive will be automatically extracted (provided you set ``extract_archive`` to ``True``), + For paths or URLs that point to a TAR or ZIP file, you can append the path + to a specific file within the archive to the ``url_or_filename``, preceded by a "!". + The archive will be automatically extracted, returning the local path to the specific file. For example:: - cached_path("model.tar.gz!weights.th", extract_archive=True) + cached_path("model.tar.gz!weights.th") .. _epwalsh/bert-xsmall-dummy: https://huggingface.co/epwalsh/bert-xsmall-dummy @@ -126,7 +126,8 @@ def cached_path( extract_archive : If ``True``, then zip or tar.gz archives will be automatically extracted. - In which case the directory is returned. + If ``"auto"``, then the archive is extracted only if ``url_or_filename`` contains a "!". + If extracted, the directory is returned. 
force_extract : If ``True`` and the file is an archive file, it will be extracted regardless @@ -168,9 +169,15 @@ def cached_path( file_path: Path extraction_path: Optional[Path] = None etag: Optional[str] = None + exclamation_index = url_or_filename.find("!") + + if isinstance(extract_archive, str): + if extract_archive == "auto": + extract_archive = exclamation_index >= 0 + else: + raise ValueError(f"Invalid value for `extract_archive`: {extract_archive}") # If we're using the /a/b/foo.zip!c/d/file.txt syntax, handle it here. - exclamation_index = url_or_filename.find("!") if extract_archive and exclamation_index >= 0: archive_path = url_or_filename[:exclamation_index] file_name = url_or_filename[exclamation_index + 1 :]