Add TFXIO for reading parquet (#52)
* Add TFXIO for reading parquet

* Fix unit tests

* Implement telemetry

* Format files

* Address PR comments

* Infer the Parquet schema if no TFMD schema is passed

* Make TFXIO arguments keyword-only (see the usage sketch below)
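
A minimal usage sketch of the new class (not part of the commit; the file pattern, column names, and descriptor are hypothetical). Everything after column_names is keyword-only:

from tfx_bsl.tfxio.parquet_tfxio import ParquetTFXIO

tfxio = ParquetTFXIO(
    file_pattern="/tmp/events/*.parquet",  # hypothetical path
    column_names=["user_id", "score"],
    # All arguments below must be passed by keyword.
    min_bundle_size=0,
    schema=None,  # no TFMD schema: the Arrow schema is inferred from the files
    telemetry_descriptors=["SomeComponent"])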
martinbomio authored Mar 15, 2022
1 parent 01b5ff5 commit d713de3
Showing 2 changed files with 436 additions and 0 deletions.
tfx_bsl/tfxio/parquet_tfxio.py (148 additions, 0 deletions)
@@ -0,0 +1,148 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFXIO implementation for Parquet."""

import copy
from typing import Optional, List, Text, Any

import apache_beam as beam
import pyarrow as pa
import pyarrow.parquet as pq
import tensorflow as tf
from apache_beam.io.filesystems import FileSystems
from tensorflow_metadata.proto.v0 import schema_pb2

from tfx_bsl.coders import csv_decoder

from tfx_bsl.tfxio import dataset_options, tensor_adapter, tensor_representation_util, telemetry
from tfx_bsl.tfxio.tfxio import TFXIO

_PARQUET_FORMAT = "parquet"


class ParquetTFXIO(TFXIO):
"""TFXIO implementation for Parquet."""

  def __init__(self,
               file_pattern: Text,
               column_names: List[Text],
               *,
               min_bundle_size: int = 0,
               schema: Optional[schema_pb2.Schema] = None,
               validate: Optional[bool] = True,
               telemetry_descriptors: Optional[List[Text]] = None):
    """Initializes a Parquet TFXIO.

    Args:
      file_pattern: A file glob pattern to read parquet files from.
      column_names: List of column names to read from the parquet files.
      min_bundle_size: the minimum size in bytes, to be considered when
        splitting the parquet input into bundles.
      schema: An optional TFMD Schema describing the dataset. If a schema is
        provided, it will determine the data type of the parquet columns.
        Otherwise, each column's data type will be inferred by the decoder.
      validate: Boolean flag to verify that the files exist at pipeline
        creation time.
      telemetry_descriptors: A set of descriptors that identify the component
        that is instantiating this TFXIO. These will be used to construct the
        namespace to contain metrics for profiling and are therefore expected
        to be identifiers of the component itself and not individual instances
        of source use.
    """
    self._file_pattern = file_pattern
    self._column_names = column_names
    self._min_bundle_size = min_bundle_size
    self._validate = validate
    self._schema = schema
    self._telemetry_descriptors = telemetry_descriptors

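  # Returns a PTransform producing pa.RecordBatch elements; batch_size caps
  # rows per batch via Table.to_batches(max_chunksize=...).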
  def BeamSource(self, batch_size: Optional[int] = None) -> beam.PTransform:

    @beam.typehints.with_input_types(Any)
    @beam.typehints.with_output_types(pa.RecordBatch)
    def _PTransformFn(pcoll_or_pipeline: Any):
      """Reads Parquet tables and converts to RecordBatches."""
      return (
          pcoll_or_pipeline | "ParquetBeamSource" >>
          beam.io.ReadFromParquetBatched(file_pattern=self._file_pattern,
                                         min_bundle_size=self._min_bundle_size,
                                         validate=self._validate,
                                         columns=self._column_names) |
          "ToRecordBatch" >> beam.FlatMap(self._TableToRecordBatch, batch_size)
          | "CollectRecordBatchTelemetry" >> telemetry.ProfileRecordBatches(
              self._telemetry_descriptors, _PARQUET_FORMAT, _PARQUET_FORMAT))

    return beam.ptransform_fn(_PTransformFn)()

  def RecordBatches(self, options: dataset_options.RecordBatchesOptions):
    raise NotImplementedError

  def TensorFlowDataset(
      self,
      options: dataset_options.TensorFlowDatasetOptions) -> tf.data.Dataset:
    raise NotImplementedError

  def _TableToRecordBatch(
      self,
      table: pa.Table,
      batch_size: Optional[int] = None) -> List[pa.RecordBatch]:
    return table.to_batches(max_chunksize=batch_size)

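  # With a TFMD schema, reuse the CSV decoder's TFMD-to-Arrow conversion;
  # without one, fall back to reading the schema from the files themselves.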
  def ArrowSchema(self) -> pa.Schema:
    if self._schema is None:
      return self._InferArrowSchema()
    return csv_decoder.GetArrowSchema(self._column_names, self._schema)

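  # Assumes all matched files share one schema: reads it from the first file
  # that matches the pattern.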
  def _InferArrowSchema(self):
    match_result = FileSystems.match([self._file_pattern])[0]
    files_metadata = match_result.metadata_list[0]
    with FileSystems.open(files_metadata.path) as f:
      return pq.read_schema(f)

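  # Prefer tensor representations declared explicitly in the TFMD schema;
  # otherwise infer them from the schema's features.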
  def TensorRepresentations(self) -> tensor_adapter.TensorRepresentations:
    result = (tensor_representation_util.GetTensorRepresentationsFromSchema(
        self._schema))
    if result is None:
      result = (
          tensor_representation_util.InferTensorRepresentationsFromSchema(
              self._schema))
    return result

  def _ProjectTfmdSchema(self, column_names: List[Text]) -> schema_pb2.Schema:
    """Creates a TFMD Schema from the current one with only the given columns."""

    result = schema_pb2.Schema()
    result.CopyFrom(self._schema)

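    # remove() on a repeated message field matches by value, so a feature
    # iterated from self._schema also identifies its equal copy in result.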
    for feature in self._schema.feature:
      if feature.name not in column_names:
        result.feature.remove(feature)

    return result

  def _ProjectImpl(self, tensor_names: List[Text]) -> "TFXIO":
    """Returns a projected TFXIO.

    Projection is pushed down to the Parquet Beam source. The projected TFXIO
    will project the record batches, the Arrow schema, and the TFMD schema.

    Args:
      tensor_names: The columns to project.
    """
    projected_schema = self._ProjectTfmdSchema(tensor_names)
    result = copy.copy(self)
    result._column_names = tensor_names  # pylint: disable=protected-access
    result._schema = projected_schema  # pylint: disable=protected-access
    return result
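
For orientation, a hypothetical end-to-end sketch (not part of the commit; the schema, file pattern, and column names are invented, and Project is the public TFXIO base-class entry point that dispatches to _ProjectImpl above):

import apache_beam as beam
from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2
from tfx_bsl.tfxio.parquet_tfxio import ParquetTFXIO

schema = text_format.Parse(
    """
    feature { name: "user_id" type: INT }
    feature { name: "score" type: FLOAT }
    """, schema_pb2.Schema())

tfxio = ParquetTFXIO(file_pattern="/tmp/events/*.parquet",
                     column_names=["user_id", "score"],
                     schema=schema)

# Projection is pushed down: only the projected columns are read.
projected = tfxio.Project(["user_id"])

with beam.Pipeline() as pipeline:
  _ = (pipeline
       | "Read" >> projected.BeamSource(batch_size=1000)
       | "Print" >> beam.Map(print))  # elements are pa.RecordBatch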