From 118c51404d93132d4a0cae135f00ff68d4fd84bf Mon Sep 17 00:00:00 2001 From: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com> Date: Thu, 16 May 2024 10:29:04 -0400 Subject: [PATCH] Implement DeduplicateTensorPerRow in MLTransform (#31307) --- sdks/python/apache_beam/ml/transforms/tft.py | 22 +++++++ .../apache_beam/ml/transforms/tft_test.py | 62 +++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/sdks/python/apache_beam/ml/transforms/tft.py b/sdks/python/apache_beam/ml/transforms/tft.py index 550dbedbc7ba..370043bc0d99 100644 --- a/sdks/python/apache_beam/ml/transforms/tft.py +++ b/sdks/python/apache_beam/ml/transforms/tft.py @@ -681,3 +681,25 @@ def apply_transform( name=self.name) } return output_dict + + +@register_input_dtype(str) +class DeduplicateTensorPerRow(TFTOperation): + def __init__(self, columns: List[str], name: Optional[str] = None): + """ Deduplicates each row (0th dimension) of the provided tensor. + + Args: + columns: A list of the columns to apply the transformation on. + name: optional. A name for this operation. + """ + self.name = name + super().__init__(columns) + + def apply_transform( + self, data: common_types.TensorType, + output_col_name: str) -> Dict[str, common_types.TensorType]: + output_dict = { + output_col_name: tft.deduplicate_tensor_per_row( + input_tensor=data, name=self.name) + } + return output_dict diff --git a/sdks/python/apache_beam/ml/transforms/tft_test.py b/sdks/python/apache_beam/ml/transforms/tft_test.py index 6763032a8eba..5c42ecc012f9 100644 --- a/sdks/python/apache_beam/ml/transforms/tft_test.py +++ b/sdks/python/apache_beam/ml/transforms/tft_test.py @@ -1009,5 +1009,67 @@ def test_multi_buckets_multi_string(self): assert_that(result, equal_to(expected_values, equals_fn=np.array_equal)) +class DeduplicateTensorPerRowTest(unittest.TestCase): + def setUp(self) -> None: + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.artifact_location) + + def test_deduplicate(self): + values = [{ + 'x': [b'a', b'b', b'a', b'b'], + }, { + 'x': [b'b', b'c', b'b', b'c'] + }] + + expected_output = [np.array([b'a', b'b']), np.array([b'b', b'c'])] + with beam.Pipeline() as p: + list_result = ( + p + | "listCreate" >> beam.Create(values) + | "listMLTransform" >> base.MLTransform( + write_artifact_location=self.artifact_location).with_transform( + tft.DeduplicateTensorPerRow(columns=['x']))) + result = (list_result | beam.Map(lambda x: x.x)) + assert_that(result, equal_to(expected_output, equals_fn=np.array_equal)) + + def test_deduplicate_no_op(self): + values = [{ + 'x': [b'a', b'b'], + }, { + 'x': [b'c', b'd'] + }] + + expected_output = [np.array([b'a', b'b']), np.array([b'c', b'd'])] + with beam.Pipeline() as p: + list_result = ( + p + | "listCreate" >> beam.Create(values) + | "listMLTransform" >> base.MLTransform( + write_artifact_location=self.artifact_location).with_transform( + tft.DeduplicateTensorPerRow(columns=['x']))) + result = (list_result | beam.Map(lambda x: x.x)) + assert_that(result, equal_to(expected_output, equals_fn=np.array_equal)) + + def test_deduplicate_different_output_sizes(self): + values = [{ + 'x': [b'a', b'b', b'a', b'b'], + }, { + 'x': [b'c', b'a', b'd', b'd'] + }] + + expected_output = [np.array([b'a', b'b']), np.array([b'c', b'a', b'd'])] + with beam.Pipeline() as p: + list_result = ( + p + | "listCreate" >> beam.Create(values) + | "listMLTransform" >> base.MLTransform( + write_artifact_location=self.artifact_location).with_transform( + tft.DeduplicateTensorPerRow(columns=['x']))) + result = (list_result | beam.Map(lambda x: x.x)) + assert_that(result, equal_to(expected_output, equals_fn=np.array_equal)) + + if __name__ == '__main__': unittest.main()