Replace the CoGBK and utils with Encode and Decode utils
AnandInguva committed Jan 29, 2024
1 parent abc1522 commit ee770c4
Showing 1 changed file with 13 additions and 99 deletions.
112 changes: 13 additions & 99 deletions sdks/python/apache_beam/ml/transforms/handlers.py
@@ -19,7 +19,7 @@
import collections
import os
import typing
import uuid
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
@@ -31,6 +31,7 @@
import apache_beam as beam
import tensorflow as tf
import tensorflow_transform.beam as tft_beam
from apache_beam.internal import pickler
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.transforms.base import ArtifactMode
from apache_beam.ml.transforms.base import ProcessHandler
@@ -50,7 +51,7 @@
    'TFTProcessHandler',
]

_ID_COLUMN = 'tmp_uuid' # Name for a temporary column.
_TEMP_KEY = 'CODED_SAMPLE' # key for the encoded sample

RAW_DATA_METADATA_DIR = 'raw_data_metadata'
SCHEMA_FILE = 'schema.pbtxt'
@@ -129,7 +130,7 @@ def process(
        new_dict[key] = [value]
      else:
        new_dict[key] = value
    yield (id, new_dict)
    yield new_dict


class _ConvertNamedTupleToDict(
@@ -157,79 +158,6 @@ def expand(
    return pcoll | beam.Map(lambda x: x._asdict())


class _ComputeAndAttachUniqueID(beam.DoFn):
  """
  Computes and attaches a unique id to each element in the PCollection.
  """
  def process(self, element):
    # UUID1 includes machine-specific bits and has a counter. As long as not too
    # many are generated at the same time, they should be unique.
    # UUID4 generation should be unique in practice as long as the underlying
    # random number generation is not compromised.
    # A combination of both should avoid the anecdotal pitfalls where
    # replacing one with the other has helped some users.
    # UUID collision will result in data loss, but we can detect that and fail.

    # TODO(https://github.com/apache/beam/issues/29593): Evaluate MLTransform
    # implementation without CoGBK.
    unique_key = uuid.uuid1().bytes + uuid.uuid4().bytes
    yield (unique_key, element)


class _GetMissingColumns(beam.DoFn):
  """
  Returns data containing only the columns that are not
  present in the schema. This is needed since TFT only outputs
  columns that are transformed by any of the data processing transforms.
  """
  def __init__(self, existing_columns):
    self.existing_columns = existing_columns

  def process(self, element):
    id, row_dict = element
    new_dict = {
        k: v
        for k, v in row_dict.items() if k not in self.existing_columns
    }
    yield (id, new_dict)


class _MakeIdAsColumn(beam.DoFn):
  """
  Extracts the id from the element and adds it as a column instead.
  """
  def process(self, element):
    id, element = element
    element[_ID_COLUMN] = id
    yield element


class _ExtractIdAndKeyPColl(beam.DoFn):
  """
  Extracts the id and returns the id and element as a tuple.
  """
  def process(self, element):
    id = element[_ID_COLUMN][0]
    del element[_ID_COLUMN]
    yield (id, element)


class _MergeDicts(beam.DoFn):
  """
  Merges processed and unprocessed columns from the CoGBK result into a
  single row.
  """
  def process(self, element):
    unused_row_id, row_dicts_tuple = element
    new_dict = {}
    for d in row_dicts_tuple:
      # After CoGBK, dicts with the processed and unprocessed portions of each
      # row are wrapped in 1-element lists, since all rows have a unique id.
      # An assertion failure here indicates a UUID collision.
      assert len(d) == 1, f"Expected 1 element, got: {len(d)}."
      new_dict.update(d[0])
    yield new_dict


class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
                                       tft_process_handler_output_type]):
  def __init__(
@@ -358,7 +286,7 @@ def _get_raw_data_feature_spec_per_column(
  def get_raw_data_metadata(
      self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata:
    raw_data_feature_spec = self.get_raw_data_feature_spec(input_types)
    raw_data_feature_spec[_ID_COLUMN] = tf.io.VarLenFeature(dtype=tf.string)
    raw_data_feature_spec[_TEMP_KEY] = tf.io.VarLenFeature(dtype=tf.string)
    return self.convert_raw_data_feature_spec_to_dataset_metadata(
        raw_data_feature_spec)

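Because `_TEMP_KEY` now carries the pickled, untransformed remainder of each row, it is registered as a variable-length string feature so it can ride through the TFT pipeline as opaque bytes. A minimal sketch of such a passthrough feature spec (the `_TEMP_KEY` name comes from this commit; the `x` column and its spec are hypothetical):

import tensorflow as tf

_TEMP_KEY = 'CODED_SAMPLE'  # Temporary passthrough column named in this commit.

# Hypothetical raw-data feature spec: 'x' is a column TFT actually transforms,
# while the pickled remainder of each row rides along as an opaque
# variable-length string feature that TFT leaves alone.
raw_data_feature_spec = {
    'x': tf.io.FixedLenFeature([], tf.float32),
    _TEMP_KEY: tf.io.VarLenFeature(dtype=tf.string),
}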
@@ -480,22 +408,21 @@ def expand(
    raw_data_metadata = metadata_io.read_metadata(
        os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))

    keyed_raw_data = (raw_data | beam.ParDo(_ComputeAndAttachUniqueID()))
    keyed_raw_data = raw_data

    feature_set = [feature.name for feature in raw_data_metadata.schema.feature]
    keyed_columns_not_in_schema = (
        keyed_raw_data
        | beam.ParDo(_GetMissingColumns(feature_set)))

    # To keep the output type consistent, scalar values are converted to list
    # values; whether a scalar, a list, or a numpy array is passed as input,
    # TFT will then always output a numpy array.
    raw_data_list = (
        keyed_raw_data
        | beam.ParDo(_EncodeDict(exclude_columns=feature_set)))

    keyed_raw_data = keyed_raw_data | beam.ParDo(
        _ConvertScalarValuesToListValues())

    raw_data_list = (keyed_raw_data | beam.ParDo(_MakeIdAsColumn()))

    with tft_beam.Context(temp_dir=self.artifact_location):
      data = (raw_data_list, raw_data_metadata)
      if self.artifact_mode == ArtifactMode.PRODUCE:
Expand Down Expand Up @@ -525,26 +452,13 @@ def expand(
    # So we will use a RowTypeConstraint to create a schema'd PCollection.
    # This is needed since new columns are included in the
    # transformed_dataset.
    del self.transformed_schema[_ID_COLUMN]
    del self.transformed_schema[_TEMP_KEY]
    row_type = RowTypeConstraint.from_fields(
        list(self.transformed_schema.items()))

    # If a non-schema PCollection is passed, and one of the input columns
    # is not transformed by any of the transforms, then the output will
    # not have that column. So we will join the missing columns from the
    # raw_data to the transformed_dataset.
    keyed_transformed_dataset = (
        transformed_dataset | beam.ParDo(_ExtractIdAndKeyPColl()))

    # The grouping is needed here since tensorflow transform only outputs
    # columns that are transformed by any of the transforms. So we will
    # join the missing columns from the raw_data to the transformed_dataset
    # using the id.
    transformed_dataset = (
        (keyed_transformed_dataset, keyed_columns_not_in_schema)
        | beam.CoGroupByKey()
        | beam.ParDo(_MergeDicts()))

    transformed_dataset = (
        transformed_dataset
        | "DecodeDict" >> beam.ParDo(_DecodeDict()))

    # The schema only contains the columns that are transformed.
    transformed_dataset = (
        transformed_dataset
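The `_EncodeDict` and `_DecodeDict` utilities referenced in this diff are defined outside the hunks shown above. A minimal sketch of what such a pair could look like, assuming it uses the `pickler` module imported at the top of the file to stash every column missing from the TFT schema in the `CODED_SAMPLE` passthrough column (an illustration, not the committed implementation):

import apache_beam as beam
from apache_beam.internal import pickler

_TEMP_KEY = 'CODED_SAMPLE'  # key for the encoded sample


class _EncodeDict(beam.DoFn):
  """Pickles the columns that TFT will not transform into a single
  temporary string column so they survive the TFT pipeline."""
  def __init__(self, exclude_columns):
    # Columns present in the TFT schema; these are left for TFT to handle.
    self.exclude_columns = exclude_columns

  def process(self, element):
    passthrough = {
        k: v
        for k, v in element.items() if k not in self.exclude_columns
    }
    new_dict = {
        k: v
        for k, v in element.items() if k in self.exclude_columns
    }
    new_dict[_TEMP_KEY] = pickler.dumps(passthrough)
    yield new_dict


class _DecodeDict(beam.DoFn):
  """Restores the untransformed columns from the temporary string column."""
  def process(self, element):
    element = dict(element)
    encoded = element.pop(_TEMP_KEY)
    # Depending on how the VarLenFeature is materialized, the encoded value
    # may come back wrapped in a length-1 list or array.
    if not isinstance(encoded, (bytes, str)):
      encoded = encoded[0]
    element.update(pickler.loads(encoded))
    yield element

Under this sketch, a single DecodeDict step after the TFT transform replaces the UUID keying, CoGroupByKey, and dict merging that the removed utilities performed.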
