Skip to content

Commit

Permalink
Extend tfxio factory to use parquet-tfxio
Browse files Browse the repository at this point in the history
  • Loading branch information
martinbomio committed May 27, 2022
1 parent 84ca4b4 commit c9b34af
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 12 deletions.
15 changes: 12 additions & 3 deletions tfx/components/util/tfxio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@
from tfx.types import artifact
from tfx.types import standard_artifacts
from tfx_bsl.tfxio import dataset_options
from tfx_bsl.tfxio import parquet_tfxio
from tfx_bsl.tfxio import raw_tf_record
from tfx_bsl.tfxio import record_to_tensor_tfxio
from tfx_bsl.tfxio import tf_example_record
from tfx_bsl.tfxio import tf_sequence_example_record
from tfx_bsl.tfxio import tfxio
from tensorflow_metadata.proto.v0 import schema_pb2


_SUPPORTED_PAYLOAD_FORMATS = ['parquet', 'tfrecords_gzip']
# TODO(b/162532479): switch to support List[str] exclusively, once tfx-bsl
# post-0.22 is released.
OneOrMorePatterns = Union[str, List[str]]
Expand Down Expand Up @@ -272,7 +275,7 @@ def make_tfxio(
option, and an error will be raised in that case. Required if
read_as_raw_records == True.
file_format: file format string for each file_pattern. Only 'tfrecords_gzip'
is supported for now.
and 'parquet' are supported for now.
Returns:
a TFXIO instance.
Expand All @@ -291,10 +294,10 @@ def make_tfxio(
f'The length of file_pattern and file_formats should be the same.'
f'Given: file_pattern={file_pattern}, file_format={file_format}')
else:
if any(item != 'tfrecords_gzip' for item in file_format):
if any(item in _SUPPORTED_PAYLOAD_FORMATS for item in file_format):
raise NotImplementedError(f'{file_format} is not supported yet.')
else: # file_format is str type.
if file_format != 'tfrecords_gzip':
if file_format in _SUPPORTED_PAYLOAD_FORMATS:
raise NotImplementedError(f'{file_format} is not supported yet.')

if read_as_raw_records:
Expand Down Expand Up @@ -330,6 +333,12 @@ def make_tfxio(
telemetry_descriptors=telemetry_descriptors,
raw_record_column_name=raw_record_column_name)

if payload_format == example_gen_pb2.PayloadFormat.FORMAT_PARQUET:
return parquet_tfxio.ParquetTFXIO(
file_pattern=file_pattern,
schema=schema,
telemetry_descriptors=telemetry_descriptors)

raise NotImplementedError(
'Unsupport payload format: {}'.format(payload_format))

Expand Down
17 changes: 9 additions & 8 deletions tfx/components/util/tfxio_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tfx.proto import example_gen_pb2
from tfx.types import standard_artifacts
from tfx_bsl.coders import tf_graph_record_decoder
from tfx_bsl.tfxio import parquet_tfxio
from tfx_bsl.tfxio import raw_tf_record
from tfx_bsl.tfxio import record_based_tfxio
from tfx_bsl.tfxio import record_to_tensor_tfxio
Expand Down Expand Up @@ -70,6 +71,10 @@
read_as_raw_records=True,
raw_record_column_name=_RAW_RECORD_COLUMN_NAME,
expected_tfxio_type=raw_tf_record.RawTfRecordTFXIO),
dict(
testcase_name='parquet_payload_format',
payload_format=example_gen_pb2.PayloadFormat.FORMAT_PARQUET,
expected_tfxio_type=parquet_tfxio.ParquetTFXIO),
]

_RESOLVE_TEST_CASES = [
Expand Down Expand Up @@ -181,11 +186,9 @@ def test_make_tfxio(self, payload_format, expected_tfxio_type,
_FAKE_FILE_PATTERN, _TELEMETRY_DESCRIPTORS, payload_format,
data_view_uri, _SCHEMA, read_as_raw_records, raw_record_column_name)
self.assertIsInstance(tfxio, expected_tfxio_type)
# We currently only create RecordBasedTFXIO and the check below relies on
# that.
self.assertIsInstance(tfxio, record_based_tfxio.RecordBasedTFXIO)
self.assertEqual(tfxio.telemetry_descriptors, _TELEMETRY_DESCRIPTORS)
self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name)
if isinstance(tfxio, record_based_tfxio.RecordBasedTFXIO):
self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name)
# Since we provide a schema, ArrowSchema() should not raise.
_ = tfxio.ArrowSchema()

Expand Down Expand Up @@ -219,11 +222,9 @@ def test_get_tfxio_factory_from_artifact(self,
raw_record_column_name)
tfxio = tfxio_factory(_FAKE_FILE_PATTERN)
self.assertIsInstance(tfxio, expected_tfxio_type)
# We currently only create RecordBasedTFXIO and the check below relies on
# that.
self.assertIsInstance(tfxio, record_based_tfxio.RecordBasedTFXIO)
self.assertEqual(tfxio.telemetry_descriptors, _TELEMETRY_DESCRIPTORS)
self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name)
if isinstance(tfxio, record_based_tfxio.RecordBasedTFXIO):
self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name)
# Since we provide a schema, ArrowSchema() should not raise.
_ = tfxio.ArrowSchema()

Expand Down
5 changes: 4 additions & 1 deletion tfx/proto/example_gen.proto
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,10 @@ enum PayloadFormat {
// Serialized any protocol buffer.
FORMAT_PROTO = 11;

reserved 1 to 5, 8 to 10, 12 to max;
// Serialized parquet messages.
FORMAT_PARQUET = 12;

reserved 1 to 5, 8 to 10, 13 to max;
}

// Enum to indicate file format that ExampleGen produces.
Expand Down

0 comments on commit c9b34af

Please sign in to comment.