diff --git a/tfx/components/util/tfxio_utils.py b/tfx/components/util/tfxio_utils.py index 7d5ee8057b..c6a59e9a0f 100644 --- a/tfx/components/util/tfxio_utils.py +++ b/tfx/components/util/tfxio_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """TFXIO (standardized TFX inputs) related utilities.""" +from __future__ import annotations from typing import Any, Callable, Dict, List, Iterator, Optional, Tuple, Union @@ -23,6 +24,7 @@ from tfx.types import artifact from tfx.types import standard_artifacts from tfx_bsl.tfxio import dataset_options +from tfx_bsl.tfxio import parquet_tfxio from tfx_bsl.tfxio import raw_tf_record from tfx_bsl.tfxio import record_to_tensor_tfxio from tfx_bsl.tfxio import tf_example_record @@ -30,6 +32,11 @@ from tfx_bsl.tfxio import tfxio from tensorflow_metadata.proto.v0 import schema_pb2 + +_SUPPORTED_FILE_FORMATS = ( + example_gen_pb2.FileFormat.FILE_FORMAT_PARQUET, + example_gen_pb2.FileFormat.FORMAT_TFRECORDS_GZIP +) # TODO(b/162532479): switch to support List[str] exclusively, once tfx-bsl # post-0.22 is released. OneOrMorePatterns = Union[str, List[str]] @@ -242,12 +249,12 @@ def get_data_view_decode_fn_from_artifact( def make_tfxio( file_pattern: OneOrMorePatterns, telemetry_descriptors: List[str], - payload_format: Union[str, int], + payload_format: Union[example_gen_pb2.PayloadFormat, int], data_view_uri: Optional[str] = None, schema: Optional[schema_pb2.Schema] = None, read_as_raw_records: bool = False, raw_record_column_name: Optional[str] = None, - file_format: Optional[Union[str, List[str]]] = None) -> tfxio.TFXIO: + file_format: Optional[Union[example_gen_pb2.FileFormat, List[example_gen_pb2.FileFormat]]] = None) -> tfxio.TFXIO: """Creates a TFXIO instance that reads `file_pattern`. Args: @@ -271,8 +278,8 @@ def make_tfxio( that column will be the raw records. Note that not all TFXIO supports this option, and an error will be raised in that case. 
Required if read_as_raw_records == True. - file_format: file format string for each file_pattern. Only 'tfrecords_gzip' - is supported for now. + file_format: file format for each file_pattern. Only FORMAT_TFRECORDS_GZIP + and FILE_FORMAT_PARQUET (example_gen_pb2.FileFormat) are supported for now. Returns: a TFXIO instance. @@ -291,10 +298,10 @@ def make_tfxio( f'The length of file_pattern and file_formats should be the same.' f'Given: file_pattern={file_pattern}, file_format={file_format}') else: - if any(item != 'tfrecords_gzip' for item in file_format): + if any(item not in _SUPPORTED_FILE_FORMATS for item in file_format): raise NotImplementedError(f'{file_format} is not supported yet.') else: # file_format is str type. - if file_format != 'tfrecords_gzip': + if file_format not in _SUPPORTED_FILE_FORMATS: raise NotImplementedError(f'{file_format} is not supported yet.') if read_as_raw_records: @@ -330,6 +337,12 @@ def make_tfxio( telemetry_descriptors=telemetry_descriptors, raw_record_column_name=raw_record_column_name) + if payload_format == example_gen_pb2.PayloadFormat.FORMAT_PARQUET: + return parquet_tfxio.ParquetTFXIO( + file_pattern=file_pattern, + schema=schema, + telemetry_descriptors=telemetry_descriptors) + raise NotImplementedError( 'Unsupport payload format: {}'.format(payload_format)) diff --git a/tfx/components/util/tfxio_utils_test.py b/tfx/components/util/tfxio_utils_test.py index 8625e7953d..5ae3334f3b 100644 --- a/tfx/components/util/tfxio_utils_test.py +++ b/tfx/components/util/tfxio_utils_test.py @@ -26,6 +26,7 @@ from tfx.proto import example_gen_pb2 from tfx.types import standard_artifacts from tfx_bsl.coders import tf_graph_record_decoder +from tfx_bsl.tfxio import parquet_tfxio from tfx_bsl.tfxio import raw_tf_record from tfx_bsl.tfxio import record_based_tfxio from tfx_bsl.tfxio import record_to_tensor_tfxio @@ -70,6 +71,10 @@ read_as_raw_records=True, raw_record_column_name=_RAW_RECORD_COLUMN_NAME, expected_tfxio_type=raw_tf_record.RawTfRecordTFXIO), + dict( + 
testcase_name='parquet_payload_format', + payload_format=example_gen_pb2.PayloadFormat.FORMAT_PARQUET, + expected_tfxio_type=parquet_tfxio.ParquetTFXIO), ] _RESOLVE_TEST_CASES = [ @@ -181,11 +186,9 @@ def test_make_tfxio(self, payload_format, expected_tfxio_type, _FAKE_FILE_PATTERN, _TELEMETRY_DESCRIPTORS, payload_format, data_view_uri, _SCHEMA, read_as_raw_records, raw_record_column_name) self.assertIsInstance(tfxio, expected_tfxio_type) - # We currently only create RecordBasedTFXIO and the check below relies on - # that. - self.assertIsInstance(tfxio, record_based_tfxio.RecordBasedTFXIO) self.assertEqual(tfxio.telemetry_descriptors, _TELEMETRY_DESCRIPTORS) - self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name) + if isinstance(tfxio, record_based_tfxio.RecordBasedTFXIO): + self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name) # Since we provide a schema, ArrowSchema() should not raise. _ = tfxio.ArrowSchema() @@ -219,11 +222,9 @@ def test_get_tfxio_factory_from_artifact(self, raw_record_column_name) tfxio = tfxio_factory(_FAKE_FILE_PATTERN) self.assertIsInstance(tfxio, expected_tfxio_type) - # We currently only create RecordBasedTFXIO and the check below relies on - # that. - self.assertIsInstance(tfxio, record_based_tfxio.RecordBasedTFXIO) self.assertEqual(tfxio.telemetry_descriptors, _TELEMETRY_DESCRIPTORS) - self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name) + if isinstance(tfxio, record_based_tfxio.RecordBasedTFXIO): + self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name) # Since we provide a schema, ArrowSchema() should not raise. 
_ = tfxio.ArrowSchema() @@ -306,7 +307,7 @@ def test_get_tf_dataset_factory_from_artifact(self): dataset_factory = tfxio_utils.get_tf_dataset_factory_from_artifact( [examples], _TELEMETRY_DESCRIPTORS) self.assertIsInstance(dataset_factory, Callable) - self.assertEqual(tf.data.Dataset, + self.assertEqual('tf.data.Dataset', inspect.signature(dataset_factory).return_annotation) def test_get_record_batch_factory_from_artifact(self): @@ -317,7 +318,7 @@ def test_get_record_batch_factory_from_artifact(self): record_batch_factory = tfxio_utils.get_record_batch_factory_from_artifact( [examples], _TELEMETRY_DESCRIPTORS) self.assertIsInstance(record_batch_factory, Callable) - self.assertEqual(Iterator[pa.RecordBatch], + self.assertEqual('Iterator[pa.RecordBatch]', inspect.signature(record_batch_factory).return_annotation) def test_raise_if_data_view_uri_not_available(self): diff --git a/tfx/proto/example_gen.proto b/tfx/proto/example_gen.proto index c419a53534..d2008b3ed8 100644 --- a/tfx/proto/example_gen.proto +++ b/tfx/proto/example_gen.proto @@ -111,7 +111,10 @@ enum PayloadFormat { // Serialized any protocol buffer. FORMAT_PROTO = 11; - reserved 1 to 5, 8 to 10, 12 to max; + // Serialized parquet messages. + FORMAT_PARQUET = 15; + + reserved 1 to 5, 8 to 10, 12 to 14, 16 to max; } // Enum to indicate file format that ExampleGen produces. @@ -122,7 +125,12 @@ enum FileFormat { // Indicates TFRecords format files with gzip compression. FORMAT_TFRECORDS_GZIP = 5; - reserved 1 to 4, 6 to max; + // Indicates parquet format files in any of the supported compressions. + // protolint:disable MAX_LINE_LENGTH + // https://arrow.apache.org/docs/python/parquet.html#compression-encoding-and-file-compatibility + FILE_FORMAT_PARQUET = 16; + + reserved 1 to 4, 7 to 15, 17 to max; } // Specification of the output of the example gen.