From 9b8ec45190385c4b3195dcff455a0fb8d0c7b338 Mon Sep 17 00:00:00 2001 From: Brian Hulette Date: Tue, 17 May 2022 16:15:11 -0700 Subject: [PATCH] Revert "[BEAM-14294] Worker changes to support trivial Batched DoFns (#17384)" (#17694) This reverts commit 1c4418ce463468b4d5a5b53dfa6e9c35470b976b. --- .../apache_beam/coders/fast_coders_test.py | 2 +- sdks/python/apache_beam/runners/common.pxd | 15 +- sdks/python/apache_beam/runners/common.py | 505 ++++-------------- .../fn_api_runner/fn_runner_test.py | 202 ------- .../portability/portable_runner_test.py | 5 - .../runners/worker/bundle_processor.py | 32 +- .../apache_beam/runners/worker/opcounters.pxd | 2 - .../apache_beam/runners/worker/opcounters.py | 17 +- .../apache_beam/runners/worker/operations.pxd | 22 +- .../apache_beam/runners/worker/operations.py | 282 ++-------- .../apache_beam/transforms/batch_dofn_test.py | 14 - .../apache_beam/transforms/combiners.py | 4 +- .../apache_beam/transforms/ptransform_test.py | 8 +- sdks/python/apache_beam/typehints/batch.py | 6 - .../apache_beam/utils/windowed_value.pxd | 8 - .../apache_beam/utils/windowed_value.py | 129 +---- .../apache_beam/utils/windowed_value_test.py | 92 ---- 17 files changed, 190 insertions(+), 1155 deletions(-) diff --git a/sdks/python/apache_beam/coders/fast_coders_test.py b/sdks/python/apache_beam/coders/fast_coders_test.py index fa8643c2a383..c7112e0e4842 100644 --- a/sdks/python/apache_beam/coders/fast_coders_test.py +++ b/sdks/python/apache_beam/coders/fast_coders_test.py @@ -29,7 +29,7 @@ class FastCoders(unittest.TestCase): def test_using_fast_impl(self): try: - utils.check_compiled('apache_beam.coders.coder_impl') + utils.check_compiled('apache_beam.coders') except RuntimeError: self.skipTest('Cython is not installed') # pylint: disable=wrong-import-order, wrong-import-position diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd index 9a3d8d250b3d..08de4b9c332f 100644 --- a/sdks/python/apache_beam/runners/common.pxd +++ b/sdks/python/apache_beam/runners/common.pxd @@ -18,7 +18,6 @@ cimport cython from apache_beam.utils.windowed_value cimport WindowedValue -from apache_beam.utils.windowed_value cimport WindowedBatch from apache_beam.transforms.cy_dataflow_distribution_counter cimport DataflowDistributionCounter from libc.stdint cimport int64_t @@ -29,15 +28,12 @@ cdef type TaggedOutput, TimestampedValue cdef class Receiver(object): cpdef receive(self, WindowedValue windowed_value) - cpdef receive_batch(self, WindowedBatch windowed_batch) - cpdef flush(self) cdef class MethodWrapper(object): cdef public object args cdef public object defaults cdef public object method_value - cdef str method_name cdef bint has_userstate_arguments cdef object state_args_to_replace cdef object timer_args_to_replace @@ -54,7 +50,6 @@ cdef class MethodWrapper(object): cdef class DoFnSignature(object): cdef public MethodWrapper process_method - cdef public MethodWrapper process_batch_method cdef public MethodWrapper start_bundle_method cdef public MethodWrapper finish_bundle_method cdef public MethodWrapper setup_lifecycle_method @@ -63,7 +58,6 @@ cdef class DoFnSignature(object): cdef public MethodWrapper initial_restriction_method cdef public MethodWrapper create_tracker_method cdef public MethodWrapper split_method - cdef public object batching_configuration cdef public object do_fn cdef public object timer_methods cdef bint _is_stateful_dofn @@ -87,7 +81,6 @@ cdef class DoFnInvoker(object): cdef class SimpleInvoker(DoFnInvoker): cdef object 
process_method - cdef object process_batch_method cdef class PerWindowInvoker(DoFnInvoker): @@ -95,14 +88,10 @@ cdef class PerWindowInvoker(DoFnInvoker): cdef DoFnContext context cdef list args_for_process cdef dict kwargs_for_process - cdef list placeholders_for_process - cdef list args_for_process_batch - cdef dict kwargs_for_process_batch - cdef list placeholders_for_process_batch + cdef list placeholders cdef bint has_windowed_inputs cdef bint cache_globally_windowed_args cdef object process_method - cdef object process_batch_method cdef bint is_splittable cdef object threadsafe_restriction_tracker cdef object threadsafe_watermark_estimator @@ -136,8 +125,6 @@ cdef class _OutputProcessor(OutputProcessor): cdef Receiver main_receivers cdef object tagged_receivers cdef DataflowDistributionCounter per_element_output_counter - cdef object output_batch_converter - @cython.locals(windowed_value=WindowedValue, output_element_count=int64_t) cpdef process_outputs(self, WindowedValue element, results, diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py index 594a7e59cf57..7c1cf49d7625 100644 --- a/sdks/python/apache_beam/runners/common.py +++ b/sdks/python/apache_beam/runners/common.py @@ -23,11 +23,9 @@ """ # pytype: skip-file - import sys import threading import traceback -from enum import Enum from typing import TYPE_CHECKING from typing import Any from typing import Dict @@ -56,12 +54,9 @@ from apache_beam.transforms.window import TimestampedValue from apache_beam.transforms.window import WindowFn from apache_beam.typehints import typehints -from apache_beam.typehints.batch import BatchConverter from apache_beam.utils.counters import Counter from apache_beam.utils.counters import CounterName from apache_beam.utils.timestamp import Timestamp -from apache_beam.utils.windowed_value import HomogeneousWindowedBatch -from apache_beam.utils.windowed_value import WindowedBatch from apache_beam.utils.windowed_value import WindowedValue if TYPE_CHECKING: @@ -115,13 +110,6 @@ def receive(self, windowed_value): # type: (WindowedValue) -> None raise NotImplementedError - def receive_batch(self, windowed_batch): - # type: (WindowedBatch) -> None - raise NotImplementedError - - def flush(self): - raise NotImplementedError - class MethodWrapper(object): """For internal use only; no backwards-compatibility guarantees. @@ -148,7 +136,6 @@ def __init__(self, obj_to_invoke, method_name): # TODO(BEAM-5878) support kwonlyargs on Python 3. self.method_value = getattr(obj_to_invoke, method_name) - self.method_name = method_name self.has_userstate_arguments = False self.state_args_to_replace = {} # type: Dict[str, core.StateSpec] @@ -227,26 +214,6 @@ def invoke_timer_callback( return self.method_value() -class BatchingPreference(Enum): - DO_NOT_CARE = 1 # This operation can operate on batches or element-at-a-time - # TODO: Should we also store batching parameters here? 
(time/size preferences) - BATCH_REQUIRED = 2 # This operation can only operate on batches - BATCH_FORBIDDEN = 3 # This operation can only work element-at-a-time - # Other possibilities: BATCH_PREFERRED (with min batch size specified) - - @property - def supports_batches(self) -> bool: - return self in (self.BATCH_REQUIRED, self.DO_NOT_CARE) - - @property - def supports_elements(self) -> bool: - return self in (self.BATCH_FORBIDDEN, self.DO_NOT_CARE) - - @property - def requires_batches(self) -> bool: - return self == self.BATCH_REQUIRED - - class DoFnSignature(object): """Represents the signature of a given ``DoFn`` object. @@ -266,7 +233,6 @@ def __init__(self, do_fn): self.do_fn = do_fn self.process_method = MethodWrapper(do_fn, 'process') - self.process_batch_method = MethodWrapper(do_fn, 'process_batch') self.start_bundle_method = MethodWrapper(do_fn, 'start_bundle') self.finish_bundle_method = MethodWrapper(do_fn, 'finish_bundle') self.setup_lifecycle_method = MethodWrapper(do_fn, 'setup') @@ -313,55 +279,23 @@ def is_unbounded_per_element(self): def _validate(self): # type: () -> None self._validate_process() - self._validate_process_batch() self._validate_bundle_method(self.start_bundle_method) self._validate_bundle_method(self.finish_bundle_method) self._validate_stateful_dofn() - def _check_duplicate_dofn_params(self, method: MethodWrapper): - param_ids = [ - d.param_id for d in method.defaults if isinstance(d, core._DoFnParam) - ] - if len(param_ids) != len(set(param_ids)): - raise ValueError( - 'DoFn %r has duplicate %s method parameters: %s.' % - (self.do_fn, method.method_name, param_ids)) - def _validate_process(self): # type: () -> None """Validate that none of the DoFnParameters are repeated in the function """ - self._check_duplicate_dofn_params(self.process_method) - - def _validate_process_batch(self): - # type: () -> None - self._check_duplicate_dofn_params(self.process_batch_method) - - for d in self.process_batch_method.defaults: - if not isinstance(d, core._DoFnParam): - continue - - # Helpful errors for params which will be supported in the future - if d == (core.DoFn.ElementParam): - # We currently assume we can just get the typehint from the first - # parameter. ElementParam breaks this assumption - raise NotImplementedError( - f"DoFn {self.do_fn!r} uses unsupported DoFn param ElementParam.") - - if d in (core.DoFn.KeyParam, core.DoFn.StateParam, core.DoFn.TimerParam): - raise NotImplementedError( - f"DoFn {self.do_fn!r} has unsupported per-key DoFn param {d}. " - "Per-key DoFn params are not yet supported for process_batch " - "(BEAM-14409).") - - # Fallback to catch anything not explicitly supported - if not d in (core.DoFn.WindowParam, - core.DoFn.TimestampParam, - core.DoFn.PaneInfoParam): - raise ValueError( - f"DoFn {self.do_fn!r} has unsupported process_batch " - f"method parameter {d}") + param_ids = [ + d.param_id for d in self.process_method.defaults + if isinstance(d, core._DoFnParam) + ] + if len(param_ids) != len(set(param_ids)): + raise ValueError( + 'DoFn %r has duplicate process method parameters: %s.' % + (self.do_fn, param_ids)) def _validate_bundle_method(self, method_wrapper): """Validate that none of the DoFnParameters are used in the function @@ -422,7 +356,7 @@ class DoFnInvoker(object): represented by a given DoFnSignature.""" def __init__(self, - output_processor, # type: _OutputProcessor + output_processor, # type: OutputProcessor signature # type: DoFnSignature ): # type: (...) 
-> None @@ -442,7 +376,7 @@ def __init__(self, @staticmethod def create_invoker( signature, # type: DoFnSignature - output_processor, # type: OutputProcessor + output_processor, # type: _OutputProcessor context=None, # type: Optional[DoFnContext] side_inputs=None, # type: Optional[List[sideinputs.SideInputMap]] input_args=None, input_kwargs=None, @@ -477,10 +411,10 @@ def create_invoker( allows a callback to be registered. """ side_inputs = side_inputs or [] + default_arg_values = signature.process_method.defaults use_per_window_invoker = process_invocation and ( - side_inputs or input_args or input_kwargs or - signature.process_method.defaults or - signature.process_batch_method.defaults or signature.is_stateful_dofn()) + side_inputs or input_args or input_kwargs or default_arg_values or + signature.is_stateful_dofn()) if not use_per_window_invoker: return SimpleInvoker(output_processor, signature) else: @@ -523,26 +457,6 @@ def invoke_process(self, """ raise NotImplementedError - def invoke_process_batch(self, - windowed_batch, # type: WindowedBatch - additional_args=None, - additional_kwargs=None - ): - # type: (...) -> None - - """Invokes the DoFn.process() function. - - Args: - windowed_batch: a WindowedBatch object that gives a batch of elements for - which process_batch() method should be invoked, along with - the window each element belongs to. - additional_args: additional arguments to be passed to the current - `DoFn.process()` invocation, usually as side inputs. - additional_kwargs: additional keyword arguments to be passed to the - current `DoFn.process()` invocation. - """ - raise NotImplementedError - def invoke_setup(self): # type: () -> None @@ -610,7 +524,6 @@ def __init__(self, # type: (...) -> None super().__init__(output_processor, signature) self.process_method = signature.process_method.method_value - self.process_batch_method = signature.process_batch_method.method_value def invoke_process(self, windowed_value, # type: WindowedValue @@ -624,94 +537,12 @@ def invoke_process(self, windowed_value, self.process_method(windowed_value.value)) return [] - def invoke_process_batch(self, - windowed_batch, # type: WindowedBatch - restriction=None, - watermark_estimator_state=None, - additional_args=None, - additional_kwargs=None - ): - # type: (...) -> None - self.output_processor.process_batch_outputs( - windowed_batch, self.process_batch_method(windowed_batch.values)) - - -def _get_arg_placeholders( - method: MethodWrapper, - input_args: Optional[List[Any]], - input_kwargs: Optional[Dict[str, any]]): - input_args = input_args if input_args else [] - input_kwargs = input_kwargs if input_kwargs else {} - - arg_names = method.args - default_arg_values = method.defaults - - # Create placeholder for element parameter of DoFn.process() method. - # Not to be confused with ArgumentPlaceHolder, which may be passed in - # input_args and is a placeholder for side-inputs. - class ArgPlaceholder(object): - def __init__(self, placeholder): - self.placeholder = placeholder - - if all(core.DoFn.ElementParam != arg for arg in default_arg_values): - # TODO(BEAM-7867): Handle cases in which len(arg_names) == - # len(default_arg_values). - args_to_pick = len(arg_names) - len(default_arg_values) - 1 - # Positional argument values for process(), with placeholders for special - # values such as the element, timestamp, etc. 
- args_with_placeholders = ([ArgPlaceholder(core.DoFn.ElementParam)] + - input_args[:args_to_pick]) - else: - args_to_pick = len(arg_names) - len(default_arg_values) - args_with_placeholders = input_args[:args_to_pick] - - # Fill the OtherPlaceholders for context, key, window or timestamp - remaining_args_iter = iter(input_args[args_to_pick:]) - for a, d in zip(arg_names[-len(default_arg_values):], default_arg_values): - if core.DoFn.ElementParam == d: - args_with_placeholders.append(ArgPlaceholder(d)) - elif core.DoFn.KeyParam == d: - args_with_placeholders.append(ArgPlaceholder(d)) - elif core.DoFn.WindowParam == d: - args_with_placeholders.append(ArgPlaceholder(d)) - elif core.DoFn.TimestampParam == d: - args_with_placeholders.append(ArgPlaceholder(d)) - elif core.DoFn.PaneInfoParam == d: - args_with_placeholders.append(ArgPlaceholder(d)) - elif core.DoFn.SideInputParam == d: - # If no more args are present then the value must be passed via kwarg - try: - args_with_placeholders.append(next(remaining_args_iter)) - except StopIteration: - if a not in input_kwargs: - raise ValueError("Value for sideinput %s not provided" % a) - elif isinstance(d, core.DoFn.StateParam): - args_with_placeholders.append(ArgPlaceholder(d)) - elif isinstance(d, core.DoFn.TimerParam): - args_with_placeholders.append(ArgPlaceholder(d)) - elif isinstance(d, type) and core.DoFn.BundleFinalizerParam == d: - args_with_placeholders.append(ArgPlaceholder(d)) - else: - # If no more args are present then the value must be passed via kwarg - try: - args_with_placeholders.append(next(remaining_args_iter)) - except StopIteration: - pass - args_with_placeholders.extend(list(remaining_args_iter)) - - # Stash the list of placeholder positions for performance - placeholders = [(i, x.placeholder) - for (i, x) in enumerate(args_with_placeholders) - if isinstance(x, ArgPlaceholder)] - - return placeholders, args_with_placeholders, input_kwargs - class PerWindowInvoker(DoFnInvoker): """An invoker that processes elements considering windowing information.""" def __init__(self, - output_processor, # type: OutputProcessor + output_processor, # type: _OutputProcessor signature, # type: DoFnSignature context, # type: DoFnContext side_inputs, # type: Iterable[sideinputs.SideInputMap] @@ -726,45 +557,95 @@ def __init__(self, self.process_method = signature.process_method.method_value default_arg_values = signature.process_method.defaults self.has_windowed_inputs = ( - not all(si.is_globally_windowed() for si in side_inputs) or any( - core.DoFn.WindowParam == arg - for arg in signature.process_method.defaults) or any( - core.DoFn.WindowParam == arg - for arg in signature.process_batch_method.defaults) or + not all(si.is_globally_windowed() for si in side_inputs) or + any(core.DoFn.WindowParam == arg for arg in default_arg_values) or signature.is_stateful_dofn()) self.user_state_context = user_state_context self.is_splittable = signature.is_splittable_dofn() - self.is_key_param_required = any( - core.DoFn.KeyParam == arg for arg in default_arg_values) self.threadsafe_restriction_tracker = None # type: Optional[ThreadsafeRestrictionTracker] self.threadsafe_watermark_estimator = None # type: Optional[ThreadsafeWatermarkEstimator] self.current_windowed_value = None # type: Optional[WindowedValue] self.bundle_finalizer_param = bundle_finalizer_param + self.is_key_param_required = False if self.is_splittable: self.splitting_lock = threading.Lock() self.current_window_index = None self.stop_window_index = None + # Try to prepare all the arguments 
that can just be filled in + # without any additional work. in the process function. + # Also cache all the placeholders needed in the process function. + # Flag to cache additional arguments on the first element if all # inputs are within the global window. self.cache_globally_windowed_args = not self.has_windowed_inputs - # Try to prepare all the arguments that can just be filled in - # without any additional work. in the process function. - # Also cache all the placeholders needed in the process function. - ( - self.placeholders_for_process, - self.args_for_process, - self.kwargs_for_process) = _get_arg_placeholders( - signature.process_method, input_args, input_kwargs) + input_args = input_args if input_args else [] + input_kwargs = input_kwargs if input_kwargs else {} + + arg_names = signature.process_method.args + + # Create placeholder for element parameter of DoFn.process() method. + # Not to be confused with ArgumentPlaceHolder, which may be passed in + # input_args and is a placeholder for side-inputs. + class ArgPlaceholder(object): + def __init__(self, placeholder): + self.placeholder = placeholder + + if all(core.DoFn.ElementParam != arg for arg in default_arg_values): + # TODO(BEAM-7867): Handle cases in which len(arg_names) == + # len(default_arg_values). + args_to_pick = len(arg_names) - len(default_arg_values) - 1 + # Positional argument values for process(), with placeholders for special + # values such as the element, timestamp, etc. + args_with_placeholders = ([ArgPlaceholder(core.DoFn.ElementParam)] + + input_args[:args_to_pick]) + else: + args_to_pick = len(arg_names) - len(default_arg_values) + args_with_placeholders = input_args[:args_to_pick] + + # Fill the OtherPlaceholders for context, key, window or timestamp + remaining_args_iter = iter(input_args[args_to_pick:]) + for a, d in zip(arg_names[-len(default_arg_values):], default_arg_values): + if core.DoFn.ElementParam == d: + args_with_placeholders.append(ArgPlaceholder(d)) + elif core.DoFn.KeyParam == d: + self.is_key_param_required = True + args_with_placeholders.append(ArgPlaceholder(d)) + elif core.DoFn.WindowParam == d: + args_with_placeholders.append(ArgPlaceholder(d)) + elif core.DoFn.TimestampParam == d: + args_with_placeholders.append(ArgPlaceholder(d)) + elif core.DoFn.PaneInfoParam == d: + args_with_placeholders.append(ArgPlaceholder(d)) + elif core.DoFn.SideInputParam == d: + # If no more args are present then the value must be passed via kwarg + try: + args_with_placeholders.append(next(remaining_args_iter)) + except StopIteration: + if a not in input_kwargs: + raise ValueError("Value for sideinput %s not provided" % a) + elif isinstance(d, core.DoFn.StateParam): + args_with_placeholders.append(ArgPlaceholder(d)) + elif isinstance(d, core.DoFn.TimerParam): + args_with_placeholders.append(ArgPlaceholder(d)) + elif isinstance(d, type) and core.DoFn.BundleFinalizerParam == d: + args_with_placeholders.append(ArgPlaceholder(d)) + else: + # If no more args are present then the value must be passed via kwarg + try: + args_with_placeholders.append(next(remaining_args_iter)) + except StopIteration: + pass + args_with_placeholders.extend(list(remaining_args_iter)) - self.process_batch_method = signature.process_batch_method.method_value + # Stash the list of placeholder positions for performance + self.placeholders = [(i, x.placeholder) + for (i, x) in enumerate(args_with_placeholders) + if isinstance(x, ArgPlaceholder)] - ( - self.placeholders_for_process_batch, - self.args_for_process_batch, - 
self.kwargs_for_process_batch) = _get_arg_placeholders( - signature.process_batch_method, input_args, input_kwargs) + self.args_for_process = args_with_placeholders + self.kwargs_for_process = input_kwargs def invoke_process(self, windowed_value, # type: WindowedValue @@ -838,33 +719,6 @@ def invoke_process(self, windowed_value, additional_args, additional_kwargs) return residuals - def invoke_process_batch(self, - windowed_batch, # type: WindowedBatch - additional_args=None, - additional_kwargs=None - ): - # type: (...) -> None - - if not additional_args: - additional_args = [] - if not additional_kwargs: - additional_kwargs = {} - - assert isinstance(windowed_batch, HomogeneousWindowedBatch) - - if self.has_windowed_inputs and len(windowed_batch.windows) != 1: - for w in windowed_batch.windows: - self._invoke_process_batch_per_window( - HomogeneousWindowedBatch.of( - windowed_batch.values, - windowed_batch.timestamp, (w, ), - windowed_batch.pane_info), - additional_args, - additional_kwargs) - else: - self._invoke_process_batch_per_window( - windowed_batch, additional_args, additional_kwargs) - def _should_process_window_for_sdf( self, windowed_value, # type: WindowedValue @@ -908,9 +762,7 @@ def _invoke_process_per_window(self, additional_kwargs, ): # type: (...) -> Optional[SplitResultResidual] - if self.has_windowed_inputs: - assert len(windowed_value.windows) <= 1 window, = windowed_value.windows side_inputs = [si[window] for si in self.side_inputs] side_inputs.extend(additional_args) @@ -946,7 +798,7 @@ def _invoke_process_per_window(self, 'Input value to a stateful DoFn or KeyParam must be a KV tuple; ' 'instead, got \'%s\'.') % (windowed_value.value, )) - for i, p in self.placeholders_for_process: + for i, p in self.placeholders: if core.DoFn.ElementParam == p: args_for_process[i] = windowed_value.value elif core.DoFn.KeyParam == p: @@ -973,15 +825,23 @@ def _invoke_process_per_window(self, elif core.DoFn.BundleFinalizerParam == p: args_for_process[i] = self.bundle_finalizer_param - kwargs_for_process = kwargs_for_process or {} - if additional_kwargs: - kwargs_for_process.update(additional_kwargs) + if kwargs_for_process is None: + kwargs_for_process = additional_kwargs + else: + for key in additional_kwargs: + kwargs_for_process[key] = additional_kwargs[key] - self.output_processor.process_outputs( - windowed_value, - self.process_method(*args_for_process, **kwargs_for_process), - self.threadsafe_watermark_estimator) + if kwargs_for_process: + self.output_processor.process_outputs( + windowed_value, + self.process_method(*args_for_process, **kwargs_for_process), + self.threadsafe_watermark_estimator) + else: + self.output_processor.process_outputs( + windowed_value, + self.process_method(*args_for_process), + self.threadsafe_watermark_estimator) if self.is_splittable: assert self.threadsafe_restriction_tracker is not None @@ -1006,68 +866,6 @@ def _invoke_process_per_window(self, deferred_timestamp=deferred_timestamp) return None - def _invoke_process_batch_per_window( - self, - windowed_batch: WindowedBatch, - additional_args, - additional_kwargs, - ): - # type: (...) 
-> Optional[SplitResultResidual] - - if self.has_windowed_inputs: - assert isinstance(windowed_batch, HomogeneousWindowedBatch) - assert len(windowed_batch.windows) <= 1 - - window, = windowed_batch.windows - side_inputs = [si[window] for si in self.side_inputs] - side_inputs.extend(additional_args) - (args_for_process_batch, - kwargs_for_process_batch) = util.insert_values_in_args( - self.args_for_process_batch, - self.kwargs_for_process_batch, - side_inputs) - elif self.cache_globally_windowed_args: - # Attempt to cache additional args if all inputs are globally - # windowed inputs when processing the first element. - self.cache_globally_windowed_args = False - - # Fill in sideInputs if they are globally windowed - global_window = GlobalWindow() - self.args_for_process_batch, self.kwargs_for_process_batch = ( - util.insert_values_in_args( - self.args_for_process_batch, self.kwargs_for_process_batch, - [si[global_window] for si in self.side_inputs])) - args_for_process_batch, kwargs_for_process_batch = ( - self.args_for_process_batch, self.kwargs_for_process_batch) - else: - args_for_process_batch, kwargs_for_process_batch = ( - self.args_for_process_batch, self.kwargs_for_process_batch) - - for i, p in self.placeholders_for_process_batch: - if core.DoFn.ElementParam == p: - args_for_process_batch[i] = windowed_batch.values - elif core.DoFn.KeyParam == p: - raise NotImplementedError("BEAM-14409: Per-key process_batch") - elif core.DoFn.WindowParam == p: - args_for_process_batch[i] = window - elif core.DoFn.TimestampParam == p: - args_for_process_batch[i] = windowed_batch.timestamp - elif core.DoFn.PaneInfoParam == p: - assert isinstance(windowed_batch, HomogeneousWindowedBatch) - args_for_process_batch[i] = windowed_batch.pane_info - elif isinstance(p, core.DoFn.StateParam): - raise NotImplementedError("BEAM-14409: Per-key process_batch") - elif isinstance(p, core.DoFn.TimerParam): - raise NotImplementedError("BEAM-14409: Per-key process_batch") - - kwargs_for_process_batch = kwargs_for_process_batch or {} - - self.output_processor.process_batch_outputs( - windowed_batch, - self.process_batch_method( - *args_for_process_batch, **kwargs_for_process_batch), - self.threadsafe_watermark_estimator) - @staticmethod def _try_split(fraction, window_index, # type: Optional[int] @@ -1372,15 +1170,11 @@ def __init__(self, else: per_element_output_counter = None - # TODO(BEAM-14293): output processor assumes DoFns are batch-to-batch or - # element-to-element, @yields_batches and @yields_elements will break this - # assumption. 
output_processor = _OutputProcessor( windowing.windowfn, main_receivers, tagged_receivers, - per_element_output_counter, - getattr(fn, 'output_batch_converter', None)) + per_element_output_counter) if do_fn_signature.is_stateful_dofn() and not user_state_context: raise Exception( @@ -1406,13 +1200,6 @@ def process(self, windowed_value): self._reraise_augmented(exn) return [] - def process_batch(self, windowed_batch): - # type: (WindowedBatch) -> None - try: - self.do_fn_invoker.invoke_process_batch(windowed_batch) - except BaseException as exn: - self._reraise_augmented(exn) - def process_with_sized_restriction(self, windowed_value): # type: (WindowedValue) -> Iterable[SplitResultResidual] (element, (restriction, estimator_state)), _ = windowed_value.value @@ -1500,11 +1287,6 @@ def process_outputs( # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None raise NotImplementedError - def process_batch_outputs( - self, windowed_input_element, results, watermark_estimator=None): - # type: (WindowedBatch, Iterable[Any], Optional[WatermarkEstimator]) -> None - raise NotImplementedError - class _OutputProcessor(OutputProcessor): """Processes output produced by DoFn method invocations.""" @@ -1513,9 +1295,7 @@ def __init__(self, window_fn, main_receivers, # type: Receiver tagged_receivers, # type: Mapping[Optional[str], Receiver] - per_element_output_counter, - output_batch_converter, # type: Optional[BatchConverter] - ): + per_element_output_counter): """Initializes ``_OutputProcessor``. Args: @@ -1528,12 +1308,7 @@ def __init__(self, self.window_fn = window_fn self.main_receivers = main_receivers self.tagged_receivers = tagged_receivers - if (per_element_output_counter is not None and - per_element_output_counter.is_cythonized): - self.per_element_output_counter = per_element_output_counter - else: - self.per_element_output_counter = None - self.output_batch_converter = output_batch_converter + self.per_element_output_counter = per_element_output_counter def process_outputs( self, windowed_input_element, results, watermark_estimator=None): @@ -1547,7 +1322,8 @@ def process_outputs( if results is None: # TODO(BEAM-3937): Remove if block after output counter released. # Only enable per_element_output_counter when counter cythonized. - if self.per_element_output_counter is not None: + if (self.per_element_output_counter is not None and + self.per_element_output_counter.is_cythonized): self.per_element_output_counter.add_input(0) return @@ -1585,75 +1361,10 @@ def process_outputs( self.main_receivers.receive(windowed_value) else: self.tagged_receivers[tag].receive(windowed_value) - - # TODO(BEAM-3937): Remove if block after output counter released. - # Only enable per_element_output_counter when counter cythonized - if self.per_element_output_counter is not None: - self.per_element_output_counter.add_input(output_element_count) - - def process_batch_outputs( - self, windowed_input_batch, results, watermark_estimator=None): - # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None - - """Dispatch the result of process computation to the appropriate receivers. - - A value wrapped in a TaggedOutput object will be unwrapped and - then dispatched to the appropriate indexed output. - """ - if results is None: - # TODO(BEAM-3937): Remove if block after output counter released. - # Only enable per_element_output_counter when counter cythonized. 
- if self.per_element_output_counter is not None: - self.per_element_output_counter.add_input(0) - return - - # TODO(BEAM-10782): Verify that the results object is a valid iterable type - # if performance_runtime_type_check is active, without harming performance - - assert self.output_batch_converter is not None - - output_element_count = 0 - for result in results: - tag = None - if isinstance(result, TaggedOutput): - tag = result.tag - if not isinstance(tag, str): - raise TypeError('In %s, tag %s is not a string' % (self, tag)) - result = result.value - if isinstance(result, (WindowedValue, TimestampedValue)): - raise TypeError( - f"Received {type(result).__name__} from DoFn that was " - "expected to produce a batch.") - if isinstance(result, WindowedBatch): - assert isinstance(result, HomogeneousWindowedBatch) - windowed_batch = result - - if (windowed_input_batch is not None and - len(windowed_input_batch.windows) != 1): - windowed_batch.windows *= len(windowed_input_batch.windows) - # TODO(BEAM-14352): Add TimestampedBatch, an analogue for TimestampedValue - # and handle it here (see TimestampedValue logic in process_outputs). - else: - # TODO: This should error unless the DoFn was defined with - # @DoFn.yields_batches(output_aligned_with_input=True) - # We should consider also validating that the length is the same as - # windowed_input_batch - windowed_batch = windowed_input_batch.with_values(result) - - output_element_count += self.output_batch_converter.get_length( - windowed_input_batch.values) - - if watermark_estimator is not None: - for timestamp in windowed_batch.timestamps: - watermark_estimator.observe_timestamp(timestamp) - if tag is None: - self.main_receivers.receive_batch(windowed_batch) - else: - self.tagged_receivers[tag].receive_batch(windowed_batch) - # TODO(BEAM-3937): Remove if block after output counter released. 
# Only enable per_element_output_counter when counter cythonized - if self.per_element_output_counter is not None: + if (self.per_element_output_counter is not None and + self.per_element_output_counter.is_cythonized): self.per_element_output_counter.add_input(output_element_count) def start_bundle_outputs(self, results): diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py index 06016706d039..6daf025d2adf 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py @@ -32,13 +32,9 @@ import uuid from typing import Any from typing import Dict -from typing import Iterator -from typing import List from typing import Tuple -from typing import no_type_check import hamcrest # pylint: disable=ungrouped-imports -import numpy as np import pytest from hamcrest.core.matcher import Matcher from hamcrest.core.string_description import StringDescription @@ -63,17 +59,14 @@ from apache_beam.runners.sdf_utils import RestrictionTrackerView from apache_beam.runners.worker import data_plane from apache_beam.runners.worker import statesampler -from apache_beam.runners.worker.operations import InefficientExecutionWarning from apache_beam.testing.synthetic_pipeline import SyntheticSDFAsSource from apache_beam.testing.test_stream import TestStream from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to -from apache_beam.tools import utils from apache_beam.transforms import environments from apache_beam.transforms import userstate from apache_beam.transforms import window from apache_beam.utils import timestamp -from apache_beam.utils import windowed_value if statesampler.FAST_SAMPLER: DEFAULT_SAMPLING_PERIOD_MS = statesampler.DEFAULT_SAMPLING_PERIOD_MS @@ -128,174 +121,6 @@ def test_pardo(self): | beam.Map(lambda e: e + 'x')) assert_that(res, equal_to(['aax', 'bcbcx'])) - def test_batch_pardo(self): - with self.create_pipeline() as p: - res = ( - p - | beam.Create(np.array([1, 2, 3], dtype=np.int64)).with_output_types( - np.int64) - | beam.ParDo(ArrayMultiplyDoFn()) - | beam.Map(lambda x: x * 3)) - - assert_that(res, equal_to([6, 12, 18])) - - def test_batch_pardo_trigger_flush(self): - try: - utils.check_compiled('apache_beam.coders.coder_impl') - except RuntimeError: - self.skipTest( - 'BEAM-14410: FnRunnerTest with non-trivial inputs flakes ' - 'in non-cython environments') - - with self.create_pipeline() as p: - res = ( - p - # Pass more than GeneralPurposeConsumerSet.MAX_BATCH_SIZE elements - # here to make sure we exercise the batch size limit. 
- | beam.Create(np.array(range(5000), - dtype=np.int64)).with_output_types(np.int64) - | beam.ParDo(ArrayMultiplyDoFn()) - | beam.Map(lambda x: x * 3)) - - assert_that(res, equal_to([i * 2 * 3 for i in range(5000)])) - - def test_batch_rebatch_pardos(self): - # Should raise a warning about the rebatching that mentions: - # - The consuming DoFn - # - The output batch type of the producer - # - The input batch type of the consumer - with self.assertWarnsRegex(InefficientExecutionWarning, - r'ListPlusOneDoFn.*NumpyArray.*List\[int64\]'): - with self.create_pipeline() as p: - res = ( - p - | beam.Create(np.array([1, 2, 3], - dtype=np.int64)).with_output_types(np.int64) - | beam.ParDo(ArrayMultiplyDoFn()) - | beam.ParDo(ListPlusOneDoFn()) - | beam.Map(lambda x: x * 3)) - - assert_that(res, equal_to([9, 15, 21])) - - def test_batch_pardo_fusion_break(self): - class NormalizeDoFn(beam.DoFn): - @no_type_check - def process_batch( - self, - batch: np.ndarray, - mean: np.float64, - ) -> Iterator[np.ndarray]: - assert isinstance(batch, np.ndarray) - yield batch - mean - - # infer_output_type must be defined (when there's no process method), - # otherwise we don't know the input type is the same as output type. - def infer_output_type(self, input_type): - return np.float64 - - with self.create_pipeline() as p: - pc = ( - p - | beam.Create(np.array([1, 2, 3], dtype=np.int64)).with_output_types( - np.int64) - | beam.ParDo(ArrayMultiplyDoFn())) - - res = ( - pc - | beam.ParDo( - NormalizeDoFn(), - mean=beam.pvalue.AsSingleton( - pc | beam.CombineGlobally(beam.combiners.MeanCombineFn())))) - assert_that(res, equal_to([-2, 0, 2])) - - def test_batch_pardo_dofn_params(self): - class ConsumeParamsDoFn(beam.DoFn): - @no_type_check - def process_batch( - self, - batch: np.ndarray, - ts=beam.DoFn.TimestampParam, - pane_info=beam.DoFn.PaneInfoParam, - ) -> Iterator[np.ndarray]: - assert isinstance(batch, np.ndarray) - assert isinstance(ts, timestamp.Timestamp) - assert isinstance(pane_info, windowed_value.PaneInfo) - - yield batch * ts.seconds() - - # infer_output_type must be defined (when there's no process method), - # otherwise we don't know the input type is the same as output type. - def infer_output_type(self, input_type): - return input_type - - with self.create_pipeline() as p: - res = ( - p - | beam.Create(np.array(range(10), dtype=np.int64)).with_output_types( - np.int64) - | beam.Map(lambda t: window.TimestampedValue(t, int(t % 2))). - with_output_types(np.int64) - | beam.ParDo(ConsumeParamsDoFn())) - - assert_that(res, equal_to([0, 1, 0, 3, 0, 5, 0, 7, 0, 9])) - - def test_batch_pardo_window_param(self): - class PerWindowDoFn(beam.DoFn): - @no_type_check - def process_batch( - self, - batch: np.ndarray, - window=beam.DoFn.WindowParam, - ) -> Iterator[np.ndarray]: - yield batch * window.start.seconds() - - # infer_output_type must be defined (when there's no process method), - # otherwise we don't know the input type is the same as output type. - def infer_output_type(self, input_type): - return input_type - - with self.create_pipeline() as p: - res = ( - p - | beam.Create(np.array(range(10), dtype=np.int64)).with_output_types( - np.int64) - | beam.Map(lambda t: window.TimestampedValue(t, int(t))). 
- with_output_types(np.int64) - | beam.WindowInto(window.FixedWindows(5)) - | beam.ParDo(PerWindowDoFn())) - - assert_that(res, equal_to([0, 0, 0, 0, 0, 25, 30, 35, 40, 45])) - - def test_batch_pardo_overlapping_windows(self): - class PerWindowDoFn(beam.DoFn): - @no_type_check - def process_batch(self, - batch: np.ndarray, - window=beam.DoFn.WindowParam) -> Iterator[np.ndarray]: - yield batch * window.start.seconds() - - # infer_output_type must be defined (when there's no process method), - # otherwise we don't know the input type is the same as output type. - def infer_output_type(self, input_type): - return input_type - - with self.create_pipeline() as p: - res = ( - p - | beam.Create(np.array(range(10), dtype=np.int64)).with_output_types( - np.int64) - | beam.Map(lambda t: window.TimestampedValue(t, int(t))). - with_output_types(np.int64) - | beam.WindowInto(window.SlidingWindows(size=5, period=3)) - | beam.ParDo(PerWindowDoFn())) - - assert_that(res, equal_to([ 0*-3, 1*-3, # [-3, 2) - 0*0, 1*0, 2*0, 3* 0, 4* 0, # [ 0, 5) - 3*3, 4*3, 5*3, 6* 3, 7* 3, # [ 3, 8) - 6*6, 7*6, 8*6, 9* 6, # [ 6, 11) - 9*9 # [ 9, 14) - ])) - @retry(stop=stop_after_attempt(3)) def test_pardo_side_outputs(self): def tee(elem, *tags): @@ -2308,33 +2133,6 @@ def process(self, element, *side_inputs): yield self._name -class ArrayMultiplyDoFn(beam.DoFn): - def process_batch(self, batch: np.ndarray, *unused_args, - **unused_kwargs) -> Iterator[np.ndarray]: - assert isinstance(batch, np.ndarray) - # GeneralPurposeConsumerSet should limit batches to MAX_BATCH_SIZE (4096) - # elements - assert np.size(batch, axis=0) <= 4096 - yield batch * 2 - - # infer_output_type must be defined (when there's no process method), - # otherwise we don't know the input type is the same as output type. - def infer_output_type(self, input_type): - return input_type - - -class ListPlusOneDoFn(beam.DoFn): - def process_batch(self, batch: List[np.int64], *unused_args, - **unused_kwargs) -> Iterator[List[np.int64]]: - assert isinstance(batch, list) - yield [element + 1 for element in batch] - - # infer_output_type must be defined (when there's no process method), - # otherwise we don't know the input type is the same as output type. 
- def infer_output_type(self, input_type): - return input_type - - if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/runners/portability/portable_runner_test.py b/sdks/python/apache_beam/runners/portability/portable_runner_test.py index e13b25d8eba9..b0404640ac79 100644 --- a/sdks/python/apache_beam/runners/portability/portable_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/portable_runner_test.py @@ -291,11 +291,6 @@ def _subprocess_command(cls, job_port, _): str(job_port), ] - def test_batch_rebatch_pardos(self): - raise unittest.SkipTest( - "Portable runners with subprocess can't make " - "assertions about warnings raised on the worker.") - class PortableRunnerTestWithSubprocessesAndMultiWorkers( PortableRunnerTestWithSubprocesses): diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index 1f6239a6dc64..3ce8fe58da5d 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -156,8 +156,8 @@ def process(self, windowed_value): def finish(self): # type: () -> None - super().finish() self.output_stream.close() + super().finish() class DataInputOperation(RunnerIOOperation): @@ -183,28 +183,21 @@ def __init__(self, windowed_coder, transform_id=transform_id, data_channel=data_channel) - - self.consumer = next(iter(consumers.values())) + # We must do this manually as we don't have a spec or spec.output_coders. + self.receivers = [ + operations.ConsumerSet.create( + self.counter_factory, + self.name_context.step_name, + 0, + next(iter(consumers.values())), + self.windowed_coder, + self._get_runtime_performance_hints()) + ] self.splitting_lock = threading.Lock() self.index = -1 self.stop = float('inf') self.started = False - def setup(self): - with self.scoped_start_state: - super().setup() - # We must do this manually as we don't have a spec or spec.output_coders. 
- self.receivers = [ - operations.ConsumerSet.create( - self.counter_factory, - self.name_context.step_name, - 0, - self.consumer, - self.windowed_coder, - self.get_output_batch_converter(), - self._get_runtime_performance_hints()) - ] - def start(self): # type: () -> None super().start() @@ -327,7 +320,6 @@ def is_valid_split_point(index): def finish(self): # type: () -> None - super().finish() with self.splitting_lock: self.index += 1 self.started = False @@ -869,7 +861,7 @@ def __init__(self, 'fnapi-step-%s' % self.process_bundle_descriptor.id, self.counter_factory) self.ops = self.create_execution_tree(self.process_bundle_descriptor) - for op in reversed(self.ops.values()): + for op in self.ops.values(): op.setup() self.splitting_lock = threading.Lock() diff --git a/sdks/python/apache_beam/runners/worker/opcounters.pxd b/sdks/python/apache_beam/runners/worker/opcounters.pxd index ef2d776eabd0..8b7d80375ef7 100644 --- a/sdks/python/apache_beam/runners/worker/opcounters.pxd +++ b/sdks/python/apache_beam/runners/worker/opcounters.pxd @@ -60,10 +60,8 @@ cdef class OperationCounters(object): cdef public libc.stdint.int64_t _sample_counter cdef public libc.stdint.int64_t _next_sample cdef public object output_type_constraints - cdef public object producer_batch_converter cpdef update_from(self, windowed_value) - cpdef update_from_batch(self, windowed_batch) cdef inline do_sample(self, windowed_value) cpdef update_collect(self) cpdef type_check(self, value) diff --git a/sdks/python/apache_beam/runners/worker/opcounters.py b/sdks/python/apache_beam/runners/worker/opcounters.py index 1dbcbb7a22cc..fad54aaeaf82 100644 --- a/sdks/python/apache_beam/runners/worker/opcounters.py +++ b/sdks/python/apache_beam/runners/worker/opcounters.py @@ -32,13 +32,12 @@ from apache_beam.typehints import TypeCheckError from apache_beam.typehints.decorators import _check_instance_type from apache_beam.utils import counters -from apache_beam.utils import windowed_value from apache_beam.utils.counters import Counter from apache_beam.utils.counters import CounterName if TYPE_CHECKING: + from apache_beam.utils import windowed_value from apache_beam.runners.worker.statesampler import StateSampler - from apache_beam.typehints.batch import BatchConverter # This module is experimental. No backwards-compatibility guarantees. 
@@ -190,9 +189,7 @@ def __init__( coder, index, suffix='out', - producer_type_hints=None, - producer_batch_converter=None, # type: Optional[BatchConverter] - ): + producer_type_hints=None): self._counter_factory = counter_factory self.element_counter = counter_factory.get_counter( '%s-%s%s-ElementCount' % (step_name, suffix, index), Counter.SUM) @@ -205,7 +202,6 @@ def __init__( self._sample_counter = 0 self._next_sample = 0 self.output_type_constraints = producer_type_hints or {} - self.producer_batch_converter = producer_batch_converter def update_from(self, windowed_value): # type: (windowed_value.WindowedValue) -> None @@ -214,15 +210,6 @@ def update_from(self, windowed_value): if self._should_sample(): self.do_sample(windowed_value) - def update_from_batch(self, windowed_batch): - # type: (windowed_value.WindowedBatch) -> None - assert self.producer_batch_converter is not None - assert isinstance(windowed_batch, windowed_value.HomogeneousWindowedBatch) - - self.element_counter.update( - self.producer_batch_converter.get_length(windowed_batch.values)) - # TODO(BEAM-14408): Update byte size estimate - def _observable_callback(self, inner_coder_impl, accumulator): def _observable_callback_inner(value, is_encoded=False): # TODO(ccy): If this stream is large, sample it as well. diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd index cf1f1b3fb512..800e5870b96b 100644 --- a/sdks/python/apache_beam/runners/worker/operations.pxd +++ b/sdks/python/apache_beam/runners/worker/operations.pxd @@ -21,7 +21,6 @@ from apache_beam.runners.common cimport DoFnRunner from apache_beam.runners.common cimport Receiver from apache_beam.runners.worker cimport opcounters from apache_beam.utils.windowed_value cimport WindowedValue -from apache_beam.utils.windowed_value cimport WindowedBatch #from libcpp.string cimport string cdef WindowedValue _globally_windowed_value @@ -35,28 +34,14 @@ cdef class ConsumerSet(Receiver): cdef public output_index cdef public coder + cpdef receive(self, WindowedValue windowed_value) cpdef update_counters_start(self, WindowedValue windowed_value) cpdef update_counters_finish(self) - cpdef update_counters_batch(self, WindowedBatch windowed_batch) - -cdef class SingletonElementConsumerSet(ConsumerSet): - cdef Operation consumer - cpdef receive(self, WindowedValue windowed_value) - cpdef receive_batch(self, WindowedBatch windowed_batch) - cpdef flush(self) -cdef class GeneralPurposeConsumerSet(ConsumerSet): - cdef list element_consumers - cdef list passthrough_batch_consumers - cdef dict other_batch_consumers - cdef bint has_batch_consumers - cdef list _batched_elements - cdef object producer_batch_converter +cdef class SingletonConsumerSet(ConsumerSet): + cdef Operation consumer - cpdef receive(self, WindowedValue windowed_value) - cpdef receive_batch(self, WindowedBatch windowed_batch) - cpdef flush(self) cdef class Operation(object): cdef readonly name_context @@ -111,7 +96,6 @@ cdef class DoOperation(Operation): cdef public dict timer_inputs cdef dict timer_specs cdef public object input_info - cdef object fn cdef class SdfProcessSizedElements(DoOperation): diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py index f46a4a183699..3464c5750c57 100644 --- a/sdks/python/apache_beam/runners/worker/operations.py +++ b/sdks/python/apache_beam/runners/worker/operations.py @@ -26,7 +26,6 @@ import collections import logging import threading -import 
warnings from typing import TYPE_CHECKING from typing import Any from typing import DefaultDict @@ -61,8 +60,6 @@ from apache_beam.transforms.combiners import PhasedCombineFnExecutor from apache_beam.transforms.combiners import curry_combine_fn from apache_beam.transforms.window import GlobalWindows -from apache_beam.typehints.batch import BatchConverter -from apache_beam.utils.windowed_value import WindowedBatch from apache_beam.utils.windowed_value import WindowedValue if TYPE_CHECKING: @@ -107,56 +104,53 @@ def create(counter_factory, output_index, consumers, # type: List[Operation] coder, - producer_type_hints, - producer_batch_converter, # type: Optional[BatchConverter] + producer_type_hints ): # type: (...) -> ConsumerSet if len(consumers) == 1: - consumer = consumers[0] - - consumer_batch_preference = consumer.get_batching_preference() - consumer_batch_converter = consumer.get_input_batch_converter() - if (not consumer_batch_preference.supports_batches and - producer_batch_converter is None and - consumer_batch_converter is None): - return SingletonElementConsumerSet( - counter_factory, - step_name, - output_index, - consumer, - coder, - producer_type_hints) - - return GeneralPurposeConsumerSet( - counter_factory, - step_name, - output_index, - coder, - producer_type_hints, - consumers, - producer_batch_converter) + return SingletonConsumerSet( + counter_factory, + step_name, + output_index, + consumers, + coder, + producer_type_hints) + else: + return ConsumerSet( + counter_factory, + step_name, + output_index, + consumers, + coder, + producer_type_hints) def __init__(self, counter_factory, step_name, # type: str output_index, - consumers, + consumers, # type: List[Operation] coder, - producer_type_hints, - producer_batch_converter + producer_type_hints ): + self.consumers = consumers + self.opcounter = opcounters.OperationCounters( counter_factory, step_name, coder, output_index, - producer_type_hints=producer_type_hints, - producer_batch_converter=producer_batch_converter) + producer_type_hints=producer_type_hints) # Used in repr. self.step_name = step_name self.output_index = output_index self.coder = coder - self.consumers = consumers + + def receive(self, windowed_value): + # type: (WindowedValue) -> None + self.update_counters_start(windowed_value) + for consumer in self.consumers: + cython.cast(Operation, consumer).process(windowed_value) + self.update_counters_finish() def try_split(self, fraction_of_remainder): # type: (...) 
-> Optional[Any] @@ -187,10 +181,6 @@ def update_counters_finish(self): # type: () -> None self.opcounter.update_collect() - def update_counters_batch(self, windowed_batch): - # type: (WindowedBatch) -> None - self.opcounter.update_from_batch(windowed_batch) - def __repr__(self): return '%s[%s.out%s, coder=%s, len(consumers)=%s]' % ( self.__class__.__name__, @@ -200,25 +190,24 @@ def __repr__(self): len(self.consumers)) -class SingletonElementConsumerSet(ConsumerSet): - """ConsumerSet representing a single consumer that can only process elements - (not batches).""" +class SingletonConsumerSet(ConsumerSet): def __init__(self, counter_factory, step_name, output_index, - consumer, # type: Operation + consumers, # type: List[Operation] coder, producer_type_hints ): - super().__init__( + assert len(consumers) == 1 + super(SingletonConsumerSet, self).__init__( counter_factory, step_name, - output_index, [consumer], + output_index, + consumers, coder, - producer_type_hints, - None) - self.consumer = consumer + producer_type_hints) + self.consumer = consumers[0] def receive(self, windowed_value): # type: (WindowedValue) -> None @@ -226,14 +215,6 @@ def receive(self, windowed_value): self.consumer.process(windowed_value) self.update_counters_finish() - def receive_batch(self, windowed_batch): - raise AssertionError( - "SingletonElementConsumerSet.receive_batch is not implemented") - - def flush(self): - # SingletonElementConsumerSet has no buffer to flush - pass - def try_split(self, fraction_of_remainder): # type: (...) -> Optional[Any] return self.consumer.try_split(fraction_of_remainder) @@ -242,133 +223,6 @@ def current_element_progress(self): return self.consumer.current_element_progress() -class GeneralPurposeConsumerSet(ConsumerSet): - """ConsumerSet implementation that handles all combinations of possible edges. 
- """ - MAX_BATCH_SIZE = 4096 - - def __init__(self, - counter_factory, - step_name, # type: str - output_index, - coder, - producer_type_hints, - consumers, # type: List[Operation] - producer_batch_converter): - super().__init__( - counter_factory, - step_name, - output_index, - consumers, - coder, - producer_type_hints, - producer_batch_converter) - - self.producer_batch_converter = producer_batch_converter - - # Partition consumers into three groups: - # - consumers that will be passed elements - # - consumers that will be passed batches (where their input batch type - # matches the output of the producer) - # - consumers that will be passed converted batches - self.element_consumers: List[Operation] = [] - self.passthrough_batch_consumers: List[Operation] = [] - other_batch_consumers: DefaultDict[ - BatchConverter, List[Operation]] = collections.defaultdict(lambda: []) - - for consumer in consumers: - if not consumer.get_batching_preference().supports_batches: - self.element_consumers.append(consumer) - elif (consumer.get_input_batch_converter() == - self.producer_batch_converter): - self.passthrough_batch_consumers.append(consumer) - else: - # Batch consumer with a mismatched batch type - if consumer.get_batching_preference().supports_elements: - # Pass it elements if we can - self.element_consumers.append(consumer) - else: - # As a last resort, explode and rebatch - consumer_batch_converter = consumer.get_input_batch_converter() - # This consumer supports batches, it must have a batch converter - assert consumer_batch_converter is not None - other_batch_consumers[consumer_batch_converter].append(consumer) - - self.other_batch_consumers: Dict[BatchConverter, List[Operation]] = dict( - other_batch_consumers) - - self.has_batch_consumers = ( - self.passthrough_batch_consumers or self.other_batch_consumers) - self._batched_elements: List[Any] = [] - - def receive(self, windowed_value): - # type: (WindowedValue) -> None - - self.update_counters_start(windowed_value) - - for consumer in self.element_consumers: - cython.cast(Operation, consumer).process(windowed_value) - - # TODO: Do this branching when contstructing ConsumerSet - if self.has_batch_consumers: - self._batched_elements.append(windowed_value) - if len(self._batched_elements) >= self.MAX_BATCH_SIZE: - self.flush() - - # TODO(BEAM-14408): Properly estimate sizes in the batch-consumer only case, - # this undercounts large iterables - self.update_counters_finish() - - def receive_batch(self, windowed_batch): - if self.element_consumers: - for wv in windowed_batch.as_windowed_values( - self.producer_batch_converter.explode_batch): - for consumer in self.element_consumers: - cython.cast(Operation, consumer).process(wv) - - for consumer in self.passthrough_batch_consumers: - cython.cast(Operation, consumer).process_batch(windowed_batch) - - for (consumer_batch_converter, - consumers) in self.other_batch_consumers.items(): - # Explode and rebatch into the new batch type (ouch!) 
- # TODO: Register direct conversions for equivalent batch types - - for consumer in consumers: - warnings.warn( - f"Input to operation {consumer} must be rebatched from type " - f"{self.producer_batch_converter.batch_type!r} to " - f"{consumer_batch_converter.batch_type!r}.\n" - "This is very inefficient, consider re-structuring your pipeline " - "or adding a DoFn to directly convert between these types.", - InefficientExecutionWarning) - cython.cast(Operation, consumer).process_batch( - windowed_batch.with_values( - consumer_batch_converter.produce_batch( - self.producer_batch_converter.explode_batch( - windowed_batch.values)))) - - self.update_counters_batch(windowed_batch) - - def flush(self): - if not self.has_batch_consumers or not self._batched_elements: - return - - for batch_converter, consumers in self.other_batch_consumers.items(): - for windowed_batch in WindowedBatch.from_windowed_values( - self._batched_elements, produce_fn=batch_converter.produce_batch): - for consumer in consumers: - cython.cast(Operation, consumer).process_batch(windowed_batch) - - for consumer in self.passthrough_batch_consumers: - for windowed_batch in WindowedBatch.from_windowed_values( - self._batched_elements, - produce_fn=self.producer_batch_converter.produce_batch): - cython.cast(Operation, consumer).process_batch(windowed_batch) - - self._batched_elements = [] - - class Operation(object): """An operation representing the live version of a work item specification. @@ -438,9 +292,7 @@ def setup(self): i, self.consumers[i], coder, - self._get_runtime_performance_hints(), - self.get_output_batch_converter(), - ) for i, + self._get_runtime_performance_hints()) for i, coder in enumerate(self.spec.output_coders) ] self.setup_done = True @@ -453,29 +305,12 @@ def start(self): # For legacy workers. self.setup() - def get_batching_preference(self): - # By default operations don't support batching, require Receiver to unbatch - return common.BatchingPreference.BATCH_FORBIDDEN - - def get_input_batch_converter(self) -> Optional[BatchConverter]: - """Returns a batch type converter if this operation can accept a batch, - otherwise None.""" - return None - - def get_output_batch_converter(self) -> Optional[BatchConverter]: - """Returns a batch type converter if this operation can produce a batch, - otherwise None.""" - return None - def process(self, o): # type: (WindowedValue) -> None """Process element in operation.""" pass - def process_batch(self, batch: WindowedBatch): - pass - def finalize_bundle(self): # type: () -> None pass @@ -494,8 +329,7 @@ def finish(self): # type: () -> None """Finish operation.""" - for receiver in self.receivers: - cython.cast(Receiver, receiver).flush() + pass def teardown(self): # type: () -> None @@ -678,8 +512,7 @@ def __init__( 0, next(iter(consumers.values())), output_coder, - self._get_runtime_performance_hints(), - self.get_output_batch_converter()) + self._get_runtime_performance_hints()) ] def process(self, unused_impulse): @@ -711,8 +544,8 @@ def __init__(self, counter_factory, step_name): self._step_name = step_name def __missing__(self, tag): - self[tag] = receiver = ConsumerSet.create( - self._counter_factory, self._step_name, tag, [], None, None, None) + self[tag] = receiver = ConsumerSet( + self._counter_factory, self._step_name, tag, [], None, None) return receiver def total_output_bytes(self): @@ -754,10 +587,6 @@ def __init__(self, # A mapping of timer tags to the input "PCollections" they come in on. 
self.input_info = None # type: Optional[OpInputInfo] - # See fn_data in dataflow_runner.py - # TODO: Store all the items from spec? - self.fn, _, _, _, _ = (pickler.loads(self.spec.serialized_fn)) - def _read_side_inputs(self, tags_and_types): # type: (...) -> Iterator[apache_sideinputs.SideInputMap] @@ -873,21 +702,6 @@ def start(self): super(DoOperation, self).start() self.dofn_runner.start() - def get_batching_preference(self): - if self.fn.process_batch_defined: - if self.fn.process_defined: - return common.BatchingPreference.DO_NOT_CARE - else: - return common.BatchingPreference.BATCH_REQUIRED - else: - return common.BatchingPreference.BATCH_FORBIDDEN - - def get_input_batch_converter(self) -> Optional[BatchConverter]: - return getattr(self.fn, 'input_batch_converter', None) - - def get_output_batch_converter(self) -> Optional[BatchConverter]: - return getattr(self.fn, 'output_batch_converter', None) - def process(self, o): # type: (WindowedValue) -> None with self.scoped_process_state: @@ -898,9 +712,6 @@ def process(self, o): self.execution_context.delayed_applications.append( (self, delayed_application)) - def process_batch(self, windowed_batch: WindowedBatch) -> None: - self.dofn_runner.process_batch(windowed_batch) - def finalize_bundle(self): # type: () -> None self.dofn_runner.finalize() @@ -924,7 +735,6 @@ def process_timer(self, tag, timer_data): def finish(self): # type: () -> None - super(DoOperation, self).finish() with self.scoped_finish_state: self.dofn_runner.finish() if self.user_state_context: @@ -1111,7 +921,6 @@ def process(self, o): def finish(self): # type: () -> None _LOGGER.debug('Finishing %s', self) - super(CombineOperation, self).finish() def teardown(self): # type: () -> None @@ -1157,7 +966,6 @@ def process(self, o): def finish(self): # type: () -> None self.flush(0) - super().finish() def flush(self, target): # type: (int) -> None @@ -1462,11 +1270,3 @@ def execute(self): op.start() for op in self._ops: op.finish() - - -class InefficientExecutionWarning(RuntimeWarning): - """Warning to indicate an inefficiency in a Beam pipeline.""" - - -# Don't ignore InefficientExecutionWarning, but only log it once -warnings.simplefilter('once', InefficientExecutionWarning) diff --git a/sdks/python/apache_beam/transforms/batch_dofn_test.py b/sdks/python/apache_beam/transforms/batch_dofn_test.py index f1fc7eda0939..5f05e371ec69 100644 --- a/sdks/python/apache_beam/transforms/batch_dofn_test.py +++ b/sdks/python/apache_beam/transforms/batch_dofn_test.py @@ -22,7 +22,6 @@ import unittest from typing import Iterator from typing import List -from typing import no_type_check from parameterized import parameterized_class @@ -128,19 +127,6 @@ def test_no_input_annotation_raises(self): r'BatchDoFnNoInputAnnotation.process_batch'): _ = pc | beam.ParDo(BatchDoFnNoInputAnnotation()) - def test_unsupported_dofn_param_raises(self): - class BatchDoFnBadParam(beam.DoFn): - @no_type_check - def process_batch(self, batch: List[int], key=beam.DoFn.KeyParam): - yield batch * key - - p = beam.Pipeline() - pc = p | beam.Create([1, 2, 3]) - - with self.assertRaisesRegex(NotImplementedError, - r'.*BatchDoFnBadParam.*KeyParam'): - _ = pc | beam.ParDo(BatchDoFnBadParam()) - if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py index d4fdfb14c18a..a22b408378e6 100644 --- a/sdks/python/apache_beam/transforms/combiners.py +++ b/sdks/python/apache_beam/transforms/combiners.py @@ -33,8 +33,6 @@ from
typing import TypeVar from typing import Union -import numpy as np - from apache_beam import typehints from apache_beam.transforms import core from apache_beam.transforms import cy_combiners @@ -90,7 +88,7 @@ def expand(self, pcoll): # TODO(laolu): This type signature is overly restrictive. This should be # more general. -@with_input_types(Union[float, int, np.int64, np.float64]) +@with_input_types(Union[float, int]) @with_output_types(float) class MeanCombineFn(core.CombineFn): """CombineFn for computing an arithmetic mean.""" diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index f3ecd00cccc2..191ba8a1f76e 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -2045,7 +2045,7 @@ def test_mean_globally_pipeline_checking_violated(self): expected_msg = \ "Type hint violation for 'CombinePerKey': " \ - "requires Tuple[TypeVariable[K], Union[float, float64, int, int64]] " \ + "requires Tuple[TypeVariable[K], Union[float, int]] " \ "but got Tuple[None, str] for element" self.assertStartswith(e.exception.args[0], expected_msg) @@ -2111,7 +2111,7 @@ def test_mean_per_key_pipeline_checking_violated(self): expected_msg = \ "Type hint violation for 'CombinePerKey(MeanCombineFn)': " \ - "requires Tuple[TypeVariable[K], Union[float, float64, int, int64]] " \ + "requires Tuple[TypeVariable[K], Union[float, int]] " \ "but got Tuple[str, str] for element" self.assertStartswith(e.exception.args[0], expected_msg) @@ -2151,8 +2151,8 @@ def test_mean_per_key_runtime_checking_violated(self): "Runtime type violation detected within " \ "OddMean/CombinePerKey(MeanCombineFn): " \ "Type-hint for argument: 'element' violated: " \ - "Union[float, float64, int, int64] type-constraint violated. " \ - "Expected an instance of one of: ('float', 'float64', 'int', 'int64'), " \ + "Union[float, int] type-constraint violated. 
" \ + "Expected an instance of one of: ('float', 'int'), " \ "received str instead" self.assertStartswith(e.exception.args[0], expected_msg) diff --git a/sdks/python/apache_beam/typehints/batch.py b/sdks/python/apache_beam/typehints/batch.py index 47294bb1f6e4..3fcf79ec45af 100644 --- a/sdks/python/apache_beam/typehints/batch.py +++ b/sdks/python/apache_beam/typehints/batch.py @@ -226,12 +226,6 @@ def __eq__(self, other) -> bool: def __hash__(self) -> int: return hash(self.__key()) - def __repr__(self): - if self.shape == (N, ): - return f'NumpyArray[{self.dtype!r}]' - else: - return f'NumpyArray[{self.dtype!r}, {self.shape!r}]' - def __getitem__(self, value): if isinstance(value, tuple): if len(value) == 2: diff --git a/sdks/python/apache_beam/utils/windowed_value.pxd b/sdks/python/apache_beam/utils/windowed_value.pxd index 91e36789d69c..5d867c83c384 100644 --- a/sdks/python/apache_beam/utils/windowed_value.pxd +++ b/sdks/python/apache_beam/utils/windowed_value.pxd @@ -43,14 +43,6 @@ cdef class WindowedValue(object): cpdef WindowedValue with_value(self, new_value) -cdef class WindowedBatch(object): - cpdef WindowedBatch with_values(self, object new_values) - -cdef class HomogeneousWindowedBatch(WindowedBatch): - cdef public WindowedValue _wv - - cpdef WindowedBatch with_values(self, object new_values) - @cython.locals(wv=WindowedValue) cpdef WindowedValue create( object value, int64_t timestamp_micros, object windows, object pane_info=*) diff --git a/sdks/python/apache_beam/utils/windowed_value.py b/sdks/python/apache_beam/utils/windowed_value.py index d80becb41c01..08fca45c31c8 100644 --- a/sdks/python/apache_beam/utils/windowed_value.py +++ b/sdks/python/apache_beam/utils/windowed_value.py @@ -30,14 +30,10 @@ # pytype: skip-file -import collections from typing import TYPE_CHECKING from typing import Any -from typing import Callable -from typing import Iterable from typing import List from typing import Optional -from typing import Sequence from typing import Tuple from apache_beam.utils.timestamp import MAX_TIMESTAMP @@ -150,14 +146,10 @@ def __repr__(self): def __eq__(self, other): if self is other: return True - - if isinstance(other, PaneInfo): - return ( - self.is_first == other.is_first and self.is_last == other.is_last and - self.timing == other.timing and self.index == other.index and - self.nonspeculative_index == other.nonspeculative_index) - - return NotImplemented + return ( + self.is_first == other.is_first and self.is_last == other.is_last and + self.timing == other.timing and self.index == other.index and + self.nonspeculative_index == other.nonspeculative_index) def __hash__(self): return hash(( @@ -211,13 +203,13 @@ class WindowedValue(object): the pane that contained this value. If None, will be set to PANE_INFO_UNKNOWN. """ - def __init__( - self, - value, - timestamp, # type: TimestampTypes - windows, # type: Tuple[BoundedWindow, ...] - pane_info=PANE_INFO_UNKNOWN # type: PaneInfo - ): + + def __init__(self, + value, + timestamp, # type: TimestampTypes + windows, # type: Tuple[BoundedWindow, ...] + pane_info=PANE_INFO_UNKNOWN # type: PaneInfo + ): # type: (...) -> None # For performance reasons, only timestamp_micros is stored by default # (as a C int). The Timestamp object is created on demand below. 
@@ -250,18 +242,16 @@ def __repr__(self): self.pane_info) def __eq__(self, other): - if isinstance(other, WindowedValue): - return ( - type(self) == type(other) and - self.timestamp_micros == other.timestamp_micros and - self.value == other.value and self.windows == other.windows and - self.pane_info == other.pane_info) - return NotImplemented + return ( + type(self) == type(other) and + self.timestamp_micros == other.timestamp_micros and + self.value == other.value and self.windows == other.windows and + self.pane_info == other.pane_info) def __hash__(self): return ((hash(self.value) & 0xFFFFFFFFFFFFFFF) + 3 * (self.timestamp_micros & 0xFFFFFFFFFFFFFF) + 7 * - (hash(tuple(self.windows)) & 0xFFFFFFFFFFFFF) + 11 * + (hash(self.windows) & 0xFFFFFFFFFFFFF) + 11 * (hash(self.pane_info) & 0xFFFFFFFFFFFFF)) def with_value(self, new_value): @@ -280,8 +270,6 @@ def __reduce__(self): # TODO(robertwb): Move this to a static method. - - def create(value, timestamp_micros, windows, pane_info=PANE_INFO_UNKNOWN): wv = WindowedValue.__new__(WindowedValue) wv.value = value @@ -291,89 +279,6 @@ def create(value, timestamp_micros, windows, pane_info=PANE_INFO_UNKNOWN): return wv -class WindowedBatch(object): - """A batch of N windowed values, each having a value, a timestamp and a set of - windows.""" - def with_values(self, new_values): - # type: (Any) -> WindowedBatch - - """Creates a new WindowedBatch with the same timestamps and windows as this. - - This is the fastest way to create a new WindowedBatch. - """ - raise NotImplementedError - - def as_windowed_values(self, explode_fn: Callable) -> Iterable[WindowedValue]: - raise NotImplementedError - - @staticmethod - def from_windowed_values( - windowed_values: Sequence[WindowedValue], *, - produce_fn: Callable) -> Iterable['WindowedBatch']: - return HomogeneousWindowedBatch.from_windowed_values( - windowed_values, produce_fn=produce_fn) - - -class HomogeneousWindowedBatch(WindowedBatch): - """A WindowedBatch with homogeneous event-time information, represented - internally as a WindowedValue.
- """ - def __init__(self, wv): - self._wv = wv - - @staticmethod - def of(values, timestamp, windows, pane_info): - return HomogeneousWindowedBatch( - WindowedValue(values, timestamp, windows, pane_info)) - - @property - def values(self): - return self._wv.value - - @property - def timestamp(self): - return self._wv.timestamp - - @property - def pane_info(self): - return self._wv.pane_info - - @property - def windows(self): - return self._wv.windows - - @windows.setter - def windows(self, value): - self._wv.windows = value - - def with_values(self, new_values): - # type: (Any) -> WindowedBatch - return HomogeneousWindowedBatch(self._wv.with_value(new_values)) - - def as_windowed_values(self, explode_fn: Callable) -> Iterable[WindowedValue]: - for value in explode_fn(self._wv.value): - yield self._wv.with_value(value) - - def __eq__(self, other): - if isinstance(other, HomogeneousWindowedBatch): - return self._wv == other._wv - return NotImplemented - - def __hash__(self): - return hash(self._wv) - - @staticmethod - def from_windowed_values( - windowed_values: Sequence[WindowedValue], *, - produce_fn: Callable) -> Iterable['WindowedBatch']: - grouped = collections.defaultdict(lambda: []) - for wv in windowed_values: - grouped[wv.with_value(None)].append(wv.value) - - for key, values in grouped.items(): - yield HomogeneousWindowedBatch(key.with_value(produce_fn(values))) - - try: WindowedValue.timestamp_object = None except TypeError: diff --git a/sdks/python/apache_beam/utils/windowed_value_test.py b/sdks/python/apache_beam/utils/windowed_value_test.py index 1e4892aa9bd3..bf4048a9bd06 100644 --- a/sdks/python/apache_beam/utils/windowed_value_test.py +++ b/sdks/python/apache_beam/utils/windowed_value_test.py @@ -20,13 +20,9 @@ # pytype: skip-file import copy -import itertools import pickle import unittest -from parameterized import parameterized -from parameterized import parameterized_class - from apache_beam.utils import windowed_value from apache_beam.utils.timestamp import Timestamp @@ -76,93 +72,5 @@ def test_pickle(self): self.assertTrue(pickle.loads(pickle.dumps(wv)) == wv) -WINDOWED_BATCH_INSTANCES = [ - windowed_value.HomogeneousWindowedBatch.of( - None, 3, (), windowed_value.PANE_INFO_UNKNOWN), - windowed_value.HomogeneousWindowedBatch.of( - None, - 3, (), - windowed_value.PaneInfo( - True, False, windowed_value.PaneInfoTiming.ON_TIME, 0, 0)), -] - - -class WindowedBatchTest(unittest.TestCase): - def test_homogeneous_windowed_batch_with_values(self): - pane_info = windowed_value.PaneInfo( - True, True, windowed_value.PaneInfoTiming.ON_TIME, 0, 0) - wb = windowed_value.HomogeneousWindowedBatch.of(['foo', 'bar'], - 6, (), - pane_info) - self.assertEqual( - wb.with_values(['baz', 'foo']), - windowed_value.HomogeneousWindowedBatch.of(['baz', 'foo'], - 6, (), - pane_info)) - - def test_homogeneous_windowed_batch_as_windowed_values(self): - pane_info = windowed_value.PaneInfo( - True, True, windowed_value.PaneInfoTiming.ON_TIME, 0, 0) - wb = windowed_value.HomogeneousWindowedBatch.of(['foo', 'bar'], - 3, (), - pane_info) - - self.assertEqual( - list(wb.as_windowed_values(iter)), - [ - windowed_value.WindowedValue('foo', 3, (), pane_info), - windowed_value.WindowedValue('bar', 3, (), pane_info) - ]) - - @parameterized.expand(itertools.combinations(WINDOWED_BATCH_INSTANCES, 2)) - def test_inequality(self, left_wb, right_wb): - self.assertNotEqual(left_wb, right_wb) - - def test_equals_different_type(self): - wb = windowed_value.HomogeneousWindowedBatch.of( - None, 3, (), 
windowed_value.PANE_INFO_UNKNOWN) - self.assertNotEqual(wb, object()) - - def test_homogeneous_from_windowed_values(self): - pane_info = windowed_value.PaneInfo( - True, True, windowed_value.PaneInfoTiming.ON_TIME, 0, 0) - - windowed_values = [ - windowed_value.WindowedValue('foofoo', 3, (), pane_info), - windowed_value.WindowedValue('foobar', 6, (), pane_info), - windowed_value.WindowedValue('foobaz', 9, (), pane_info), - windowed_value.WindowedValue('barfoo', 3, (), pane_info), - windowed_value.WindowedValue('barbar', 6, (), pane_info), - windowed_value.WindowedValue('barbaz', 9, (), pane_info), - windowed_value.WindowedValue('bazfoo', 3, (), pane_info), - windowed_value.WindowedValue('bazbar', 6, (), pane_info), - windowed_value.WindowedValue('bazbaz', 9, (), pane_info), - ] - - self.assertEqual( - list( - windowed_value.WindowedBatch.from_windowed_values( - windowed_values, produce_fn=list)), - [ - windowed_value.HomogeneousWindowedBatch.of( - ['foofoo', 'barfoo', 'bazfoo'], 3, (), pane_info), - windowed_value.HomogeneousWindowedBatch.of( - ['foobar', 'barbar', 'bazbar'], 6, (), pane_info), - windowed_value.HomogeneousWindowedBatch.of( - ['foobaz', 'barbaz', 'bazbaz'], 9, (), pane_info) - ]) - - -@parameterized_class(('wb', ), [(wb, ) for wb in WINDOWED_BATCH_INSTANCES]) -class WindowedBatchUtilitiesTest(unittest.TestCase): - def test_hash(self): - wb_copy = copy.copy(self.wb) - self.assertFalse(self.wb is wb_copy) - self.assertEqual({self.wb: 100}.get(wb_copy), 100) - - def test_pickle(self): - self.assertTrue(pickle.loads(pickle.dumps(self.wb)) == self.wb) - - if __name__ == '__main__': unittest.main()
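The grouping behaviour exercised by `test_homogeneous_from_windowed_values` above rests on one trick from the removed `HomogeneousWindowedBatch.from_windowed_values`: `wv.with_value(None)` yields a hashable key that captures (timestamp, windows, pane_info), so values with equal keys can share one homogeneous batch. A condensed sketch of that mechanism (plain `WindowedValue` results standing in for the removed batch class), assuming `WindowedValue` equality and hashing as defined earlier in this patch:

```python
import collections

from apache_beam.utils.windowed_value import WindowedValue


def group_for_batching(windowed_values, produce_fn=list):
    # Key by everything except the value itself; with_value(None) keeps
    # timestamp, windows and pane_info while blanking out the payload.
    grouped = collections.defaultdict(list)
    for wv in windowed_values:
        grouped[wv.with_value(None)].append(wv.value)
    # One batch per distinct (timestamp, windows, pane_info) combination.
    return [key.with_value(produce_fn(vs)) for key, vs in grouped.items()]


wvs = [
    WindowedValue('a', 3, ()),
    WindowedValue('b', 3, ()),
    WindowedValue('c', 6, ()),
]
print([wv.value for wv in group_for_batching(wvs)])  # [['a', 'b'], ['c']]
```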