diff --git a/sdks/python/apache_beam/ml/transforms/handlers.py b/sdks/python/apache_beam/ml/transforms/handlers.py
index 864e1017255a..e92516531dd9 100644
--- a/sdks/python/apache_beam/ml/transforms/handlers.py
+++ b/sdks/python/apache_beam/ml/transforms/handlers.py
@@ -86,32 +86,30 @@ class DataCoder:
-  def __init__(self, exclude_columns=None):
+  def __init__(
+      self,
+      exclude_columns,
+      coder=coders.registry.get_coder(Any),
+  ):
     """
     Uses PickleCoder to encode/decode the dictionaries.

     Args:
       exclude_columns: list of columns to exclude from the encoding.
     """
-    self.coder = coders.registry.get_coder(Any)
-    self.exclude_columns = exclude_columns
-
-  def set_unused_columns(self, exclude_columns):
+    self.coder = coder
     self.exclude_columns = exclude_columns

   def encode(self, element):
-    if not self.exclude_columns:
-      return element
+    # if not set(element.keys()) - set(self.exclude_columns):
+    #   return element
     data_to_encode = element.copy()
     for key in self.exclude_columns:
       if key in data_to_encode:
         del data_to_encode[key]
-
     element[_TEMP_KEY] = self.coder.encode(data_to_encode)
     return element

   def decode(self, element):
-    if not self.exclude_columns:
-      return element
     clone = copy.copy(element)
     clone.update(self.coder.decode(clone[_TEMP_KEY].item()))
     del clone[_TEMP_KEY]
@@ -176,7 +174,6 @@ def __init__(
     self.artifact_mode = artifact_mode
     if artifact_mode not in ['produce', 'consume']:
       raise ValueError('artifact_mode must be either `produce` or `consume`.')
-    self.data_coder = DataCoder()

   def append_transform(self, transform):
     self.transforms.append(transform)
@@ -365,7 +362,6 @@ def expand(
         artifact_location, which was previously used to store the produced
         artifacts.
     """
-
     if self.artifact_mode == ArtifactMode.PRODUCE:
       # If we are computing artifacts, we should fail for windows other than
       # default windowing since for example, for a fixed window, each window can
@@ -409,24 +405,29 @@ def expand(
     raw_data_metadata = metadata_io.read_metadata(
         os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))

-    keyed_raw_data = (raw_data)  # | beam.ParDo(_ComputeAndAttachUniqueID()))
-
     feature_set = [feature.name for feature in raw_data_metadata.schema.feature]
-    self.data_coder.set_unused_columns(exclude_columns=feature_set)
+
+    # TFT ignores columns in the input data that aren't explicitly defined
+    # in the schema, since TFT operations are designed to work with a
+    # predetermined schema. To preserve these extra columns without
+    # disrupting TFT processing, they are temporarily encoded as bytes and
+    # added to the PCollection under a unique identifier.
+    data_coder = DataCoder(exclude_columns=feature_set)
+    data_with_encoded_columns = (
+        raw_data
+        | "EncodeUnmodifiedColumns" >>
+        beam.Map(lambda elem: data_coder.encode(elem)))

     # To maintain consistency by outputting numpy array all the time,
     # whether a scalar value or list or np array is passed as input,
-    # we will convert scalar values to list values and TFT will ouput
+    # we will convert scalar values to list values and TFT will output
+    # numpy array all the time.
-    raw_data_list = (
-        keyed_raw_data
-        | beam.Map(lambda elem: self.data_coder.encode(elem)))
-
-    keyed_raw_data = keyed_raw_data | beam.ParDo(
+    data_list = data_with_encoded_columns | beam.ParDo(
         _ConvertScalarValuesToListValues())

     with tft_beam.Context(temp_dir=self.artifact_location):
-      data = (raw_data_list, raw_data_metadata)
+      data = (data_list, raw_data_metadata)
       if self.artifact_mode == ArtifactMode.PRODUCE:
         transform_fn = (
             data
@@ -436,7 +437,7 @@ def expand(
         self.write_transform_artifacts(transform_fn, self.artifact_location)
       else:
         transform_fn = (
-            raw_data_list.pipeline
+            data_list.pipeline
             | "ReadTransformFn" >> tft_beam.ReadTransformFn(
                 self.artifact_location))
       (transformed_dataset, transformed_metadata) = (
@@ -458,9 +459,11 @@ def expand(
       row_type = RowTypeConstraint.from_fields(
           list(self.transformed_schema.items()))

+      # Decode the extra columns that were encoded as bytes.
       transformed_dataset = (
           transformed_dataset
-          | "DecodeDict" >> beam.Map(lambda x: self.data_coder.decode(x)))
+          |
+          "DecodeUnmodifiedColumns" >> beam.Map(lambda x: data_coder.decode(x)))

       # The schema only contains the columns that are transformed.
       transformed_dataset = (
           transformed_dataset
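Reviewer note: below is a minimal, self-contained sketch of the encode/decode round trip this change introduces. The `_TEMP_KEY` value, the `'language'`/`'id'` columns, and the simulated TFT output are illustrative assumptions; only `coders.registry.get_coder(Any)` and the copy/encode/decode logic mirror the diff.

```python
import copy
from typing import Any

import numpy as np
from apache_beam import coders

_TEMP_KEY = '_tmp_key'  # stand-in for the module-level constant in handlers.py

coder = coders.registry.get_coder(Any)

# 'language' is in the TFT schema; 'id' is not, so it must survive unchanged.
element = {'language': ['python'], 'id': 42}
feature_set = ['language']

# encode(): stash every non-schema column as bytes under _TEMP_KEY.
data_to_encode = {k: v for k, v in element.items() if k not in feature_set}
element[_TEMP_KEY] = coder.encode(data_to_encode)

# TFT runs here and only emits the schema columns, handing values back
# wrapped in numpy arrays; that is why decode() calls `.item()` on the
# _TEMP_KEY entry before decoding.
transformed = {
    'language': np.array([b'python']),
    _TEMP_KEY: np.array(element[_TEMP_KEY]),
}

# decode(): restore the stashed columns and drop the temporary key.
clone = copy.copy(transformed)
clone.update(coder.decode(clone[_TEMP_KEY].item()))
del clone[_TEMP_KEY]
assert clone['id'] == 42
```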
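Also for reviewers: the "convert scalar values to list values" comment explains why `_ConvertScalarValuesToListValues` runs before TFT, but that DoFn's body is not part of this diff. The sketch below is a guess at the behavior the comment describes (wrap bare scalars so TFT consistently emits numpy arrays); the helper name `_convert` is hypothetical.

```python
import numpy as np

def _convert(element):
  # Wrap bare scalars in a single-element list; lists and arrays pass through.
  return {k: [v] if np.isscalar(v) else v for k, v in element.items()}

assert _convert({'x': 1, 'y': [2], 'z': np.array([3])})['x'] == [1]
```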