Commit: Address comments
AnandInguva committed Feb 5, 2024
1 parent 094a888 commit 066c4ce
Showing 1 changed file with 27 additions and 24 deletions.
51 changes: 27 additions & 24 deletions sdks/python/apache_beam/ml/transforms/handlers.py
@@ -86,32 +86,30 @@


class DataCoder:
def __init__(self, exclude_columns=None):
def __init__(
self,
exclude_columns,
coder=coders.registry.get_coder(Any),
):
"""
Uses PickleCoder to encode/decode the dictonaries.
Args:
exclude_columns: list of columns to exclude from the encoding.
"""
self.coder = coders.registry.get_coder(Any)
self.exclude_columns = exclude_columns

def set_unused_columns(self, exclude_columns):
self.coder = coder
self.exclude_columns = exclude_columns

def encode(self, element):
if not self.exclude_columns:
return element
# if not set(element.keys()) - set(self.exclude_columns):
# return element
data_to_encode = element.copy()
for key in self.exclude_columns:
if key in data_to_encode:
del data_to_encode[key]

element[_TEMP_KEY] = self.coder.encode(data_to_encode)
return element

def decode(self, element):
if not self.exclude_columns:
return element
clone = copy.copy(element)
clone.update(self.coder.decode(clone[_TEMP_KEY].item()))
del clone[_TEMP_KEY]
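For context, here is a minimal round-trip sketch of the DataCoder above. It is an illustration under stated assumptions: _TEMP_KEY stands in for the module's real constant, and np.array(...) mimics TFT handing the pickled bytes back as a numpy scalar, which is why decode() calls .item().

import copy
from typing import Any

import numpy as np
from apache_beam import coders

_TEMP_KEY = '_tmp_key'  # placeholder; handlers.py defines the real constant

coder = coders.registry.get_coder(Any)
exclude_columns = ['x']  # the schema columns TFT already handles

element = {'x': 1, 'label': 'keep-me'}

# encode(): pickle every column except the excluded ones into one bytes field.
extra = {k: v for k, v in element.items() if k not in exclude_columns}
element[_TEMP_KEY] = coder.encode(extra)

# TFT would drop 'label' here and return the bytes as a numpy scalar.
element = {'x': np.array([1]), _TEMP_KEY: np.array(element[_TEMP_KEY])}

# decode(): restore the pickled columns and drop the temporary field.
clone = copy.copy(element)
clone.update(coder.decode(clone[_TEMP_KEY].item()))
del clone[_TEMP_KEY]
assert clone['label'] == 'keep-me'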
@@ -176,7 +174,6 @@ def __init__(
self.artifact_mode = artifact_mode
if artifact_mode not in ['produce', 'consume']:
raise ValueError('artifact_mode must be either `produce` or `consume`.')
self.data_coder = DataCoder()

def append_transform(self, transform):
self.transforms.append(transform)
@@ -365,7 +362,6 @@ def expand(
artifact_location, which was previously used to store the produced
artifacts.
"""

if self.artifact_mode == ArtifactMode.PRODUCE:
# If we are computing artifacts, we should fail for windows other than
# default windowing since for example, for a fixed window, each window can
@@ -409,24 +405,29 @@ def expand(
raw_data_metadata = metadata_io.read_metadata(
os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))

keyed_raw_data = (raw_data) # | beam.ParDo(_ComputeAndAttachUniqueID()))

feature_set = [feature.name for feature in raw_data_metadata.schema.feature]
self.data_coder.set_unused_columns(exclude_columns=feature_set)

# TFT ignores columns in the input data that aren't explicitly defined
# in the schema. This is because TFT operations
# are designed to work with a predetermined schema.
# To preserve these extra columns without disrupting TFT processing,
# they are temporarily encoded as bytes and added to the PCollection with
# a unique identifier.
data_coder = DataCoder(exclude_columns=feature_set)
data_with_encoded_columns = (
raw_data
| "EncodeUnmodifiedColumns" >>
beam.Map(lambda elem: data_coder.encode(elem)))

# To maintain consistency by outputting numpy array all the time,
# whether a scalar value or list or np array is passed as input,
# we will convert scalar values to list values and TFT will output
# numpy array all the time.
raw_data_list = (
keyed_raw_data
| beam.Map(lambda elem: self.data_coder.encode(elem)))

keyed_raw_data = keyed_raw_data | beam.ParDo(
data_list = data_with_encoded_columns | beam.ParDo(
_ConvertScalarValuesToListValues())

with tft_beam.Context(temp_dir=self.artifact_location):
data = (raw_data_list, raw_data_metadata)
data = (data_list, raw_data_metadata)
if self.artifact_mode == ArtifactMode.PRODUCE:
transform_fn = (
data
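The scalar-to-list normalization described in the comment above could look roughly like the sketch below. This is hedged: the real DoFn is _ConvertScalarValuesToListValues in handlers.py and may differ in detail, for example in how it treats numpy inputs.

import numpy as np
import apache_beam as beam

class ConvertScalarValuesToListValues(beam.DoFn):
  # Wrap bare scalars in lists so every column carries a list-like value
  # and TFT emits numpy arrays uniformly downstream.
  def process(self, element):
    yield {
        key: value if isinstance(value, (list, np.ndarray)) else [value]
        for key, value in element.items()
    }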
@@ -436,7 +437,7 @@
self.write_transform_artifacts(transform_fn, self.artifact_location)
else:
transform_fn = (
raw_data_list.pipeline
data_list.pipeline
| "ReadTransformFn" >> tft_beam.ReadTransformFn(
self.artifact_location))
(transformed_dataset, transformed_metadata) = (
@@ -458,9 +459,11 @@
row_type = RowTypeConstraint.from_fields(
list(self.transformed_schema.items()))

# Decode the extra columns that were encoded as bytes.
transformed_dataset = (
transformed_dataset
| "DecodeDict" >> beam.Map(lambda x: self.data_coder.decode(x)))
|
"DecodeUnmodifiedColumns" >> beam.Map(lambda x: data_coder.decode(x)))
# The schema only contains the columns that are transformed.
transformed_dataset = (
transformed_dataset
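For reference, RowTypeConstraint.from_fields builds a schema-aware row type from (name, type) pairs, which is how transformed_schema is turned into row_type in the hunk above. A small sketch with a made-up field mapping:

import numpy as np
from apache_beam.typehints.row_type import RowTypeConstraint

transformed_schema = {'x': np.float32}  # hypothetical column-to-type mapping
row_type = RowTypeConstraint.from_fields(list(transformed_schema.items()))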
