Skip to content

Commit

Permalink
Use the number of shards option in DownloadConfig when generating data non-deterministically
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 702640057
  • Loading branch information
tomvdw authored and The TensorFlow Datasets Authors committed Dec 4, 2024
1 parent 4d8506d commit 10401d3
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
2 changes: 2 additions & 0 deletions tensorflow_datasets/core/split_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,12 +550,14 @@ def _build_from_pcollection(
logging.info(
'`nondeterministic_order` is set to True, using NoShuffleBeamWriter'
)
num_shards = self._shard_config.num_shards if self._shard_config else None
beam_writer = writer_lib.NoShuffleBeamWriter(
serializer=serializer,
file_format=file_adapters.FileFormat.from_value(
filename_template.filetype_suffix
),
filename_template=filename_template,
num_shards=num_shards,
)
else:
logging.info('Deterministic ordering is enabled, using BeamWriter')
Expand Down
17 changes: 12 additions & 5 deletions tensorflow_datasets/core/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,7 @@ def __init__(
serializer: example_serializer.Serializer,
filename_template: naming.ShardedFileTemplate,
file_format: file_adapters.FileFormat,
num_shards: int | None = None,
):
"""Init BeamWriter.
Expand All @@ -740,6 +741,8 @@ def __init__(
serializer: class that can serialize examples.
filename_template: template to format sharded filenames.
file_format: the file format to use.
num_shards: the number of shards to use. If `None`, then the number of
shards is calculated automatically.
"""
self._original_state = dict(
serializer=serializer,
Expand All @@ -750,6 +753,7 @@ def __init__(
self._file_adapter = file_adapters.ADAPTER_FOR_FORMAT[self._file_format]
self._filename_template = filename_template
self._serializer = serializer
self._num_shards = num_shards

@functools.lru_cache()
def _get_counter(self, name: str, namespace: str = "BeamWriter"):
Expand All @@ -775,14 +779,17 @@ def _serialize_example(

def write_from_pcollection(self, examples_pcollection):
  """Returns the PTransform that writes the (key, example) PCollection.

  The examples are shuffled (non-deterministic order), serialized, and
  written with the file-format-specific Beam sink.

  Args:
    examples_pcollection: PCollection of (key, example) pairs to write.

  Returns:
    The terminal PTransform produced by the file adapter's Beam sink.
  """
  serialized_examples = (
      examples_pcollection
      | "Shuffle" >> beam.Reshuffle()
      | "Serialize" >> beam.Map(self._serialize_example)
  )
  # If the caller requested an explicit shard count, reshuffle into that
  # many bundles so the sink emits the desired number of shards.
  # NOTE(review): this assumes beam.Reshuffle's first positional argument
  # controls the bucket/bundle count — confirm against the Beam version used.
  if self._num_shards is not None:
    serialized_examples = serialized_examples | "Reshard" >> beam.Reshuffle(
        self._num_shards
    )
  return serialized_examples | "Write" >> self._file_adapter.beam_sink(
      filename_template=self._filename_template
  )

def finalize(self) -> tuple[list[int], int]:
Expand Down

0 comments on commit 10401d3

Please sign in to comment.