From 6405b6b3f836160ed942cd76178478c8d237b3ea Mon Sep 17 00:00:00 2001
From: Claude
Date: Thu, 3 Oct 2024 14:55:12 +0000
Subject: [PATCH 1/3] Try deepcopy combine_fn and fallback to pickling if TypeError.

---
 .../transforms/combinefn_lifecycle_test.py | 11 ++--
 sdks/python/apache_beam/transforms/core.py | 59 +++++++++++++------
 2 files changed, 48 insertions(+), 22 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py
index 62dbbc5fb77c..2a86f0251e75 100644
--- a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py
+++ b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py
@@ -53,15 +53,18 @@ def test_combining_value_state(self):
 
 
 @parameterized_class([
-    {'runner': direct_runner.BundleBasedDirectRunner},
-    {'runner': fn_api_runner.FnApiRunner},
-]) # yapf: disable
+    {'runner': direct_runner.BundleBasedDirectRunner, 'pickler': 'dill'},
+    {'runner': direct_runner.BundleBasedDirectRunner, 'pickler': 'cloudpickle'},
+    {'runner': fn_api_runner.FnApiRunner, 'pickler': 'dill'},
+    {'runner': fn_api_runner.FnApiRunner, 'pickler': 'cloudpickle'},
+    ]) # yapf: disable
 class LocalCombineFnLifecycleTest(unittest.TestCase):
   def tearDown(self):
     CallSequenceEnforcingCombineFn.instances.clear()
 
   def test_combine(self):
-    run_combine(TestPipeline(runner=self.runner()))
+    test_options = PipelineOptions(flags=[f"--pickle_library={self.pickler}"])
+    run_combine(TestPipeline(runner=self.runner(), options=test_options))
     self._assert_teardown_called()
 
   def test_non_liftable_combine(self):
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index e7180bc093b0..24d1991f1769 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -3158,33 +3158,56 @@ def process(self, element):
           yield pvalue.TaggedOutput('hot', ((self._nonce % fanout, key), value))
 
     class PreCombineFn(CombineFn):
+      def __init__(self):
+        # Deepcopy of the combine_fn to avoid sharing state between lifted
+        # stages when using cloudpickle.
+        try:
+          self._combine_fn_copy = copy.deepcopy(combine_fn)
+        except TypeError as e:
+          logging.warning(
+              'Failed to copy combine function. Ensure python dependencies are'
+              ' properly set up: %s',
+              e)
+          self._combine_fn_copy = pickler.loads(pickler.dumps(combine_fn))
+
+        self.setup = self._combine_fn_copy.setup
+        self.create_accumulator = self._combine_fn_copy.create_accumulator
+        self.add_input = self._combine_fn_copy.add_input
+        self.merge_accumulators = self._combine_fn_copy.merge_accumulators
+        self.compact = self._combine_fn_copy.compact
+        self.teardown = self._combine_fn_copy.teardown
+
       @staticmethod
       def extract_output(accumulator):
         # Boolean indicates this is an accumulator.
         return (True, accumulator)
 
-      setup = combine_fn.setup
-      create_accumulator = combine_fn.create_accumulator
-      add_input = combine_fn.add_input
-      merge_accumulators = combine_fn.merge_accumulators
-      compact = combine_fn.compact
-      teardown = combine_fn.teardown
-
     class PostCombineFn(CombineFn):
-      @staticmethod
-      def add_input(accumulator, element):
+      def __init__(self):
+        # Deepcopy of the combine_fn to avoid sharing state between lifted
+        # stages when using cloudpickle.
+        try:
+          self._combine_fn_copy = copy.deepcopy(combine_fn)
+        except TypeError as e:
+          logging.warning(
+              'Failed to copy combine function. Ensure python dependencies are'
+              ' properly set up: %s',
+              e)
+          self._combine_fn_copy = pickler.loads(pickler.dumps(combine_fn))
+
+        self.setup = self._combine_fn_copy.setup
+        self.create_accumulator = self._combine_fn_copy.create_accumulator
+        self.merge_accumulators = self._combine_fn_copy.merge_accumulators
+        self.compact = self._combine_fn_copy.compact
+        self.extract_output = self._combine_fn_copy.extract_output
+        self.teardown = self._combine_fn_copy.teardown
+
+      def add_input(self, accumulator, element):
         is_accumulator, value = element
         if is_accumulator:
-          return combine_fn.merge_accumulators([accumulator, value])
+          return self._combine_fn_copy.merge_accumulators([accumulator, value])
         else:
-          return combine_fn.add_input(accumulator, value)
-
-      setup = combine_fn.setup
-      create_accumulator = combine_fn.create_accumulator
-      merge_accumulators = combine_fn.merge_accumulators
-      compact = combine_fn.compact
-      extract_output = combine_fn.extract_output
-      teardown = combine_fn.teardown
+          return self._combine_fn_copy.add_input(accumulator, value)
 
     def StripNonce(nonce_key_value):
       (_, key), value = nonce_key_value

From 085103a7fbfca1038adb6bf69c27ebd48f11a029 Mon Sep 17 00:00:00 2001
From: Claude
Date: Tue, 8 Oct 2024 14:46:22 +0000
Subject: [PATCH 2/3] Remove logging, add unit test

---
 .../combinefn_lifecycle_pipeline.py | 33 +++++++++++++++++++
 .../transforms/combinefn_lifecycle_test.py | 7 ++++
 sdks/python/apache_beam/transforms/core.py | 12 ++-----
 3 files changed, 42 insertions(+), 10 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py b/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py
index 3cb5f32c3114..a3a25ec102b8 100644
--- a/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py
+++ b/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py
@@ -21,6 +21,7 @@
 from typing import Tuple
 
 import apache_beam as beam
+import math
 from apache_beam.options.pipeline_options import TypeOptions
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
@@ -124,6 +125,38 @@ def run_combine(pipeline, input_elements=5, lift_combiners=True):
     assert_that(pcoll, equal_to([(expected_result, expected_result)]))
 
 
+def run_combine_uncopyable_attr(
+    pipeline, input_elements=5, lift_combiners=True):
+  # Calculate the expected result, which is the sum of an arithmetic sequence.
+  # By default, this is equal to: 0 + 1 + 2 + 3 + 4 = 10
+  expected_result = input_elements * (input_elements - 1) / 2
+
+  # Enable runtime type checking in order to cover TypeCheckCombineFn by
+  # the test.
+  pipeline.get_pipeline_options().view_as(TypeOptions).runtime_type_check = True
+  pipeline.get_pipeline_options().view_as(
+      TypeOptions).allow_unsafe_triggers = True
+
+  with pipeline as p:
+    pcoll = p | 'Start' >> beam.Create(range(input_elements))
+
+    # Certain triggers, such as AfterCount, are incompatible with combiner
+    # lifting. We can use that fact to prevent combiners from being lifted.
+    if not lift_combiners:
+      pcoll |= beam.WindowInto(
+          window.GlobalWindows(),
+          trigger=trigger.AfterCount(input_elements),
+          accumulation_mode=trigger.AccumulationMode.DISCARDING)
+
+    combine_fn = CallSequenceEnforcingCombineFn()
+    # Modules are not deep copyable. Ensure fanout falls back to pickling for
+    # copying combine_fn.
+    combine_fn.module_attribute = math
+    pcoll |= 'Do' >> beam.CombineGlobally(combine_fn).with_fanout(fanout=1)
+
+    assert_that(pcoll, equal_to([expected_result]))
+
+
 def run_pardo(pipeline, input_elements=10):
   with pipeline as p:
     _ = (
diff --git a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py
index 2a86f0251e75..34dbc1ac54b3 100644
--- a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py
+++ b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py
@@ -32,6 +32,7 @@ from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.transforms.combinefn_lifecycle_pipeline import CallSequenceEnforcingCombineFn
 from apache_beam.transforms.combinefn_lifecycle_pipeline import run_combine
 from apache_beam.transforms.combinefn_lifecycle_pipeline import run_pardo
+from apache_beam.transforms.combinefn_lifecycle_pipeline import run_combine_uncopyable_attr
 
 
 @pytest.mark.it_validatesrunner
@@ -67,6 +68,12 @@ def test_combine(self):
     run_combine(TestPipeline(runner=self.runner(), options=test_options))
     self._assert_teardown_called()
 
+  def test_combine_deepcopy_fails(self):
+    test_options = PipelineOptions(flags=[f"--pickle_library={self.pickler}"])
+    run_combine_uncopyable_attr(
+        TestPipeline(runner=self.runner(), options=test_options))
+    self._assert_teardown_called()
+
   def test_non_liftable_combine(self):
     test_options = PipelineOptions(flags=['--allow_unsafe_triggers'])
     run_combine(
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 24d1991f1769..7c8a3e3f42b9 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -3163,11 +3163,7 @@ def __init__(self):
         # stages when using cloudpickle.
         try:
           self._combine_fn_copy = copy.deepcopy(combine_fn)
-        except TypeError as e:
-          logging.warning(
-              'Failed to copy combine function. Ensure python dependencies are'
-              ' properly set up: %s',
-              e)
+        except Exception:
           self._combine_fn_copy = pickler.loads(pickler.dumps(combine_fn))
 
         self.setup = self._combine_fn_copy.setup
@@ -3188,11 +3184,7 @@ def __init__(self):
         # stages when using cloudpickle.
         try:
           self._combine_fn_copy = copy.deepcopy(combine_fn)
-        except TypeError as e:
-          logging.warning(
-              'Failed to copy combine function. Ensure python dependencies are'
-              ' properly set up: %s',
-              e)
+        except Exception:
           self._combine_fn_copy = pickler.loads(pickler.dumps(combine_fn))
 
         self.setup = self._combine_fn_copy.setup

From ddea31b394d075b35aa62bed82b03382b038fa31 Mon Sep 17 00:00:00 2001
From: Claude
Date: Wed, 9 Oct 2024 14:52:05 +0000
Subject: [PATCH 3/3] Linter fixes

---
 .../apache_beam/transforms/combinefn_lifecycle_pipeline.py | 2 +-
 sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py b/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py
index a3a25ec102b8..56610e95297f 100644
--- a/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py
+++ b/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py
@@ -17,11 +17,11 @@
 
 # pytype: skip-file
 
+import math
 from typing import Set
 from typing import Tuple
 
 import apache_beam as beam
-import math
 from apache_beam.options.pipeline_options import TypeOptions
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
diff --git a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py
index 34dbc1ac54b3..647e08db7aaa 100644
--- a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py
+++ b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py
@@ -31,8 +31,8 @@
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.transforms.combinefn_lifecycle_pipeline import CallSequenceEnforcingCombineFn
 from apache_beam.transforms.combinefn_lifecycle_pipeline import run_combine
-from apache_beam.transforms.combinefn_lifecycle_pipeline import run_pardo
+from apache_beam.transforms.combinefn_lifecycle_pipeline import run_combine_uncopyable_attr
+from apache_beam.transforms.combinefn_lifecycle_pipeline import run_pardo
 
 
 @pytest.mark.it_validatesrunner
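
Reviewer note (not part of the patches above): the snippet below is a minimal, self-contained sketch of the copy strategy the lifted PreCombineFn/PostCombineFn stages follow after this series: try copy.deepcopy first, and fall back to a pickler round-trip when the CombineFn holds something deepcopy cannot handle, such as the module attribute exercised by run_combine_uncopyable_attr. SimpleSumFn and copy_combine_fn are illustrative names invented for this sketch, and the import of apache_beam.internal.pickler is assumed to be the dill/cloudpickle wrapper that the patched code calls as pickler.

import copy
import math

import apache_beam as beam
# Assumption: apache_beam.internal.pickler is the dill/cloudpickle-backed
# helper that the patched code refers to simply as `pickler`.
from apache_beam.internal import pickler


class SimpleSumFn(beam.CombineFn):
  """Toy CombineFn whose module attribute makes copy.deepcopy raise."""
  def __init__(self):
    # Modules are not deep-copyable, mirroring run_combine_uncopyable_attr.
    self.module_attribute = math

  def create_accumulator(self):
    return 0

  def add_input(self, accumulator, value):
    return accumulator + value

  def merge_accumulators(self, accumulators):
    return sum(accumulators)

  def extract_output(self, accumulator):
    return accumulator


def copy_combine_fn(combine_fn):
  """Sketch of the copy logic in PreCombineFn/PostCombineFn.__init__."""
  try:
    # Fast path: a deep copy keeps the state of the lifted stages independent.
    return copy.deepcopy(combine_fn)
  except Exception:
    # Fallback: dill/cloudpickle serialize module references, so the
    # round-trip succeeds where deepcopy raises TypeError.
    return pickler.loads(pickler.dumps(combine_fn))


if __name__ == '__main__':
  original = SimpleSumFn()
  duplicate = copy_combine_fn(original)
  assert duplicate is not original
  # The module reference survives the pickling fallback.
  assert duplicate.module_attribute is math

The second patch widens the fallback trigger from TypeError to a bare Exception, so any deepcopy failure, not only the module case above, takes the pickling path.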