diff --git a/sdks/python/apache_beam/examples/complete/game/leader_board_it_test.py b/sdks/python/apache_beam/examples/complete/game/leader_board_it_test.py
index afbaa182bf67..8f5f91ce2689 100644
--- a/sdks/python/apache_beam/examples/complete/game/leader_board_it_test.py
+++ b/sdks/python/apache_beam/examples/complete/game/leader_board_it_test.py
@@ -130,6 +130,7 @@ def test_leader_board_it(self):
         self.project, teams_query, self.DEFAULT_EXPECTED_CHECKSUM)

     extra_opts = {
+        'allow_unsafe_triggers': True,
         'subscription': self.input_sub.name,
         'dataset': self.dataset_ref.dataset_id,
         'topic': self.input_topic.name,
diff --git a/sdks/python/apache_beam/examples/complete/game/leader_board_test.py b/sdks/python/apache_beam/examples/complete/game/leader_board_test.py
index 167ce4a1f664..1c1cd6548923 100644
--- a/sdks/python/apache_beam/examples/complete/game/leader_board_test.py
+++ b/sdks/python/apache_beam/examples/complete/game/leader_board_test.py
@@ -24,6 +24,7 @@

 import apache_beam as beam
 from apache_beam.examples.complete.game import leader_board
+from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
@@ -59,7 +60,8 @@ def test_leader_board_teams(self):
                     ('team3', 13)]))

   def test_leader_board_users(self):
-    with TestPipeline() as p:
+    test_options = PipelineOptions(flags=['--allow_unsafe_triggers'])
+    with TestPipeline(options=test_options) as p:
       result = (
           self.create_data(p)
           | leader_board.CalculateUserScores(allowed_lateness=120))
diff --git a/sdks/python/apache_beam/examples/snippets/snippets_test.py b/sdks/python/apache_beam/examples/snippets/snippets_test.py
index 5c7bd46b40b2..8f215a336177 100644
--- a/sdks/python/apache_beam/examples/snippets/snippets_test.py
+++ b/sdks/python/apache_beam/examples/snippets/snippets_test.py
@@ -1081,8 +1081,8 @@ def test_model_early_late_triggers(self):
       assert_that(counts, equal_to([('a', 4), ('b', 2), ('a', 1)]))

   def test_model_setting_trigger(self):
-    pipeline_options = PipelineOptions()
-    pipeline_options.view_as(StandardOptions).streaming = True
+    pipeline_options = PipelineOptions(
+        flags=['--streaming', '--allow_unsafe_triggers'])

     with TestPipeline(options=pipeline_options) as p:
       test_stream = (
@@ -1136,8 +1136,8 @@ def test_model_composite_triggers(self):
       assert_that(counts, equal_to([('a', 3), ('b', 2), ('a', 2), ('c', 2)]))

   def test_model_other_composite_triggers(self):
-    pipeline_options = PipelineOptions()
-    pipeline_options.view_as(StandardOptions).streaming = True
+    pipeline_options = PipelineOptions(
+        flags=['--streaming', '--allow_unsafe_triggers'])

     with TestPipeline(options=pipeline_options) as p:
       test_stream = (
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
index ff1b50e55e47..9eb59b51d507 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
@@ -41,6 +41,7 @@
 from apache_beam.io.gcp.internal.clients import bigquery as bigquery_api
 from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultMatcher
 from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultStreamingMatcher
+from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.runners.dataflow.test_dataflow_runner import TestDataflowRunner
 from apache_beam.runners.runner import PipelineState
@@ -648,8 +649,10 @@ def __call__(self):
         with_auto_sharding=with_auto_sharding)

     # Need to test this with the DirectRunner to avoid serializing mocks
+    test_options = PipelineOptions(flags=['--allow_unsafe_triggers'])
+    test_options.view_as(StandardOptions).streaming = is_streaming
     with TestPipeline(runner='BundleBasedDirectRunner',
-                      options=StandardOptions(streaming=is_streaming)) as p:
+                      options=test_options) as p:
       if is_streaming:
         _SIZE = len(_ELEMENTS)
         fisrt_batch = [
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_test.py b/sdks/python/apache_beam/io/gcp/bigquery_test.py
index c3178da03f48..41bbfe277990 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_test.py
@@ -1322,7 +1322,8 @@ def _run_pubsub_bq_pipeline(self, method, triggering_frequency=None):
     args = self.test_pipeline.get_full_options_as_args(
         on_success_matcher=hc.all_of(*matchers),
         wait_until_finish_duration=self.WAIT_UNTIL_FINISH_DURATION,
-        streaming=True)
+        streaming=True,
+        allow_unsafe_triggers=True)

     def add_schema_info(element):
       yield {'number': element}
diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py
index 073014c598ca..335cca8cfff1 100644
--- a/sdks/python/apache_beam/options/pipeline_options.py
+++ b/sdks/python/apache_beam/options/pipeline_options.py
@@ -532,6 +532,15 @@ def _add_argparse_args(cls, parser):
         'operations such as GropuByKey. This is unsafe, as runners may group '
         'keys based on their encoded bytes, but is available for backwards '
         'compatibility. See BEAM-11719.')
+    parser.add_argument(
+        '--allow_unsafe_triggers',
+        default=False,
+        action='store_true',
+        help='Allow the use of unsafe triggers. Unsafe triggers have the '
+        'potential to cause data loss due to finishing and/or never having '
+        'their condition met. Some operations, such as GroupByKey, disallow '
+        'this. This exists for cases where such loss is acceptable and for '
+        'backwards compatibility. See BEAM-9487.')

   def validate(self, unused_validator):
     errors = []
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index 73510806d11c..69362a4449be 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -238,6 +238,11 @@ def options(self):
     # type: () -> PipelineOptions
     return self._options

+  @property
+  def allow_unsafe_triggers(self):
+    # type: () -> bool
+    return self._options.view_as(TypeOptions).allow_unsafe_triggers
+
   def _current_transform(self):
     # type: () -> AppliedPTransform

diff --git a/sdks/python/apache_beam/testing/test_stream_test.py b/sdks/python/apache_beam/testing/test_stream_test.py
index ec6309d450ec..94445dd1b477 100644
--- a/sdks/python/apache_beam/testing/test_stream_test.py
+++ b/sdks/python/apache_beam/testing/test_stream_test.py
@@ -24,6 +24,7 @@
 import apache_beam as beam
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import StandardOptions
+from apache_beam.options.pipeline_options import TypeOptions
 from apache_beam.portability import common_urns
 from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileHeader
 from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileRecord
@@ -427,6 +428,7 @@ def test_gbk_execution_after_processing_trigger_fired(self):

     options = PipelineOptions()
     options.view_as(StandardOptions).streaming = True
+    options.view_as(TypeOptions).allow_unsafe_triggers = True
     p = TestPipeline(options=options)
     records = (
         p
diff --git a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py
index c900a480853d..a244f805ee35 100644
--- a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py
+++ b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py
@@ -26,6 +26,7 @@
 from parameterized import parameterized_class

 from apache_beam.options.pipeline_options import DebugOptions
+from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.runners.direct import direct_runner
 from apache_beam.runners.portability import fn_api_runner
@@ -88,7 +89,10 @@ def test_combine(self):
     self._assert_teardown_called()

   def test_non_liftable_combine(self):
-    run_combine(TestPipeline(runner=self.runner()), lift_combiners=False)
+    test_options = PipelineOptions(flags=['--allow_unsafe_triggers'])
+    run_combine(
+        TestPipeline(runner=self.runner(), options=test_options),
+        lift_combiners=False)
     self._assert_teardown_called()

   def test_combining_value_state(self):
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index cd6fb666334d..acc3c70126c2 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -2314,15 +2314,30 @@ def infer_output_type(self, input_type):
         key_type, typehints.WindowedValue[value_type]]]  # type: ignore[misc]

   def expand(self, pcoll):
+    from apache_beam.transforms.trigger import DataLossReason
     from apache_beam.transforms.trigger import DefaultTrigger
     windowing = pcoll.windowing
+    trigger = windowing.triggerfn
     if not pcoll.is_bounded and isinstance(
-        windowing.windowfn, GlobalWindows) and isinstance(windowing.triggerfn,
+        windowing.windowfn, GlobalWindows) and isinstance(trigger,
                                                           DefaultTrigger):
       raise ValueError(
           'GroupByKey cannot be applied to an unbounded ' +
           'PCollection with global windowing and a default trigger')

+    if not pcoll.pipeline.allow_unsafe_triggers:
+      unsafe_reason = trigger.may_lose_data(windowing)
+      if unsafe_reason != DataLossReason.NO_POTENTIAL_LOSS:
+        msg = 'Unsafe trigger: `{}` may lose data. '.format(trigger)
+        msg += 'Reason: {}. '.format(
+            str(unsafe_reason).replace('DataLossReason.', ''))
+        msg += 'This can be overridden with the --allow_unsafe_triggers flag.'
+        raise ValueError(msg)
+    else:
+      _LOGGER.warning(
+          'Skipping trigger safety check. '
+          'This could lead to incomplete or missing groups.')
+
     return pvalue.PCollection.from_(pcoll)

   def infer_output_type(self, input_type):
diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py
index bd747931e431..64c684745ae5 100644
--- a/sdks/python/apache_beam/transforms/ptransform_test.py
+++ b/sdks/python/apache_beam/transforms/ptransform_test.py
@@ -41,13 +41,16 @@
 from apache_beam.io.iobase import Read
 from apache_beam.metrics import Metrics
 from apache_beam.metrics.metric import MetricsFilter
+from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import TypeOptions
 from apache_beam.portability import common_urns
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.test_stream import TestStream
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
+from apache_beam.testing.util import is_empty
 from apache_beam.transforms import WindowInto
+from apache_beam.transforms import trigger
 from apache_beam.transforms import window
 from apache_beam.transforms.display import DisplayData
 from apache_beam.transforms.display import DisplayDataItem
@@ -481,6 +484,33 @@ def test_group_by_key_unbounded_global_default_trigger(self):
       with TestPipeline() as pipeline:
         pipeline | TestStream() | beam.GroupByKey()

+  def test_group_by_key_unsafe_trigger(self):
+    with self.assertRaisesRegex(ValueError, 'Unsafe trigger'):
+      with TestPipeline() as pipeline:
+        _ = (
+            pipeline
+            | beam.Create([(None, None)])
+            | WindowInto(
+                window.GlobalWindows(),
+                trigger=trigger.AfterCount(5),
+                accumulation_mode=trigger.AccumulationMode.ACCUMULATING)
+            | beam.GroupByKey())
+
+  def test_group_by_key_allow_unsafe_triggers(self):
+    test_options = PipelineOptions(flags=['--allow_unsafe_triggers'])
+    with TestPipeline(options=test_options) as pipeline:
+      pcoll = (
+          pipeline
+          | beam.Create([(1, 1), (1, 2), (1, 3), (1, 4)])
+          | WindowInto(
+              window.GlobalWindows(),
+              trigger=trigger.AfterCount(5),
+              accumulation_mode=trigger.AccumulationMode.ACCUMULATING)
+          | beam.GroupByKey())
+      # We need five elements, but the PCollection only has four - this shows
+      # how this option can be dangerous.
+      assert_that(pcoll, is_empty())
+
   def test_group_by_key_reiteration(self):
     class MyDoFn(beam.DoFn):
       def process(self, gbk_result):
diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py
index 556789532ca5..6569d3fe2987 100644
--- a/sdks/python/apache_beam/transforms/trigger.py
+++ b/sdks/python/apache_beam/transforms/trigger.py
@@ -28,7 +28,11 @@
 import numbers
 from abc import ABCMeta
 from abc import abstractmethod
+from enum import Flag
+from enum import auto
+from functools import reduce
 from itertools import zip_longest
+from operator import or_

 from apache_beam.coders import coder_impl
 from apache_beam.coders import observable
@@ -156,6 +160,13 @@ def with_prefix(self, prefix):
         prefix + self.tag, self.timestamp_combiner_impl)


+class DataLossReason(Flag):
+  """Enum defining potential reasons that a trigger may cause data loss."""
+  NO_POTENTIAL_LOSS = 0
+  MAY_FINISH = auto()
+  CONDITION_NOT_GUARANTEED = auto()
+
+
 # pylint: disable=unused-argument
 # TODO(robertwb): Provisional API, Java likely to change as well.
 class TriggerFn(metaclass=ABCMeta):
@@ -237,6 +248,43 @@ def reset(self, window, context):
     """Clear any state and timers used by this TriggerFn."""
     pass

+  @abstractmethod
+  def may_lose_data(self, windowing):
+    # type: (core.Windowing) -> DataLossReason
+
+    """Returns whether or not this trigger could cause data loss.
+
+    A trigger can cause data loss in the following scenarios:
+
+    * The trigger has a chance to finish. For instance, AfterWatermark()
+      without a late trigger would cause all late data to be lost. This
+      scenario is only accounted for if the windowing strategy allows
+      late data. Otherwise, the trigger is not responsible for the data
+      loss.
+    * The trigger condition may not be met. For instance,
+      Repeatedly(AfterCount(N)) may not fire due to N not being met. This
+      is only accounted for if the condition itself led to data loss.
+      Repeatedly(AfterCount(1)) is safe, since it would only not fire if
+      there is no data to lose, but Repeatedly(AfterCount(2)) can cause
+      data loss if there is only one record.
+
+    Note that this only returns the potential for loss. It does not mean that
+    there will be data loss. It also only accounts for loss related to the
+    trigger, not other potential causes.
+
+    Args:
+      windowing: The Windowing that this trigger belongs to. It does not need
+        to be the top-level trigger.
+
+    Returns:
+      The DataLossReason. If there is no potential loss,
+      DataLossReason.NO_POTENTIAL_LOSS is returned. Otherwise, all the
+      potential reasons are returned as a single value. For instance, if
+      data loss can result from finishing or not having the condition met,
+      the result will be DataLossReason.MAY_FINISH|CONDITION_NOT_GUARANTEED.
+ """ + pass + # pylint: enable=unused-argument @@ -290,6 +338,9 @@ def on_fire(self, watermark, window, context): def reset(self, window, context): context.clear_timer(str(window), TimeDomain.WATERMARK) + def may_lose_data(self, unused_windowing): + return DataLossReason.NO_POTENTIAL_LOSS + def __eq__(self, other): return type(self) == type(other) @@ -338,6 +389,9 @@ def on_fire(self, timestamp, window, context): def reset(self, window, context): pass + def may_lose_data(self, unused_windowing): + return DataLossReason.MAY_FINISH + @staticmethod def from_runner_api(proto, context): return AfterProcessingTime( @@ -389,6 +443,9 @@ def should_fire(self, time_domain, watermark, window, context): def on_fire(self, watermark, window, context): return False + def may_lose_data(self, unused_windowing): + return DataLossReason.NO_POTENTIAL_LOSS + @staticmethod def from_runner_api(proto, context): return Always() @@ -433,6 +490,14 @@ def should_fire(self, time_domain, watermark, window, context): def on_fire(self, watermark, window, context): return True + def may_lose_data(self, unused_windowing): + """No potential data loss. + + Though Never doesn't explicitly trigger, it still collects data on + windowing closing, so any data loss is due to windowing closing. + """ + return DataLossReason.NO_POTENTIAL_LOSS + @staticmethod def from_runner_api(proto, context): return _Never() @@ -454,6 +519,7 @@ class AfterWatermark(TriggerFn): LATE_TAG = _CombiningValueStateTag('is_late', any) def __init__(self, early=None, late=None): + # TODO(zhoufek): Maybe don't wrap early/late if they are already Repeatedly self.early = Repeatedly(early) if early else None self.late = Repeatedly(late) if late else None @@ -524,6 +590,20 @@ def reset(self, window, context): if self.late: self.late.reset(window, NestedContext(context, 'late')) + def may_lose_data(self, windowing): + """May cause data loss if the windowing allows lateness and either: + + * The late trigger is not set + * The late trigger may cause data loss. + + The second case is equivalent to Repeatedly(late).may_lose_data(windowing) + """ + if windowing.allowed_lateness == 0: + return DataLossReason.NO_POTENTIAL_LOSS + if self.late is None: + return DataLossReason.MAY_FINISH + return self.late.may_lose_data(windowing) + def __eq__(self, other): return ( type(self) == type(other) and self.early == other.early and @@ -593,6 +673,12 @@ def on_fire(self, watermark, window, context): def reset(self, window, context): context.clear_state(self.COUNT_TAG) + def may_lose_data(self, unused_windowing): + reason = DataLossReason.MAY_FINISH + if self.count > 1: + reason |= DataLossReason.CONDITION_NOT_GUARANTEED + return reason + @staticmethod def from_runner_api(proto, unused_context): return AfterCount(proto.element_count.element_count) @@ -637,6 +723,17 @@ def on_fire(self, watermark, window, context): def reset(self, window, context): self.underlying.reset(window, context) + def may_lose_data(self, windowing): + """Repeatedly may only lose data if the underlying trigger may not have + its condition met. + + For underlying triggers that may finish, Repeatedly overrides that + behavior. 
+ """ + return ( + self.underlying.may_lose_data(windowing) + & DataLossReason.CONDITION_NOT_GUARANTEED) + @staticmethod def from_runner_api(proto, context): return Repeatedly( @@ -742,6 +839,15 @@ class AfterAny(_ParallelTriggerFn): """ combine_op = any + def may_lose_data(self, windowing): + reason = DataLossReason.NO_POTENTIAL_LOSS + for trigger in self.triggers: + t_reason = trigger.may_lose_data(windowing) + if t_reason == DataLossReason.NO_POTENTIAL_LOSS: + return t_reason + reason |= t_reason + return reason + class AfterAll(_ParallelTriggerFn): """Fires when all subtriggers have fired. @@ -750,6 +856,9 @@ class AfterAll(_ParallelTriggerFn): """ combine_op = all + def may_lose_data(self, windowing): + return reduce(or_, (t.may_lose_data(windowing) for t in self.triggers)) + class AfterEach(TriggerFn): @@ -805,6 +914,9 @@ def reset(self, window, context): for ix, trigger in enumerate(self.triggers): trigger.reset(window, self._sub_context(context, ix)) + def may_lose_data(self, windowing): + return reduce(or_, (t.may_lose_data(windowing) for t in self.triggers)) + @staticmethod def _sub_context(context, index): return NestedContext(context, '%d/' % index) diff --git a/sdks/python/apache_beam/transforms/trigger_test.py b/sdks/python/apache_beam/transforms/trigger_test.py index a3dd4385c74c..9e1a5694ab6d 100644 --- a/sdks/python/apache_beam/transforms/trigger_test.py +++ b/sdks/python/apache_beam/transforms/trigger_test.py @@ -39,6 +39,7 @@ from apache_beam.testing.test_stream import TestStream from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +from apache_beam.transforms import WindowInto from apache_beam.transforms import ptransform from apache_beam.transforms import trigger from apache_beam.transforms.core import Windowing @@ -50,6 +51,7 @@ from apache_beam.transforms.trigger import AfterProcessingTime from apache_beam.transforms.trigger import AfterWatermark from apache_beam.transforms.trigger import Always +from apache_beam.transforms.trigger import DataLossReason from apache_beam.transforms.trigger import DefaultTrigger from apache_beam.transforms.trigger import GeneralTriggerDriver from apache_beam.transforms.trigger import InMemoryUnmergedState @@ -57,6 +59,7 @@ from apache_beam.transforms.trigger import TriggerFn from apache_beam.transforms.trigger import _Never from apache_beam.transforms.window import FixedWindows +from apache_beam.transforms.window import GlobalWindows from apache_beam.transforms.window import IntervalWindow from apache_beam.transforms.window import Sessions from apache_beam.transforms.window import TimestampCombiner @@ -433,6 +436,128 @@ def test_picklable_output(self): pickle.loads(pickle.dumps(unwindowed)).value, list(range(10))) +class MayLoseDataTest(unittest.TestCase): + def _test(self, trigger, lateness, expected): + windowing = WindowInto( + GlobalWindows(), + trigger=trigger, + accumulation_mode=AccumulationMode.ACCUMULATING, + allowed_lateness=lateness).windowing + self.assertEqual(trigger.may_lose_data(windowing), expected) + + def test_default_trigger(self): + self._test(DefaultTrigger(), 0, DataLossReason.NO_POTENTIAL_LOSS) + + def test_after_processing_time(self): + self._test(AfterProcessingTime(), 0, DataLossReason.MAY_FINISH) + + def test_always(self): + self._test(Always(), 0, DataLossReason.NO_POTENTIAL_LOSS) + + def test_never(self): + self._test(_Never(), 0, DataLossReason.NO_POTENTIAL_LOSS) + + def test_after_watermark_no_allowed_lateness(self): + self._test(AfterWatermark(), 0, 
DataLossReason.NO_POTENTIAL_LOSS) + + def test_after_watermark_late_none(self): + self._test(AfterWatermark(), 60, DataLossReason.MAY_FINISH) + + def test_after_watermark_no_allowed_lateness_safe_late(self): + self._test( + AfterWatermark(late=DefaultTrigger()), + 0, + DataLossReason.NO_POTENTIAL_LOSS) + + def test_after_watermark_safe_late(self): + self._test( + AfterWatermark(late=DefaultTrigger()), + 60, + DataLossReason.NO_POTENTIAL_LOSS) + + def test_after_watermark_no_allowed_lateness_may_finish_late(self): + self._test( + AfterWatermark(late=AfterProcessingTime()), + 0, + DataLossReason.NO_POTENTIAL_LOSS) + + def test_after_watermark_may_finish_late(self): + self._test( + AfterWatermark(late=AfterProcessingTime()), + 60, + DataLossReason.NO_POTENTIAL_LOSS) + + def test_after_watermark_no_allowed_lateness_condition_late(self): + self._test( + AfterWatermark(late=AfterCount(5)), 0, DataLossReason.NO_POTENTIAL_LOSS) + + def test_after_watermark_condition_late(self): + self._test( + AfterWatermark(late=AfterCount(5)), + 60, + DataLossReason.CONDITION_NOT_GUARANTEED) + + def test_after_count_one(self): + self._test(AfterCount(1), 0, DataLossReason.MAY_FINISH) + + def test_after_count_gt_one(self): + self._test( + AfterCount(2), + 0, + DataLossReason.MAY_FINISH | DataLossReason.CONDITION_NOT_GUARANTEED) + + def test_repeatedly_safe_underlying(self): + self._test( + Repeatedly(DefaultTrigger()), 0, DataLossReason.NO_POTENTIAL_LOSS) + + def test_repeatedly_may_finish_underlying(self): + self._test(Repeatedly(AfterCount(1)), 0, DataLossReason.NO_POTENTIAL_LOSS) + + def test_repeatedly_condition_underlying(self): + self._test( + Repeatedly(AfterCount(2)), 0, DataLossReason.CONDITION_NOT_GUARANTEED) + + def test_after_any_some_unsafe(self): + self._test( + AfterAny(AfterCount(1), DefaultTrigger()), + 0, + DataLossReason.NO_POTENTIAL_LOSS) + + def test_after_any_same_reason(self): + self._test( + AfterAny(AfterCount(1), AfterProcessingTime()), + 0, + DataLossReason.MAY_FINISH) + + def test_after_any_different_reasons(self): + self._test( + AfterAny(Repeatedly(AfterCount(2)), AfterProcessingTime()), + 0, + DataLossReason.MAY_FINISH | DataLossReason.CONDITION_NOT_GUARANTEED) + + def test_after_all_some_unsafe(self): + self._test( + AfterAll(AfterCount(1), DefaultTrigger()), 0, DataLossReason.MAY_FINISH) + + def test_after_all_safe(self): + self._test( + AfterAll(Repeatedly(AfterCount(1)), DefaultTrigger()), + 0, + DataLossReason.NO_POTENTIAL_LOSS) + + def test_after_each_some_unsafe(self): + self._test( + AfterEach(AfterCount(1), DefaultTrigger()), + 0, + DataLossReason.MAY_FINISH) + + def test_after_each_all_safe(self): + self._test( + AfterEach(Repeatedly(AfterCount(1)), DefaultTrigger()), + 0, + DataLossReason.NO_POTENTIAL_LOSS) + + class RunnerApiTest(unittest.TestCase): def test_trigger_encoding(self): for trigger_fn in (DefaultTrigger(), @@ -451,7 +576,8 @@ def test_trigger_encoding(self): class TriggerPipelineTest(unittest.TestCase): def test_after_count(self): - with TestPipeline() as p: + test_options = PipelineOptions(flags=['--allow_unsafe_triggers']) + with TestPipeline(options=test_options) as p: def construct_timestamped(k_t): return TimestampedValue((k_t[0], k_t[1]), k_t[1])
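
A minimal sketch of the new API from the user's side (not part of the diff above; it assumes an SDK build that includes this change, and the variable names are illustrative). It mirrors test_repeatedly_condition_underlying: Repeatedly never finishes, but AfterCount(2) may never see a second element, so the check reports CONDITION_NOT_GUARANTEED.

from apache_beam.transforms import trigger
from apache_beam.transforms.core import Windowing
from apache_beam.transforms.trigger import DataLossReason
from apache_beam.transforms.window import GlobalWindows

# Build a windowing strategy with a trigger whose condition is not guaranteed.
windowing = Windowing(
    GlobalWindows(),
    triggerfn=trigger.Repeatedly(trigger.AfterCount(2)),
    accumulation_mode=trigger.AccumulationMode.DISCARDING)

# The new TriggerFn.may_lose_data() reports why this trigger is unsafe.
reason = windowing.triggerfn.may_lose_data(windowing)
assert reason == DataLossReason.CONDITION_NOT_GUARANTEED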
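And a sketch of what the new GroupByKey check means for a pipeline, mirroring test_group_by_key_unsafe_trigger and test_group_by_key_allow_unsafe_triggers above (the helper name and data here are made up for the example):

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.transforms import trigger
from apache_beam.transforms import window

def apply_unsafe_gbk(p):
  # AfterCount(5) over four elements may never fire, so GroupByKey now
  # rejects it unless --allow_unsafe_triggers is passed.
  return (
      p
      | beam.Create([(1, 1), (1, 2), (1, 3), (1, 4)])
      | beam.WindowInto(
          window.GlobalWindows(),
          trigger=trigger.AfterCount(5),
          accumulation_mode=trigger.AccumulationMode.ACCUMULATING)
      | beam.GroupByKey())

# With default options, applying GroupByKey raises something like
#   ValueError: Unsafe trigger: `AfterCount(5)` may lose data. ...
# With the opt-out flag the pipeline builds and runs, but (as the new test
# asserts) the group is dropped because the trigger condition is never met.
with beam.Pipeline(options=PipelineOptions(
    flags=['--allow_unsafe_triggers'])) as p:
  apply_unsafe_gbk(p)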