Skip to content

Commit

Permalink
Simplify ReadFromCroissant by removing the pipeline argument and making it a PCollection.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 702731977
  • Loading branch information
marcenacp authored and The TensorFlow Datasets Authors committed Dec 4, 2024
1 parent 0419f1a commit 9d6090d
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions tensorflow_datasets/core/dataset_builders/croissant_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ def _split_generators(
dl_manager: download.DownloadManager,
pipeline: beam.Pipeline,
) -> dict[splits_lib.Split, split_builder_lib.SplitGenerator]:
del dl_manager # unused
del pipeline # unused
# If a split recordset is joined for the required record set, we generate
# splits accordingly. Otherwise, it generates a single `default` split with
# all the records.
Expand All @@ -302,7 +304,6 @@ def _split_generators(
split_key = split_reference.reference_field.references.field
return {
split[split_key]: self._generate_examples(
pipeline=pipeline,
filters={
**self._filters,
split_reference.reference_field.id: split[split_key],
Expand All @@ -311,21 +312,15 @@ def _split_generators(
for split in split_reference.split_record_set.data
}
else:
return {
'default': self._generate_examples(
pipeline=pipeline, filters=self._filters
)
}
return {'default': self._generate_examples(filters=self._filters)}

def _generate_examples(
self,
pipeline: beam.Pipeline,
filters: dict[str, Any],
) -> beam.PTransform:
"""Generates the examples for the given record set.
Args:
pipeline: The Beam pipeline.
filters: A dict of filters to apply to the records. The keys should be
field names and the values should be the values to filter by. If a
record matches all the filters, it will be included in the dataset.
Expand Down Expand Up @@ -354,10 +349,12 @@ def convert_to_tfds_format(
conversion_utils.to_tfds_value(record, features),
)

return records.beam_reader(
pipeline=pipeline
) | f'Convert to TFDS format for filters: {json.dumps(filters)}' >> beam.MapTuple(
convert_to_tfds_format,
features=self.info.features,
record_set_id=record_set.id,
return (
records.beam_reader()
| f'Convert to TFDS format for filters: {json.dumps(filters)}'
>> beam.MapTuple(
convert_to_tfds_format,
features=self.info.features,
record_set_id=record_set.id,
)
)

0 comments on commit 9d6090d

Please sign in to comment.