From e85448d7d54cc5ba8e82f513c0639af29675bf43 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 30 Sep 2024 16:10:47 +0000 Subject: [PATCH] Add state sampler stub --- .../runners/worker/statesampler_interface.py | 24 +++++++ .../runners/worker/statesampler_slow.py | 3 +- .../runners/worker/statesampler_stub.py | 32 +++++++++ sdks/python/apache_beam/transforms/core.py | 69 ++++++++++++++++++- .../apache_beam/transforms/ptransform_test.py | 31 +++++---- 5 files changed, 144 insertions(+), 15 deletions(-) create mode 100644 sdks/python/apache_beam/runners/worker/statesampler_interface.py create mode 100644 sdks/python/apache_beam/runners/worker/statesampler_stub.py diff --git a/sdks/python/apache_beam/runners/worker/statesampler_interface.py b/sdks/python/apache_beam/runners/worker/statesampler_interface.py new file mode 100644 index 000000000000..aed470036ffb --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/statesampler_interface.py @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABC, abstractmethod + + +class StateSamplerInterface(ABC): + @abstractmethod + def update_metric(self, typed_metric_name, value): + raise NotImplementedError diff --git a/sdks/python/apache_beam/runners/worker/statesampler_slow.py b/sdks/python/apache_beam/runners/worker/statesampler_slow.py index be801284450a..eb05ee97a74a 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler_slow.py +++ b/sdks/python/apache_beam/runners/worker/statesampler_slow.py @@ -21,9 +21,10 @@ from apache_beam.runners import common from apache_beam.utils import counters +from apache_beam.runners.worker.statesampler_interface import StateSamplerInterface -class StateSampler(object): +class StateSampler(StateSamplerInterface): def __init__(self, sampling_period_ms): self._state_stack = [ ScopedState(self, counters.CounterName('unknown'), None) diff --git a/sdks/python/apache_beam/runners/worker/statesampler_stub.py b/sdks/python/apache_beam/runners/worker/statesampler_stub.py new file mode 100644 index 000000000000..563eaed8be43 --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/statesampler_stub.py @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from apache_beam.runners.worker.statesampler_interface import StateSamplerInterface + + +class StubStateSampler(StateSamplerInterface): + def __init__(self): + self._update_metric_calls = {} + + def update_metric(self, typed_metric_name, value): + if (typed_metric_name not in self._update_metric_calls): + self._update_metric_calls[typed_metric_name] = value + return + self._update_metric_calls[typed_metric_name] += value + + def get_recorded_calls(self): + return self._update_metric_calls diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 9c798d3ce6dc..a11c37ad7660 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -29,6 +29,7 @@ import traceback import types import typing +import functools from itertools import dropwhile from apache_beam import coders @@ -2495,6 +2496,68 @@ def as_result(self, error_post_processing=None): return self._pvalue +class _DeferredStateUpdatingPool: + """ + A class that submits a DoFn#process but defers updating counter metrics until + after the subprocess finishes execution. + """ + def __init__(self, pool, timeout): + """ + Args: + process_pool (ProcessPoolExecutor). + timeout (Optional[float]): The maximum time allowed for execution. + """ + self._pool = pool + self._timeout = timeout + + @staticmethod + def _wrapped_fn(fn, *args, **kwargs): + """Records thread scoped state modifications in the subprocess/thread and + replays them once the thread/subprocess returns""" + from apache_beam.runners.worker.statesampler_stub import StubStateSampler + stub_state_sampler = StubStateSampler() + + from apache_beam.runners.worker.statesampler import set_current_tracker + set_current_tracker(stub_state_sampler) + + results = fn(*args, **kwargs) + if results is not None: + # Ensure we iterate over the entire output list in the given amount of + # time. + results = list(results) + return (results, stub_state_sampler) + + def submit(self, process_fn, *args, **kwargs): + """ + Submits the process_fn for execution. + + Args: + process_fn (Callable): DoFn#process function to be executed in a + subprocess or thread. + *args: Positional arguments to be passed to the wrapped method. + **kwargs: Keyword arguments to be passed to the wrapped method. + + Returns: + Optional[list]: The results of the submitted_fn execution, or None if + no results. + """ + results, stub_state_sampler = self._pool.submit( + functools.partial(self._wrapped_fn, process_fn), + *args, **kwargs).result(self._timeout) + + from apache_beam.runners.worker.statesampler import get_current_tracker + tracker = get_current_tracker() + + if tracker is not None: + for typed_metric_name, value in ( + stub_state_sampler.get_recorded_calls().items() + ): + tracker.update_metric(typed_metric_name, value) + if results is None: + return + return list(results) + + class _SubprocessDoFn(DoFn): """Process method run in a subprocess, turning hard crashes into exceptions. """ @@ -2533,8 +2596,10 @@ def _call_remote(self, method, *args, **kwargs): self._pool = concurrent.futures.ProcessPoolExecutor(1) self._pool.submit(self._remote_init, self._serialized_fn).result() try: - return self._pool.submit(method, *args, **kwargs).result( - self._timeout if method == self._remote_process else None) + return _DeferredStateUpdatingPool( + self._pool, + self._timeout if method == self._remote_process else None).submit( + method, *args, **kwargs) except (concurrent.futures.process.BrokenProcessPool, TimeoutError, concurrent.futures._base.TimeoutError): diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index 7db017a59158..fe011e964368 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -2755,17 +2755,18 @@ def test_timeout(self): label='CheckBad') def test_increment_counter(self): - # Counters are not currently supported for - # ParDo#with_exception_handling(use_subprocess=True). - if (self.use_subprocess): - return - class CounterDoFn(beam.DoFn): def __init__(self): - self.records_counter = Metrics.counter(self.__class__, 'recordsCounter') + self.records_counter1 = Metrics.counter( + self.__class__, 'recordsCounter1') + self.records_counter2 = Metrics.counter( + self.__class__, 'recordsCounter2') def process(self, element): - self.records_counter.inc() + self.records_counter1.inc() + self.records_counter2.inc() + self.records_counter2.inc() + yield element with TestPipeline() as p: _, _ = ( @@ -2773,12 +2774,18 @@ def process(self, element): .with_exception_handling( use_subprocess=self.use_subprocess, timeout=1)) results = p.result - metric_results = results.metrics().query( - MetricsFilter().with_name("recordsCounter")) - records_counter = metric_results['counters'][0] - self.assertEqual(records_counter.key.metric.name, 'recordsCounter') - self.assertEqual(records_counter.result, 3) + metric_results1 = results.metrics().query( + MetricsFilter().with_name("recordsCounter1")) + records_counter1 = metric_results1['counters'][0] + metric_results2 = results.metrics().query( + MetricsFilter().with_name("recordsCounter2")) + records_counter2 = metric_results2['counters'][0] + + self.assertEqual(records_counter1.key.metric.name, 'recordsCounter1') + self.assertEqual(records_counter1.result, 3) + self.assertEqual(records_counter2.key.metric.name, 'recordsCounter2') + self.assertEqual(records_counter2.result, 6) def test_lifecycle(self): die = type(self).die