Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add callback to with_exception_handling #32136

Merged
merged 4 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1572,7 +1572,9 @@ def with_exception_handling(
use_subprocess=False,
threshold=1,
threshold_windowing=None,
timeout=None):
timeout=None,
on_failure_callback: typing.Optional[typing.Callable[
[Exception, typing.Any], None]] = None):
"""Automatically provides a dead letter output for skipping bad records.
This can allow a pipeline to continue successfully rather than fail or
continuously throw errors on retry when bad elements are encountered.
Expand Down Expand Up @@ -1620,6 +1622,11 @@ def with_exception_handling(
defaults to the windowing of the input.
timeout: If the element has not finished processing in timeout seconds,
raise a TimeoutError. Defaults to None, meaning no time limit.
on_failure_callback: If an element fails or times out,
on_failure_callback will be invoked. It will receive the exception
and the element being processed in as args. Be careful with this
damccorm marked this conversation as resolved.
Show resolved Hide resolved
callback - if you set a timeout, it will not apply to the callback,
and if the callback fails it will not be retried.
"""
args, kwargs = self.raw_side_inputs
return self.label >> _ExceptionHandlingWrapper(
Expand All @@ -1633,7 +1640,8 @@ def with_exception_handling(
use_subprocess,
threshold,
threshold_windowing,
timeout)
timeout,
on_failure_callback)

def default_type_hints(self):
return self.fn.get_type_hints()
Expand Down Expand Up @@ -2232,7 +2240,8 @@ def __init__(
use_subprocess,
threshold,
threshold_windowing,
timeout):
timeout,
on_failure_callback):
if partial and use_subprocess:
raise ValueError('partial and use_subprocess are mutually incompatible.')
self._fn = fn
Expand All @@ -2246,6 +2255,7 @@ def __init__(
self._threshold = threshold
self._threshold_windowing = threshold_windowing
self._timeout = timeout
self._on_failure_callback = on_failure_callback

def expand(self, pcoll):
if self._use_subprocess:
Expand All @@ -2256,7 +2266,11 @@ def expand(self, pcoll):
wrapped_fn = self._fn
result = pcoll | ParDo(
_ExceptionHandlingWrapperDoFn(
wrapped_fn, self._dead_letter_tag, self._exc_class, self._partial),
wrapped_fn,
self._dead_letter_tag,
self._exc_class,
self._partial,
self._on_failure_callback),
*self._args,
**self._kwargs).with_outputs(
self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
Expand Down Expand Up @@ -2295,11 +2309,13 @@ def check_threshold(bad, total, threshold, window=DoFn.WindowParam):


class _ExceptionHandlingWrapperDoFn(DoFn):
def __init__(self, fn, dead_letter_tag, exc_class, partial):
def __init__(
self, fn, dead_letter_tag, exc_class, partial, on_failure_callback):
self._fn = fn
self._dead_letter_tag = dead_letter_tag
self._exc_class = exc_class
self._partial = partial
self._on_failure_callback = on_failure_callback

def __getattribute__(self, name):
if (name.startswith('__') or name in self.__dict__ or
Expand All @@ -2316,6 +2332,11 @@ def process(self, *args, **kwargs):
result = list(result)
yield from result
except self._exc_class as exn:
if self._on_failure_callback is not None:
try:
self._on_failure_callback(exn, args[0])
except Exception as e:
logging.warning('on_failure_callback failed with error: %s', e)
yield pvalue.TaggedOutput(
self._dead_letter_tag,
(
Expand Down
71 changes: 71 additions & 0 deletions sdks/python/apache_beam/transforms/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
# pytype: skip-file

import logging
import os
import tempfile
import unittest

import pytest
Expand Down Expand Up @@ -87,6 +89,13 @@ def process(self, element):
yield element


class TestDoFn9(beam.DoFn):
def process(self, element):
if len(element) > 3:
raise ValueError('Not allowed to have long elements')
yield element


class CreateTest(unittest.TestCase):
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
Expand Down Expand Up @@ -170,6 +179,68 @@ def test_flatten_mismatched_windows(self):
_ = (source1, source2, source3) | "flatten" >> beam.Flatten()


class ExceptionHandlingTest(unittest.TestCase):
def test_routes_failures(self):
with beam.Pipeline() as pipeline:
good, bad = (
pipeline | beam.Create(['abc', 'long_word', 'foo', 'bar', 'foobar'])
| beam.ParDo(TestDoFn9()).with_exception_handling()
)
bad_elements = bad | beam.Map(lambda x: x[0])
damccorm marked this conversation as resolved.
Show resolved Hide resolved
assert_that(good, equal_to(['abc', 'foo', 'bar']), 'good')
assert_that(bad_elements, equal_to(['long_word', 'foobar']), 'bad')

def test_handles_callbacks(self):
with tempfile.TemporaryDirectory() as tmp_dirname:
tmp_path = os.path.join(tmp_dirname, 'tmp_filename')
file_contents = 'random content'

def failure_callback(e, el):
if type(e) is not ValueError:
raise Exception(f'Failed to pass in correct exception, received {e}')
if el != 'foobar':
raise Exception(f'Failed to pass in correct element, received {el}')
f = open(tmp_path, "a")
logging.warning(tmp_path)
f.write(file_contents)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found this didn't work. I'm not totally sure why, but I think it is because of how the scoping of the function works as it is passed through Beam. Somehow, it seems like its referencing a copy of the variable, I'm guessing it gets copied by value somewhere along the way... Maybe related to us spinning up new threads to handle pieces of this logic?

Regardless, I'm inclined to leave it rather than digging in further since it is still effectively testing correctness at this point.

f.close()

with beam.Pipeline() as pipeline:
good, bad = (
pipeline | beam.Create(['abc', 'bcd', 'foo', 'bar', 'foobar'])
| beam.ParDo(TestDoFn9()).with_exception_handling(
on_failure_callback=failure_callback)
)
bad_elements = bad | beam.Map(lambda x: x[0])
assert_that(good, equal_to(['abc', 'bcd', 'foo', 'bar']), 'good')
assert_that(bad_elements, equal_to(['foobar']), 'bad')
with open(tmp_path) as f:
s = f.read()
self.assertEqual(s, file_contents)

def test_handles_no_callback_triggered(self):
with tempfile.TemporaryDirectory() as tmp_dirname:
tmp_path = os.path.join(tmp_dirname, 'tmp_filename')
file_contents = 'random content'

def failure_callback(e, el):
f = open(tmp_path, "a")
logging.warning(tmp_path)
f.write(file_contents)
f.close()

with beam.Pipeline() as pipeline:
good, bad = (
pipeline | beam.Create(['abc', 'bcd', 'foo', 'bar'])
| beam.ParDo(TestDoFn9()).with_exception_handling(
on_failure_callback=failure_callback)
)
bad_elements = bad | beam.Map(lambda x: x[0])
assert_that(good, equal_to(['abc', 'bcd', 'foo', 'bar']), 'good')
assert_that(bad_elements, equal_to([]), 'bad')
self.assertFalse(os.path.isfile(tmp_path))


class FlatMapTest(unittest.TestCase):
def test_default(self):

Expand Down
Loading