Skip to content

Commit

Permalink
Implement DeduplicateTensorPerRow in MLTransform (#31307)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrmccluskey authored May 16, 2024
1 parent 6cb30cc commit 118c514
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 0 deletions.
22 changes: 22 additions & 0 deletions sdks/python/apache_beam/ml/transforms/tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,3 +681,25 @@ def apply_transform(
name=self.name)
}
return output_dict


@register_input_dtype(str)
class DeduplicateTensorPerRow(TFTOperation):
def __init__(self, columns: List[str], name: Optional[str] = None):
""" Deduplicates each row (0th dimension) of the provided tensor.
Args:
columns: A list of the columns to apply the transformation on.
name: optional. A name for this operation.
"""
self.name = name
super().__init__(columns)

def apply_transform(
self, data: common_types.TensorType,
output_col_name: str) -> Dict[str, common_types.TensorType]:
output_dict = {
output_col_name: tft.deduplicate_tensor_per_row(
input_tensor=data, name=self.name)
}
return output_dict
62 changes: 62 additions & 0 deletions sdks/python/apache_beam/ml/transforms/tft_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,5 +1009,67 @@ def test_multi_buckets_multi_string(self):
assert_that(result, equal_to(expected_values, equals_fn=np.array_equal))


class DeduplicateTensorPerRowTest(unittest.TestCase):
def setUp(self) -> None:
self.artifact_location = tempfile.mkdtemp()

def tearDown(self):
shutil.rmtree(self.artifact_location)

def test_deduplicate(self):
values = [{
'x': [b'a', b'b', b'a', b'b'],
}, {
'x': [b'b', b'c', b'b', b'c']
}]

expected_output = [np.array([b'a', b'b']), np.array([b'b', b'c'])]
with beam.Pipeline() as p:
list_result = (
p
| "listCreate" >> beam.Create(values)
| "listMLTransform" >> base.MLTransform(
write_artifact_location=self.artifact_location).with_transform(
tft.DeduplicateTensorPerRow(columns=['x'])))
result = (list_result | beam.Map(lambda x: x.x))
assert_that(result, equal_to(expected_output, equals_fn=np.array_equal))

def test_deduplicate_no_op(self):
values = [{
'x': [b'a', b'b'],
}, {
'x': [b'c', b'd']
}]

expected_output = [np.array([b'a', b'b']), np.array([b'c', b'd'])]
with beam.Pipeline() as p:
list_result = (
p
| "listCreate" >> beam.Create(values)
| "listMLTransform" >> base.MLTransform(
write_artifact_location=self.artifact_location).with_transform(
tft.DeduplicateTensorPerRow(columns=['x'])))
result = (list_result | beam.Map(lambda x: x.x))
assert_that(result, equal_to(expected_output, equals_fn=np.array_equal))

def test_deduplicate_different_output_sizes(self):
values = [{
'x': [b'a', b'b', b'a', b'b'],
}, {
'x': [b'c', b'a', b'd', b'd']
}]

expected_output = [np.array([b'a', b'b']), np.array([b'c', b'a', b'd'])]
with beam.Pipeline() as p:
list_result = (
p
| "listCreate" >> beam.Create(values)
| "listMLTransform" >> base.MLTransform(
write_artifact_location=self.artifact_location).with_transform(
tft.DeduplicateTensorPerRow(columns=['x'])))
result = (list_result | beam.Map(lambda x: x.x))
assert_that(result, equal_to(expected_output, equals_fn=np.array_equal))


if __name__ == '__main__':
unittest.main()

0 comments on commit 118c514

Please sign in to comment.