
Add support for loading data in a specific file format
PiperOrigin-RevId: 686816079
tomvdw authored and The TensorFlow Datasets Authors committed Dec 9, 2024
1 parent 4aa203e commit 97eed86
Showing 5 changed files with 84 additions and 27 deletions.
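In short: the commit threads an optional `file_format` argument through the public loading entry points, so a dataset that was prepared in more than one file format can be read back in whichever format the caller asks for. A minimal usage sketch (the dataset name is illustrative, and both calls assume the dataset was actually generated in the requested formats):

```python
import tensorflow_datasets as tfds

builder = tfds.builder("mnist")  # assumes the dataset is already prepared

# Read the tfrecord copy as a tf.data.Dataset.
ds = builder.as_dataset(split="train", file_format="tfrecord")

# Read the array_record copy as a random-access data source.
source = builder.as_data_source(split="train", file_format="array_record")
```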
51 changes: 41 additions & 10 deletions tensorflow_datasets/core/dataset_builder.py
@@ -815,6 +815,7 @@ def as_data_source(
       *,
       decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
       deserialize_method: decode.DeserializeMethod = decode.DeserializeMethod.DESERIALIZE_AND_DECODE,
+      file_format: str | file_adapters.FileFormat | None = None,
   ) -> ListOrTreeOrElem[Sequence[Any]]:
     """Constructs an `ArrayRecordDataSource`.
 
@@ -833,6 +834,8 @@ def as_data_source(
         the features. Decoding is only supported if the examples are tf
         examples. Note that if the deserialize_method method is other than
         PARSE_AND_DECODE, then the `decoders` argument is ignored.
+      file_format: if the dataset is stored in multiple file formats, then this
+        can be used to specify which format to use.
 
     Returns:
       `Sequence` if `split`,
@@ -868,22 +871,31 @@ def as_data_source(
           "Dataset info file format is not set! For random access, one of the"
           f" following formats is required: {random_access_formats_msg}"
       )
 
     suitable_formats = available_formats.intersection(random_access_formats)
-    if suitable_formats:
+    if not suitable_formats:
+      raise NotImplementedError(unsupported_format_msg)
+
+    if file_format is not None:
+      file_format = file_adapters.FileFormat.from_value(file_format)
+      if file_format not in suitable_formats:
+        raise ValueError(
+            f"Requested file format {file_format} is not available for this"
+            f" dataset. Available formats: {available_formats}"
+        )
+      chosen_format = file_format
+    else:
       chosen_format = suitable_formats.pop()
       logging.info(
           "Found random access formats: %s. Chose to use %s. Overriding file"
           " format in the dataset info.",
           ", ".join([f.name for f in suitable_formats]),
           chosen_format,
       )
-      # Change the dataset info to read from a random access format.
-      info.set_file_format(
-          chosen_format, override=True, override_if_initialized=True
-      )
-    else:
-      raise NotImplementedError(unsupported_format_msg)
+    # Change the dataset info to read from a random access format.
+    info.set_file_format(
+        chosen_format, override=True, override_if_initialized=True
+    )
 
     # Create a dataset for each of the given splits
     def build_single_data_source(split: str) -> Sequence[Any]:
@@ -924,6 +936,7 @@ def as_dataset(
       decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
       read_config: read_config_lib.ReadConfig | None = None,
       as_supervised: bool = False,
+      file_format: str | file_adapters.FileFormat | None = None,
   ):
     # pylint: disable=line-too-long
     """Constructs a `tf.data.Dataset`.
@@ -993,6 +1006,9 @@ def as_dataset(
         a 2-tuple structure `(input, label)` according to
         `builder.info.supervised_keys`. If `False`, the default, the returned
         `tf.data.Dataset` will have a dictionary with all the features.
+      file_format: if the dataset is stored in multiple file formats, then this
+        argument can be used to specify the file format to load. If not
+        specified, the default file format is used.
 
     Returns:
       `tf.data.Dataset`, or if `split=None`, `dict<key: tfds.Split, value:
@@ -1026,6 +1042,7 @@ def as_dataset(
           decoders=decoders,
           read_config=read_config,
           as_supervised=as_supervised,
+          file_format=file_format,
       )
     all_ds = tree.map_structure(build_single_dataset, split)
     return all_ds
@@ -1038,18 +1055,23 @@ def _build_single_dataset(
       decoders: TreeDict[decode.partial_decode.DecoderArg] | None,
       read_config: read_config_lib.ReadConfig,
       as_supervised: bool,
+      file_format: str | file_adapters.FileFormat | None = None,
   ) -> tf.data.Dataset:
     """as_dataset for a single split."""
     wants_full_dataset = batch_size == -1
     if wants_full_dataset:
       batch_size = self.info.splits.total_num_examples or sys.maxsize
 
+    if file_format is not None:
+      file_format = file_adapters.FileFormat.from_value(file_format)
+
     # Build base dataset
     ds = self._as_dataset(
         split=split,
         shuffle_files=shuffle_files,
         decoders=decoders,
         read_config=read_config,
+        file_format=file_format,
     )
     # Auto-cache small datasets which are small enough to fit in memory.
     if self._should_cache_ds(
@@ -1235,6 +1257,7 @@ def _as_dataset(
       decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
       read_config: read_config_lib.ReadConfig | None = None,
       shuffle_files: bool = False,
+      file_format: str | file_adapters.FileFormat | None = None,
   ) -> tf.data.Dataset:
     """Constructs a `tf.data.Dataset`.
 
@@ -1250,6 +1273,9 @@ def _as_dataset(
       read_config: `tfds.ReadConfig`
       shuffle_files: `bool`, whether to shuffle the input files. Optional,
         defaults to `False`.
+      file_format: if the dataset is stored in multiple file formats, then this
+        argument can be used to specify the file format to load. If not
+        specified, the default file format is used.
 
     Returns:
       `tf.data.Dataset`
@@ -1487,6 +1513,8 @@ def __init__(
 
   @functools.cached_property
   def _example_specs(self):
+    if self.info.features is None:
+      raise ValueError("Features are not set!")
     return self.info.features.get_serialized_info()
 
   def _as_dataset(  # pytype: disable=signature-mismatch  # overriding-parameter-type-checks
@@ -1495,6 +1523,7 @@ def _as_dataset(  # pytype: disable=signature-mismatch  # overriding-parameter-type-checks
       decoders: TreeDict[decode.partial_decode.DecoderArg] | None,
       read_config: read_config_lib.ReadConfig,
       shuffle_files: bool,
+      file_format: file_adapters.FileFormat | None = None,
   ) -> tf.data.Dataset:
     # Partial decoding
     # TODO(epot): Should be moved inside `features.decode_example`
@@ -1508,10 +1537,12 @@ def _as_dataset(  # pytype: disable=signature-mismatch  # overriding-parameter-type-checks
       example_specs = self._example_specs
       decoders = decoders  # pylint: disable=self-assigning-variable
 
+    if features is None:
+      raise ValueError("Features are not set!")
+
     reader = reader_lib.Reader(
         self.data_dir,
         example_specs=example_specs,
-        file_format=self.info.file_format,
+        file_format=file_format or self.info.file_format,
     )
     decode_fn = functools.partial(features.decode_example, decoders=decoders)
     return reader.read(
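The selection logic in `as_data_source` is easiest to read in isolation: intersect the formats the dataset was written in with the formats that support random access, then honor an explicit request only if it falls inside that intersection. A self-contained sketch of that decision, with plain strings standing in for `file_adapters.FileFormat` members and an assumed set of random-access formats:

```python
# Plain-string model of the format selection in as_data_source above.
# Which formats count as random-access is an assumption for illustration.
RANDOM_ACCESS_FORMATS = {"array_record"}


def choose_format(available: set[str], requested: str | None) -> str:
  suitable = available & RANDOM_ACCESS_FORMATS
  if not suitable:
    raise NotImplementedError("No random-access format available.")
  if requested is not None:
    if requested not in suitable:
      raise ValueError(f"{requested} not available; have: {available}")
    return requested
  return suitable.pop()  # no explicit request: pick any suitable format


assert choose_format({"tfrecord", "array_record"}, None) == "array_record"
assert choose_format({"tfrecord", "array_record"}, "array_record") == "array_record"
```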
17 changes: 12 additions & 5 deletions tensorflow_datasets/core/load.py
@@ -226,7 +226,7 @@ def _try_load_from_files_first(
     **builder_kwargs: Any,
 ) -> bool:
   """Returns True if files should be used rather than code."""
-  if set(builder_kwargs) - {'version', 'config', 'data_dir'}:
+  if set(builder_kwargs) - {'version', 'config', 'data_dir', 'file_format'}:
     return False  # Has extra kwargs, requires original code.
   elif builder_kwargs.get('version') == 'experimental_latest':
     return False  # Requested version requires original code
@@ -485,10 +485,13 @@ def _fetch_builder(
     data_dir: epath.PathLike | None,
     builder_kwargs: dict[str, Any] | None,
     try_gcs: bool,
+    file_format: str | file_adapters.FileFormat | None = None,
 ) -> dataset_builder.DatasetBuilder:
   """Fetches the `tfds.core.DatasetBuilder` by name."""
   if builder_kwargs is None:
     builder_kwargs = {}
+  if file_format is not None:
+    builder_kwargs['file_format'] = file_format
   return builder(name, data_dir=data_dir, try_gcs=try_gcs, **builder_kwargs)
 
 
@@ -529,6 +532,7 @@ def load(
     download_and_prepare_kwargs: dict[str, Any] | None = None,
     as_dataset_kwargs: dict[str, Any] | None = None,
     try_gcs: bool = False,
+    file_format: str | file_adapters.FileFormat | None = None,
 ):
   # pylint: disable=line-too-long
   """Loads the named dataset into a `tf.data.Dataset`.
@@ -636,6 +640,9 @@ def load(
       fully bypass GCS, please use `try_gcs=False` and
       `download_and_prepare_kwargs={'download_config':
       tfds.core.download.DownloadConfig(try_download_gcs=False)})`.
+    file_format: if the dataset is stored in multiple file formats, then this
+      argument can be used to specify the file format to load. If not specified,
+      the default file format is used.
 
   Returns:
     ds: `tf.data.Dataset`, the dataset requested, or if `split` is None, a
@@ -648,10 +655,10 @@ def load(
     Split-specific information is available in `ds_info.splits`.
   """  # fmt: skip
   dbuilder = _fetch_builder(
-      name,
-      data_dir,
-      builder_kwargs,
-      try_gcs,
+      name=name,
+      data_dir=data_dir,
+      builder_kwargs=builder_kwargs,
+      try_gcs=try_gcs,
   )
   _download_and_prepare_builder(dbuilder, download, download_and_prepare_kwargs)
 
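Two details worth noting here: `load` exposes the new argument directly, and `'file_format'` joins the allow-list in `_try_load_from_files_first`, so passing it no longer forces a fallback to the dataset's generation code. A hedged sketch of the call (dataset name illustrative; the data must already exist in the requested format):

```python
import tensorflow_datasets as tfds

# Stays on the "load from generated files" path thanks to the
# _try_load_from_files_first allow-list change above.
ds = tfds.load("mnist", split="train", file_format="tfrecord")
```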
23 changes: 21 additions & 2 deletions tensorflow_datasets/core/read_only_builder.py
@@ -33,6 +33,7 @@
 from etils import etree
 from tensorflow_datasets.core import dataset_builder
 from tensorflow_datasets.core import dataset_info
+from tensorflow_datasets.core import file_adapters
 from tensorflow_datasets.core import logging as tfds_logging
 from tensorflow_datasets.core import naming
 from tensorflow_datasets.core import registered
@@ -57,6 +58,7 @@ def __init__(
       builder_dir: epath.PathLike,
       *,
       info_proto: dataset_info_pb2.DatasetInfo | None = None,
+      file_format: str | file_adapters.FileFormat | None = None,
   ):
     """Constructor.
 
@@ -66,6 +68,8 @@ def __init__(
       info_proto: DatasetInfo describing the name, config, etc of the requested
         dataset. Note that this overwrites dataset info that may be present in
         builder_dir.
+      file_format: The desired file format to use for the dataset. If not
+        specified, the file format in the DatasetInfo is used.
 
     Raises:
       FileNotFoundError: If the builder_dir does not exist.
@@ -74,6 +78,15 @@ def __init__(
     if not info_proto:
       info_proto = dataset_info.read_proto_from_builder_dir(builder_dir)
     self._info_proto = info_proto
+    if file_format is not None:
+      file_format = file_adapters.FileFormat.from_value(file_format)
+      available_formats = set([self._info_proto.file_format])
+      available_formats.update(self._info_proto.alternative_file_formats)
+      if file_format.file_suffix not in available_formats:
+        raise ValueError(
+            f'File format {file_format.file_suffix} does not match the file'
+            f' formats in the DatasetInfo: {sorted(available_formats)}.'
+        )
 
     self.name = info_proto.name
     self.VERSION = version_lib.Version(info_proto.version)  # pylint: disable=invalid-name
@@ -89,6 +102,7 @@ def __init__(
     # __init__ will call _build_data_dir, _create_builder_config,
     # _pick_version to set the data_dir, config, and version
     super().__init__(
+        file_format=file_format,
         data_dir=builder_dir,
         config=builder_config,
         version=info_proto.version,
@@ -154,6 +168,7 @@ def _download_and_prepare(self, **kwargs):  # pylint: disable=arguments-differ
 
 
 def builder_from_directory(
     builder_dir: epath.PathLike,
+    file_format: str | file_adapters.FileFormat | None = None,
 ) -> dataset_builder.DatasetBuilder:
   """Loads a `tfds.core.DatasetBuilder` from the given generated dataset path.
@@ -171,11 +186,13 @@ def builder_from_directory(
   Args:
     builder_dir: Path of the directory containing the dataset to read ( e.g.
       `~/tensorflow_datasets/mnist/3.0.0/`).
+    file_format: The desired file format to use for the dataset. If not
+      specified, the default file format in the DatasetInfo is used.
 
   Returns:
     builder: `tfds.core.DatasetBuilder`, builder for dataset at the given path.
   """
-  return ReadOnlyBuilder(builder_dir=builder_dir)
+  return ReadOnlyBuilder(builder_dir=builder_dir, file_format=file_format)
 
 
 def builder_from_directories(
@@ -308,7 +325,8 @@ def builder_from_files(
       f'and that it has been generated in: {data_dirs}. If the dataset has'
       ' configs, you might have to specify the config name.'
   )
-  return builder_from_directory(builder_dir)
+  file_format = builder_kwargs.pop('file_format', None)
+  return builder_from_directory(builder_dir, file_format=file_format)
 
 
 def _find_builder_dir(name: str, **builder_kwargs: Any) -> epath.Path | None:
@@ -339,6 +357,7 @@ def _find_builder_dir(name: str, **builder_kwargs: Any) -> epath.Path | None:
   version = str(version) if version else None
   config = builder_kwargs.pop('config', None)
   data_dir = builder_kwargs.pop('data_dir', None)
+  _ = builder_kwargs.pop('file_format', None)
 
   # Builder cannot be found if it uses:
   # * namespace
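For already-generated data addressed by path, the same request goes through `builder_from_directory`, and `ReadOnlyBuilder.__init__` validates it against the `file_format` and `alternative_file_formats` recorded in the stored DatasetInfo, raising `ValueError` on a mismatch. A sketch (path and format are illustrative):

```python
import os

import tensorflow_datasets as tfds

# The directory must contain a generated dataset whose DatasetInfo lists
# the requested format, either as its primary or an alternative format.
builder = tfds.builder_from_directory(
    os.path.expanduser("~/tensorflow_datasets/mnist/3.0.0/"),
    file_format="array_record",
)
source = builder.as_data_source(split="train")
```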
16 changes: 8 additions & 8 deletions tensorflow_datasets/core/reader.py
@@ -20,7 +20,7 @@
 import functools
 import os
 import re
-from typing import Any, Callable, List, NamedTuple, Optional, Sequence
+from typing import Any, Callable, NamedTuple, Optional, Sequence
 
 from absl import logging
 import numpy as np
@@ -32,6 +32,7 @@
 from tensorflow_datasets.core.utils import file_utils
 from tensorflow_datasets.core.utils import read_config as read_config_lib
 from tensorflow_datasets.core.utils import shard_utils
+from tensorflow_datasets.core.utils.lazy_imports_utils import lineage_log
 from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
 from tensorflow_datasets.core.utils.lazy_imports_utils import tree
 
@@ -361,39 +362,36 @@ def _verify_read_config_for_ordered_dataset(
     logging.warning(error_message)
 
 
-class Reader(object):
+class Reader:
   """Build a tf.data.Dataset object out of Instruction instance(s).
 
   This class should not typically be exposed to the TFDS user.
   """
 
   def __init__(
       self,
       path,  # TODO(b/216427814) remove this as it isn't used anymore
       example_specs,
       file_format=file_adapters.DEFAULT_FILE_FORMAT,
   ):
     """Initializes Reader.
 
     Args:
       path (str): path where tfrecords are stored.
       example_specs: spec to build ExampleParser.
       file_format: file_adapters.FileFormat, format of the record files in which
         the dataset will be read/written from.
     """
     self._path = path
     self._parser = example_parser.ExampleParser(example_specs)
     self._file_format = file_format
 
   def read(
       self,
       *,
       instructions: Tree[splits_lib.SplitArg],
-      split_infos: List[splits_lib.SplitInfo],
+      split_infos: list[splits_lib.SplitInfo],
       read_config: read_config_lib.ReadConfig,
       shuffle_files: bool,
       disable_shuffling: bool = False,
-      decode_fn: Optional[DecodeFn] = None,
+      decode_fn: DecodeFn | None = None,
   ) -> Tree[tf.data.Dataset]:
     """Returns tf.data.Dataset instance(s).
@@ -417,7 +415,9 @@ def read(
 
     splits_dict = splits_lib.SplitDict(split_infos=split_infos)
 
-    def _read_instruction_to_ds(instruction):
+    def _read_instruction_to_ds(
+        instruction: splits_lib.SplitArg,
+    ) -> tf.data.Dataset:
       file_instructions = splits_dict[instruction].file_instructions
       return self.read_files(
           file_instructions,
4 changes: 2 additions & 2 deletions tensorflow_datasets/core/splits.py
@@ -587,7 +587,7 @@ class AbstractSplit(abc.ABC):
   """
 
   @classmethod
-  def from_spec(cls, spec: SplitArg) -> 'AbstractSplit':
+  def from_spec(cls, spec: SplitArg) -> AbstractSplit:
     """Creates a ReadInstruction instance out of a string spec.
 
     Args:
@@ -632,7 +632,7 @@ def to_absolute(self, split_infos) -> list[_AbsoluteInstruction]:
     """
     raise NotImplementedError
 
-  def __add__(self, other: Union[str, 'AbstractSplit']) -> 'AbstractSplit':
+  def __add__(self, other: Union[str, AbstractSplit]) -> AbstractSplit:
     """Sum of 2 splits."""
     if not isinstance(other, (str, AbstractSplit)):
       raise TypeError(f'Adding split {self!r} with non-split value: {other!r}')
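These two annotation changes only drop the string quoting (which presumably relies on the module enabling postponed evaluation of annotations); the methods themselves are what split specs reach: `from_spec` parses a spec string and `__add__` implements the `+` between its parts. A user-level sketch (dataset name illustrative):

```python
import tensorflow_datasets as tfds

# 'train[:50%]+test' is parsed via AbstractSplit.from_spec; the '+'
# concatenation maps onto AbstractSplit.__add__.
ds = tfds.load("mnist", split="train[:50%]+test")
```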
