Merge branch 'main' into allow-polars-as-valid-output-type
psmyth94 authored May 31, 2024
2 parents e3ea2e6 + 456f790 commit 713f012
Showing 17 changed files with 294 additions and 145 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -61,7 +61,7 @@ jobs:
uv pip install --system -r additional-tests-requirements.txt --no-deps
- name: Install dependencies (latest versions)
if: ${{ matrix.deps_versions == 'deps-latest' }}
run: uv pip install --system --upgrade pyarrow "huggingface-hub<0.23.0" dill
run: uv pip install --system --upgrade pyarrow huggingface-hub dill
- name: Install dependencies (minimum versions)
if: ${{ matrix.deps_versions != 'deps-latest' }}
run: uv pip install --system pyarrow==12.0.0 huggingface-hub==0.21.2 transformers dill==0.3.1.1
7 changes: 4 additions & 3 deletions pyproject.toml
@@ -15,9 +15,10 @@ known-first-party = ["datasets"]

[tool.pytest.ini_options]
# Test fails if a FutureWarning is thrown by `huggingface_hub`
filterwarnings = [
"error::FutureWarning:huggingface_hub*",
]
# Temporarily disabled because transformers 4.41.1 calls deprecated code from `huggingface_hub` that causes FutureWarning
# filterwarnings = [
# "error::FutureWarning:huggingface_hub*",
# ]
markers = [
"unit: unit test",
"integration: integration test",
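Note: while the config entry above stays commented out, the same strictness can still be reproduced locally. A rough Python-level equivalent of the disabled pytest filter, for context only (the module regex is an assumption mirroring the `huggingface_hub*` pattern):

```python
import warnings

# Promote FutureWarning raised from huggingface_hub modules to errors,
# which is what the disabled filterwarnings entry used to enforce.
warnings.filterwarnings(
    "error",
    category=FutureWarning,
    module=r"huggingface_hub.*",
)
```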
4 changes: 2 additions & 2 deletions setup.py
@@ -131,11 +131,11 @@
"multiprocess",
# to save datasets locally or on any filesystem
# minimum 2023.1.0 to support protocol=kwargs in fsspec's `open`, `get_fs_token_paths`, etc.: see https://github.com/fsspec/filesystem_spec/pull/1143
"fsspec[http]>=2023.1.0,<=2024.3.1",
"fsspec[http]>=2023.1.0,<=2024.5.0",
# for data streaming via http
"aiohttp",
# To get datasets from the Datasets Hub on huggingface.co
"huggingface-hub>=0.21.2,<0.23.0", # temporary pin: see https://github.com/huggingface/datasets/issues/6860
"huggingface-hub>=0.21.2",
# Utilities from PyPA to e.g., compare versions
"packaging",
# To parse YAML metadata from dataset cards
52 changes: 28 additions & 24 deletions src/datasets/data_files.py
@@ -22,6 +22,9 @@
from .utils.py_utils import glob_pattern_to_regex, string_to_dict


SingleOriginMetadata = Union[Tuple[str, str], Tuple[str], Tuple[()]]


SANITIZED_DEFAULT_SPLIT = str(Split.TRAIN)


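The new `SingleOriginMetadata` alias covers the three tuple shapes produced by `_get_single_origin_metadata` below. Illustrative values only (hypothetical repo id, revision, and ETag):

```python
# (repo_id, revision) for files resolved on the Hugging Face Hub
hub_file: SingleOriginMetadata = ("username/my_dataset", "abc123")
# a single ETag or mtime string for other remote or local files
remote_file: SingleOriginMetadata = ('"0x8d9e7a"',)
# empty tuple when no identifying metadata is available
unknown_file: SingleOriginMetadata = ()
```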
@@ -361,6 +364,7 @@ def resolve_pattern(
base_path (str): Base path to use when resolving relative paths.
allowed_extensions (Optional[list], optional): White-list of file extensions to use. Defaults to None (all extensions).
For example: allowed_extensions=[".csv", ".json", ".txt", ".parquet"]
download_config ([`DownloadConfig`], *optional*): Specific download configuration parameters.
Returns:
List[str]: List of paths or URLs to the local or remote files that match the patterns.
"""
@@ -516,17 +520,17 @@ def get_metadata_patterns(
def _get_single_origin_metadata(
data_file: str,
download_config: Optional[DownloadConfig] = None,
) -> Tuple[str]:
) -> SingleOriginMetadata:
data_file, storage_options = _prepare_path_and_storage_options(data_file, download_config=download_config)
fs, *_ = url_to_fs(data_file, **storage_options)
if isinstance(fs, HfFileSystem):
resolved_path = fs.resolve_path(data_file)
return (resolved_path.repo_id, resolved_path.revision)
return resolved_path.repo_id, resolved_path.revision
elif isinstance(fs, HTTPFileSystem) and data_file.startswith(config.HF_ENDPOINT):
hffs = HfFileSystem(endpoint=config.HF_ENDPOINT, token=download_config.token)
data_file = "hf://" + data_file[len(config.HF_ENDPOINT) + 1 :].replace("/resolve/", "@", 1)
resolved_path = hffs.resolve_path(data_file)
return (resolved_path.repo_id, resolved_path.revision)
return resolved_path.repo_id, resolved_path.revision
info = fs.info(data_file)
# s3fs uses "ETag", gcsfs uses "etag", and for local we simply check mtime
for key in ["ETag", "etag", "mtime"]:
@@ -539,7 +543,7 @@ def _get_origin_metadata(
data_files: List[str],
download_config: Optional[DownloadConfig] = None,
max_workers: Optional[int] = None,
) -> Tuple[str]:
) -> List[SingleOriginMetadata]:
max_workers = max_workers if max_workers is not None else config.HF_DATASETS_MULTITHREADING_MAX_WORKERS
return thread_map(
partial(_get_single_origin_metadata, download_config=download_config),
@@ -555,11 +559,11 @@ def _get_origin_metadata(
class DataFilesList(List[str]):
"""
List of data files (absolute local paths or URLs).
It has two construction methods given the user's data files patterns :
It has two construction methods given the user's data files patterns:
- ``from_hf_repo``: resolve patterns inside a dataset repository
- ``from_local_or_remote``: resolve patterns from a local path
Moreover DataFilesList has an additional attribute ``origin_metadata``.
Moreover, DataFilesList has an additional attribute ``origin_metadata``.
It can store:
- the last modified time of local files
- ETag of remote files
@@ -570,11 +574,11 @@ class DataFilesList(List[str]):
This is useful for caching Dataset objects that are obtained from a list of data files.
"""

def __init__(self, data_files: List[str], origin_metadata: List[Tuple[str]]):
def __init__(self, data_files: List[str], origin_metadata: List[SingleOriginMetadata]) -> None:
super().__init__(data_files)
self.origin_metadata = origin_metadata

def __add__(self, other):
def __add__(self, other: "DataFilesList") -> "DataFilesList":
return DataFilesList([*self, *other], self.origin_metadata + other.origin_metadata)

@classmethod
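A minimal usage sketch (hypothetical paths and metadata) of how the now-annotated `__add__` concatenates both the file list and the origin metadata:

```python
a = DataFilesList(["data/train-0.parquet"], [("user/repo", "main")])
b = DataFilesList(["data/train-1.parquet"], [("user/repo", "main")])
combined = a + b  # still a DataFilesList
assert list(combined) == ["data/train-0.parquet", "data/train-1.parquet"]
assert combined.origin_metadata == [("user/repo", "main"), ("user/repo", "main")]
```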
@@ -646,9 +650,9 @@ class DataFilesDict(Dict[str, DataFilesList]):
- ``from_hf_repo``: resolve patterns inside a dataset repository
- ``from_local_or_remote``: resolve patterns from a local path
Moreover each list is a DataFilesList. It is possible to hash the dictionary
Moreover, each list is a DataFilesList. It is possible to hash the dictionary
and get a different hash if and only if at least one file changed.
For more info, see ``DataFilesList``.
For more info, see [`DataFilesList`].
This is useful for caching Dataset objects that are obtained from a list of data files.
@@ -666,14 +670,14 @@ def from_local_or_remote(
out = cls()
for key, patterns_for_key in patterns.items():
out[key] = (
DataFilesList.from_local_or_remote(
patterns_for_key
if isinstance(patterns_for_key, DataFilesList)
else DataFilesList.from_local_or_remote(
patterns_for_key,
base_path=base_path,
allowed_extensions=allowed_extensions,
download_config=download_config,
)
if not isinstance(patterns_for_key, DataFilesList)
else patterns_for_key
)
return out

@@ -689,15 +693,15 @@ def from_hf_repo(
out = cls()
for key, patterns_for_key in patterns.items():
out[key] = (
DataFilesList.from_hf_repo(
patterns_for_key
if isinstance(patterns_for_key, DataFilesList)
else DataFilesList.from_hf_repo(
patterns_for_key,
dataset_info=dataset_info,
base_path=base_path,
allowed_extensions=allowed_extensions,
download_config=download_config,
)
if not isinstance(patterns_for_key, DataFilesList)
else patterns_for_key
)
return out

@@ -712,14 +716,14 @@ def from_patterns(
out = cls()
for key, patterns_for_key in patterns.items():
out[key] = (
DataFilesList.from_patterns(
patterns_for_key
if isinstance(patterns_for_key, DataFilesList)
else DataFilesList.from_patterns(
patterns_for_key,
base_path=base_path,
allowed_extensions=allowed_extensions,
download_config=download_config,
)
if not isinstance(patterns_for_key, DataFilesList)
else patterns_for_key
)
return out

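For all three constructors the reordered conditional keeps the same behaviour: values that are already `DataFilesList` instances pass through unchanged, and anything else is resolved from its patterns. A hedged usage sketch (hypothetical patterns, assuming local CSV files):

```python
patterns = {"train": ["data/train-*.csv"], "test": ["data/test-*.csv"]}
data_files = DataFilesDict.from_patterns(patterns, base_path=".")
# data_files["train"] is a DataFilesList of resolved paths plus origin metadata
```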
@@ -751,7 +755,7 @@ def __add__(self, other):
@classmethod
def from_patterns(
cls, patterns: List[str], allowed_extensions: Optional[List[str]] = None
) -> "DataFilesPatternsDict":
) -> "DataFilesPatternsList":
return cls(patterns, [allowed_extensions] * len(patterns))

def resolve(
@@ -777,7 +781,7 @@ def resolve(
origin_metadata = _get_origin_metadata(data_files, download_config=download_config)
return DataFilesList(data_files, origin_metadata)

def filter_extensions(self, extensions: List[str]) -> "DataFilesList":
def filter_extensions(self, extensions: List[str]) -> "DataFilesPatternsList":
return DataFilesPatternsList(
self, [allowed_extensions + extensions for allowed_extensions in self.allowed_extensions]
)
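A short sketch tying the corrected return annotations together: `from_patterns` builds a `DataFilesPatternsList`, and `resolve()` is what finally yields a `DataFilesList`. The `base_path` argument here is an assumption, since the full `resolve` signature is collapsed above:

```python
lazy_patterns = DataFilesPatternsList.from_patterns(["data/*.parquet"])  # hypothetical pattern
data_files = lazy_patterns.resolve(base_path=".")  # -> DataFilesList with origin_metadata
```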
@@ -795,12 +799,12 @@ def from_patterns(
out = cls()
for key, patterns_for_key in patterns.items():
out[key] = (
DataFilesPatternsList.from_patterns(
patterns_for_key
if isinstance(patterns_for_key, DataFilesPatternsList)
else DataFilesPatternsList.from_patterns(
patterns_for_key,
allowed_extensions=allowed_extensions,
)
if not isinstance(patterns_for_key, DataFilesPatternsList)
else patterns_for_key
)
return out

16 changes: 6 additions & 10 deletions src/datasets/features/features.py
@@ -312,21 +312,17 @@ def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool, optimize_list_cas
True,
)
elif config.TORCH_AVAILABLE and "torch" in sys.modules and isinstance(obj, torch.Tensor):
if obj.dtype == torch.bfloat16:
return _cast_to_python_objects(
obj.detach().to(torch.float).cpu().numpy(),
only_1d_for_numpy=only_1d_for_numpy,
optimize_list_casting=optimize_list_casting,
)[0], True
if obj.ndim == 0:
return obj.detach().cpu().numpy()[()], True
elif not only_1d_for_numpy or obj.ndim == 1:
return obj.detach().cpu().numpy(), True
else:
if obj.dtype == torch.bfloat16:
return (
[
_cast_to_python_objects(
x, only_1d_for_numpy=only_1d_for_numpy, optimize_list_casting=optimize_list_casting
)[0]
for x in obj.detach().to(torch.float).cpu().numpy()
],
True,
)
return (
[
_cast_to_python_objects(
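The reordering above runs the `bfloat16` check before the dimensionality checks because NumPy has no native `bfloat16` dtype, so calling `.numpy()` on such a tensor raises. A minimal sketch of the workaround the diff applies (assumes `torch` is installed):

```python
import torch

t = torch.tensor([1.5, 2.5], dtype=torch.bfloat16)
# t.cpu().numpy() would raise: NumPy cannot represent bfloat16.
arr = t.detach().to(torch.float).cpu().numpy()  # cast to float32 first
```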
9 changes: 1 addition & 8 deletions src/datasets/packaged_modules/arrow/arrow.py
@@ -30,20 +30,13 @@ def _split_generators(self, dl_manager):
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
dl_manager.download_config.extract_on_the_fly = True
data_files = dl_manager.download_and_extract(self.config.data_files)
if isinstance(data_files, (str, list, tuple)):
files = data_files
if isinstance(files, str):
files = [files]
# Use `dl_manager.iter_files` to skip hidden files in an extracted archive
files = [dl_manager.iter_files(file) for file in files]
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
splits = []
for split_name, files in data_files.items():
if isinstance(files, str):
files = [files]
# Use `dl_manager.iter_files` to skip hidden files in an extracted archive
files = [dl_manager.iter_files(file) for file in files]
# Infer features is they are stoed in the arrow schema
# Infer features if they are stored in the arrow schema
if self.info.features is None:
for file in itertools.chain.from_iterable(files):
with open(file, "rb") as f:
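For context, "infer features from the arrow schema" boils down to reading the schema from the Arrow IPC data and converting it to `Features`. A hedged sketch with a hypothetical path, not the exact builder code (which is collapsed above):

```python
import pyarrow as pa
import datasets

with open("data/train-00000.arrow", "rb") as f:
    schema = pa.ipc.open_stream(f).schema  # read only the IPC schema
features = datasets.Features.from_arrow_schema(schema)
```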
4 changes: 2 additions & 2 deletions src/datasets/packaged_modules/audiofolder/audiofolder.py
@@ -30,8 +30,8 @@ class AudioFolder(folder_based_builder.FolderBasedBuilder):
#
# AUDIO_EXTENSIONS = [f".{format.lower()}" for format in sf.available_formats().keys()]
#
# # .mp3 is currently decoded via `torchaudio`, .opus decoding is supported if version of `libsndfile` >= 1.0.30:
# AUDIO_EXTENSIONS.extend([".mp3", ".opus"])
# # .opus decoding is supported if libsndfile >= 1.0.31:
# AUDIO_EXTENSIONS.extend([".opus"])
# ```
# We intentionally do not run this code on launch because:
# (1) Soundfile is an optional dependency, so importing it in global namespace is not allowed
6 changes: 0 additions & 6 deletions src/datasets/packaged_modules/csv/csv.py
@@ -153,12 +153,6 @@ def _split_generators(self, dl_manager):
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
dl_manager.download_config.extract_on_the_fly = True
data_files = dl_manager.download_and_extract(self.config.data_files)
if isinstance(data_files, (str, list, tuple)):
files = data_files
if isinstance(files, str):
files = [files]
files = [dl_manager.iter_files(file) for file in files]
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
splits = []
for split_name, files in data_files.items():
if isinstance(files, str):