diff --git a/tfx_bsl/tfxio/parquet_tfxio.py b/tfx_bsl/tfxio/parquet_tfxio.py
index 6dee4036..b2eaba66 100644
--- a/tfx_bsl/tfxio/parquet_tfxio.py
+++ b/tfx_bsl/tfxio/parquet_tfxio.py
@@ -39,8 +39,8 @@ class ParquetTFXIO(tfxio.TFXIO):
 
   def __init__(self,
                file_pattern: str,
-               column_names: List[str],
                *,
+               column_names: Optional[List[str]] = None,
                min_bundle_size: int = 0,
                schema: Optional[schema_pb2.Schema] = None,
                validate: bool = True,
@@ -51,7 +51,8 @@ def __init__(self,
       file_pattern: A file glob pattern to read parquet files from.
-      column_names: List of column names to read from the parquet files.
+      column_names: List of column names to read from the parquet files. If
+        not provided, all columns in the dataset will be read.
       min_bundle_size: the minimum size in bytes, to be considered when
         splitting the parquet input into bundles.
       schema: An optional TFMD Schema describing the dataset. If schema is
         provided, it will determine the data type of the parquet columns.
         Otherwise, the each column's data type will be inferred by the decoder.
@@ -70,6 +71,10 @@ def __init__(self,
     self._schema = schema
     self._telemetry_descriptors = telemetry_descriptors
 
+  @property
+  def telemetry_descriptors(self) -> Optional[List[str]]:
+    return self._telemetry_descriptors
+
   def BeamSource(self, batch_size: Optional[int] = None) -> beam.PTransform:
 
     @beam.typehints.with_input_types(Union[beam.PCollection, beam.Pipeline])
@@ -106,7 +111,11 @@ def _TableToRecordBatch(
   def ArrowSchema(self) -> pa.Schema:
     if self._schema is None:
       return self._InferArrowSchema()
-    return csv_decoder.GetArrowSchema(self._column_names, self._schema)
+
+    # If column names are not passed, default to all columns in the schema.
+    columns = self._column_names or [f.name for f in self._schema.feature]
+
+    return csv_decoder.GetArrowSchema(columns, self._schema)
 
   def _InferArrowSchema(self):
     match_result = FileSystems.match([self._file_pattern])[0]
diff --git a/tfx_bsl/tfxio/parquet_tfxio_test.py b/tfx_bsl/tfxio/parquet_tfxio_test.py
index cd8bf4f2..3e6799c0 100644
--- a/tfx_bsl/tfxio/parquet_tfxio_test.py
+++ b/tfx_bsl/tfxio/parquet_tfxio_test.py
@@ -301,6 +301,84 @@ def _AssertFn(record_batch_list):
       record_batch_pcoll = (p | tfxio.BeamSource(batch_size=_NUM_ROWS))
       beam_testing_util.assert_that(record_batch_pcoll, _AssertFn)
 
+  def testOptionalColumnNames(self):
+    """Tests reading all columns when column_names is not provided."""
+    tfxio = ParquetTFXIO(
+        file_pattern=self._example_file,
+        schema=_SCHEMA)
+
+    def _AssertFn(record_batch_list):
+      self.assertLen(record_batch_list, 1)
+      record_batch = record_batch_list[0]
+      self._ValidateRecordBatch(record_batch, _EXPECTED_ARROW_SCHEMA)
+
+    with beam.Pipeline() as p:
+      record_batch_pcoll = (p | tfxio.BeamSource(batch_size=_NUM_ROWS))
+      beam_testing_util.assert_that(record_batch_pcoll, _AssertFn)
+
+  def testOptionalColumnNamesAndSchema(self):
+    """Tests reading all columns when column_names and schema are omitted."""
+    tfxio = ParquetTFXIO(file_pattern=self._example_file)
+
+    def _AssertFn(record_batch_list):
+      self.assertLen(record_batch_list, 1)
+      record_batch = record_batch_list[0]
+      self._ValidateRecordBatch(record_batch, _EXPECTED_ARROW_SCHEMA)
+
+    with beam.Pipeline() as p:
+      record_batch_pcoll = (p | tfxio.BeamSource(batch_size=_NUM_ROWS))
+      beam_testing_util.assert_that(record_batch_pcoll, _AssertFn)
+
+  def testSubsetOfColumnNamesWithCompleteSchema(self):
+    """Tests reading a subset of columns with a complete schema."""
+    tfxio = ParquetTFXIO(
+        file_pattern=self._example_file,
+        column_names=['int_feature'],
+        schema=_SCHEMA)
+
+    def _AssertFn(record_batch_list):
+      self.assertLen(record_batch_list, 1)
+      record_batch = record_batch_list[0]
+      expected_arrow_schema = pa.schema([
+          pa.field("int_feature", pa.large_list(pa.int64())),
+      ])
+      self._ValidateRecordBatch(record_batch, expected_arrow_schema)
+
+    with beam.Pipeline() as p:
+      record_batch_pcoll = (p | tfxio.BeamSource(batch_size=_NUM_ROWS))
+      beam_testing_util.assert_that(record_batch_pcoll, _AssertFn)
+
+  def testSubsetOfColumnNamesWithSubsetSchema(self):
+    """Tests reading a subset of columns with a matching subset schema."""
+    schema = text_format.Parse(
+        """
+        feature {
+          name: "int_feature"
+          type: INT
+          value_count {
+            min: 0
+            max: 2
+          }
+        }
+        """, schema_pb2.Schema())
+
+    tfxio = ParquetTFXIO(
+        file_pattern=self._example_file,
+        column_names=['int_feature'],
+        schema=schema)
+
+    def _AssertFn(record_batch_list):
+      self.assertLen(record_batch_list, 1)
+      record_batch = record_batch_list[0]
+      expected_arrow_schema = pa.schema([
+          pa.field("int_feature", pa.large_list(pa.int64())),
+      ])
+      self._ValidateRecordBatch(record_batch, expected_arrow_schema)
+
+    with beam.Pipeline() as p:
+      record_batch_pcoll = (p | tfxio.BeamSource(batch_size=_NUM_ROWS))
+      beam_testing_util.assert_that(record_batch_pcoll, _AssertFn)
+
   def _ValidateRecordBatch(self, record_batch, expected_arrow_schema):
     self.assertIsInstance(record_batch, pa.RecordBatch)
     self.assertEqual(record_batch.num_rows, 2)
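
Usage sketch (reviewer note, not part of the patch): with this change,
column_names becomes an optional keyword-only argument. The file pattern and
schema below are illustrative placeholders, assuming a Parquet layout and TFMD
schema like the ones used in the tests.

    from tfx_bsl.tfxio.parquet_tfxio import ParquetTFXIO

    # Reads every column; the Arrow schema is inferred from the Parquet files
    # because no TFMD schema is supplied.
    all_columns = ParquetTFXIO(file_pattern="/path/to/data/*.parquet")

    # Reads only the named columns, typed according to the TFMD schema.
    subset = ParquetTFXIO(
        file_pattern="/path/to/data/*.parquet",
        column_names=["int_feature"],
        schema=my_tfmd_schema)  # placeholder schema_pb2.Schema instance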