diff --git a/tfx/components/util/tfxio_utils.py b/tfx/components/util/tfxio_utils.py index ca905666226..4b54abfb2d3 100644 --- a/tfx/components/util/tfxio_utils.py +++ b/tfx/components/util/tfxio_utils.py @@ -32,7 +32,7 @@ from tensorflow_metadata.proto.v0 import schema_pb2 -_SUPPORTED_PAYLOAD_FORMATS = ['parquet', 'tfrecords_gzip'] +_SUPPORTED_FILE_FORMATS = {example_gen_pb2.FileFormat.FORMAT_PARQUET, example_gen_pb2.FileFormat.FORMAT_TFRECORDS_GZIP} # TODO(b/162532479): switch to support List[str] exclusively, once tfx-bsl # post-0.22 is released. OneOrMorePatterns = Union[str, List[str]] @@ -245,12 +245,12 @@ def get_data_view_decode_fn_from_artifact( def make_tfxio( file_pattern: OneOrMorePatterns, telemetry_descriptors: List[str], - payload_format: Union[str, int], + payload_format: Union[example_gen_pb2.PayloadFormat, int], data_view_uri: Optional[str] = None, schema: Optional[schema_pb2.Schema] = None, read_as_raw_records: bool = False, raw_record_column_name: Optional[str] = None, - file_format: Optional[Union[str, List[str]]] = None) -> tfxio.TFXIO: + file_format: Optional[Union[example_gen_pb2.FileFormat, List[example_gen_pb2.FileFormat]]] = None) -> tfxio.TFXIO: """Creates a TFXIO instance that reads `file_pattern`. Args: @@ -274,7 +274,7 @@ def make_tfxio( that column will be the raw records. Note that not all TFXIO supports this option, and an error will be raised in that case. Required if read_as_raw_records == True. - file_format: file format string for each file_pattern. Only 'tfrecords_gzip' + file_format: file format for each file_pattern. Only 'tfrecords_gzip' and 'parquet' are supported for now. Returns: @@ -294,10 +294,10 @@ def make_tfxio( f'The length of file_pattern and file_formats should be the same.' 
f'Given: file_pattern={file_pattern}, file_format={file_format}') else: - if any(item in _SUPPORTED_PAYLOAD_FORMATS for item in file_format): + if any(item not in _SUPPORTED_FILE_FORMATS for item in file_format): raise NotImplementedError(f'{file_format} is not supported yet.') else: # file_format is str type. - if file_format in _SUPPORTED_PAYLOAD_FORMATS: + if file_format not in _SUPPORTED_FILE_FORMATS: raise NotImplementedError(f'{file_format} is not supported yet.') if read_as_raw_records: diff --git a/tfx/proto/example_gen.proto b/tfx/proto/example_gen.proto index 98c01b5c8c1..5eed00daf7b 100644 --- a/tfx/proto/example_gen.proto +++ b/tfx/proto/example_gen.proto @@ -125,7 +125,11 @@ enum FileFormat { // Indicates TFRecords format files with gzip compression. FORMAT_TFRECORDS_GZIP = 5; - reserved 1 to 4, 6 to max; + // Indicates parquet format files in any of the supported compressions. + // https://arrow.apache.org/docs/python/parquet.html#compression-encoding-and-file-compatibility + FORMAT_PARQUET = 6; + + reserved 1 to 4, 7 to max; } // Specification of the output of the example gen.