Skip to content

Commit

Permalink
Add complete interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Claude committed Nov 17, 2024
1 parent e85448d commit 70b8066
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 17 deletions.
38 changes: 36 additions & 2 deletions sdks/python/apache_beam/runners/worker/statesampler_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,41 @@
from abc import ABC, abstractmethod


class ScopedStateInterface(ABC):
@abstractmethod
def sampled_seconds(self) -> float:
pass

@abstractmethod
def sampled_msecs_int(self) -> int:
pass

@abstractmethod
def __enter__(self):
pass

@abstractmethod
def __exit__(self, exc_type, exc_value, traceback):
pass


class StateSamplerInterface(ABC):
@abstractmethod
def update_metric(self, typed_metric_name, value):
raise NotImplementedError
def start(self) -> None:
pass

@abstractmethod
def stop(self) -> None:
pass

@abstractmethod
def reset(self) -> None:
pass

@abstractmethod
def current_state(self) -> ScopedStateInterface:
pass

@abstractmethod
def update_metric(self, typed_metric_name, value) -> None:
pass
4 changes: 2 additions & 2 deletions sdks/python/apache_beam/runners/worker/statesampler_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from apache_beam.runners import common
from apache_beam.utils import counters
from apache_beam.runners.worker.statesampler_interface import StateSamplerInterface
from apache_beam.runners.worker.statesampler_interface import StateSamplerInterface, ScopedStateInterface


class StateSampler(StateSamplerInterface):
Expand Down Expand Up @@ -73,7 +73,7 @@ def reset(self) -> None:
pass


class ScopedState(object):
class ScopedState(ScopedStateInterface):
def __init__(
self,
sampler: StateSampler,
Expand Down
14 changes: 13 additions & 1 deletion sdks/python/apache_beam/runners/worker/statesampler_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
#

from apache_beam.runners.worker.statesampler_interface import StateSamplerInterface
from apache_beam.runners.worker.statesampler_interface import StateSamplerInterface, ScopedStateInterface


class StubStateSampler(StateSamplerInterface):
Expand All @@ -30,3 +30,15 @@ def update_metric(self, typed_metric_name, value):

def get_recorded_calls(self):
return self._update_metric_calls

def start(self) -> None:
raise NotImplementedError()

def stop(self) -> None:
raise NotImplementedError()

def reset(self) -> None:
raise NotImplementedError()

def current_state(self) -> ScopedStateInterface:
raise NotImplementedError()
12 changes: 5 additions & 7 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2553,9 +2553,7 @@ def submit(self, process_fn, *args, **kwargs):
stub_state_sampler.get_recorded_calls().items()
):
tracker.update_metric(typed_metric_name, value)
if results is None:
return
return list(results)
return results


class _SubprocessDoFn(DoFn):
Expand Down Expand Up @@ -2596,10 +2594,10 @@ def _call_remote(self, method, *args, **kwargs):
self._pool = concurrent.futures.ProcessPoolExecutor(1)
self._pool.submit(self._remote_init, self._serialized_fn).result()
try:
return _DeferredStateUpdatingPool(
self._pool,
self._timeout if method == self._remote_process else None).submit(
method, *args, **kwargs)
if (method == self._remote_process):
return _DeferredStateUpdatingPool(self._pool, self._timeout).submit(
method, *args, **kwargs)
return self._pool.submit(method, *args, **kwargs).result(None)
except (concurrent.futures.process.BrokenProcessPool,
TimeoutError,
concurrent.futures._base.TimeoutError):
Expand Down
17 changes: 12 additions & 5 deletions sdks/python/apache_beam/transforms/ptransform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"""Unit tests for the PTransform and descendants."""

# pytype: skip-file

import collections
import operator
import os
Expand Down Expand Up @@ -2769,10 +2768,18 @@ def process(self, element):
yield element

with TestPipeline() as p:
_, _ = (
(p | beam.Create([1,2,3])) | beam.ParDo(CounterDoFn())
.with_exception_handling(
use_subprocess=self.use_subprocess, timeout=1))
good, _ = (
(p
| beam.Create([1,2,3]))
| beam.ParDo(CounterDoFn())
.with_exception_handling(
use_subprocess=self.use_subprocess,
timeout=1 if self.use_subprocess else .1
)
)

assert_that(good, equal_to([1, 2, 3]), label='CheckGood')

results = p.result

metric_results1 = results.metrics().query(
Expand Down

0 comments on commit 70b8066

Please sign in to comment.