Skip to content

Commit

Permalink
Add state sampler stub
Browse files Browse the repository at this point in the history
  • Loading branch information
Claude committed Nov 17, 2024
1 parent 4d6dcd5 commit e85448d
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 15 deletions.
24 changes: 24 additions & 0 deletions sdks/python/apache_beam/runners/worker/statesampler_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from abc import ABC, abstractmethod


class StateSamplerInterface(ABC):
  """Abstract contract for objects that accept metric updates.

  Subclasses must override ``update_metric``; until they do, the class
  cannot be instantiated.
  """
  @abstractmethod
  def update_metric(self, typed_metric_name, value):
    """Applies ``value`` to the metric identified by ``typed_metric_name``.

    Args:
      typed_metric_name: Identifier of the metric being updated.
      value: The update to apply to that metric.

    Raises:
      NotImplementedError: Always, if this base implementation is reached
        (for example through ``super()``).
    """
    raise NotImplementedError()
3 changes: 2 additions & 1 deletion sdks/python/apache_beam/runners/worker/statesampler_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@

from apache_beam.runners import common
from apache_beam.utils import counters
from apache_beam.runners.worker.statesampler_interface import StateSamplerInterface


class StateSampler(object):
class StateSampler(StateSamplerInterface):
def __init__(self, sampling_period_ms):
self._state_stack = [
ScopedState(self, counters.CounterName('unknown'), None)
Expand Down
32 changes: 32 additions & 0 deletions sdks/python/apache_beam/runners/worker/statesampler_stub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from apache_beam.runners.worker.statesampler_interface import StateSamplerInterface


class StubStateSampler(StateSamplerInterface):
  """Sampler that only accumulates metric updates in a dict.

  Every metric name maps to the running total of the values passed to
  ``update_metric`` for that name.
  """
  def __init__(self):
    self._update_metric_calls = {}

  def update_metric(self, typed_metric_name, value):
    """Adds ``value`` to the total recorded for ``typed_metric_name``."""
    recorded = self._update_metric_calls
    if typed_metric_name in recorded:
      recorded[typed_metric_name] += value
    else:
      # First update seen for this metric: it becomes the initial total.
      recorded[typed_metric_name] = value

  def get_recorded_calls(self):
    """Returns the dict mapping each metric name to its accumulated value."""
    return self._update_metric_calls
69 changes: 67 additions & 2 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import traceback
import types
import typing
import functools
from itertools import dropwhile

from apache_beam import coders
Expand Down Expand Up @@ -2495,6 +2496,68 @@ def as_result(self, error_post_processing=None):
return self._pvalue


class _DeferredStateUpdatingPool:
  """Runs a DoFn#process call on a pool and replays its metric updates.

  Metric updates made while the submitted function runs are recorded by a
  StubStateSampler on the executing side, shipped back together with the
  results, and applied to the caller's current tracker only after the call
  has finished.
  """
  def __init__(self, pool, timeout):
    """
    Args:
      pool: Executor whose ``submit`` returns a future, e.g. a
        ProcessPoolExecutor.
      timeout (Optional[float]): The maximum time allowed for execution,
        passed to the future's ``result``; None waits without a limit.
    """
    self._pool = pool
    self._timeout = timeout

  @staticmethod
  def _wrapped_fn(fn, *args, **kwargs):
    """Runs ``fn`` with a StubStateSampler installed as the current tracker.

    This is the callable that actually executes on the pool. Metric updates
    performed by ``fn`` are recorded on the stub so the submitting side can
    replay them once the call returns.

    Args:
      fn (Callable): The function to execute.
      *args: Positional arguments forwarded to ``fn``.
      **kwargs: Keyword arguments forwarded to ``fn``.
    Returns:
      Tuple: ``(results, stub_state_sampler)`` where ``results`` is None when
      ``fn`` returned None, otherwise a list of everything ``fn`` produced.
    """
    from apache_beam.runners.worker.statesampler_stub import StubStateSampler
    from apache_beam.runners.worker.statesampler import set_current_tracker

    stub_state_sampler = StubStateSampler()
    set_current_tracker(stub_state_sampler)

    results = fn(*args, **kwargs)
    if results is not None:
      # Drain lazy outputs (e.g. generators) here, so that the work they do
      # happens on the pool side and counts against the caller's timeout.
      results = list(results)
    return results, stub_state_sampler

  def submit(self, process_fn, *args, **kwargs):
    """
    Submits the process_fn for execution and waits for it to finish.
    Args:
      process_fn (Callable): DoFn#process function to be executed in a
        subprocess or thread.
      *args: Positional arguments to be passed to the wrapped method.
      **kwargs: Keyword arguments to be passed to the wrapped method.
    Returns:
      Optional[list]: The results of the process_fn execution, or None if
        no results.
    Raises:
      concurrent.futures.TimeoutError: If the call does not finish within
        the configured timeout.
      Exception: Whatever the submitted call itself raised.
    """
    future = self._pool.submit(
        functools.partial(self._wrapped_fn, process_fn), *args, **kwargs)
    results, stub_state_sampler = future.result(self._timeout)

    from apache_beam.runners.worker.statesampler import get_current_tracker
    tracker = get_current_tracker()

    if tracker is not None:
      # Replay the updates recorded on the pool side onto the real tracker.
      recorded = stub_state_sampler.get_recorded_calls()
      for typed_metric_name, value in recorded.items():
        tracker.update_metric(typed_metric_name, value)

    # _wrapped_fn already returned either None or a fully materialized list,
    # so no further copy is needed.
    return results


class _SubprocessDoFn(DoFn):
"""Process method run in a subprocess, turning hard crashes into exceptions.
"""
Expand Down Expand Up @@ -2533,8 +2596,10 @@ def _call_remote(self, method, *args, **kwargs):
self._pool = concurrent.futures.ProcessPoolExecutor(1)
self._pool.submit(self._remote_init, self._serialized_fn).result()
try:
return self._pool.submit(method, *args, **kwargs).result(
self._timeout if method == self._remote_process else None)
return _DeferredStateUpdatingPool(
self._pool,
self._timeout if method == self._remote_process else None).submit(
method, *args, **kwargs)
except (concurrent.futures.process.BrokenProcessPool,
TimeoutError,
concurrent.futures._base.TimeoutError):
Expand Down
31 changes: 19 additions & 12 deletions sdks/python/apache_beam/transforms/ptransform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2755,30 +2755,37 @@ def test_timeout(self):
label='CheckBad')

def test_increment_counter(self):
  # Runs a ParDo with exception handling over three elements and checks the
  # counter totals reported by the pipeline result.
  #
  # NOTE(review): this span comes from a rendered diff whose +/- markers were
  # lost, so removed and added lines appear to be interleaved (the
  # use_subprocess early return and the 'recordsCounter' checks sit next to
  # the 'recordsCounter1'/'recordsCounter2' ones) -- confirm against the
  # real file before relying on it.

  # Counters are not currently supported for
  # ParDo#with_exception_handling(use_subprocess=True).
  if (self.use_subprocess):
    return

  class CounterDoFn(beam.DoFn):
    def __init__(self):
      self.records_counter = Metrics.counter(self.__class__, 'recordsCounter')
      self.records_counter1 = Metrics.counter(
          self.__class__, 'recordsCounter1')
      self.records_counter2 = Metrics.counter(
          self.__class__, 'recordsCounter2')

    def process(self, element):
      # Per element: recordsCounter and recordsCounter1 go up by one,
      # recordsCounter2 by two.
      self.records_counter.inc()
      self.records_counter1.inc()
      self.records_counter2.inc()
      self.records_counter2.inc()
      yield element

  with TestPipeline() as p:
    _, _ = (
        (p | beam.Create([1,2,3])) | beam.ParDo(CounterDoFn())
        .with_exception_handling(
            use_subprocess=self.use_subprocess, timeout=1))
  results = p.result
  metric_results = results.metrics().query(
      MetricsFilter().with_name("recordsCounter"))
  records_counter = metric_results['counters'][0]

  # Three input elements, one increment each.
  self.assertEqual(records_counter.key.metric.name, 'recordsCounter')
  self.assertEqual(records_counter.result, 3)
  metric_results1 = results.metrics().query(
      MetricsFilter().with_name("recordsCounter1"))
  records_counter1 = metric_results1['counters'][0]
  metric_results2 = results.metrics().query(
      MetricsFilter().with_name("recordsCounter2"))
  records_counter2 = metric_results2['counters'][0]

  # Three elements: one increment each for recordsCounter1, two each for
  # recordsCounter2.
  self.assertEqual(records_counter1.key.metric.name, 'recordsCounter1')
  self.assertEqual(records_counter1.result, 3)
  self.assertEqual(records_counter2.key.metric.name, 'recordsCounter2')
  self.assertEqual(records_counter2.result, 6)

def test_lifecycle(self):
die = type(self).die
Expand Down

0 comments on commit e85448d

Please sign in to comment.