Remove CoGBK in MLTransform's TFTProcessHandler #30146

Merged · 9 commits · Feb 13, 2024

149 changes: 49 additions & 100 deletions in sdks/python/apache_beam/ml/transforms/handlers.py
@@ -17,9 +17,10 @@
# pytype: skip-file

import collections
import copy
import os
import typing
import uuid
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
@@ -31,6 +32,7 @@
import apache_beam as beam
import tensorflow as tf
import tensorflow_transform.beam as tft_beam
from apache_beam import coders
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.transforms.base import ArtifactMode
from apache_beam.ml.transforms.base import ProcessHandler
@@ -50,7 +52,7 @@
'TFTProcessHandler',
]

_ID_COLUMN = 'tmp_uuid' # Name for a temporary column.
_TEMP_KEY = 'CODED_SAMPLE' # key for the encoded sample

RAW_DATA_METADATA_DIR = 'raw_data_metadata'
SCHEMA_FILE = 'schema.pbtxt'
@@ -83,20 +85,52 @@
tft_process_handler_output_type = typing.Union[beam.Row, Dict[str, np.ndarray]]


class DataCoder:
  def __init__(self, exclude_columns=None):
    """
    Uses PickleCoder to encode/decode the dictionaries.

    Args:
      exclude_columns: list of columns to exclude from the encoding.
    """
    self.coder = coders.registry.get_coder(Any)
    self.exclude_columns = exclude_columns

  def set_unused_columns(self, exclude_columns):
    self.exclude_columns = exclude_columns

  def encode(self, element):
    if not self.exclude_columns:
[Review comment · Contributor]
Interesting. Is it possible for exclude_columns to be empty? I'd imagine it could rather be the opposite, where all columns are being processed, so there is nothing to encode/decode.

[Reply · Contributor Author]
Yes, that's right, but it errors out: we add the temp id column name to the schema during construction, so TFT errors out if the PCollection doesn't have the temp id column. So when the unused columns are None, we have to encode the empty dict and pass it to the PColl.

      return element
    data_to_encode = element.copy()
    for key in self.exclude_columns:
      if key in data_to_encode:
        del data_to_encode[key]

    element[_TEMP_KEY] = self.coder.encode(data_to_encode)
    return element

  def decode(self, element):
    if not self.exclude_columns:
      return element
    clone = copy.copy(element)
    clone.update(self.coder.decode(clone[_TEMP_KEY].item()))
[Review comment · Contributor]
What is the function of .item() here? What is the type of clone[_TEMP_KEY]? And given that we call .item() here, will the elements in clone have a consistent type after decoding?

[Reply · AnandInguva, Feb 12, 2024]
The type of clone[_TEMP_KEY] is a numpy array, and .item() returns the underlying element of the numpy array.

> will elements in clone have consistent type after decoding

It should, depending on the Coder.
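
A minimal sketch of that numpy unwrapping (not from the PR; the payload value is hypothetical), assuming TFT hands _TEMP_KEY back as a length-1 bytes array:

```python
import numpy as np

# TFT materializes the encoded column as a numpy bytes array; .item()
# unwraps the single underlying element back to plain Python bytes.
encoded = np.array([b'pickled-payload'])  # hypothetical pickled bytes
assert isinstance(encoded, np.ndarray)
assert isinstance(encoded.item(), bytes)
```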

    del clone[_TEMP_KEY]
    return clone
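
For orientation (not part of the diff), a minimal round-trip sketch of how DataCoder is meant to be used; the column names and the manual numpy wrapping that stands in for TFT's output are assumptions:

```python
import numpy as np

coder = DataCoder()
# Assume 'age' is in the TFT schema and 'label' is not.
coder.set_unused_columns(exclude_columns=['age'])

element = {'age': 25, 'label': 'spam'}
encoded = coder.encode(element)  # pickles {'label': 'spam'} into _TEMP_KEY

# TFT returns _TEMP_KEY as a numpy bytes array; mimic that wrapping here.
encoded[_TEMP_KEY] = np.array([encoded[_TEMP_KEY]])
decoded = coder.decode(encoded)  # restores 'label' and drops _TEMP_KEY
assert decoded['label'] == 'spam'
```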


class _ConvertScalarValuesToListValues(beam.DoFn):
  def process(
      self,
      element,
  ):
    id, element = element
    new_dict = {}
    for key, value in element.items():
      if isinstance(value,
                    tuple(_primitive_types_to_typing_container_type.keys())):
        new_dict[key] = [value]
      else:
        new_dict[key] = value
    yield (id, new_dict)
    yield new_dict


class _ConvertNamedTupleToDict(
@@ -124,79 +158,6 @@ def expand(
    return pcoll | beam.Map(lambda x: x._asdict())


class _ComputeAndAttachUniqueID(beam.DoFn):
  """
  Computes and attaches a unique id to each element in the PCollection.
  """
  def process(self, element):
    # UUID1 includes machine-specific bits and has a counter. As long as not too
    # many are generated at the same time, they should be unique.
    # UUID4 generation should be unique in practice as long as underlying random
    # number generation is not compromised.
    # A combination of both should avoid the anecdotal pitfalls where
    # replacing one with the other has helped some users.
    # UUID collision will result in data loss, but we can detect that and fail.

    # TODO(https://github.com/apache/beam/issues/29593): Evaluate MLTransform
    # implementation without CoGBK.
    unique_key = uuid.uuid1().bytes + uuid.uuid4().bytes
    yield (unique_key, element)


class _GetMissingColumns(beam.DoFn):
  """
  Returns data containing only the columns that are not
  present in the schema. This is needed since TFT only outputs
  columns that are transformed by any of the data processing transforms.
  """
  def __init__(self, existing_columns):
    self.existing_columns = existing_columns

  def process(self, element):
    id, row_dict = element
    new_dict = {
        k: v
        for k, v in row_dict.items() if k not in self.existing_columns
    }
    yield (id, new_dict)


class _MakeIdAsColumn(beam.DoFn):
  """
  Extracts the id from the element and adds it as a column instead.
  """
  def process(self, element):
    id, element = element
    element[_ID_COLUMN] = id
    yield element


class _ExtractIdAndKeyPColl(beam.DoFn):
  """
  Extracts the id and returns the id and element as a tuple.
  """
  def process(self, element):
    id = element[_ID_COLUMN][0]
    del element[_ID_COLUMN]
    yield (id, element)


class _MergeDicts(beam.DoFn):
  """
  Merges processed and unprocessed columns from CoGBK result into a single row.
  """
  def process(self, element):
    unused_row_id, row_dicts_tuple = element
    new_dict = {}
    for d in row_dicts_tuple:
      # After CoGBK, dicts with processed and unprocessed portions of each row
      # are wrapped in 1-element lists, since all rows have a unique id.
      # Assertion could fail due to UUID collision.
      assert len(d) == 1, f"Expected 1 element, got: {len(d)}."
      new_dict.update(d[0])
    yield new_dict
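
For context on what the removed CoGBK path consumed, a small hedged sketch of _MergeDicts input; the key and column values are made up:

```python
# Shape after CoGroupByKey: (uuid_key, ([processed_dict], [unprocessed_dict])).
element = (b'uuid-bytes', ([{'age': [0.5]}], [{'label': 'spam'}]))

_, row_dicts_tuple = element
merged = {}
for d in row_dicts_tuple:
  assert len(d) == 1  # unique ids should yield 1-element lists
  merged.update(d[0])
print(merged)  # {'age': [0.5], 'label': 'spam'}
```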


class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
                                       tft_process_handler_output_type]):
  def __init__(
@@ -215,6 +176,7 @@ def __init__(
    self.artifact_mode = artifact_mode
    if artifact_mode not in ['produce', 'consume']:
      raise ValueError('artifact_mode must be either `produce` or `consume`.')
    self.data_coder = DataCoder()

  def append_transform(self, transform):
    self.transforms.append(transform)

@@ -325,7 +287,7 @@ def _get_raw_data_feature_spec_per_column(
  def get_raw_data_metadata(
      self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata:
    raw_data_feature_spec = self.get_raw_data_feature_spec(input_types)
    raw_data_feature_spec[_ID_COLUMN] = tf.io.VarLenFeature(dtype=tf.string)
    raw_data_feature_spec[_TEMP_KEY] = tf.io.VarLenFeature(dtype=tf.string)
    return self.convert_raw_data_feature_spec_to_dataset_metadata(
        raw_data_feature_spec)

@@ -447,22 +409,22 @@ def expand(
    raw_data_metadata = metadata_io.read_metadata(
        os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))

    keyed_raw_data = (raw_data | beam.ParDo(_ComputeAndAttachUniqueID()))
    keyed_raw_data = (raw_data)  # | beam.ParDo(_ComputeAndAttachUniqueID()))
[Review comment · Contributor]
Leftover comment; also, we no longer add keys, so keyed_ might not be the best name.

[Reply · Contributor Author]
Removed the keyed_ from the variable names.


    feature_set = [feature.name for feature in raw_data_metadata.schema.feature]
    keyed_columns_not_in_schema = (
        keyed_raw_data
        | beam.ParDo(_GetMissingColumns(feature_set)))
    self.data_coder.set_unused_columns(exclude_columns=feature_set)

    # To maintain consistency by outputting numpy arrays all the time,
    # whether a scalar value, a list, or an np array is passed as input,
    # we will convert scalar values to list values and TFT will output
    # numpy arrays all the time.
    raw_data_list = (
[Review comment · Contributor]
For my understanding, why is this called raw_data_list? It's modified, so not raw, I think; and what's the _list about?

[Reply · Contributor Author]
Yes, it is modified. I removed raw from the variable name.

_list: we convert the scalar element to a list (len 1) to maintain uniformity. Users can pass lists/np arrays to TFT ops, and TFT outputs numpy arrays. When users pass scalars, TFT outputs scalars. To maintain a consistent output format, we convert scalars to lists.
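
A tiny sketch of that normalization (illustrative only; the scalar type list here is an assumption):

```python
element = {'x': 5, 'y': [1.0, 2.0]}
# Wrap bare scalars in length-1 lists so TFT consistently emits arrays.
normalized = {
    k: [v] if isinstance(v, (int, float, str, bytes)) else v
    for k, v in element.items()
}
print(normalized)  # {'x': [5], 'y': [1.0, 2.0]}
```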

        keyed_raw_data
        | beam.Map(lambda elem: self.data_coder.encode(elem)))

    keyed_raw_data = keyed_raw_data | beam.ParDo(
        _ConvertScalarValuesToListValues())

    raw_data_list = (keyed_raw_data | beam.ParDo(_MakeIdAsColumn()))

    with tft_beam.Context(temp_dir=self.artifact_location):
      data = (raw_data_list, raw_data_metadata)
      if self.artifact_mode == ArtifactMode.PRODUCE:
@@ -492,26 +454,13 @@ def expand(
      # So we will use a RowTypeConstraint to create a schema'd PCollection.
      # This is needed since new columns are included in the
      # transformed_dataset.
      del self.transformed_schema[_ID_COLUMN]
      del self.transformed_schema[_TEMP_KEY]
      row_type = RowTypeConstraint.from_fields(
          list(self.transformed_schema.items()))

      # If a non-schema PCollection is passed, and one of the input columns
      # is not transformed by any of the transforms, then the output will
      # not have that column. So we will join the missing columns from the
      # raw_data to the transformed_dataset.
      keyed_transformed_dataset = (
          transformed_dataset | beam.ParDo(_ExtractIdAndKeyPColl()))

      # The grouping is needed here since tensorflow transform only outputs
      # columns that are transformed by any of the transforms. So we will
      # join the missing columns from the raw_data to the transformed_dataset
      # using the id.
      transformed_dataset = (
          (keyed_transformed_dataset, keyed_columns_not_in_schema)
          | beam.CoGroupByKey()
          | beam.ParDo(_MergeDicts()))

      transformed_dataset = (
          transformed_dataset
          | "DecodeDict" >> beam.Map(lambda x: self.data_coder.decode(x)))
      # The schema only contains the columns that are transformed.
      transformed_dataset = (
          transformed_dataset