Skip to content

Commit

Permalink
Merge pull request #14857 from [BEAM-9487] Add trigger safety check t…
Browse files Browse the repository at this point in the history
…o GroupByKey

[BEAM-9487] Add trigger safety check to GroupByKey
  • Loading branch information
pabloem authored May 27, 2021
2 parents 4f1f1c1 + 847efa5 commit 99aa83d
Show file tree
Hide file tree
Showing 13 changed files with 320 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def test_leader_board_it(self):
self.project, teams_query, self.DEFAULT_EXPECTED_CHECKSUM)

extra_opts = {
'allow_unsafe_triggers': True,
'subscription': self.input_sub.name,
'dataset': self.dataset_ref.dataset_id,
'topic': self.input_topic.name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import apache_beam as beam
from apache_beam.examples.complete.game import leader_board
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
Expand Down Expand Up @@ -59,7 +60,8 @@ def test_leader_board_teams(self):
('team3', 13)]))

def test_leader_board_users(self):
with TestPipeline() as p:
test_options = PipelineOptions(flags=['--allow_unsafe_triggers'])
with TestPipeline(options=test_options) as p:
result = (
self.create_data(p)
| leader_board.CalculateUserScores(allowed_lateness=120))
Expand Down
8 changes: 4 additions & 4 deletions sdks/python/apache_beam/examples/snippets/snippets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,8 +1081,8 @@ def test_model_early_late_triggers(self):
assert_that(counts, equal_to([('a', 4), ('b', 2), ('a', 1)]))

def test_model_setting_trigger(self):
pipeline_options = PipelineOptions()
pipeline_options.view_as(StandardOptions).streaming = True
pipeline_options = PipelineOptions(
flags=['--streaming', '--allow_unsafe_triggers'])

with TestPipeline(options=pipeline_options) as p:
test_stream = (
Expand Down Expand Up @@ -1136,8 +1136,8 @@ def test_model_composite_triggers(self):
assert_that(counts, equal_to([('a', 3), ('b', 2), ('a', 2), ('c', 2)]))

def test_model_other_composite_triggers(self):
pipeline_options = PipelineOptions()
pipeline_options.view_as(StandardOptions).streaming = True
pipeline_options = PipelineOptions(
flags=['--streaming', '--allow_unsafe_triggers'])

with TestPipeline(options=pipeline_options) as p:
test_stream = (
Expand Down
5 changes: 4 additions & 1 deletion sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from apache_beam.io.gcp.internal.clients import bigquery as bigquery_api
from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultMatcher
from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultStreamingMatcher
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.runners.dataflow.test_dataflow_runner import TestDataflowRunner
from apache_beam.runners.runner import PipelineState
Expand Down Expand Up @@ -648,8 +649,10 @@ def __call__(self):
with_auto_sharding=with_auto_sharding)

# Need to test this with the DirectRunner to avoid serializing mocks
test_options = PipelineOptions(flags=['--allow_unsafe_triggers'])
test_options.view_as(StandardOptions).streaming = is_streaming
with TestPipeline(runner='BundleBasedDirectRunner',
options=StandardOptions(streaming=is_streaming)) as p:
options=test_options) as p:
if is_streaming:
_SIZE = len(_ELEMENTS)
fisrt_batch = [
Expand Down
3 changes: 2 additions & 1 deletion sdks/python/apache_beam/io/gcp/bigquery_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,7 +1322,8 @@ def _run_pubsub_bq_pipeline(self, method, triggering_frequency=None):
args = self.test_pipeline.get_full_options_as_args(
on_success_matcher=hc.all_of(*matchers),
wait_until_finish_duration=self.WAIT_UNTIL_FINISH_DURATION,
streaming=True)
streaming=True,
allow_unsafe_triggers=True)

def add_schema_info(element):
yield {'number': element}
Expand Down
9 changes: 9 additions & 0 deletions sdks/python/apache_beam/options/pipeline_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,15 @@ def _add_argparse_args(cls, parser):
'operations such as GropuByKey. This is unsafe, as runners may group '
'keys based on their encoded bytes, but is available for backwards '
'compatibility. See BEAM-11719.')
parser.add_argument(
'--allow_unsafe_triggers',
default=False,
action='store_true',
help='Allow the use of unsafe triggers. Unsafe triggers have the '
'potential to cause data loss due to finishing and/or never having '
'their condition met. Some operations, such as GroupByKey, disallow '
'this. This exists for cases where such loss is acceptable and for '
'backwards compatibility. See BEAM-9487.')

def validate(self, unused_validator):
errors = []
Expand Down
5 changes: 5 additions & 0 deletions sdks/python/apache_beam/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,11 @@ def options(self):
# type: () -> PipelineOptions
return self._options

@property
def allow_unsafe_triggers(self):
# type: () -> bool
return self._options.view_as(TypeOptions).allow_unsafe_triggers

def _current_transform(self):
# type: () -> AppliedPTransform

Expand Down
2 changes: 2 additions & 0 deletions sdks/python/apache_beam/testing/test_stream_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.portability import common_urns
from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileHeader
from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileRecord
Expand Down Expand Up @@ -427,6 +428,7 @@ def test_gbk_execution_after_processing_trigger_fired(self):

options = PipelineOptions()
options.view_as(StandardOptions).streaming = True
options.view_as(TypeOptions).allow_unsafe_triggers = True
p = TestPipeline(options=options)
records = (
p
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from parameterized import parameterized_class

from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.runners.direct import direct_runner
from apache_beam.runners.portability import fn_api_runner
Expand Down Expand Up @@ -88,7 +89,10 @@ def test_combine(self):
self._assert_teardown_called()

def test_non_liftable_combine(self):
run_combine(TestPipeline(runner=self.runner()), lift_combiners=False)
test_options = PipelineOptions(flags=['--allow_unsafe_triggers'])
run_combine(
TestPipeline(runner=self.runner(), options=test_options),
lift_combiners=False)
self._assert_teardown_called()

def test_combining_value_state(self):
Expand Down
17 changes: 16 additions & 1 deletion sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2314,15 +2314,30 @@ def infer_output_type(self, input_type):
key_type, typehints.WindowedValue[value_type]]] # type: ignore[misc]

def expand(self, pcoll):
from apache_beam.transforms.trigger import DataLossReason
from apache_beam.transforms.trigger import DefaultTrigger
windowing = pcoll.windowing
trigger = windowing.triggerfn
if not pcoll.is_bounded and isinstance(
windowing.windowfn, GlobalWindows) and isinstance(windowing.triggerfn,
windowing.windowfn, GlobalWindows) and isinstance(trigger,
DefaultTrigger):
raise ValueError(
'GroupByKey cannot be applied to an unbounded ' +
'PCollection with global windowing and a default trigger')

if not pcoll.pipeline.allow_unsafe_triggers:
unsafe_reason = trigger.may_lose_data(windowing)
if unsafe_reason != DataLossReason.NO_POTENTIAL_LOSS:
msg = 'Unsafe trigger: `{}` may lose data. '.format(trigger)
msg += 'Reason: {}. '.format(
str(unsafe_reason).replace('DataLossReason.', ''))
msg += 'This can be overriden with the --allow_unsafe_triggers flag.'
raise ValueError(msg)
else:
_LOGGER.warning(
'Skipping trigger safety check. '
'This could lead to incomplete or missing groups.')

return pvalue.PCollection.from_(pcoll)

def infer_output_type(self, input_type):
Expand Down
30 changes: 30 additions & 0 deletions sdks/python/apache_beam/transforms/ptransform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,16 @@
from apache_beam.io.iobase import Read
from apache_beam.metrics import Metrics
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.portability import common_urns
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import TestStream
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.testing.util import is_empty
from apache_beam.transforms import WindowInto
from apache_beam.transforms import trigger
from apache_beam.transforms import window
from apache_beam.transforms.display import DisplayData
from apache_beam.transforms.display import DisplayDataItem
Expand Down Expand Up @@ -481,6 +484,33 @@ def test_group_by_key_unbounded_global_default_trigger(self):
with TestPipeline() as pipeline:
pipeline | TestStream() | beam.GroupByKey()

def test_group_by_key_unsafe_trigger(self):
with self.assertRaisesRegex(ValueError, 'Unsafe trigger'):
with TestPipeline() as pipeline:
_ = (
pipeline
| beam.Create([(None, None)])
| WindowInto(
window.GlobalWindows(),
trigger=trigger.AfterCount(5),
accumulation_mode=trigger.AccumulationMode.ACCUMULATING)
| beam.GroupByKey())

def test_group_by_key_allow_unsafe_triggers(self):
test_options = PipelineOptions(flags=['--allow_unsafe_triggers'])
with TestPipeline(options=test_options) as pipeline:
pcoll = (
pipeline
| beam.Create([(1, 1), (1, 2), (1, 3), (1, 4)])
| WindowInto(
window.GlobalWindows(),
trigger=trigger.AfterCount(5),
accumulation_mode=trigger.AccumulationMode.ACCUMULATING)
| beam.GroupByKey())
# We need five, but it only has four - Displays how this option is
# dangerous.
assert_that(pcoll, is_empty())

def test_group_by_key_reiteration(self):
class MyDoFn(beam.DoFn):
def process(self, gbk_result):
Expand Down
Loading

0 comments on commit 99aa83d

Please sign in to comment.