Skip to content

Commit

Permalink
Add method to get the split info from a dataset info proto.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 664803703
  • Loading branch information
tomvdw authored and The TensorFlow Datasets Authors committed Aug 19, 2024
1 parent fa4eda5 commit 62c9456
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 4 deletions.
53 changes: 49 additions & 4 deletions tensorflow_datasets/core/dataset_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -1272,17 +1272,62 @@ def update_info_proto_with_features(
return completed_info_proto


def available_file_formats(
dataset_info_proto: dataset_info_pb2.DatasetInfo,
) -> set[str]:
"""Returns the available file formats for the given dataset."""
return set(
[dataset_info_proto.file_format]
+ list(dataset_info_proto.alternative_file_formats)
)


def supports_file_format(
dataset_info_proto: dataset_info_pb2.DatasetInfo,
file_format: str | file_adapters.FileFormat,
) -> bool:
"""Returns whether the given file format is supported by the dataset."""
if isinstance(file_format, file_adapters.FileFormat):
file_format = file_format.value
return (
file_format == dataset_info_proto.file_format
or file_format in dataset_info_proto.alternative_file_formats
)
return file_format in available_file_formats(dataset_info_proto)


def get_split_info_from_proto(
dataset_info_proto: dataset_info_pb2.DatasetInfo,
split_name: str,
data_dir: epath.PathLike,
file_format: file_adapters.FileFormat,
) -> splits_lib.SplitInfo | None:
"""Returns split info from the given dataset info proto.
Args:
dataset_info_proto: the proto with the dataset info.
split_name: the split for which to retrieve info for.
data_dir: the directory where the data is stored.
file_format: the file format for which to get the split info.
"""
if not supports_file_format(dataset_info_proto, file_format):
available_format = available_file_formats(dataset_info_proto)
raise ValueError(
f"File format {file_format.value} does not match available dataset file"
f" formats: {sorted(available_format)}."
)
for split_info in dataset_info_proto.splits:
if split_info.name == split_name:
filename_template = naming.ShardedFileTemplate(
dataset_name=dataset_info_proto.name,
data_dir=epath.Path(data_dir),
filetype_suffix=file_format.file_suffix,
)
# Override the default file name template if it was set.
if split_info.filepath_template:
filename_template = filename_template.replace(
template=split_info.filepath_template
)
return splits_lib.SplitInfo.from_proto(
proto=split_info, filename_template=filename_template
)
return None


class MetadataDict(Metadata, dict):
Expand Down
76 changes: 76 additions & 0 deletions tensorflow_datasets/core/dataset_info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import json
import os
import pathlib
import re
import tempfile
import time
from typing import Union
Expand Down Expand Up @@ -698,6 +699,81 @@ def test_supports_file_format():
)


class GetSplitInfoFromProtoTest(testing.TestCase):

def _dataset_info_proto_with_splits(self):
return dataset_info_pb2.DatasetInfo(
name="dataset",
file_format="tfrecord",
alternative_file_formats=["riegeli"],
splits=[
dataset_info_pb2.SplitInfo(
name="train",
shard_lengths=[1, 2, 3],
num_bytes=42,
),
dataset_info_pb2.SplitInfo(
name="test",
shard_lengths=[1, 2, 3],
num_bytes=42,
filepath_template="{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}",
),
],
)

def test_get_split_info_from_proto_undefined_filename_template(self):
actual = dataset_info.get_split_info_from_proto(
dataset_info_proto=self._dataset_info_proto_with_splits(),
split_name="train",
data_dir="/path/to/data",
file_format=file_adapters.FileFormat.TFRECORD,
)
assert actual.name == "train"
assert actual.shard_lengths == [1, 2, 3]
assert actual.num_bytes == 42
assert actual.filename_template.dataset_name == "dataset"
assert actual.filename_template.template == naming.DEFAULT_FILENAME_TEMPLATE

def test_get_split_info_from_proto_defined_filename_template(self):
actual = dataset_info.get_split_info_from_proto(
dataset_info_proto=self._dataset_info_proto_with_splits(),
split_name="test",
data_dir="/path/to/data",
file_format=file_adapters.FileFormat.TFRECORD,
)
assert actual.name == "test"
assert actual.shard_lengths == [1, 2, 3]
assert actual.filename_template.dataset_name == "dataset"
assert (
actual.filename_template.template
== "{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}"
)

def test_get_split_info_from_proto_non_existing_split(self):
actual = dataset_info.get_split_info_from_proto(
dataset_info_proto=self._dataset_info_proto_with_splits(),
split_name="undefined",
data_dir="/path/to/data",
file_format=file_adapters.FileFormat.TFRECORD,
)
assert actual is None

def test_get_split_info_from_proto_unavailable_format(self):
with pytest.raises(
ValueError,
match=re.escape(
"File format parquet does not match available dataset file formats:"
" ['riegeli', 'tfrecord']."
),
):
dataset_info.get_split_info_from_proto(
dataset_info_proto=self._dataset_info_proto_with_splits(),
split_name="undefined",
data_dir="/path/to/data",
file_format=file_adapters.FileFormat.PARQUET,
)


# pylint: disable=g-inconsistent-quotes
_INFO_STR = '''tfds.core.DatasetInfo(
name='mnist',
Expand Down

0 comments on commit 62c9456

Please sign in to comment.