diff --git a/sdks/python/apache_beam/ml/transforms/handlers.py b/sdks/python/apache_beam/ml/transforms/handlers.py
index 3c37ddef1ed5..5bcd0d165761 100644
--- a/sdks/python/apache_beam/ml/transforms/handlers.py
+++ b/sdks/python/apache_beam/ml/transforms/handlers.py
@@ -17,9 +17,10 @@
 # pytype: skip-file

 import collections
+import copy
 import os
 import typing
-import uuid
+from typing import Any
 from typing import Dict
 from typing import List
 from typing import Optional
@@ -31,6 +32,7 @@
 import apache_beam as beam
 import tensorflow as tf
 import tensorflow_transform.beam as tft_beam
+from apache_beam import coders
 from apache_beam.io.filesystems import FileSystems
 from apache_beam.ml.transforms.base import ArtifactMode
 from apache_beam.ml.transforms.base import ProcessHandler
@@ -50,7 +52,7 @@
     'TFTProcessHandler',
 ]

-_ID_COLUMN = 'tmp_uuid'  # Name for a temporary column.
+_TEMP_KEY = 'CODED_SAMPLE'  # Key for the encoded sample.

 RAW_DATA_METADATA_DIR = 'raw_data_metadata'
 SCHEMA_FILE = 'schema.pbtxt'
@@ -83,12 +85,41 @@
 tft_process_handler_output_type = typing.Union[beam.Row, Dict[str, np.ndarray]]


+class _DataCoder:
+  def __init__(
+      self,
+      exclude_columns,
+      coder=coders.registry.get_coder(Any),
+  ):
+    """
+    Encodes/decodes items of a dictionary into a single element.
+    Args:
+      exclude_columns: list of columns to exclude from the encoding.
+    """
+    self.coder = coder
+    self.exclude_columns = exclude_columns
+
+  def encode(self, element):
+    data_to_encode = element.copy()
+    element_to_return = element.copy()
+    for key in self.exclude_columns:
+      if key in data_to_encode:
+        del data_to_encode[key]
+    element_to_return[_TEMP_KEY] = self.coder.encode(data_to_encode)
+    return element_to_return
+
+  def decode(self, element):
+    clone = copy.copy(element)
+    clone.update(self.coder.decode(clone[_TEMP_KEY].item()))
+    del clone[_TEMP_KEY]
+    return clone
+
+
 class _ConvertScalarValuesToListValues(beam.DoFn):
   def process(
       self,
       element,
   ):
-    id, element = element
     new_dict = {}
     for key, value in element.items():
       if isinstance(value,
@@ -96,7 +127,7 @@ def process(
         new_dict[key] = [value]
       else:
         new_dict[key] = value
-    yield (id, new_dict)
+    yield new_dict


 class _ConvertNamedTupleToDict(
@@ -124,79 +155,6 @@ def expand(
     return pcoll | beam.Map(lambda x: x._asdict())


-class _ComputeAndAttachUniqueID(beam.DoFn):
-  """
-  Computes and attaches a unique id to each element in the PCollection.
-  """
-  def process(self, element):
-    # UUID1 includes machine-specific bits and has a counter. As long as not too
-    # many are generated at the same time, they should be unique.
-    # UUID4 generation should be unique in practice as long as underlying random
-    # number generation is not compromised.
-    # A combintation of both should avoid the anecdotal pitfalls where
-    # replacing one with the other has helped some users.
-    # UUID collision will result in data loss, but we can detect that and fail.
-
-    # TODO(https://github.com/apache/beam/issues/29593): Evaluate MLTransform
-    # implementation without CoGBK.
-    unique_key = uuid.uuid1().bytes + uuid.uuid4().bytes
-    yield (unique_key, element)
-
-
-class _GetMissingColumns(beam.DoFn):
-  """
-  Returns data containing only the columns that are not
-  present in the schema. This is needed since TFT only outputs
-  columns that are transformed by any of the data processing transforms.
-  """
-  def __init__(self, existing_columns):
-    self.existing_columns = existing_columns
-
-  def process(self, element):
-    id, row_dict = element
-    new_dict = {
-        k: v
-        for k, v in row_dict.items() if k not in self.existing_columns
-    }
-    yield (id, new_dict)
-
-
-class _MakeIdAsColumn(beam.DoFn):
-  """
-  Extracts the id from the element and adds it as a column instead.
-  """
-  def process(self, element):
-    id, element = element
-    element[_ID_COLUMN] = id
-    yield element
-
-
-class _ExtractIdAndKeyPColl(beam.DoFn):
-  """
-  Extracts the id and return id and element as a tuple.
-  """
-  def process(self, element):
-    id = element[_ID_COLUMN][0]
-    del element[_ID_COLUMN]
-    yield (id, element)
-
-
-class _MergeDicts(beam.DoFn):
-  """
-  Merges processed and unprocessed columns from CoGBK result into a single row.
-  """
-  def process(self, element):
-    unused_row_id, row_dicts_tuple = element
-    new_dict = {}
-    for d in row_dicts_tuple:
-      # After CoGBK, dicts with processed and unprocessed portions of each row
-      # are wrapped in 1-element lists, since all rows have a unique id.
-      # Assertion could fail due to UUID collision.
-      assert len(d) == 1, f"Expected 1 element, got: {len(d)}."
-      new_dict.update(d[0])
-    yield new_dict
-
-
 class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
                                        tft_process_handler_output_type]):
   def __init__(
@@ -325,7 +283,7 @@ def _get_raw_data_feature_spec_per_column(
   def get_raw_data_metadata(
       self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata:
     raw_data_feature_spec = self.get_raw_data_feature_spec(input_types)
-    raw_data_feature_spec[_ID_COLUMN] = tf.io.VarLenFeature(dtype=tf.string)
+    raw_data_feature_spec[_TEMP_KEY] = tf.io.VarLenFeature(dtype=tf.string)
     return self.convert_raw_data_feature_spec_to_dataset_metadata(
         raw_data_feature_spec)

@@ -403,7 +361,6 @@ def expand(
       artifact_location, which was previously used to store the produced
       artifacts.
     """
-
     if self.artifact_mode == ArtifactMode.PRODUCE:
       # If we are computing artifacts, we should fail for windows other than
       # default windowing since for example, for a fixed window, each window can
@@ -447,24 +404,29 @@ def expand(
       raw_data_metadata = metadata_io.read_metadata(
           os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))

-    keyed_raw_data = (raw_data | beam.ParDo(_ComputeAndAttachUniqueID()))
-
     feature_set = [feature.name for feature in raw_data_metadata.schema.feature]
-    keyed_columns_not_in_schema = (
-        keyed_raw_data
-        | beam.ParDo(_GetMissingColumns(feature_set)))
+
+    # TFT ignores columns in the input data that aren't explicitly defined
+    # in the schema. This is because TFT operations
+    # are designed to work with a predetermined schema.
+    # To preserve these extra columns without disrupting TFT processing,
+    # they are temporarily encoded as bytes and added to each element under
+    # a temporary key.
+    data_coder = _DataCoder(exclude_columns=feature_set)
+    data_with_encoded_columns = (
+        raw_data
+        | "EncodeUnmodifiedColumns" >>
+        beam.Map(lambda elem: data_coder.encode(elem)))

     # To maintain consistency by outputting numpy array all the time,
     # whether a scalar value or list or np array is passed as input,
-    # we will convert scalar values to list values and TFT will ouput
+    # we will convert scalar values to list values and TFT will output
     # numpy array all the time.
-    keyed_raw_data = keyed_raw_data | beam.ParDo(
+    data_list = data_with_encoded_columns | beam.ParDo(
         _ConvertScalarValuesToListValues())

-    raw_data_list = (keyed_raw_data | beam.ParDo(_MakeIdAsColumn()))
-
     with tft_beam.Context(temp_dir=self.artifact_location):
-      data = (raw_data_list, raw_data_metadata)
+      data = (data_list, raw_data_metadata)
       if self.artifact_mode == ArtifactMode.PRODUCE:
         transform_fn = (
             data
@@ -474,7 +436,7 @@ def expand(
         self.write_transform_artifacts(transform_fn, self.artifact_location)
       else:
         transform_fn = (
-            raw_data_list.pipeline
+            data_list.pipeline
             | "ReadTransformFn" >> tft_beam.ReadTransformFn(
                 self.artifact_location))
       (transformed_dataset, transformed_metadata) = (
@@ -492,26 +454,15 @@ def expand(
       # So we will use a RowTypeConstraint to create a schema'd PCollection.
       # this is needed since new columns are included in the
      # transformed_dataset.
-      del self.transformed_schema[_ID_COLUMN]
+      del self.transformed_schema[_TEMP_KEY]
       row_type = RowTypeConstraint.from_fields(
           list(self.transformed_schema.items()))

-      # If a non schema PCollection is passed, and one of the input columns
-      # is not transformed by any of the transforms, then the output will
-      # not have that column. So we will join the missing columns from the
-      # raw_data to the transformed_dataset.
-      keyed_transformed_dataset = (
-          transformed_dataset | beam.ParDo(_ExtractIdAndKeyPColl()))
-
-      # The grouping is needed here since tensorflow transform only outputs
-      # columns that are transformed by any of the transforms. So we will
-      # join the missing columns from the raw_data to the transformed_dataset
-      # using the id.
+      # Decode the extra columns that were encoded as bytes.
       transformed_dataset = (
-          (keyed_transformed_dataset, keyed_columns_not_in_schema)
-          | beam.CoGroupByKey()
-          | beam.ParDo(_MergeDicts()))
-
+          transformed_dataset
+          |
+          "DecodeUnmodifiedColumns" >> beam.Map(lambda x: data_coder.decode(x)))
       # The schema only contains the columns that are transformed.
       transformed_dataset = (
           transformed_dataset
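
The sketch below is not part of the diff; it only illustrates the round trip that _DataCoder performs, so that the replacement for the old UUID/CoGroupByKey join is easier to follow. The example element, the schema_columns list, and the numpy wrapping that stands in for TFT's VarLenFeature output are made-up stand-ins; only _TEMP_KEY, coders.registry.get_coder(Any), and the encode/decode steps mirror the code above.

    from typing import Any

    import numpy as np
    from apache_beam import coders

    _TEMP_KEY = 'CODED_SAMPLE'
    coder = coders.registry.get_coder(Any)

    element = {'x': [1.0, 2.0], 'label': 'cat'}  # 'label' is not in the TFT schema
    schema_columns = ['x']  # hypothetical feature_set for this sketch

    # encode(): copy the element, drop the schema columns from the copy, and
    # stash whatever is left as a single bytes value under the temporary key.
    extra = {k: v for k, v in element.items() if k not in schema_columns}
    encoded = dict(element)
    encoded[_TEMP_KEY] = coder.encode(extra)

    # TFT would now transform only the schema columns and hand the temporary
    # key back as a length-1 numpy array (it is declared as a VarLenFeature of
    # tf.string); emulate that shape here.
    transformed = {
        'x': np.array([1.0, 2.0]),
        _TEMP_KEY: np.array([encoded[_TEMP_KEY]], dtype=object),
    }

    # decode(): unpack the stashed columns and drop the temporary key, matching
    # _DataCoder.decode(), which calls .item() on the length-1 array.
    restored = dict(transformed)
    restored.update(coder.decode(restored[_TEMP_KEY].item()))
    del restored[_TEMP_KEY]
    assert restored['label'] == 'cat'

Because the untouched columns ride along inside each element, the removed classes (_ComputeAndAttachUniqueID, _GetMissingColumns, _ExtractIdAndKeyPColl, _MergeDicts) and the CoGroupByKey rejoin they supported are no longer needed.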