Remove CoGBK in MLTransform's TFTProcessHandler #30146

Merged
9 commits merged on Feb 13, 2024
162 changes: 56 additions & 106 deletions
sdks/python/apache_beam/ml/transforms/handlers.py
@@ -17,9 +17,10 @@
# pytype: skip-file

import collections
import copy
import os
import typing
import uuid
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
@@ -31,6 +32,7 @@
import apache_beam as beam
import tensorflow as tf
import tensorflow_transform.beam as tft_beam
from apache_beam import coders
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.transforms.base import ArtifactMode
from apache_beam.ml.transforms.base import ProcessHandler
@@ -50,7 +52,7 @@
'TFTProcessHandler',
]

_ID_COLUMN = 'tmp_uuid' # Name for a temporary column.
_TEMP_KEY = 'CODED_SAMPLE' # key for the encoded sample

RAW_DATA_METADATA_DIR = 'raw_data_metadata'
SCHEMA_FILE = 'schema.pbtxt'
@@ -83,20 +85,48 @@
tft_process_handler_output_type = typing.Union[beam.Row, Dict[str, np.ndarray]]


class DataCoder:
def __init__(
self,
exclude_columns,
coder=coders.registry.get_coder(Any),
):
"""
Uses PickleCoder to encode/decode the dictionaries.
Args:
exclude_columns: list of columns to exclude from the encoding.
"""
self.coder = coder
self.exclude_columns = exclude_columns

def encode(self, element):
data_to_encode = element.copy()
for key in self.exclude_columns:
if key in data_to_encode:
del data_to_encode[key]
element[_TEMP_KEY] = self.coder.encode(data_to_encode)
return element

def decode(self, element):
clone = copy.copy(element)
clone.update(self.coder.decode(clone[_TEMP_KEY].item()))
A Contributor asked:
What is the function of .item() here? What is the type of clone[_TEMP_KEY]? And given that we call .item() here, will the elements in clone have a consistent type after decoding?

AnandInguva (Contributor, Author) replied on Feb 12, 2024:
The type of clone[_TEMP_KEY] is a NumPy array, and .item() returns the underlying element of that array. As for whether the elements in clone will have a consistent type after decoding: they should, depending on the Coder.

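For readers following the thread, a minimal sketch of the round trip being discussed; the column name is hypothetical, and the single-element NumPy array stands in for the value TFT hands back for the temp column, per the reply above:

import numpy as np
from typing import Any
from apache_beam import coders

coder = coders.registry.get_coder(Any)  # same default coder DataCoder falls back to
encoded = coder.encode({'untouched_column': 5})  # hypothetical unmodified column

# TFT returns the temp column wrapped in a NumPy array; .item() unwraps the
# single encoded bytes element so the coder can decode the original dict.
wrapped = np.array([encoded], dtype=object)
assert coder.decode(wrapped.item()) == {'untouched_column': 5}
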
del clone[_TEMP_KEY]
return clone


class _ConvertScalarValuesToListValues(beam.DoFn):
def process(
self,
element,
):
id, element = element
new_dict = {}
for key, value in element.items():
if isinstance(value,
tuple(_primitive_types_to_typing_container_type.keys())):
new_dict[key] = [value]
else:
new_dict[key] = value
yield (id, new_dict)
yield new_dict


class _ConvertNamedTupleToDict(
@@ -124,79 +154,6 @@ def expand(
return pcoll | beam.Map(lambda x: x._asdict())


class _ComputeAndAttachUniqueID(beam.DoFn):
"""
Computes and attaches a unique id to each element in the PCollection.
"""
def process(self, element):
# UUID1 includes machine-specific bits and has a counter. As long as not too
# many are generated at the same time, they should be unique.
# UUID4 generation should be unique in practice as long as underlying random
# number generation is not compromised.
# A combination of both should avoid the anecdotal pitfalls where
# replacing one with the other has helped some users.
# UUID collision will result in data loss, but we can detect that and fail.

# TODO(https://github.com/apache/beam/issues/29593): Evaluate MLTransform
# implementation without CoGBK.
unique_key = uuid.uuid1().bytes + uuid.uuid4().bytes
yield (unique_key, element)


class _GetMissingColumns(beam.DoFn):
"""
Returns data containing only the columns that are not
present in the schema. This is needed since TFT only outputs
columns that are transformed by any of the data processing transforms.
"""
def __init__(self, existing_columns):
self.existing_columns = existing_columns

def process(self, element):
id, row_dict = element
new_dict = {
k: v
for k, v in row_dict.items() if k not in self.existing_columns
}
yield (id, new_dict)


class _MakeIdAsColumn(beam.DoFn):
"""
Extracts the id from the element and adds it as a column instead.
"""
def process(self, element):
id, element = element
element[_ID_COLUMN] = id
yield element


class _ExtractIdAndKeyPColl(beam.DoFn):
"""
Extracts the id and return id and element as a tuple.
"""
def process(self, element):
id = element[_ID_COLUMN][0]
del element[_ID_COLUMN]
yield (id, element)


class _MergeDicts(beam.DoFn):
"""
Merges processed and unprocessed columns from CoGBK result into a single row.
"""
def process(self, element):
unused_row_id, row_dicts_tuple = element
new_dict = {}
for d in row_dicts_tuple:
# After CoGBK, dicts with processed and unprocessed portions of each row
# are wrapped in 1-element lists, since all rows have a unique id.
# Assertion could fail due to UUID collision.
assert len(d) == 1, f"Expected 1 element, got: {len(d)}."
new_dict.update(d[0])
yield new_dict


class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
tft_process_handler_output_type]):
def __init__(
Expand Down Expand Up @@ -325,7 +282,7 @@ def _get_raw_data_feature_spec_per_column(
def get_raw_data_metadata(
self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata:
raw_data_feature_spec = self.get_raw_data_feature_spec(input_types)
raw_data_feature_spec[_ID_COLUMN] = tf.io.VarLenFeature(dtype=tf.string)
raw_data_feature_spec[_TEMP_KEY] = tf.io.VarLenFeature(dtype=tf.string)
return self.convert_raw_data_feature_spec_to_dataset_metadata(
raw_data_feature_spec)

Expand Down Expand Up @@ -403,7 +360,6 @@ def expand(
artifact_location, which was previously used to store the produced
artifacts.
"""

if self.artifact_mode == ArtifactMode.PRODUCE:
# If we are computing artifacts, we should fail for windows other than
# default windowing since for example, for a fixed window, each window can
Expand Down Expand Up @@ -447,24 +403,29 @@ def expand(
raw_data_metadata = metadata_io.read_metadata(
os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))

keyed_raw_data = (raw_data | beam.ParDo(_ComputeAndAttachUniqueID()))

feature_set = [feature.name for feature in raw_data_metadata.schema.feature]
keyed_columns_not_in_schema = (
keyed_raw_data
| beam.ParDo(_GetMissingColumns(feature_set)))

# TFT ignores columns in the input data that aren't explicitly defined
# in the schema. This is because TFT operations
# are designed to work with a predetermined schema.
# To preserve these extra columns without disrupting TFT processing,
# they are temporarily encoded as bytes and added to the PCollection with
# a unique identifier
data_coder = DataCoder(exclude_columns=feature_set)
data_with_encoded_columns = (
raw_data
| "EncodeUnmodifiedColumns" >>
beam.Map(lambda elem: data_coder.encode(elem)))

# To maintain consistency by outputting numpy array all the time,
# whether a scalar value or list or np array is passed as input,
# we will convert scalar values to list values and TFT will output
# numpy array all the time.
keyed_raw_data = keyed_raw_data | beam.ParDo(
data_list = data_with_encoded_columns | beam.ParDo(
_ConvertScalarValuesToListValues())

raw_data_list = (keyed_raw_data | beam.ParDo(_MakeIdAsColumn()))

with tft_beam.Context(temp_dir=self.artifact_location):
data = (raw_data_list, raw_data_metadata)
data = (data_list, raw_data_metadata)
if self.artifact_mode == ArtifactMode.PRODUCE:
transform_fn = (
data
@@ -474,7 +435,7 @@
self.write_transform_artifacts(transform_fn, self.artifact_location)
else:
transform_fn = (
raw_data_list.pipeline
data_list.pipeline
| "ReadTransformFn" >> tft_beam.ReadTransformFn(
self.artifact_location))
(transformed_dataset, transformed_metadata) = (
@@ -492,26 +453,15 @@
# So we will use a RowTypeConstraint to create a schema'd PCollection.
# this is needed since new columns are included in the
# transformed_dataset.
del self.transformed_schema[_ID_COLUMN]
del self.transformed_schema[_TEMP_KEY]
row_type = RowTypeConstraint.from_fields(
list(self.transformed_schema.items()))

# If a non schema PCollection is passed, and one of the input columns
# is not transformed by any of the transforms, then the output will
# not have that column. So we will join the missing columns from the
# raw_data to the transformed_dataset.
keyed_transformed_dataset = (
transformed_dataset | beam.ParDo(_ExtractIdAndKeyPColl()))

# The grouping is needed here since tensorflow transform only outputs
# columns that are transformed by any of the transforms. So we will
# join the missing columns from the raw_data to the transformed_dataset
# using the id.
# Decode the extra columns that were encoded as bytes.
transformed_dataset = (
(keyed_transformed_dataset, keyed_columns_not_in_schema)
| beam.CoGroupByKey()
| beam.ParDo(_MergeDicts()))

transformed_dataset
|
"DecodeUnmodifiedColumns" >> beam.Map(lambda x: data_coder.decode(x)))
# The schema only contains the columns that are transformed.
transformed_dataset = (
transformed_dataset