Use state sampler stub to defer metrics updates when DoFn#process is executed in subprocess. #32600

Open · wants to merge 2 commits into base: master
58 changes: 58 additions & 0 deletions sdks/python/apache_beam/runners/worker/statesampler_interface.py
@@ -0,0 +1,58 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from abc import ABC, abstractmethod


class ScopedStateInterface(ABC):
@abstractmethod
def sampled_seconds(self) -> float:
pass

@abstractmethod
def sampled_msecs_int(self) -> int:
pass

@abstractmethod
def __enter__(self):
pass

@abstractmethod
def __exit__(self, exc_type, exc_value, traceback):
pass


class StateSamplerInterface(ABC):
@abstractmethod
def start(self) -> None:
pass

@abstractmethod
def stop(self) -> None:
pass

@abstractmethod
def reset(self) -> None:
pass

@abstractmethod
def current_state(self) -> ScopedStateInterface:
pass

@abstractmethod
def update_metric(self, typed_metric_name, value) -> None:
pass
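
For reference, a minimal sketch of a concrete implementation of these interfaces, assuming this new module is importable with the PR applied. The no-op classes below are hypothetical and used purely for illustration; they are not part of this change:

```python
from apache_beam.runners.worker.statesampler_interface import (
    ScopedStateInterface, StateSamplerInterface)


class _NoOpScopedState(ScopedStateInterface):
  """Hypothetical scoped state that tracks nothing."""
  def sampled_seconds(self) -> float:
    return 0.0

  def sampled_msecs_int(self) -> int:
    return 0

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    return False


class _NoOpStateSampler(StateSamplerInterface):
  """Hypothetical sampler that satisfies the full interface."""
  def start(self) -> None:
    pass

  def stop(self) -> None:
    pass

  def reset(self) -> None:
    pass

  def current_state(self) -> ScopedStateInterface:
    return _NoOpScopedState()

  def update_metric(self, typed_metric_name, value) -> None:
    pass
```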
5 changes: 3 additions & 2 deletions sdks/python/apache_beam/runners/worker/statesampler_slow.py
@@ -21,9 +21,10 @@

from apache_beam.runners import common
from apache_beam.utils import counters
from apache_beam.runners.worker.statesampler_interface import StateSamplerInterface, ScopedStateInterface


class StateSampler(object):
class StateSampler(StateSamplerInterface):
def __init__(self, sampling_period_ms):
self._state_stack = [
ScopedState(self, counters.CounterName('unknown'), None)
@@ -72,7 +73,7 @@ def reset(self) -> None:
pass


class ScopedState(object):
class ScopedState(ScopedStateInterface):
def __init__(
self,
sampler: StateSampler,
44 changes: 44 additions & 0 deletions sdks/python/apache_beam/runners/worker/statesampler_stub.py
@@ -0,0 +1,44 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from apache_beam.runners.worker.statesampler_interface import StateSamplerInterface, ScopedStateInterface


class StubStateSampler(StateSamplerInterface):
def __init__(self):
self._update_metric_calls = {}

def update_metric(self, typed_metric_name, value):
# Accumulate updates per metric name so they can be replayed on the real
# tracker once the subprocess call returns.
if typed_metric_name not in self._update_metric_calls:
self._update_metric_calls[typed_metric_name] = value
return
self._update_metric_calls[typed_metric_name] += value

def get_recorded_calls(self):
return self._update_metric_calls

def start(self) -> None:
raise NotImplementedError()

def stop(self) -> None:
raise NotImplementedError()

def reset(self) -> None:
raise NotImplementedError()

def current_state(self) -> ScopedStateInterface:
raise NotImplementedError()
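
A quick sketch of how the stub aggregates repeated updates (illustrative only; a plain string is used as the metric name here, whereas the runner passes a typed metric name object):

```python
from apache_beam.runners.worker.statesampler_stub import StubStateSampler

sampler = StubStateSampler()
# Repeated updates to the same metric name are summed in place.
sampler.update_metric('recordsCounter', 1)
sampler.update_metric('recordsCounter', 2)
assert sampler.get_recorded_calls() == {'recordsCounter': 3}
```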
67 changes: 65 additions & 2 deletions sdks/python/apache_beam/transforms/core.py
@@ -29,6 +29,7 @@
import traceback
import types
import typing
import functools
from itertools import dropwhile

from apache_beam import coders
@@ -2495,6 +2496,66 @@ def as_result(self, error_post_processing=None):
return self._pvalue


class _DeferredStateUpdatingPool:
"""
A class that submits a DoFn#process but defers updating counter metrics until
after the subprocess finishes execution.
"""
def __init__(self, pool, timeout):
"""
Args:
pool (ProcessPoolExecutor): The executor used to run the function.
timeout (Optional[float]): The maximum time allowed for execution.
"""
self._pool = pool
self._timeout = timeout

@staticmethod
def _wrapped_fn(fn, *args, **kwargs):
"""Records thread scoped state modifications in the subprocess/thread and
replays them once the thread/subprocess returns"""
from apache_beam.runners.worker.statesampler_stub import StubStateSampler
stub_state_sampler = StubStateSampler()

from apache_beam.runners.worker.statesampler import set_current_tracker
set_current_tracker(stub_state_sampler)

results = fn(*args, **kwargs)
if results is not None:
# Ensure we iterate over the entire output list in the given amount of
# time.
results = list(results)
return (results, stub_state_sampler)

def submit(self, process_fn, *args, **kwargs):
"""
Submits the process_fn for execution.

Args:
process_fn (Callable): DoFn#process function to be executed in a
subprocess or thread.
*args: Positional arguments to be passed to the wrapped method.
**kwargs: Keyword arguments to be passed to the wrapped method.

Returns:
Optional[list]: The results of the process_fn execution, or None if
there are no results.
"""
results, stub_state_sampler = self._pool.submit(
functools.partial(self._wrapped_fn, process_fn),
*args, **kwargs).result(self._timeout)

from apache_beam.runners.worker.statesampler import get_current_tracker
tracker = get_current_tracker()

if tracker is not None:
for typed_metric_name, value in (
stub_state_sampler.get_recorded_calls().items()
):
tracker.update_metric(typed_metric_name, value)
return results


class _SubprocessDoFn(DoFn):
"""Process method run in a subprocess, turning hard crashes into exceptions.
"""
@@ -2533,8 +2594,10 @@ def _call_remote(self, method, *args, **kwargs):
self._pool = concurrent.futures.ProcessPoolExecutor(1)
self._pool.submit(self._remote_init, self._serialized_fn).result()
try:
return self._pool.submit(method, *args, **kwargs).result(
self._timeout if method == self._remote_process else None)
if method == self._remote_process:
return _DeferredStateUpdatingPool(self._pool, self._timeout).submit(
method, *args, **kwargs)
return self._pool.submit(method, *args, **kwargs).result(None)
except (concurrent.futures.process.BrokenProcessPool,
TimeoutError,
concurrent.futures._base.TimeoutError):
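
The core idea of the change, recording metric updates in the child process and replaying them on the parent's tracker once the future resolves, can be sketched independently of the Beam internals. The names below are illustrative and not part of the PR's API:

```python
import concurrent.futures
import functools


def _run_with_recorder(fn, *args, **kwargs):
  # Runs in the child process: collect metric updates locally instead of
  # applying them to a tracker that only exists in the parent.
  recorded = {}

  def record(name, value):
    recorded[name] = recorded.get(name, 0) + value

  result = fn(record, *args, **kwargs)
  return result, recorded


def process_element(record, element):
  record('elements', 1)
  return element * 2


if __name__ == '__main__':
  with concurrent.futures.ProcessPoolExecutor(1) as pool:
    result, recorded = pool.submit(
        functools.partial(_run_with_recorder, process_element), 21).result(10)
  # Back in the parent: replay the recorded updates on the real tracker
  # (printed here for simplicity).
  for name, value in recorded.items():
    print(name, value)  # elements 1
  print(result)  # 42
```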
48 changes: 31 additions & 17 deletions sdks/python/apache_beam/transforms/ptransform_test.py
@@ -18,7 +18,6 @@
"""Unit tests for the PTransform and descendants."""

# pytype: skip-file

import collections
import operator
import os
@@ -2755,30 +2754,45 @@ def test_timeout(self):
label='CheckBad')

def test_increment_counter(self):
# Counters are not currently supported for
# ParDo#with_exception_handling(use_subprocess=True).
if (self.use_subprocess):
return

class CounterDoFn(beam.DoFn):
def __init__(self):
self.records_counter = Metrics.counter(self.__class__, 'recordsCounter')
self.records_counter1 = Metrics.counter(
self.__class__, 'recordsCounter1')
self.records_counter2 = Metrics.counter(
self.__class__, 'recordsCounter2')

def process(self, element):
self.records_counter.inc()
self.records_counter1.inc()
self.records_counter2.inc()
self.records_counter2.inc()
yield element

with TestPipeline() as p:
_, _ = (
(p | beam.Create([1,2,3])) | beam.ParDo(CounterDoFn())
.with_exception_handling(
use_subprocess=self.use_subprocess, timeout=1))
good, _ = (
(p
| beam.Create([1,2,3]))
| beam.ParDo(CounterDoFn())
.with_exception_handling(
use_subprocess=self.use_subprocess,
timeout=1 if self.use_subprocess else .1
)
)

assert_that(good, equal_to([1, 2, 3]), label='CheckGood')

results = p.result
metric_results = results.metrics().query(
MetricsFilter().with_name("recordsCounter"))
records_counter = metric_results['counters'][0]

self.assertEqual(records_counter.key.metric.name, 'recordsCounter')
self.assertEqual(records_counter.result, 3)
metric_results1 = results.metrics().query(
MetricsFilter().with_name("recordsCounter1"))
records_counter1 = metric_results1['counters'][0]
metric_results2 = results.metrics().query(
MetricsFilter().with_name("recordsCounter2"))
records_counter2 = metric_results2['counters'][0]

self.assertEqual(records_counter1.key.metric.name, 'recordsCounter1')
self.assertEqual(records_counter1.result, 3)
self.assertEqual(records_counter2.key.metric.name, 'recordsCounter2')
self.assertEqual(records_counter2.result, 6)

def test_lifecycle(self):
die = type(self).die