Skip to content

Commit

Permalink
Merge pull request #4761 from martinbomio:parquet-tfxio-support
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 452081871
  • Loading branch information
tfx-copybara committed Jun 8, 2022
2 parents 4b5008a + 035e3cf commit bce3fc8
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 16 deletions.
30 changes: 24 additions & 6 deletions tfx/components/util/tfxio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@
from tfx.types import artifact
from tfx.types import standard_artifacts
from tfx_bsl.tfxio import dataset_options
from tfx_bsl.tfxio import parquet_tfxio
from tfx_bsl.tfxio import raw_tf_record
from tfx_bsl.tfxio import record_to_tensor_tfxio
from tfx_bsl.tfxio import tf_example_record
from tfx_bsl.tfxio import tf_sequence_example_record
from tfx_bsl.tfxio import tfxio
from tensorflow_metadata.proto.v0 import schema_pb2

_SUPPORTED_FILE_FORMATS = (example_gen_pb2.FileFormat.FILE_FORMAT_PARQUET,
example_gen_pb2.FileFormat.FORMAT_TFRECORDS_GZIP)
# TODO(b/162532479): switch to support List[str] exclusively, once tfx-bsl
# post-0.22 is released.
OneOrMorePatterns = Union[str, List[str]]
Expand Down Expand Up @@ -242,12 +245,13 @@ def get_data_view_decode_fn_from_artifact(
def make_tfxio(
file_pattern: OneOrMorePatterns,
telemetry_descriptors: List[str],
payload_format: Union[str, int],
payload_format: int,
data_view_uri: Optional[str] = None,
schema: Optional[schema_pb2.Schema] = None,
read_as_raw_records: bool = False,
raw_record_column_name: Optional[str] = None,
file_format: Optional[Union[str, List[str]]] = None) -> tfxio.TFXIO:
file_format: Optional[Union[int, List[int], str, List[str]]] = None
) -> tfxio.TFXIO:
"""Creates a TFXIO instance that reads `file_pattern`.
Args:
Expand All @@ -271,8 +275,8 @@ def make_tfxio(
that column will be the raw records. Note that not all TFXIO supports this
option, and an error will be raised in that case. Required if
read_as_raw_records == True.
file_format: file format string for each file_pattern. Only 'tfrecords_gzip'
is supported for now.
file_format: file format for each file_pattern. Only 'tfrecords_gzip' and
'parquet' are supported for now.
Returns:
a TFXIO instance.
Expand All @@ -291,10 +295,18 @@ def make_tfxio(
f'The length of file_pattern and file_formats should be the same.'
f'Given: file_pattern={file_pattern}, file_format={file_format}')
else:
if any(item != 'tfrecords_gzip' for item in file_format):
# TODO(b/216604827): deprecated str file format.
file_format = [
example_gen_pb2.FileFormat.FORMAT_TFRECORDS_GZIP
if item == 'tfrecords_gzip' else item for item in file_format
]
if any(item not in _SUPPORTED_FILE_FORMATS for item in file_format):
raise NotImplementedError(f'{file_format} is not supported yet.')
else: # file_format is str type.
if file_format != 'tfrecords_gzip':
# TODO(b/216604827): deprecated str file format.
if file_format == 'tfrecords_gzip':
file_format = example_gen_pb2.FileFormat.FORMAT_TFRECORDS_GZIP
if file_format not in _SUPPORTED_FILE_FORMATS:
raise NotImplementedError(f'{file_format} is not supported yet.')

if read_as_raw_records:
Expand Down Expand Up @@ -330,6 +342,12 @@ def make_tfxio(
telemetry_descriptors=telemetry_descriptors,
raw_record_column_name=raw_record_column_name)

if payload_format == example_gen_pb2.PayloadFormat.FORMAT_PARQUET:
return parquet_tfxio.ParquetTFXIO(
file_pattern=file_pattern,
schema=schema,
telemetry_descriptors=telemetry_descriptors)

raise NotImplementedError(
'Unsupport payload format: {}'.format(payload_format))

Expand Down
17 changes: 9 additions & 8 deletions tfx/components/util/tfxio_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tfx.proto import example_gen_pb2
from tfx.types import standard_artifacts
from tfx_bsl.coders import tf_graph_record_decoder
from tfx_bsl.tfxio import parquet_tfxio
from tfx_bsl.tfxio import raw_tf_record
from tfx_bsl.tfxio import record_based_tfxio
from tfx_bsl.tfxio import record_to_tensor_tfxio
Expand Down Expand Up @@ -70,6 +71,10 @@
read_as_raw_records=True,
raw_record_column_name=_RAW_RECORD_COLUMN_NAME,
expected_tfxio_type=raw_tf_record.RawTfRecordTFXIO),
dict(
testcase_name='parquet_payload_format',
payload_format=example_gen_pb2.PayloadFormat.FORMAT_PARQUET,
expected_tfxio_type=parquet_tfxio.ParquetTFXIO),
]

_RESOLVE_TEST_CASES = [
Expand Down Expand Up @@ -181,11 +186,9 @@ def test_make_tfxio(self, payload_format, expected_tfxio_type,
_FAKE_FILE_PATTERN, _TELEMETRY_DESCRIPTORS, payload_format,
data_view_uri, _SCHEMA, read_as_raw_records, raw_record_column_name)
self.assertIsInstance(tfxio, expected_tfxio_type)
# We currently only create RecordBasedTFXIO and the check below relies on
# that.
self.assertIsInstance(tfxio, record_based_tfxio.RecordBasedTFXIO)
self.assertEqual(tfxio.telemetry_descriptors, _TELEMETRY_DESCRIPTORS)
self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name)
if isinstance(tfxio, record_based_tfxio.RecordBasedTFXIO):
self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name)
# Since we provide a schema, ArrowSchema() should not raise.
_ = tfxio.ArrowSchema()

Expand Down Expand Up @@ -219,11 +222,9 @@ def test_get_tfxio_factory_from_artifact(self,
raw_record_column_name)
tfxio = tfxio_factory(_FAKE_FILE_PATTERN)
self.assertIsInstance(tfxio, expected_tfxio_type)
# We currently only create RecordBasedTFXIO and the check below relies on
# that.
self.assertIsInstance(tfxio, record_based_tfxio.RecordBasedTFXIO)
self.assertEqual(tfxio.telemetry_descriptors, _TELEMETRY_DESCRIPTORS)
self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name)
if isinstance(tfxio, record_based_tfxio.RecordBasedTFXIO):
self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name)
# Since we provide a schema, ArrowSchema() should not raise.
_ = tfxio.ArrowSchema()

Expand Down
15 changes: 13 additions & 2 deletions tfx/proto/example_gen.proto
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,10 @@ enum PayloadFormat {
// Serialized any protocol buffer.
FORMAT_PROTO = 11;

reserved 1 to 5, 8 to 10, 12 to max;
// Serialized parquet messages.
FORMAT_PARQUET = 15;

reserved 1 to 5, 8 to 10, 12 to 14, 16 to max;
}

// Enum to indicate file format that ExampleGen produces.
Expand All @@ -122,7 +125,15 @@ enum FileFormat {
// Indicates TFRecords format files with gzip compression.
FORMAT_TFRECORDS_GZIP = 5;

reserved 1 to 4, 6 to max;
// File format:
// Indicates RecordIO format files.
FORMAT_RECORDIO = 3;

// Indicates parquet format files in any of the supported compressions.
// https://arrow.apache.org/docs/python/parquet.html
FILE_FORMAT_PARQUET = 16;

reserved 1, 2, 4, 6, 7 to 15, 17 to max;
}

// Specification of the output of the example gen.
Expand Down

0 comments on commit bce3fc8

Please sign in to comment.