From 9d6090d0d4b7a791538e456408a43b166fbb7f1b Mon Sep 17 00:00:00 2001 From: Pierre Marcenac Date: Wed, 4 Dec 2024 08:16:58 -0800 Subject: [PATCH] Simplify `ReadFromCroissant` by removing the pipeline argument and making it a PCollection. PiperOrigin-RevId: 702731977 --- .../dataset_builders/croissant_builder.py | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/tensorflow_datasets/core/dataset_builders/croissant_builder.py b/tensorflow_datasets/core/dataset_builders/croissant_builder.py index 500d2ba4b82..78742c5839c 100644 --- a/tensorflow_datasets/core/dataset_builders/croissant_builder.py +++ b/tensorflow_datasets/core/dataset_builders/croissant_builder.py @@ -288,6 +288,8 @@ def _split_generators( dl_manager: download.DownloadManager, pipeline: beam.Pipeline, ) -> dict[splits_lib.Split, split_builder_lib.SplitGenerator]: + del dl_manager # unused + del pipeline # unused # If a split recordset is joined for the required record set, we generate # splits accordingly. Otherwise, it generates a single `default` split with # all the records. @@ -302,7 +304,6 @@ def _split_generators( split_key = split_reference.reference_field.references.field return { split[split_key]: self._generate_examples( - pipeline=pipeline, filters={ **self._filters, split_reference.reference_field.id: split[split_key], @@ -311,21 +312,15 @@ def _split_generators( for split in split_reference.split_record_set.data } else: - return { - 'default': self._generate_examples( - pipeline=pipeline, filters=self._filters - ) - } + return {'default': self._generate_examples(filters=self._filters)} def _generate_examples( self, - pipeline: beam.Pipeline, filters: dict[str, Any], ) -> beam.PTransform: """Generates the examples for the given record set. Args: - pipeline: The Beam pipeline. filters: A dict of filters to apply to the records. The keys should be field names and the values should be the values to filter by. If a record matches all the filters, it will be included in the dataset. @@ -354,10 +349,12 @@ def convert_to_tfds_format( conversion_utils.to_tfds_value(record, features), ) - return records.beam_reader( - pipeline=pipeline - ) | f'Convert to TFDS format for filters: {json.dumps(filters)}' >> beam.MapTuple( - convert_to_tfds_format, - features=self.info.features, - record_set_id=record_set.id, + return ( + records.beam_reader() + | f'Convert to TFDS format for filters: {json.dumps(filters)}' + >> beam.MapTuple( + convert_to_tfds_format, + features=self.info.features, + record_set_id=record_set.id, + ) )