diff --git a/tensorflow_datasets/core/file_adapters.py b/tensorflow_datasets/core/file_adapters.py
index 603a21da2a3..b443cf13fb5 100644
--- a/tensorflow_datasets/core/file_adapters.py
+++ b/tensorflow_datasets/core/file_adapters.py
@@ -27,12 +27,12 @@
 
 from etils import epy
 from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
+from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_data_source
 from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_module
 from tensorflow_datasets.core.utils.lazy_imports_utils import parquet as pq
 from tensorflow_datasets.core.utils.lazy_imports_utils import pyarrow as pa
 from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
-
 
 with epy.lazy_imports():
   # pylint: disable=g-import-not-at-top
   from etils import epath
@@ -325,6 +325,14 @@ def write_examples(
       writer.write(serialized_example)
     writer.close()
 
+  @classmethod
+  def num_examples(cls, filename: epath.PathLike) -> int:
+    """Returns the number of examples in the given file."""
+    data_source = array_record_data_source.ArrayRecordDataSource(
+        paths=[os.fspath(filename)]
+    )
+    return len(data_source)
+
 
 class ParquetFileAdapter(FileAdapter):
   """File adapter for the [Parquet](https://parquet.apache.org) file format.
diff --git a/tensorflow_datasets/core/writer.py b/tensorflow_datasets/core/writer.py
index 5f4690a1b7c..d6764ea86b4 100644
--- a/tensorflow_datasets/core/writer.py
+++ b/tensorflow_datasets/core/writer.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 from collections.abc import Iterable, Iterator, Sequence
+import concurrent.futures
 import dataclasses
 import functools
 import itertools
@@ -802,16 +803,21 @@ def finalize(self) -> tuple[list[int], int]:
     logging.info("Finalizing writer for %s", self._filename_template.split)
     # We don't know the number of shards, the length of each shard, nor the
     # total size, so we compute them here.
-    length_per_shard = {}
-    total_size_bytes = 0
     prefix = epath.Path(self._filename_template.filepath_prefix())
-    for shard in self._filename_template.data_dir.glob(f"{prefix.name}*"):
+    shards = self._filename_template.data_dir.glob(f"{prefix.name}*")
+
+    def _get_length_and_size(shard: epath.Path) -> tuple[epath.Path, int, int]:
       length = self._file_adapter.num_examples(shard)
-      length_per_shard[shard] = length
-      total_size_bytes += shard.stat().length
-    shard_lengths: list[int] = []
-    for _, length in sorted(length_per_shard.items()):
-      shard_lengths.append(length)
+      size = shard.stat().length
+      return shard, length, size
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
+      shard_sizes = executor.map(_get_length_and_size, shards)
+
+    shard_sizes = sorted(shard_sizes, key=lambda x: x[0])
+    shard_lengths: list[int] = [x[1] for x in shard_sizes]
+    total_size_bytes: int = sum([x[2] for x in shard_sizes])
+
     logging.info(
         "Found %d shards with a total size of %d bytes.",
         len(shard_lengths),
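For context on the file_adapters.py change: the new ArrayRecordFileAdapter.num_examples
relies on ArrayRecordDataSource exposing each file's record count through len().
A minimal standalone sketch of the same idea, assuming the array_record package is
installed; the shard path below is hypothetical:

    # Count the records in an ArrayRecord file without deserializing them.
    from array_record.python import array_record_data_source

    def count_examples(path: str) -> int:
      # The data source takes the count from file metadata, so len() does
      # not require scanning every record.
      ds = array_record_data_source.ArrayRecordDataSource(paths=[path])
      return len(ds)

    # Hypothetical shard name, for illustration only.
    print(count_examples("mnist-train.array_record-00000-of-00001"))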
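For context on the writer.py change: finalize() previously counted and statted shards
one at a time, and each num_examples/stat call can be a remote-filesystem round trip,
so a thread pool amortizes that latency. Note that consuming the executor.map iterator
after the with block is safe: map submits all tasks eagerly, and shutdown(wait=True)
has already forced them to finish. A self-contained sketch of the same pattern, with a
hypothetical _length_and_size standing in for the real file-adapter call:

    import concurrent.futures
    import os

    def _length_and_size(path: str) -> tuple[str, int, int]:
      # Stand-in for FileAdapter.num_examples() + stat(); pretends each byte
      # is one example so the sketch stays self-contained.
      size = os.path.getsize(path)
      return path, size, size

    def gather_shard_info(paths: list[str]) -> tuple[list[int], int]:
      # The work is I/O-bound, so threads suffice despite the GIL;
      # max_workers bounds the number of concurrent filesystem requests.
      with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
        results = list(executor.map(_length_and_size, paths))
      # Sort by path so shard_lengths[i] corresponds to the i-th shard file.
      results.sort(key=lambda x: x[0])
      shard_lengths = [length for _, length, _ in results]
      total_size_bytes = sum(size for _, _, size in results)
      return shard_lengths, total_size_bytes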