Compute shard lengths and size in parallel in NoShuffleBeamWriter
Also add a more efficient way to compute the length for array record.

PiperOrigin-RevId: 702676333
tomvdw authored and The TensorFlow Datasets Authors committed Dec 4, 2024
1 parent aef1fdc commit eaefd56
Showing 2 changed files with 23 additions and 9 deletions.
10 changes: 9 additions & 1 deletion tensorflow_datasets/core/file_adapters.py
@@ -27,12 +27,12 @@
 
 from etils import epy
 from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
+from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_data_source
 from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_module
 from tensorflow_datasets.core.utils.lazy_imports_utils import parquet as pq
 from tensorflow_datasets.core.utils.lazy_imports_utils import pyarrow as pa
 from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
 
-
 with epy.lazy_imports():
   # pylint: disable=g-import-not-at-top
   from etils import epath
@@ -325,6 +325,14 @@ def write_examples(
       writer.write(serialized_example)
     writer.close()
 
+  @classmethod
+  def num_examples(cls, filename: epath.PathLike) -> int:
+    """Returns the number of examples in the given file."""
+    data_source = array_record_data_source.ArrayRecordDataSource(
+        paths=[os.fspath(filename)]
+    )
+    return len(data_source)
+
 
 class ParquetFileAdapter(FileAdapter):
   """File adapter for the [Parquet](https://parquet.apache.org) file format.
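For context, a minimal usage sketch of the new num_examples override. The shard path below is hypothetical, and the adapter class is assumed to be ArrayRecordFileAdapter (the class name is not visible in this hunk); ArrayRecordDataSource can report its length from the files' metadata rather than by iterating over examples, which is presumably the "more efficient" length computation the commit message refers to.

    from tensorflow_datasets.core import file_adapters

    # Hypothetical path to a single ArrayRecord shard written by TFDS.
    shard = "/tmp/mnist/mnist-train.array_record-00000-of-00001"
    print(file_adapters.ArrayRecordFileAdapter.num_examples(shard))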
22 changes: 14 additions & 8 deletions tensorflow_datasets/core/writer.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 from collections.abc import Iterable, Iterator, Sequence
+import concurrent.futures
 import dataclasses
 import functools
 import itertools
@@ -802,16 +803,21 @@ def finalize(self) -> tuple[list[int], int]:
     logging.info("Finalizing writer for %s", self._filename_template.split)
     # We don't know the number of shards, the length of each shard, nor the
     # total size, so we compute them here.
-    length_per_shard = {}
-    total_size_bytes = 0
     prefix = epath.Path(self._filename_template.filepath_prefix())
-    for shard in self._filename_template.data_dir.glob(f"{prefix.name}*"):
+    shards = self._filename_template.data_dir.glob(f"{prefix.name}*")
+
+    def _get_length_and_size(shard: epath.Path) -> tuple[epath.Path, int, int]:
       length = self._file_adapter.num_examples(shard)
-      length_per_shard[shard] = length
-      total_size_bytes += shard.stat().length
-    shard_lengths: list[int] = []
-    for _, length in sorted(length_per_shard.items()):
-      shard_lengths.append(length)
+      size = shard.stat().length
+      return shard, length, size
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
+      shard_sizes = executor.map(_get_length_and_size, shards)
+
+    shard_sizes = sorted(shard_sizes, key=lambda x: x[0])
+    shard_lengths: list[int] = [x[1] for x in shard_sizes]
+    total_size_bytes: int = sum([x[2] for x in shard_sizes])
+
     logging.info(
         "Found %d shards with a total size of %d bytes.",
         len(shard_lengths),
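To see the new finalize() pattern in isolation, here is a self-contained sketch of the same fan-out; the shard files and the count_records helper are hypothetical stand-ins for TFDS shards and FileAdapter.num_examples.

    import concurrent.futures
    import pathlib
    import tempfile

    # Three hypothetical "shards", one line per example.
    tmp_dir = pathlib.Path(tempfile.mkdtemp())
    for i, n in enumerate([3, 1, 2]):
      (tmp_dir / f"data-{i:05d}").write_text("example\n" * n)

    def count_records(path: pathlib.Path) -> int:
      # Stand-in for self._file_adapter.num_examples(shard).
      return len(path.read_text().splitlines())

    def get_length_and_size(path: pathlib.Path) -> tuple[str, int, int]:
      # Mirrors _get_length_and_size: each shard is processed independently,
      # so the calls can run on separate threads.
      return path.name, count_records(path), path.stat().st_size

    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
      results = sorted(executor.map(get_length_and_size, tmp_dir.iterdir()))

    # The sort keys on shard name: executor.map preserves input order, but
    # iterdir()/glob() order is not guaranteed, and shard_lengths must be
    # in shard order.
    shard_lengths = [length for _, length, _ in results]    # [3, 1, 2]
    total_size_bytes = sum(size for _, _, size in results)
    print(shard_lengths, total_size_bytes)

Threads (rather than processes) fit here because the per-shard work is dominated by file I/O, so the interpreter lock is released while the 16 workers overlap their reads.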
