Skip to content

Commit

Permalink
Fix PermissionError on Windows CI (#6477)
Browse files Browse the repository at this point in the history
* Align DatasetDict with Dataset as context manager

* Fix test
  • Loading branch information
albertvillanova authored Dec 6, 2023
1 parent f772102 commit d78f070
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
11 changes: 11 additions & 0 deletions src/datasets/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ def _check_values_features(self):
f"All datasets in `DatasetDict` should have the same features but features for '{item_a[0]}' and '{item_b[0]}' don't match: {item_a[1].features} != {item_b[1].features}"
)

def __enter__(self):
    """Enter the ``with`` block and hand back this same object unchanged."""
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    """Drop the table attributes of every contained dataset on leaving ``with``.

    For each value of this mapping, the ``_data`` and ``_indices``
    attributes are deleted when present. Deleting the pyarrow tables is what
    properly closes the files used for memory mapped tables.

    The exception arguments are not inspected and nothing is returned, so an
    exception raised inside the ``with`` body keeps propagating.
    """
    for split in self.values():
        # Same order as before: the data table first, then the indices table.
        for table_attr in ("_data", "_indices"):
            if hasattr(split, table_attr):
                delattr(split, table_attr)

def __getitem__(self, k) -> Dataset:
if isinstance(k, (str, NamedSplit)) or len(self) == 0:
return super().__getitem__(k)
Expand Down
7 changes: 3 additions & 4 deletions tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -1313,10 +1313,9 @@ def test_load_dataset_with_unsupported_extensions(text_dir_with_unsupported_exte
@pytest.mark.integration
def test_loading_from_the_datasets_hub():
with tempfile.TemporaryDirectory() as tmp_dir:
dataset = load_dataset(SAMPLE_DATASET_IDENTIFIER, cache_dir=tmp_dir)
assert len(dataset["train"]) == 2
assert len(dataset["validation"]) == 3
del dataset
with load_dataset(SAMPLE_DATASET_IDENTIFIER, cache_dir=tmp_dir) as dataset:
assert len(dataset["train"]) == 2
assert len(dataset["validation"]) == 3


@pytest.mark.integration
Expand Down

0 comments on commit d78f070

Please sign in to comment.