Simplify file_utils.list_dataset_variants().
PiperOrigin-RevId: 707088515
fineguy authored and The TensorFlow Datasets Authors committed Dec 18, 2024
1 parent 6b93631 commit 687c7d9
Showing 3 changed files with 66 additions and 181 deletions.
200 changes: 65 additions & 135 deletions tensorflow_datasets/core/utils/file_utils.py
@@ -343,113 +343,6 @@ def _find_files_with_glob(
yield from _find_files_without_glob(folder, globs, file_names)


def _find_references_with_glob(
folder: epath.Path,
is_data_dir: bool,
is_dataset_dir: bool,
namespace: str | None = None,
include_old_tfds_version: bool = True,
glob_suffixes: Sequence[str] = ('json',),
) -> Iterator[naming.DatasetReference]:
"""Yields all dataset references in the given folder.
Args:
folder: the folder where to look for datasets. Can be either a root data
dir, or a dataset folder.
is_data_dir: Whether `folder` is a root TFDS data dir.
is_dataset_dir: Whether `folder` is the folder of one specific dataset.
namespace: Optional namespace to which the found datasets belong.
include_old_tfds_version: include datasets that have been generated with
TFDS before 4.0.0.
glob_suffixes: list of file suffixes to use to create the glob for
interesting TFDS files. Defaults to json files.
"""
if is_dataset_dir and is_data_dir:
raise ValueError('Folder cannot be both a data dir and dataset dir!')
if not is_data_dir and not is_dataset_dir:
raise ValueError('Folder must be either a data dir or a dataset dir!')

if is_data_dir:
data_dir = folder
dataset_name = None
stars = ['*/*/*/*', '*/*/*']
else:
data_dir = folder.parent
dataset_name = folder.name
stars = ['*/*/*', '*/*']

globs: list[str] = []
for star in stars:
if glob_suffixes:
globs.extend([f'{star}.{suffix}' for suffix in glob_suffixes])
else:
globs.append(star)

# Check which files match the globs and are files we are interested in.
matched_files_per_folder = collections.defaultdict(set)
for file in _find_files_with_glob(
folder,
globs=globs,
file_names=_INFO_FILE_NAMES,
):
matched_files_per_folder[file.parent].add(file.name)

for data_folder, matched_files in matched_files_per_folder.items():
if constants.DATASET_INFO_FILENAME not in matched_files:
logging.warning(
'Ignoring dataset folder %s, which has no dataset_info.json',
os.fspath(data_folder),
)
continue
if (
not include_old_tfds_version
and constants.FEATURES_FILENAME not in matched_files
):
logging.info(
'Ignoring dataset folder %s, which has no features.json',
os.fspath(data_folder),
)
continue

version = data_folder.name
if not version_lib.Version.is_valid(version):
logging.warning(
'Ignoring dataset folder %s, which has invalid version %s',
os.fspath(data_folder),
version,
)
continue

config = None
if is_data_dir:
if data_folder.parent.parent == folder:
dataset_name = data_folder.parent.name
elif data_folder.parent.parent.parent == folder:
dataset_name = data_folder.parent.parent.name
config = data_folder.parent.name
else:
raise ValueError(
f'Could not detect dataset and config from path {data_folder} in'
f' {folder}'
)
else:
if data_folder.parent != folder:
config = data_folder.parent.name

if not naming.is_valid_dataset_name(dataset_name):
logging.warning('Invalid dataset name: %s', dataset_name)
continue

yield naming.DatasetReference(
namespace=namespace,
data_dir=data_dir,
dataset_name=dataset_name,
config=config,
version=version,
info_filenames=matched_files,
)
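
For orientation, both the removed helper and its replacement crawl the standard TFDS on-disk layout. A sketch of the shapes the old globs matched, with illustrative paths that are not from this commit (assuming the default glob_suffixes=('json',)):

# is_data_dir=True -> globs ['*/*/*/*.json', '*/*/*.json'], relative to the root data dir:
#   imagenet/r1.0/1.0.0/dataset_info.json   # dataset/config/version
#   mnist/3.0.1/dataset_info.json           # dataset/version (no config)
# is_dataset_dir=True -> globs ['*/*/*.json', '*/*.json'], relative to the dataset folder:
#   r1.0/1.0.0/features.json                # config/version
#   3.0.1/dataset_info.json                 # version only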


def list_dataset_versions(
dataset_config_dir: epath.PathLike,
) -> list[version_lib.Version]:
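
A minimal usage sketch of list_dataset_versions, which this hunk shows only as unchanged context (the config directory path is hypothetical):

from tensorflow_datasets.core.utils import file_utils

versions = file_utils.list_dataset_versions('/data/imagenet/r1.0')
# e.g. [Version('1.0.0'), Version('1.0.1')]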
@@ -476,45 +369,77 @@ def list_dataset_versions(


def list_dataset_variants(
dataset_dir: epath.PathLike,
dataset_dir: Path,
namespace: str | None = None,
include_versions: bool = True,
include_old_tfds_version: bool = False,
glob_suffixes: Sequence[str] = ('json',),
) -> Iterator[naming.DatasetReference]:
"""Yields all variants (config + version) found in `dataset_dir`.
Arguments:
Args:
dataset_dir: the folder of the dataset.
namespace: optional namespace to which this data dir belongs.
include_versions: whether to list what versions are available.
include_old_tfds_version: include datasets that have been generated with
TFDS before 4.0.0.
glob_suffixes: list of file suffixes to use to create the glob for
interesting TFDS files. Defaults to json files.

Yields:
all variants of the given dataset.
""" # fmt: skip
dataset_dir = epath.Path(dataset_dir)
references = {}
for reference in _find_references_with_glob(
folder=dataset_dir,
is_data_dir=False,
is_dataset_dir=True,
namespace=namespace,
include_old_tfds_version=include_old_tfds_version,
glob_suffixes=glob_suffixes,
data_dir = dataset_dir.parent
dataset_name = dataset_dir.name
globs = [
'*/*/*.json', # with nested config directory
'*/*.json', # without nested config directory
]

# Check which files match the globs and are files we are interested in.
matched_files_by_variant_dir = collections.defaultdict(set)
for file in _find_files_with_glob(
dataset_dir,
globs=globs,
file_names=_INFO_FILE_NAMES,
):
if include_versions:
key = f'{reference.dataset_name}/{reference.config}:{reference.version}'
else:
key = f'{reference.dataset_name}/{reference.config}'
reference = reference.replace(version=None)
references[key] = reference
matched_files_by_variant_dir[file.parent].add(file.name)

for variant_dir, matched_files in matched_files_by_variant_dir.items():
if constants.DATASET_INFO_FILENAME not in matched_files:
logging.warning(
'Ignoring variant folder %s, which has no %s',
variant_dir,
constants.DATASET_INFO_FILENAME,
)
continue

if (
not include_old_tfds_version
and constants.FEATURES_FILENAME not in matched_files
):
logging.info(
'Ignoring variant folder %s, which has no %s',
variant_dir,
constants.FEATURES_FILENAME,
)
continue

for reference in references.values():
yield reference
version = variant_dir.name
if not version_lib.Version.is_valid(version):
logging.warning(
'Ignoring variant folder %s, which has invalid version %s',
variant_dir,
version,
)
continue

config_dir = variant_dir.parent
config = config_dir.name if config_dir != dataset_dir else None

yield naming.DatasetReference(
namespace=namespace,
data_dir=data_dir,
dataset_name=dataset_name,
config=config,
version=version,
info_filenames=matched_files,
)
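
A minimal usage sketch of the simplified function, assuming Path in the new signature resolves to epath.Path (the dataset directory path is illustrative):

from etils import epath
from tensorflow_datasets.core.utils import file_utils

# Hypothetical layout: /data/my_dataset/<config>/<version>/dataset_info.json
for ref in file_utils.list_dataset_variants(
    dataset_dir=epath.Path('/data/my_dataset'),
):
  print(ref.dataset_name, ref.config, ref.version)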


def list_datasets_in_data_dir(
@@ -547,22 +472,27 @@ def list_datasets_in_data_dir(
for dataset_dir in epath.Path(data_dir).iterdir():
if not dataset_dir.is_dir():
continue
if not naming.is_valid_dataset_name(dataset_dir.name):
dataset_name = dataset_dir.name
if not naming.is_valid_dataset_name(dataset_name):
logging.warning('Invalid dataset name: %s', dataset_name)
continue
num_datasets += 1
if include_configs:
for variant in list_dataset_variants(
dataset_dir=dataset_dir,
namespace=namespace,
include_versions=include_versions,
include_old_tfds_version=include_old_tfds_version,
):
num_variants += 1
yield variant
if include_versions:
yield variant
else:
yield variant.replace(version=None)
break
else:
num_variants += 1
yield naming.DatasetReference(
dataset_name=dataset_dir.name, namespace=namespace, data_dir=data_dir
dataset_name=dataset_name, namespace=namespace, data_dir=data_dir
)
logging.info(
'Found %d datasets and %d variants in %s',
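For context, a hedged sketch of how a caller might drive the updated list_datasets_in_data_dir; the root directory is hypothetical, and only parameters visible in this diff are passed:

from tensorflow_datasets.core.utils import file_utils

for ref in file_utils.list_datasets_in_data_dir(
    data_dir='/data',
    include_configs=True,
    include_versions=False,  # versions are now stripped here, not in list_dataset_variants
):
  print(ref)  # naming.DatasetReference with version=None
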
46 changes: 1 addition & 45 deletions tensorflow_datasets/core/utils/file_utils_test.py
@@ -200,18 +200,13 @@ def test_list_dataset_variants_with_configs(mock_fs: testing.MockFs):
constants.FEATURES_FILENAME,
constants.DATASET_INFO_FILENAME,
}
glob_suffixes = [
'json',
]
for config, versions in configs_and_versions.items():
for version in versions:
for info_filename in info_filenames:
mock_fs.add_file(_DATASET_DIR / config / version / info_filename)

references = sorted(
file_utils.list_dataset_variants(
dataset_dir=_DATASET_DIR, glob_suffixes=glob_suffixes
)
file_utils.list_dataset_variants(dataset_dir=_DATASET_DIR)
)
assert references == [
naming.DatasetReference(
@@ -238,43 +233,6 @@ def test_list_dataset_variants_with_configs(mock_fs: testing.MockFs):
]


def test_list_dataset_variants_with_configs_no_versions(
mock_fs: testing.MockFs,
):
configs_and_versions = {
'x': [_VERSION, '1.0.1'],
'y': ['2.0.0'],
}
info_filenames = {
constants.DATASET_INFO_FILENAME,
constants.FEATURES_FILENAME,
}
for config, versions in configs_and_versions.items():
for version in versions:
for filename in info_filenames:
mock_fs.add_file(_DATASET_DIR / config / version / filename)

references = sorted(
file_utils.list_dataset_variants(
dataset_dir=_DATASET_DIR, include_versions=False
)
)
assert references == [
naming.DatasetReference(
dataset_name=_DATASET_NAME,
config='x',
data_dir=_DATA_DIR,
info_filenames=info_filenames,
),
naming.DatasetReference(
dataset_name=_DATASET_NAME,
config='y',
data_dir=_DATA_DIR,
info_filenames=info_filenames,
),
]


def test_list_dataset_variants_without_configs(mock_fs: testing.MockFs):
# Version 1.0.0 doesn't have features.json, because it was generated with an
# old version of TFDS.
@@ -286,7 +244,6 @@ def test_list_dataset_variants_without_configs(mock_fs: testing.MockFs):
references = sorted(
file_utils.list_dataset_variants(
dataset_dir=_DATASET_DIR,
include_versions=True,
include_old_tfds_version=True,
)
)
@@ -312,7 +269,6 @@ def test_list_dataset_variants_without_configs(mock_fs: testing.MockFs):
references = sorted(
file_utils.list_dataset_variants(
dataset_dir=_DATASET_DIR,
include_versions=True,
include_old_tfds_version=False,
)
)
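The removed test covered include_versions=False, which list_dataset_variants no longer accepts; that behavior now lives in list_datasets_in_data_dir. For reference, a sketch of the mock tree the removed test built, assuming _DATASET_DIR = _DATA_DIR / _DATASET_NAME and _VERSION == '1.0.0' (both are test-module constants not shown in this diff):

# <_DATA_DIR>/<_DATASET_NAME>/
#   x/1.0.0/  dataset_info.json, features.json
#   x/1.0.1/  dataset_info.json, features.json
#   y/2.0.0/  dataset_info.json, features.json
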
1 change: 0 additions & 1 deletion tensorflow_datasets/scripts/cli/convert_format_utils.py
@@ -740,7 +740,6 @@ def convert_dataset_dir(

references = file_utils.list_dataset_variants(
dataset_dir=dataset_dir,
include_versions=True,
include_old_tfds_version=True,
)
from_to_dirs = _create_from_to_dirs(