diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
index 97996bd6cbb2..4e65156f3bc7 100644
--- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
+++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
@@ -42,6 +42,7 @@
 import re
 import sys
 import time
+import traceback
 import warnings
 from copy import copy
 from datetime import datetime
@@ -557,13 +558,11 @@ def _cached_gcs_file_copy(self, from_path, to_path, sha256):
           source_file_names=[cached_path], destination_file_names=[to_path])
     _LOGGER.info('Copied cached artifact from %s to %s', from_path, to_path)
 
-  @retry.with_exponential_backoff(
-      retry_filter=retry.retry_on_server_errors_and_timeout_filter)
   def _uncached_gcs_file_copy(self, from_path, to_path):
     to_folder, to_name = os.path.split(to_path)
     total_size = os.path.getsize(from_path)
-    with open(from_path, 'rb') as f:
-      self.stage_file(to_folder, to_name, f, total_size=total_size)
+    self.stage_file_with_retry(
+        to_folder, to_name, from_path, total_size=total_size)
 
   def _stage_resources(self, pipeline, options):
     google_cloud_options = options.view_as(GoogleCloudOptions)
@@ -692,6 +691,41 @@
           (gcs_or_local_path, e))
       raise
 
+  @retry.with_exponential_backoff(
+      retry_filter=retry.retry_on_server_errors_and_timeout_filter)
+  def stage_file_with_retry(
+      self,
+      gcs_or_local_path,
+      file_name,
+      stream_or_path,
+      mime_type='application/octet-stream',
+      total_size=None):
+
+    if isinstance(stream_or_path, str):
+      path = stream_or_path
+      with open(path, 'rb') as stream:
+        self.stage_file(
+            gcs_or_local_path, file_name, stream, mime_type, total_size)
+    elif isinstance(stream_or_path, io.IOBase):
+      stream = stream_or_path
+      try:
+        self.stage_file(
+            gcs_or_local_path, file_name, stream, mime_type, total_size)
+      except Exception as exn:
+        if stream.seekable():
+          # Reset the cursor so a retry re-reads the stream from the start.
+          stream.seek(0)
+          raise exn
+        else:
+          raise retry.PermanentException(
+              "Skip retrying because we caught exception: " +
+              ''.join(traceback.format_exception_only(exn.__class__, exn)) +
+              ', but the stream is not seekable.')
+    else:
+      raise retry.PermanentException(
+          "Skip retrying because type " + str(type(stream_or_path)) +
+          " of stream_or_path is unsupported.")
+
   @retry.no_retries  # Using no_retries marks this as an integration point.
   def create_job(self, job):
     """Creates job description. May stage and/or submit for remote execution."""
@@ -703,7 +737,7 @@ def create_job(self, job):
         job.options.view_as(GoogleCloudOptions).template_location)
 
     if job.options.view_as(DebugOptions).lookup_experiment('upload_graph'):
-      self.stage_file(
+      self.stage_file_with_retry(
           job.options.view_as(GoogleCloudOptions).staging_location,
           "dataflow_graph.json",
          io.BytesIO(job.json().encode('utf-8')))
@@ -718,7 +752,7 @@ def create_job(self, job):
     if job_location:
       gcs_or_local_path = os.path.dirname(job_location)
       file_name = os.path.basename(job_location)
-      self.stage_file(
+      self.stage_file_with_retry(
          gcs_or_local_path, file_name, io.BytesIO(job.json().encode('utf-8')))
 
     if not template_location:
@@ -790,7 +824,7 @@ def create_job_description(self, job):
     resources = self._stage_resources(job.proto_pipeline, job.options)
 
     # Stage proto pipeline.
-    self.stage_file(
+    self.stage_file_with_retry(
         job.google_cloud_options.staging_location,
         shared_names.STAGED_PIPELINE_FILENAME,
        io.BytesIO(job.proto_pipeline.SerializeToString()))
diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
index 6587e619a500..d055065cb9d9 100644
--- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
@@ -19,11 +19,13 @@
 
 # pytype: skip-file
 
+import io
 import itertools
 import json
 import logging
 import os
 import sys
+import time
 import unittest
 
 import mock
@@ -42,6 +44,7 @@
 from apache_beam.transforms import DoFn
 from apache_beam.transforms import ParDo
 from apache_beam.transforms.environments import DockerEnvironment
+from apache_beam.utils import retry
 
 # Protect against environments where apitools library is not available.
 # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
@@ -1064,7 +1067,11 @@ def test_graph_is_uploaded(self):
         side_effect=None):
       client.create_job(job)
       client.stage_file.assert_called_once_with(
-          mock.ANY, "dataflow_graph.json", mock.ANY)
+          mock.ANY,
+          "dataflow_graph.json",
+          mock.ANY,
+          'application/octet-stream',
+          None)
       client.create_job_description.assert_called_once()
 
   def test_create_job_returns_existing_job(self):
@@ -1174,8 +1181,18 @@ def test_template_file_generation_with_upload_graph(self):
 
       client.create_job(job)
       client.stage_file.assert_has_calls([
-        mock.call(mock.ANY, 'dataflow_graph.json', mock.ANY),
-        mock.call(mock.ANY, 'template', mock.ANY)
+        mock.call(
+            mock.ANY,
+            'dataflow_graph.json',
+            mock.ANY,
+            'application/octet-stream',
+            None),
+        mock.call(
+            mock.ANY,
+            'template',
+            mock.ANY,
+            'application/octet-stream',
+            None)
       ])
       client.create_job_description.assert_called_once()
       # template is generated, but job should not be submitted to the
@@ -1653,6 +1670,93 @@ def exists_return_value(*args):
             }))
     self.assertEqual(pipeline, pipeline_expected)
 
+  def test_stage_file_with_retry(self):
+    def effect(self, *args, **kwargs):
+      nonlocal count
+      count += 1
+      # Fail the first two calls and succeed afterward.
+      if count <= 2:
+        raise Exception("This exception is raised for testing purposes.")
+
+    class Unseekable(io.IOBase):
+      def seekable(self):
+        return False
+
+    pipeline_options = PipelineOptions([
+        '--project',
+        'test_project',
+        '--job_name',
+        'test_job_name',
+        '--temp_location',
+        'gs://test-location/temp',
+    ])
+    pipeline_options.view_as(GoogleCloudOptions).no_auth = True
+    client = apiclient.DataflowApplicationClient(pipeline_options)
+
+    with mock.patch.object(client, 'stage_file') as mock_stage_file:
+      mock_stage_file.side_effect = effect
+
+      with mock.patch.object(time, 'sleep') as mock_sleep:
+        with mock.patch("builtins.open",
+                        mock.mock_open(read_data="data")) as mock_file_open:
+          count = 0
+          # Calling with a file name.
+          client.stage_file_with_retry(
+              "/to", "new_name", "/from/old_name", total_size=4)
+          self.assertEqual(mock_stage_file.call_count, 3)
+          self.assertEqual(mock_sleep.call_count, 2)
+          self.assertEqual(mock_file_open.call_count, 3)
+
+          count = 0
+          mock_stage_file.reset_mock()
+          mock_sleep.reset_mock()
+          mock_file_open.reset_mock()
+
+          # Calling with a seekable stream.
+          client.stage_file_with_retry(
+              "/to", "new_name", io.BytesIO(b'test'), total_size=4)
+          self.assertEqual(mock_stage_file.call_count, 3)
+          self.assertEqual(mock_sleep.call_count, 2)
+          # open() is not called when a stream is provided.
+          mock_file_open.assert_not_called()
+
+          count = 0
+          mock_sleep.reset_mock()
+          mock_file_open.reset_mock()
+          mock_stage_file.reset_mock()
+
+          # Calling with an unseekable stream.
+          self.assertRaises(
+              retry.PermanentException,
+              client.stage_file_with_retry,
+              "/to",
+              "new_name",
+              Unseekable(),
+              total_size=4)
+          # Unseekable streams are staged once; if staging fails, no retries
+          # are attempted.
+          self.assertEqual(mock_stage_file.call_count, 1)
+          mock_sleep.assert_not_called()
+          mock_file_open.assert_not_called()
+
+          count = 0
+          mock_sleep.reset_mock()
+          mock_file_open.reset_mock()
+          mock_stage_file.reset_mock()
+
+          # Calling with an unsupported type.
+          self.assertRaises(
+              retry.PermanentException,
+              client.stage_file_with_retry,
+              "/to",
+              "new_name",
+              object(),
+              total_size=4)
+          # No staging call is made for an unsupported argument type.
+          mock_stage_file.assert_not_called()
+          mock_sleep.assert_not_called()
+          mock_file_open.assert_not_called()
+
 
 if __name__ == '__main__':
   unittest.main()