Extend tfxio factory to use parquet-tfxio #4761

Merged
25 changes: 19 additions & 6 deletions tfx/components/util/tfxio_utils.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFXIO (standardized TFX inputs) related utilities."""
from __future__ import annotations
Contributor Author

had to add this to allow for proto enums to be used as types in signatures. Not sure if this is something desirable from the tfx side or not @jiyongjung0 @1025KB
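For context, a minimal standalone sketch (not part of this PR) of the behavior being relied on: under PEP 563 postponed evaluation, annotations are stored as strings and never evaluated at function-definition time, so objects that are not ordinary classes (protobuf enum types such as example_gen_pb2.PayloadFormat are EnumTypeWrapper instances) can appear inside Union[...] without raising. The fake wrapper below is a hypothetical stand-in, not the real proto type.

```python
# Sketch only: _FakeEnumTypeWrapper stands in for a protobuf EnumTypeWrapper
# such as example_gen_pb2.PayloadFormat; it is an instance, not a class.
from __future__ import annotations

from typing import Union


class _FakeEnumTypeWrapper:
  pass


PayloadFormat = _FakeEnumTypeWrapper()


def make_tfxio_sketch(payload_format: Union[PayloadFormat, int]) -> None:
  # The annotation above is kept as the plain string "Union[PayloadFormat, int]"
  # and is never evaluated, so it does not matter that PayloadFormat is not a
  # class. Without the __future__ import, typing would typically reject it at
  # function-definition time.
  del payload_format


print(make_tfxio_sketch.__annotations__['payload_format'])  # Union[PayloadFormat, int]
```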

Contributor Author

I can go back to string in the make_tfxio function signature instead of the proto enum type.
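For comparison, a rough sketch (hypothetical, not in the diff) of what the string-based alternative mentioned here could look like, resolving the name to the enum value internally:

```python
# Hypothetical alternative: accept the payload format name as a string and
# resolve it via the EnumTypeWrapper, instead of typing the parameter with the
# proto enum itself.
from typing import Union

from tfx.proto import example_gen_pb2


def _resolve_payload_format(payload_format: Union[str, int]) -> int:
  if isinstance(payload_format, str):
    # EnumTypeWrapper.Value() maps an enum value name to its integer value.
    return example_gen_pb2.PayloadFormat.Value(payload_format)
  return payload_format
```

For example, _resolve_payload_format('FORMAT_PARQUET') would return the integer value of example_gen_pb2.PayloadFormat.FORMAT_PARQUET.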


from typing import Any, Callable, Dict, List, Iterator, Optional, Tuple, Union

@@ -23,13 +24,19 @@
from tfx.types import artifact
from tfx.types import standard_artifacts
from tfx_bsl.tfxio import dataset_options
from tfx_bsl.tfxio import parquet_tfxio
from tfx_bsl.tfxio import raw_tf_record
from tfx_bsl.tfxio import record_to_tensor_tfxio
from tfx_bsl.tfxio import tf_example_record
from tfx_bsl.tfxio import tf_sequence_example_record
from tfx_bsl.tfxio import tfxio
from tensorflow_metadata.proto.v0 import schema_pb2


_SUPPORTED_FILE_FORMATS = (
example_gen_pb2.FileFormat.FILE_FORMAT_PARQUET,
example_gen_pb2.FileFormat.FORMAT_TFRECORDS_GZIP
)
# TODO(b/162532479): switch to support List[str] exclusively, once tfx-bsl
# post-0.22 is released.
OneOrMorePatterns = Union[str, List[str]]
@@ -242,12 +249,12 @@ def get_data_view_decode_fn_from_artifact(
def make_tfxio(
file_pattern: OneOrMorePatterns,
telemetry_descriptors: List[str],
payload_format: Union[str, int],
payload_format: Union[example_gen_pb2.PayloadFormat, int],
data_view_uri: Optional[str] = None,
schema: Optional[schema_pb2.Schema] = None,
read_as_raw_records: bool = False,
raw_record_column_name: Optional[str] = None,
file_format: Optional[Union[str, List[str]]] = None) -> tfxio.TFXIO:
file_format: Optional[Union[example_gen_pb2.FileFormat, List[example_gen_pb2.FileFormat]]] = None) -> tfxio.TFXIO:
"""Creates a TFXIO instance that reads `file_pattern`.

Args:
@@ -271,8 +278,8 @@ def make_tfxio(
that column will be the raw records. Note that not all TFXIO supports this
option, and an error will be raised in that case. Required if
read_as_raw_records == True.
file_format: file format string for each file_pattern. Only 'tfrecords_gzip'
is supported for now.
file_format: file format for each file_pattern. Only 'tfrecords_gzip'
and 'parquet' are supported for now.

Returns:
a TFXIO instance.
@@ -291,10 +298,10 @@ def make_tfxio(
f'The length of file_pattern and file_formats should be the same.'
f'Given: file_pattern={file_pattern}, file_format={file_format}')
else:
if any(item != 'tfrecords_gzip' for item in file_format):
if any(item not in _SUPPORTED_FILE_FORMATS for item in file_format):
raise NotImplementedError(f'{file_format} is not supported yet.')
else: # file_format is str type.
if file_format != 'tfrecords_gzip':
if file_format not in _SUPPORTED_FILE_FORMATS:
raise NotImplementedError(f'{file_format} is not supported yet.')

if read_as_raw_records:
@@ -330,6 +337,12 @@ def make_tfxio(
telemetry_descriptors=telemetry_descriptors,
raw_record_column_name=raw_record_column_name)

if payload_format == example_gen_pb2.PayloadFormat.FORMAT_PARQUET:
return parquet_tfxio.ParquetTFXIO(
file_pattern=file_pattern,
schema=schema,
telemetry_descriptors=telemetry_descriptors)

raise NotImplementedError(
'Unsupport payload format: {}'.format(payload_format))
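Taken together, a hypothetical call exercising the parquet branch added above (the file pattern, telemetry descriptors, and the decision to pass no schema are illustrative, not taken from the PR):

```python
from tfx.components.util import tfxio_utils
from tfx.proto import example_gen_pb2

# Illustrative arguments; any parquet file pattern and descriptor list would do.
parquet_tfxio_instance = tfxio_utils.make_tfxio(
    file_pattern='/path/to/examples/*.parquet',
    telemetry_descriptors=['SomeComponent'],
    payload_format=example_gen_pb2.PayloadFormat.FORMAT_PARQUET,
    schema=None,  # a schema_pb2.Schema may be supplied here
    file_format=example_gen_pb2.FileFormat.FILE_FORMAT_PARQUET)
# With FORMAT_PARQUET, the factory returns a parquet_tfxio.ParquetTFXIO.
```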

21 changes: 11 additions & 10 deletions tfx/components/util/tfxio_utils_test.py
@@ -26,6 +26,7 @@
from tfx.proto import example_gen_pb2
from tfx.types import standard_artifacts
from tfx_bsl.coders import tf_graph_record_decoder
from tfx_bsl.tfxio import parquet_tfxio
from tfx_bsl.tfxio import raw_tf_record
from tfx_bsl.tfxio import record_based_tfxio
from tfx_bsl.tfxio import record_to_tensor_tfxio
@@ -70,6 +71,10 @@
read_as_raw_records=True,
raw_record_column_name=_RAW_RECORD_COLUMN_NAME,
expected_tfxio_type=raw_tf_record.RawTfRecordTFXIO),
dict(
testcase_name='parquet_payload_format',
payload_format=example_gen_pb2.PayloadFormat.FORMAT_PARQUET,
expected_tfxio_type=parquet_tfxio.ParquetTFXIO),
]

_RESOLVE_TEST_CASES = [
@@ -181,11 +186,9 @@ def test_make_tfxio(self, payload_format, expected_tfxio_type,
_FAKE_FILE_PATTERN, _TELEMETRY_DESCRIPTORS, payload_format,
data_view_uri, _SCHEMA, read_as_raw_records, raw_record_column_name)
self.assertIsInstance(tfxio, expected_tfxio_type)
# We currently only create RecordBasedTFXIO and the check below relies on
# that.
self.assertIsInstance(tfxio, record_based_tfxio.RecordBasedTFXIO)
self.assertEqual(tfxio.telemetry_descriptors, _TELEMETRY_DESCRIPTORS)
self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name)
if isinstance(tfxio, record_based_tfxio.RecordBasedTFXIO):
self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name)
# Since we provide a schema, ArrowSchema() should not raise.
_ = tfxio.ArrowSchema()

@@ -219,11 +222,9 @@ def test_get_tfxio_factory_from_artifact(self,
raw_record_column_name)
tfxio = tfxio_factory(_FAKE_FILE_PATTERN)
self.assertIsInstance(tfxio, expected_tfxio_type)
# We currently only create RecordBasedTFXIO and the check below relies on
# that.
self.assertIsInstance(tfxio, record_based_tfxio.RecordBasedTFXIO)
self.assertEqual(tfxio.telemetry_descriptors, _TELEMETRY_DESCRIPTORS)
self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name)
if isinstance(tfxio, record_based_tfxio.RecordBasedTFXIO):
self.assertEqual(tfxio.raw_record_column_name, raw_record_column_name)
# Since we provide a schema, ArrowSchema() should not raise.
_ = tfxio.ArrowSchema()

@@ -306,7 +307,7 @@ def test_get_tf_dataset_factory_from_artifact(self):
dataset_factory = tfxio_utils.get_tf_dataset_factory_from_artifact(
[examples], _TELEMETRY_DESCRIPTORS)
self.assertIsInstance(dataset_factory, Callable)
self.assertEqual(tf.data.Dataset,
self.assertEqual('tf.data.Dataset',
Contributor Author

this stringification of the signature is a result of adding: from __future__ import annotations

inspect.signature(dataset_factory).return_annotation)
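The comment above refers to PEP 563 behavior: with from __future__ import annotations in effect, inspect.signature(...).return_annotation is the annotation's source text as a string, not the evaluated object. A minimal standalone sketch (not from the PR):

```python
from __future__ import annotations

import inspect
from typing import Iterator


def factory() -> Iterator[int]:
  yield 1


# The return annotation stays an unevaluated string, which is why the
# assertions in this test now compare against string literals.
assert inspect.signature(factory).return_annotation == 'Iterator[int]'
```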

def test_get_record_batch_factory_from_artifact(self):
Expand All @@ -317,7 +318,7 @@ def test_get_record_batch_factory_from_artifact(self):
record_batch_factory = tfxio_utils.get_record_batch_factory_from_artifact(
[examples], _TELEMETRY_DESCRIPTORS)
self.assertIsInstance(record_batch_factory, Callable)
self.assertEqual(Iterator[pa.RecordBatch],
self.assertEqual('Iterator[pa.RecordBatch]',
inspect.signature(record_batch_factory).return_annotation)

def test_raise_if_data_view_uri_not_available(self):
12 changes: 10 additions & 2 deletions tfx/proto/example_gen.proto
@@ -111,7 +111,10 @@ enum PayloadFormat {
// Serialized any protocol buffer.
FORMAT_PROTO = 11;

reserved 1 to 5, 8 to 10, 12 to max;
// Serialized parquet messages.
FORMAT_PARQUET = 15;

reserved 1 to 5, 8 to 10, 12 to 14, 16 to max;
}

// Enum to indicate file format that ExampleGen produces.
@@ -122,7 +125,12 @@ enum FileFormat {
// Indicates TFRecords format files with gzip compression.
FORMAT_TFRECORDS_GZIP = 5;

reserved 1 to 4, 6 to max;
// Indicates parquet format files in any of the supported compressions.
// protolint:disable MAX_LINE_LENGTH
// https://arrow.apache.org/docs/python/parquet.html#compression-encoding-and-file-compatibility
FILE_FORMAT_PARQUET = 16;

reserved 1 to 4, 7 to 15, 17 to max;
}

// Specification of the output of the example gen.