From c9b34af73b797421510c329b1dd6b224f5906429 Mon Sep 17 00:00:00 2001 From: Martin Bomio Date: Wed, 23 Mar 2022 18:05:48 -0600 Subject: [PATCH 1/7] Extend tfxio factory to use parquet-tfxio --- tfx/components/util/tfxio_utils.py | 15 ++++++++++++--- tfx/components/util/tfxio_utils_test.py | 17 +++++++++-------- tfx/proto/example_gen.proto | 5 ++++- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/tfx/components/util/tfxio_utils.py b/tfx/components/util/tfxio_utils.py index 7d5ee8057b..ca90566622 100644 --- a/tfx/components/util/tfxio_utils.py +++ b/tfx/components/util/tfxio_utils.py @@ -23,6 +23,7 @@ from tfx.types import artifact from tfx.types import standard_artifacts from tfx_bsl.tfxio import dataset_options +from tfx_bsl.tfxio import parquet_tfxio from tfx_bsl.tfxio import raw_tf_record from tfx_bsl.tfxio import record_to_tensor_tfxio from tfx_bsl.tfxio import tf_example_record @@ -30,6 +31,8 @@ from tfx_bsl.tfxio import tfxio from tensorflow_metadata.proto.v0 import schema_pb2 + +_SUPPORTED_PAYLOAD_FORMATS = ['parquet', 'tfrecords_gzip'] # TODO(b/162532479): switch to support List[str] exclusively, once tfx-bsl # post-0.22 is released. OneOrMorePatterns = Union[str, List[str]] @@ -272,7 +275,7 @@ def make_tfxio( option, and an error will be raised in that case. Required if read_as_raw_records == True. file_format: file format string for each file_pattern. Only 'tfrecords_gzip' - is supported for now. + and 'parquet' are supported for now. Returns: a TFXIO instance. @@ -291,10 +294,10 @@ def make_tfxio( f'The length of file_pattern and file_formats should be the same.' f'Given: file_pattern={file_pattern}, file_format={file_format}') else: - if any(item != 'tfrecords_gzip' for item in file_format): + if any(item in _SUPPORTED_PAYLOAD_FORMATS for item in file_format): raise NotImplementedError(f'{file_format} is not supported yet.') else: # file_format is str type. - if file_format != 'tfrecords_gzip': + if file_format in _SUPPORTED_PAYLOAD_FORMATS: raise NotImplementedError(f'{file_format} is not supported yet.') if read_as_raw_records: @@ -330,6 +333,12 @@ def make_tfxio( telemetry_descriptors=telemetry_descriptors, raw_record_column_name=raw_record_column_name) + if payload_format == example_gen_pb2.PayloadFormat.FORMAT_PARQUET: + return parquet_tfxio.ParquetTFXIO( + file_pattern=file_pattern, + schema=schema, + telemetry_descriptors=telemetry_descriptors) + raise NotImplementedError( 'Unsupport payload format: {}'.format(payload_format)) diff --git a/tfx/components/util/tfxio_utils_test.py b/tfx/components/util/tfxio_utils_test.py index 8625e7953d..e0af18c98f 100644 --- a/tfx/components/util/tfxio_utils_test.py +++ b/tfx/components/util/tfxio_utils_test.py @@ -26,6 +26,7 @@ from tfx.proto import example_gen_pb2 from tfx.types import standard_artifacts from tfx_bsl.coders import tf_graph_record_decoder +from tfx_bsl.tfxio import parquet_tfxio from tfx_bsl.tfxio import raw_tf_record from tfx_bsl.tfxio import record_based_tfxio from tfx_bsl.tfxio import record_to_tensor_tfxio @@ -70,6 +71,10 @@ read_as_raw_records=True, raw_record_column_name=_RAW_RECORD_COLUMN_NAME, expected_tfxio_type=raw_tf_record.RawTfRecordTFXIO), + dict( + testcase_name='parquet_payload_format', + payload_format=example_gen_pb2.PayloadFormat.FORMAT_PARQUET, + expected_tfxio_type=parquet_tfxio.ParquetTFXIO), ] _RESOLVE_TEST_CASES = [ @@ -181,11 +186,9 @@ def test_make_tfxio(self, payload_format, expected_tfxio_type, _FAKE_FILE_PATTERN, _TELEMETRY_DESCRIPTORS, payload_format, data_view_uri, _SCHEMA, read_as_raw_records, raw_record_column_name) self.assertIsInstance(tfxio, expected_tfxio_type) - # We currently only create RecordBasedTFXIO and the check below relies on - # that. - self.assertIsInstance(tfxio, record_based_tfxio.RecordBasedTFXIO) self.assertEqual(tfxio.telemetry_descriptors, _TELEMETRY_DESCRIPTORS) - self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name) + if isinstance(tfxio, record_based_tfxio.RecordBasedTFXIO): + self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name) # Since we provide a schema, ArrowSchema() should not raise. _ = tfxio.ArrowSchema() @@ -219,11 +222,9 @@ def test_get_tfxio_factory_from_artifact(self, raw_record_column_name) tfxio = tfxio_factory(_FAKE_FILE_PATTERN) self.assertIsInstance(tfxio, expected_tfxio_type) - # We currently only create RecordBasedTFXIO and the check below relies on - # that. - self.assertIsInstance(tfxio, record_based_tfxio.RecordBasedTFXIO) self.assertEqual(tfxio.telemetry_descriptors, _TELEMETRY_DESCRIPTORS) - self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name) + if isinstance(tfxio, record_based_tfxio.RecordBasedTFXIO): + self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name) # Since we provide a schema, ArrowSchema() should not raise. _ = tfxio.ArrowSchema() diff --git a/tfx/proto/example_gen.proto b/tfx/proto/example_gen.proto index c419a53534..98c01b5c8c 100644 --- a/tfx/proto/example_gen.proto +++ b/tfx/proto/example_gen.proto @@ -111,7 +111,10 @@ enum PayloadFormat { // Serialized any protocol buffer. FORMAT_PROTO = 11; - reserved 1 to 5, 8 to 10, 12 to max; + // Serialized parquet messages. + FORMAT_PARQUET = 12; + + reserved 1 to 5, 8 to 10, 13 to max; } // Enum to indicate file format that ExampleGen produces. From 5ac2ff2fe3309d3a6196199fb5b16c30c84a2c0e Mon Sep 17 00:00:00 2001 From: Martin Bomio Date: Wed, 30 Mar 2022 17:52:17 -0400 Subject: [PATCH 2/7] Add parquet file format --- tfx/components/util/tfxio_utils.py | 12 ++++++------ tfx/proto/example_gen.proto | 6 +++++- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/tfx/components/util/tfxio_utils.py b/tfx/components/util/tfxio_utils.py index ca90566622..4b54abfb2d 100644 --- a/tfx/components/util/tfxio_utils.py +++ b/tfx/components/util/tfxio_utils.py @@ -32,7 +32,7 @@ from tensorflow_metadata.proto.v0 import schema_pb2 -_SUPPORTED_PAYLOAD_FORMATS = ['parquet', 'tfrecords_gzip'] +_SUPPORTED_FILE_FORMATS = {example_gen_pb2.FileFormat.FORMAT_PARQUET, example_gen_pb2.FileFormat.FORMAT_TFRECORDS_GZIP} # TODO(b/162532479): switch to support List[str] exclusively, once tfx-bsl # post-0.22 is released. OneOrMorePatterns = Union[str, List[str]] @@ -245,12 +245,12 @@ def get_data_view_decode_fn_from_artifact( def make_tfxio( file_pattern: OneOrMorePatterns, telemetry_descriptors: List[str], - payload_format: Union[str, int], + payload_format: Union[example_gen_pb2.PayloadFormat, int], data_view_uri: Optional[str] = None, schema: Optional[schema_pb2.Schema] = None, read_as_raw_records: bool = False, raw_record_column_name: Optional[str] = None, - file_format: Optional[Union[str, List[str]]] = None) -> tfxio.TFXIO: + file_format: Optional[Union[example_gen_pb2.FileFormat, List[example_gen_pb2.FileFormat]]] = None) -> tfxio.TFXIO: """Creates a TFXIO instance that reads `file_pattern`. Args: @@ -274,7 +274,7 @@ def make_tfxio( that column will be the raw records. Note that not all TFXIO supports this option, and an error will be raised in that case. Required if read_as_raw_records == True. - file_format: file format string for each file_pattern. Only 'tfrecords_gzip' + file_format: file format for each file_pattern. Only 'tfrecords_gzip' and 'parquet' are supported for now. Returns: @@ -294,10 +294,10 @@ def make_tfxio( f'The length of file_pattern and file_formats should be the same.' f'Given: file_pattern={file_pattern}, file_format={file_format}') else: - if any(item in _SUPPORTED_PAYLOAD_FORMATS for item in file_format): + if any(item in _SUPPORTED_FILE_FORMATS for item in file_format): raise NotImplementedError(f'{file_format} is not supported yet.') else: # file_format is str type. - if file_format in _SUPPORTED_PAYLOAD_FORMATS: + if file_format in _SUPPORTED_FILE_FORMATS: raise NotImplementedError(f'{file_format} is not supported yet.') if read_as_raw_records: diff --git a/tfx/proto/example_gen.proto b/tfx/proto/example_gen.proto index 98c01b5c8c..5eed00daf7 100644 --- a/tfx/proto/example_gen.proto +++ b/tfx/proto/example_gen.proto @@ -125,7 +125,11 @@ enum FileFormat { // Indicates TFRecords format files with gzip compression. FORMAT_TFRECORDS_GZIP = 5; - reserved 1 to 4, 6 to max; + // Indicates parquet format files in any of the supported compressions. + // https://arrow.apache.org/docs/python/parquet.html#compression-encoding-and-file-compatibility + FORMAT_PARQUET = 6; + + reserved 1 to 4, 7 to max; } // Specification of the output of the example gen. From 3913c44224f2900842db3f8b3a8837d7c52bdb80 Mon Sep 17 00:00:00 2001 From: Martin Bomio Date: Tue, 17 May 2022 16:55:11 -0400 Subject: [PATCH 3/7] Address PR comments --- tfx/components/util/tfxio_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tfx/components/util/tfxio_utils.py b/tfx/components/util/tfxio_utils.py index 4b54abfb2d..c2cc04ee2a 100644 --- a/tfx/components/util/tfxio_utils.py +++ b/tfx/components/util/tfxio_utils.py @@ -32,7 +32,7 @@ from tensorflow_metadata.proto.v0 import schema_pb2 -_SUPPORTED_FILE_FORMATS = {example_gen_pb2.FileFormat.FORMAT_PARQUET, example_gen_pb2.FileFormat.FORMAT_TFRECORDS_GZIP} +_SUPPORTED_FILE_FORMATS = (example_gen_pb2.FileFormat.FORMAT_PARQUET, example_gen_pb2.FileFormat.FORMAT_TFRECORDS_GZIP) # TODO(b/162532479): switch to support List[str] exclusively, once tfx-bsl # post-0.22 is released. OneOrMorePatterns = Union[str, List[str]] @@ -294,10 +294,10 @@ def make_tfxio( f'The length of file_pattern and file_formats should be the same.' f'Given: file_pattern={file_pattern}, file_format={file_format}') else: - if any(item in _SUPPORTED_FILE_FORMATS for item in file_format): + if any(item not in _SUPPORTED_FILE_FORMATS for item in file_format): raise NotImplementedError(f'{file_format} is not supported yet.') else: # file_format is str type. - if file_format in _SUPPORTED_FILE_FORMATS: + if file_format not in _SUPPORTED_FILE_FORMATS: raise NotImplementedError(f'{file_format} is not supported yet.') if read_as_raw_records: From 30bec961def3800b1f9b1f1b597aed74ec7e694e Mon Sep 17 00:00:00 2001 From: Martin Bomio Date: Tue, 24 May 2022 09:48:25 -0400 Subject: [PATCH 4/7] More PR suggestions --- tfx/proto/example_gen.proto | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tfx/proto/example_gen.proto b/tfx/proto/example_gen.proto index 5eed00daf7..93f9179822 100644 --- a/tfx/proto/example_gen.proto +++ b/tfx/proto/example_gen.proto @@ -112,9 +112,9 @@ enum PayloadFormat { FORMAT_PROTO = 11; // Serialized parquet messages. - FORMAT_PARQUET = 12; + FORMAT_PARQUET = 15; - reserved 1 to 5, 8 to 10, 13 to max; + reserved 1 to 5, 8 to 10, 12 to 14, 16 to max; } // Enum to indicate file format that ExampleGen produces. @@ -127,9 +127,9 @@ enum FileFormat { // Indicates parquet format files in any of the supported compressions. // https://arrow.apache.org/docs/python/parquet.html#compression-encoding-and-file-compatibility - FORMAT_PARQUET = 6; + FORMAT_PARQUET = 16; - reserved 1 to 4, 7 to max; + reserved 1 to 4, 7 to 15, 17 to max; } // Specification of the output of the example gen. From 3bdccaf2bb48364e80a924f0588c8b726000a798 Mon Sep 17 00:00:00 2001 From: Martin Bomio Date: Fri, 27 May 2022 10:39:17 -0400 Subject: [PATCH 5/7] Rename parquet file format enum --- tfx/components/util/tfxio_utils.py | 5 ++++- tfx/proto/example_gen.proto | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tfx/components/util/tfxio_utils.py b/tfx/components/util/tfxio_utils.py index c2cc04ee2a..9a76c6e54d 100644 --- a/tfx/components/util/tfxio_utils.py +++ b/tfx/components/util/tfxio_utils.py @@ -32,7 +32,10 @@ from tensorflow_metadata.proto.v0 import schema_pb2 -_SUPPORTED_FILE_FORMATS = (example_gen_pb2.FileFormat.FORMAT_PARQUET, example_gen_pb2.FileFormat.FORMAT_TFRECORDS_GZIP) +_SUPPORTED_FILE_FORMATS = ( + example_gen_pb2.FileFormat.FILE_FORMAT_PARQUET, + example_gen_pb2.FileFormat.FORMAT_TFRECORDS_GZIP +) # TODO(b/162532479): switch to support List[str] exclusively, once tfx-bsl # post-0.22 is released. OneOrMorePatterns = Union[str, List[str]] diff --git a/tfx/proto/example_gen.proto b/tfx/proto/example_gen.proto index 93f9179822..4e9f0ee4f2 100644 --- a/tfx/proto/example_gen.proto +++ b/tfx/proto/example_gen.proto @@ -127,7 +127,7 @@ enum FileFormat { // Indicates parquet format files in any of the supported compressions. // https://arrow.apache.org/docs/python/parquet.html#compression-encoding-and-file-compatibility - FORMAT_PARQUET = 16; + FILE_FORMAT_PARQUET = 16; reserved 1 to 4, 7 to 15, 17 to max; } From 35bfd1f72b395d288f2043680fa1c1591b7dfbaa Mon Sep 17 00:00:00 2001 From: Martin Bomio Date: Fri, 27 May 2022 17:30:53 -0400 Subject: [PATCH 6/7] Add future annotations to support proto enum typing --- tfx/components/util/tfxio_utils.py | 1 + tfx/components/util/tfxio_utils_test.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tfx/components/util/tfxio_utils.py b/tfx/components/util/tfxio_utils.py index 9a76c6e54d..c6a59e9a0f 100644 --- a/tfx/components/util/tfxio_utils.py +++ b/tfx/components/util/tfxio_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """TFXIO (standardized TFX inputs) related utilities.""" +from __future__ import annotations from typing import Any, Callable, Dict, List, Iterator, Optional, Tuple, Union diff --git a/tfx/components/util/tfxio_utils_test.py b/tfx/components/util/tfxio_utils_test.py index e0af18c98f..5ae3334f3b 100644 --- a/tfx/components/util/tfxio_utils_test.py +++ b/tfx/components/util/tfxio_utils_test.py @@ -307,7 +307,7 @@ def test_get_tf_dataset_factory_from_artifact(self): dataset_factory = tfxio_utils.get_tf_dataset_factory_from_artifact( [examples], _TELEMETRY_DESCRIPTORS) self.assertIsInstance(dataset_factory, Callable) - self.assertEqual(tf.data.Dataset, + self.assertEqual('tf.data.Dataset', inspect.signature(dataset_factory).return_annotation) def test_get_record_batch_factory_from_artifact(self): @@ -318,7 +318,7 @@ def test_get_record_batch_factory_from_artifact(self): record_batch_factory = tfxio_utils.get_record_batch_factory_from_artifact( [examples], _TELEMETRY_DESCRIPTORS) self.assertIsInstance(record_batch_factory, Callable) - self.assertEqual(Iterator[pa.RecordBatch], + self.assertEqual('Iterator[pa.RecordBatch]', inspect.signature(record_batch_factory).return_annotation) def test_raise_if_data_view_uri_not_available(self): From 035e3cf8439686d5f083ff2c4698e4668bb08e85 Mon Sep 17 00:00:00 2001 From: Martin Bomio Date: Mon, 30 May 2022 11:59:06 -0400 Subject: [PATCH 7/7] Disable max line limit in protolint for comment --- tfx/proto/example_gen.proto | 1 + 1 file changed, 1 insertion(+) diff --git a/tfx/proto/example_gen.proto b/tfx/proto/example_gen.proto index 4e9f0ee4f2..d2008b3ed8 100644 --- a/tfx/proto/example_gen.proto +++ b/tfx/proto/example_gen.proto @@ -126,6 +126,7 @@ enum FileFormat { FORMAT_TFRECORDS_GZIP = 5; // Indicates parquet format files in any of the supported compressions. + // protolint:disable MAX_LINE_LENGTH // https://arrow.apache.org/docs/python/parquet.html#compression-encoding-and-file-compatibility FILE_FORMAT_PARQUET = 16;