diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
index 97996bd6cbb2..068cef986f3b 100644
--- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
+++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
@@ -557,13 +557,11 @@ def _cached_gcs_file_copy(self, from_path, to_path, sha256):
         source_file_names=[cached_path], destination_file_names=[to_path])
     _LOGGER.info('Copied cached artifact from %s to %s', from_path, to_path)
 
-  @retry.with_exponential_backoff(
-      retry_filter=retry.retry_on_server_errors_and_timeout_filter)
   def _uncached_gcs_file_copy(self, from_path, to_path):
     to_folder, to_name = os.path.split(to_path)
     total_size = os.path.getsize(from_path)
-    with open(from_path, 'rb') as f:
-      self.stage_file(to_folder, to_name, f, total_size=total_size)
+    self.stage_file_with_retry(
+        to_folder, to_name, from_path, total_size=total_size)
 
   def _stage_resources(self, pipeline, options):
     google_cloud_options = options.view_as(GoogleCloudOptions)
@@ -692,6 +690,29 @@ def stage_file(
           (gcs_or_local_path, e))
       raise
 
+  @retry.with_exponential_backoff(
+      retry_filter=retry.retry_on_server_errors_and_timeout_filter)
+  def stage_file_with_retry(
+      self,
+      gcs_or_local_path,
+      file_name,
+      stream_or_path,
+      mime_type='application/octet-stream',
+      total_size=None):
+
+    if isinstance(stream_or_path, str):
+      path = stream_or_path
+      with open(path, 'rb') as stream:
+        self.stage_file(
+            gcs_or_local_path, file_name, stream, mime_type, total_size)
+    elif isinstance(stream_or_path, io.BufferedIOBase):
+      stream = stream_or_path
+      assert stream.seekable(), "stream must be seekable"
+      if stream.tell() > 0:
+        stream.seek(0)
+      self.stage_file(
+          gcs_or_local_path, file_name, stream, mime_type, total_size)
+
   @retry.no_retries  # Using no_retries marks this as an integration point.
   def create_job(self, job):
     """Creates job description. May stage and/or submit for remote execution."""
@@ -703,7 +724,7 @@ def create_job(self, job):
           job.options.view_as(GoogleCloudOptions).template_location)
 
     if job.options.view_as(DebugOptions).lookup_experiment('upload_graph'):
-      self.stage_file(
+      self.stage_file_with_retry(
           job.options.view_as(GoogleCloudOptions).staging_location,
           "dataflow_graph.json",
           io.BytesIO(job.json().encode('utf-8')))
@@ -718,7 +739,7 @@ def create_job(self, job):
     if job_location:
       gcs_or_local_path = os.path.dirname(job_location)
       file_name = os.path.basename(job_location)
-      self.stage_file(
+      self.stage_file_with_retry(
          gcs_or_local_path, file_name, io.BytesIO(job.json().encode('utf-8')))
 
     if not template_location:
@@ -790,7 +811,7 @@ def create_job_description(self, job):
     resources = self._stage_resources(job.proto_pipeline, job.options)
 
     # Stage proto pipeline.
-    self.stage_file(
+    self.stage_file_with_retry(
         job.google_cloud_options.staging_location,
         shared_names.STAGED_PIPELINE_FILENAME,
         io.BytesIO(job.proto_pipeline.SerializeToString()))
diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
index 6587e619a500..4b32dc567b06 100644
--- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
@@ -19,11 +19,13 @@
 
 # pytype: skip-file
 
+import io
 import itertools
 import json
 import logging
 import os
 import sys
+import time
 import unittest
 
 import mock
@@ -1064,7 +1066,11 @@ def test_graph_is_uploaded(self):
                            side_effect=None):
       client.create_job(job)
       client.stage_file.assert_called_once_with(
-          mock.ANY, "dataflow_graph.json", mock.ANY)
+          mock.ANY,
+          "dataflow_graph.json",
+          mock.ANY,
+          'application/octet-stream',
+          None)
       client.create_job_description.assert_called_once()
 
   def test_create_job_returns_existing_job(self):
@@ -1174,8 +1180,18 @@ def test_template_file_generation_with_upload_graph(self):
       client.create_job(job)
 
       client.stage_file.assert_has_calls([
-          mock.call(mock.ANY, 'dataflow_graph.json', mock.ANY),
-          mock.call(mock.ANY, 'template', mock.ANY)
+          mock.call(
+              mock.ANY,
+              'dataflow_graph.json',
+              mock.ANY,
+              'application/octet-stream',
+              None),
+          mock.call(
+              mock.ANY,
+              'template',
+              mock.ANY,
+              'application/octet-stream',
+              None)
       ])
       client.create_job_description.assert_called_once()
       # template is generated, but job should not be submitted to the
@@ -1653,6 +1669,50 @@ def exists_return_value(*args):
             }))
 
     self.assertEqual(pipeline, pipeline_expected)
 
+  def test_stage_file_with_retry(self):
+    count = 0
+
+    def effect(self, *args, **kwargs):
+      nonlocal count
+      count += 1
+      if count > 1:
+        return
+      raise Exception("This exception is raised for testing purposes.")
+
+    pipeline_options = PipelineOptions([
+        '--project',
+        'test_project',
+        '--job_name',
+        'test_job_name',
+        '--temp_location',
+        'gs://test-location/temp',
+    ])
+    pipeline_options.view_as(GoogleCloudOptions).no_auth = True
+    client = apiclient.DataflowApplicationClient(pipeline_options)
+
+    with mock.patch.object(time, 'sleep'):
+      count = 0
+      with mock.patch("builtins.open",
+                      mock.mock_open(read_data="data")) as mock_file_open:
+        with mock.patch.object(client, 'stage_file') as mock_stage_file:
+          mock_stage_file.side_effect = effect
+          # call with a file name
+          client.stage_file_with_retry(
+              "/to", "new_name", "/from/old_name", total_size=1024)
+          self.assertEqual(mock_file_open.call_count, 2)
+          self.assertEqual(mock_stage_file.call_count, 2)
+
+      count = 0
+      with mock.patch("builtins.open",
+                      mock.mock_open(read_data="data")) as mock_file_open:
+        with mock.patch.object(client, 'stage_file') as mock_stage_file:
+          mock_stage_file.side_effect = effect
+          # call with a seekable stream
+          client.stage_file_with_retry(
+              "/to", "new_name", io.BytesIO(b'test'), total_size=4)
+          mock_file_open.assert_not_called()
+          self.assertEqual(mock_stage_file.call_count, 2)
+
 
 if __name__ == '__main__':
   unittest.main()
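
Note: the new stage_file_with_retry method wraps stage_file in
retry.with_exponential_backoff and, before each attempt, either reopens the
given file path or rewinds the seekable stream, so a retried upload never
resumes from a partially consumed stream. The sketch below is a minimal,
self-contained illustration of that rewind-and-retry pattern; it is not Beam
code, and the upload_with_retry / flaky_upload names and the backoff
parameters are hypothetical.

    import io
    import time


    def upload_with_retry(stream, upload, attempts=3, initial_backoff=0.1):
        """Retry `upload(stream)`, rewinding the stream before each attempt.

        `upload` stands in for a transport call that may fail transiently; a
        real implementation would also filter which exceptions are retryable.
        """
        assert stream.seekable(), "stream must be seekable"
        backoff = initial_backoff
        for attempt in range(attempts):
            # Rewind so a retry re-reads the payload from the beginning
            # instead of resuming wherever the failed attempt stopped.
            if stream.tell() > 0:
                stream.seek(0)
            try:
                return upload(stream)
            except Exception:
                if attempt == attempts - 1:
                    raise
                time.sleep(backoff)
                backoff *= 2


    # The first attempt consumes the stream and then fails; the rewind lets
    # the second attempt see the full payload again.
    calls = {'n': 0}


    def flaky_upload(stream):
        calls['n'] += 1
        data = stream.read()
        if calls['n'] == 1:
            raise RuntimeError('simulated transient server error')
        return len(data)


    payload = io.BytesIO(b'pipeline graph bytes')
    assert upload_with_retry(payload, flaky_upload) == len(b'pipeline graph bytes')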